In [None]:
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 1

In [None]:
import logging
import random

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR')
import scipy.stats as st
import holoviews as hv
hv.extension('bokeh')
from holoviews import dim
from IPython.display import Markdown, display
from IPython.core.display import HTML

import matplotlib
# Default font sizes for tick labels, axis labels and titles.
# Several later cells call matplotlib.rc again with other sizes.
matplotlib.rc('xtick', labelsize=14)     
matplotlib.rc('ytick', labelsize=14)
matplotlib.rc('axes', labelsize=14, titlesize=14)


# Raise pandas' display limits so wide tables are shown without truncation.
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

from counts_analysis.c_utils import COUNTS_CSV, CLASSES, set_settings, set_counts

#== Load Datasets ==#
# Both CSV paths are looked up in the project's COUNTS_CSV mapping.
df = pd.read_csv(COUNTS_CSV['counts'])
pred_df = pd.read_csv(COUNTS_CSV['counts-v10'])
# Dataset without problematic classes (Gyrodinium, Pseudo-nitzschia chain)
# NOTE(review): this relies on CLASSES not listing those classes -- confirm in c_utils.
df_ = df[df['class'].isin(CLASSES)].reset_index(drop=True)
# Working copies; `data` and `pred_data` are reassigned by several later cells.
data = df.copy()
pred_data = pred_df.copy()

def printmd(string):
    """Render `string` as Markdown in the cell output."""
    rendered = Markdown(string)
    display(rendered)

#=== Set count forms & settings ===#
# *_counts hold count column names produced by set_counts; *_settings are
# derived from them by set_settings.
# NOTE(review): the plotting helpers in this notebook label counts[0] as
# micro, counts[1] as lab and counts[2] as pier -- confirm set_counts returns
# the columns in that order.
# Original raw counts
rc_counts = set_counts('gtruth', 'raw count', micro_default=True)            
rc_settings = set_settings(rc_counts)
print('Example of setting\n{}'.format(rc_settings))
# Predicted raw counts
rc_counts_pred = set_counts('predicted', 'raw count', micro_default=True)            
rc_settings_pred = set_settings(rc_counts_pred)
# Relative abundance
rel_counts = set_counts('gtruth', 'relative abundance', micro_default=False)
# Replace the first entry with the microscopy relative-abundance column.
rel_counts = ['micro cells/mL relative abundance'] + list(rel_counts[1:])
# Predicted Relative Abundance
rel_counts_pred = set_counts('predicted', 'relative abundance', micro_default=False)
# Same first-entry replacement as for rel_counts.
rel_counts_pred = ['micro cells/mL relative abundance'] + list(rel_counts_pred[1:])

#=== Set classifier gtruth vs predictions
# Each list is [ground-truth column, predicted column] for one camera.
lab_gtruth_pred = ['lab {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]
pier_gtruth_pred = ['pier {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]

In [None]:
# "high" / "low" count tables read from absolute paths, so this cell only
# works on a machine that has /data6 mounted.
# NOTE: a later cell reassigns `high` and `low` to frames rebuilt from `df`.
high = pd.read_csv('/data6/phytoplankton-db/counts/master_counts_v11-high.csv')
low = pd.read_csv('/data6/phytoplankton-db/counts/master_counts_v11-low.csv')

In [None]:
""" Experimenting different relative abundance combinations

"Other" distribution computed to understand how HAB species would fare
- 6/11/2020 3:03 PM microscopy "other" distribution MUCH lower than this. Seems that "other" dominates much of the camera samples

"""
# Re-read the 'counts-v10' table and keep only rows of the "Other" class.
dataset = pd.read_csv(COUNTS_CSV['counts-v10'])
dataset = dataset[dataset['class'] == 'Other']
# Summary statistics of the lab relative abundance for "Other".
print(dataset['lab gtruth relative abundance'].describe())
# Final expression of the cell: shown as the cell's rich output.
dataset[['class', 'datetime', 'lab gtruth relative abundance', 'pier gtruth relative abundance']]

In [None]:
dataset = df.copy()

matplotlib.rc('xtick', labelsize=18)
matplotlib.rc('ytick', labelsize=18)
matplotlib.rc('axes', labelsize=18)

printmd("# Total Counts Distribution")
current_palette_7 = sns.color_palette("muted", 9)
sns.set_palette(current_palette_7)

# Per-date totals of every raw-count column. Computed once and reused by all
# figures below instead of repeating the same groupby for each plot.
total_counts = dataset.groupby('datetime')[rc_counts].sum()
legend_labels = ['micro cells/mL', 'lab cells/1000s', 'pier cells/2000s']

# Time Series
total_counts.plot(kind='line', figsize=(18,5))
plt.xlabel('Date Sampled')
plt.ylabel('Total Counts')
plt.legend(labels=legend_labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

# Distribution
total_counts.plot(kind='hist', figsize=(8,5), alpha=0.4)
plt.xlabel('Total Counts')
plt.legend(labels=legend_labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

# Box & Whiskers plot
total_counts.plot(kind='box', figsize=(8,5))
plt.xlabel('Sampling Techniques')
plt.ylabel('Total Counts')

# Box & Whiskers plot (logged)
total_counts.plot(kind='box', figsize=(7,8), logy=True)
plt.xlabel('Sampling Techniques')
plt.ylabel('Total Counts (logged)')

# Grand total per column, then summary statistics as the cell output.
print(total_counts.sum())
total_counts.describe()

def plot_correlation(counts, data,  ax_idx):
    """Draw the three pairwise count scatters and their fitted lines on one Axes.

    Args:
        counts: sequence of three column names; the legends label them micro,
            lab and pier in that order.
        data: DataFrame containing those columns.
        ax_idx: matplotlib Axes to draw on.
    """
    from validate_exp.v_utils import best_fit

    # (x column, y column, legend label) for each comparison.
    comparisons = (
        (counts[0], counts[1], 'lab (Y) - micro (X)'),
        (counts[0], counts[2], 'pier (Y) - micro (X)'),
        (counts[1], counts[2], 'pier (Y) - lab (X)'),
    )

    # All scatters are drawn before any fit line, matching the colour cycle order.
    for x_col, y_col, legend in comparisons:
        sns.scatterplot(x=data[x_col], y=data[y_col], ax=ax_idx, label=legend)

    for x_col, y_col, _ in comparisons:
        fit_x, fit_y = best_fit(data[x_col], data[y_col], False, verbose=False)
        ax_idx.plot(fit_x, fit_y)

    ax_idx.set_xlabel('Count (X)')
    ax_idx.set_ylabel('Count (Y)')

    # Give both axes the same range, starting at zero.
    _, y_top = ax_idx.get_ylim()
    _, x_top = ax_idx.get_xlim()
    upper = max(y_top, x_top)
    ax_idx.set_ylim(0, upper)
    ax_idx.set_xlim(0, upper)

def plot_class_summary(counts, data):
    """Show a three-panel summary of `data`: box plot, time series, correlation.

    Args:
        counts: sequence of three count column names (micro, lab, pier order
            as labelled by plot_correlation).
        data: DataFrame with a 'datetime' column and the `counts` columns.

    NOTE: this calls plot_correlation(counts, data, ax_idx=...), i.e. the
    definition in this cell. A later cell redefines plot_correlation with a
    different signature, so this function must be used before that cell runs.
    """
    matplotlib.rc('xtick', labelsize=18)
    matplotlib.rc('ytick', labelsize=18)
    matplotlib.rc('axes', labelsize=18)

    fig, ax = plt.subplots(1,3, figsize=(30,8))

    # Per-date totals are taken from the function's own arguments, so all
    # three panels describe the same data and columns.
    totals = data.groupby('datetime')[list(counts)].sum()

    #=== plot box whiskers plot ===#
    totals.plot(kind='box', ax=ax[0])
    ax[0].set_xlabel('Sampling Techniques')
    ax[0].set_ylabel('Total Counts')

    #=== plot time series ===#
    totals.plot(kind='line', ax=ax[1])
    ax[1].set_xlabel('Date Sampled')
    ax[1].set_ylabel('Total Counts')

    #=== plot correlation plot ===#
    plot_correlation(counts, data, ax_idx=ax[2])
    plt.show()


    

In [None]:
# Larger fonts for the figures produced by this cell.
matplotlib.rc('xtick', labelsize=18)     
matplotlib.rc('ytick', labelsize=18)
matplotlib.rc('axes', labelsize=18)
# NOTE(review): this sets the *serif* font list while the family selected on
# the next line is sans-serif -- confirm whether 'Calibri' was meant for the
# sans-serif list instead.
matplotlib.rc('font', serif='Calibri') 
matplotlib.rc('font', family='sans-serif') 

printmd("# Total Counts Distribution")
# Nine colours from seaborn's "muted" palette become the default colour cycle.
current_palette_7 = sns.color_palette("muted", 9)
sns.set_palette(current_palette_7)

def plot_class_distribution(dataset, counts, logged=False, relative=False):
    """Plot per-class distributions of the given count columns.

    Draws one box plot (with the raw points overlaid) grouped by class and
    setting, then one KDE panel per class with a curve per setting.

    Args:
        dataset: DataFrame with 'class', 'datetime' and the `counts` columns.
        counts: sequence of three column names; the KDE curves label them
            micro, lab and pier in that order.
        logged: if True, use a symlog y axis on the box plot.
        relative: if True, label values 'Relative Abundance' instead of
            'Total Counts'.
    """
    #=== Box&Whisker ===#
    sm = dataset[['class', 'datetime'] + list(counts)]
    # Long format: one row per (class, datetime, setting). var_name is given
    # as a scalar, which is the form pandas documents for this argument.
    sm = sm.melt(id_vars=['class', 'datetime'], var_name='setting', value_name='count')
    sm = sm.sort_values('class')
#     sm['setting'] = sm['setting'].map({'micro cells/mL': 'micro', 'lab gtruth raw count': 'lab', 'pier gtruth raw count':'pier'})
    plt.figure(figsize=(30, 7))
    sns.boxplot(x='class', y='count', hue='setting', data=sm)
    sns.stripplot(x='class', y='count', hue='setting', data=sm, color=".25", dodge=True)

    ylabel = 'Relative Abundance' if relative else 'Total Counts'
    if logged:
        ylabel += ' (logged)'
        plt.yscale('symlog')
        plt.ylim(0)

    plt.ylabel(ylabel)
    plt.xlabel('Classes')
    plt.tight_layout()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()

    #=== Histogram ===#
    num_cols = 5
    class_names = sorted(sm['class'].unique())
    # Keep at least the two rows of the fixed layout, and add rows when there
    # are more than ten classes so indexing the grid cannot run out of panels.
    num_rows = max(2, (len(class_names) + num_cols - 1) // num_cols)
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(20, 4.5 * num_rows))
    for idx, cls in enumerate(class_names):
        panel = ax[idx // num_cols, idx % num_cols]
        by_setting = sm[sm['class'] == cls].groupby('setting')
        for position, label in enumerate(('micro', 'lab', 'pier')):
            sns.kdeplot(by_setting.get_group(counts[position])['count'], bw=1,
                        ax=panel, label=label)

        panel.set_title(cls + "\nDistribution")
        panel.set_xlabel(ylabel)
        panel.set_ylabel('Frequency')

    plt.tight_layout()
    plt.show()
    

classes = ['Lingulodinium polyedra', 'Prorocentrum micans', 'Pseudo-nitzschia chain']


def filter_classes(df, classes, high=True):
    """Select rows by class membership and renumber the index from zero.

    With high=True the rows whose 'class' is in `classes` are kept; with
    high=False those rows are dropped and all others are kept.
    """
    selected = df['class'].isin(classes)
    if not high:
        selected = ~selected
    return df[selected].reset_index(drop=True)

"""Overall"""
# Overall distributions: raw counts first, then relative abundance.
plot_class_distribution(df, rc_counts, False)
plot_class_distribution(df, rel_counts, False, relative=True)


# The bare strings below only label groups of disabled experiments; the last
# one is the final expression of this cell.
"""High & Low"""
# plot_class_distribution(high, rel_counts, False)
# plot_class_distribution(low, rel_counts, False)

"""
High & Low | Rare & Dominant
"""
# plot_class_distribution(df, rc_counts, False)
# high_cls = filter_classes(high, classes)
# plot_class_distribution(high_cls, rc_counts, False)
# lowhigh_cls = filter_classes(high, classes, False)
# plot_class_distribution(lowhigh_cls, rc_counts, False)

# low_cls = filter_classes(low, classes)
# plot_class_distribution(low_cls, rc_counts, False)
# highlow_cls = filter_classes(low, classes, False)
# plot_class_distribution(highlow_cls, rc_counts, False)

"""Rare & Dominant"""
# dominant = filter_classes(df, classes)
# plot_class_distribution(dominant, rc_counts, True)
# rare = filter_classes(df, classes, high=False)
# plot_class_distribution(rare, rc_counts, True)

"""Rare & Dominant | Relative Counts"""
# dominant = filter_classes(df, classes)
# plot_class_distribution(dominant, rel_counts, False, True)
# rare = filter_classes(df, classes, high=False)
# plot_class_distribution(rare, rel_counts, False, True)

In [None]:
"""Compute relative abundance of each date without certain classes"""
def compute_relative_abundance(raw_count, data, group_col='datetime'):
    """Add a relative-abundance (percent) column computed from a raw-count column.

    Every value is divided by the total of `raw_count` within its `group_col`
    group and multiplied by 100. Groups whose total is zero are left unchanged.

    Args:
        raw_count: name of the raw-count column. If it contains 'micro' the
            output column is 'micro cells/mL relative abundance'; otherwise it
            is '<word0> <word1> relative abundance', built from the first two
            whitespace-separated words of `raw_count`.
        data: DataFrame holding `group_col` and `raw_count`. Modified in place.
        group_col: column whose groups define the denominators.

    Returns:
        The same `data` object, with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        first, second = raw_count.split()[:2]
        relative_column = f'{first} {second} relative abundance'

    def _to_percent(group):
        # The zero check uses the same total that is used as the divisor.
        total = group.sum()
        return group / total * 100.0 if total != 0 else group

    # transform() returns values aligned to data's own index. apply() yields a
    # group-keyed index on recent pandas, which breaks the column assignment.
    data[relative_column] = data.groupby(group_col)[raw_count].transform(_to_percent)
    return data

# Drop the problematic classes. The second name must match the class label
# exactly: spelled 'Pseudo-nitzschia chainn' it matches nothing and that class
# stays in the data.
data = df[~df['class'].isin(['Gyrodinium', 'Pseudo-nitzschia chain'])].reset_index(drop=True)
# data = df[df['class'].isin(['Lingulodinium polyedra', 'Prorocentrum micans', 'Pseudo-nitzschia chain'])].reset_index(drop=True)

# Recompute the relative abundance of every raw-count column on the reduced data.
for rc in rc_counts:
    print(rc)
    data = compute_relative_abundance(rc, data)
filtered_data = data.copy()

In [None]:
# Preview the first 25 rows of the recomputed relative-abundance columns.
filtered_data[['class', 'datetime', 'lab gtruth relative abundance', 'micro cells/mL relative abundance', 'pier gtruth relative abundance']].head(25)

In [None]:
# Relative abundance with all classes (df) ...
plot_class_distribution(df, rel_counts, False, relative=True)
display(df.groupby('class')[rel_counts].describe())
# ... versus the same plots and statistics after class filtering (filtered_data).
plot_class_distribution(filtered_data, rel_counts, False, relative=True)
display(filtered_data.groupby('class')[rel_counts].describe())

In [None]:
# One font size for ticks, axis labels and titles in this cell's figures.
fontsize = 20
matplotlib.rc('xtick', labelsize=fontsize)     
matplotlib.rc('ytick', labelsize=fontsize)
matplotlib.rc('axes', labelsize=fontsize, titlesize=fontsize)

def plot_all_classes(dataset, counts, logged=False, relative=False):
    """Draw one box plot of `counts`, grouped by class and coloured by setting.

    Args:
        dataset: DataFrame with 'class', 'datetime' and the `counts` columns.
        counts: iterable of count column names to compare.
        logged: if True, use a symlog y axis starting at zero.
        relative: if True, label the y axis 'Relative Abundance' instead of
            'Total Counts'.
    """
    sm = dataset[['class', 'datetime'] + list(counts)]
    # Long format: one row per (class, datetime, setting). var_name is given
    # as a scalar, which is the form pandas documents for this argument.
    sm = sm.melt(id_vars=['class', 'datetime'], var_name='setting', value_name='count')
    sm = sm.sort_values('class')
#     sm['setting'] = sm['setting'].map({'micro cells/mL relative abundance': 'micro', 'lab gtruth relative abundance': 'lab', 'pier gtruth relative abundance':'pier'})

    plt.figure(figsize=(10, 6))
    sns.boxplot(x='class', y='count', hue='setting', data=sm)
#     sns.stripplot(x='class', y='count', hue='setting', data=sm, color=".25", dodge=True)

    ylabel = 'Relative Abundance' if relative else 'Total Counts'
    if logged:
        ylabel += ' (logged)'
        plt.yscale('symlog')
        plt.ylim(0)

#     plt.ylim(0, 100)
    plt.ylabel(ylabel)
    plt.xlabel('Classes')
    plt.tight_layout()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    # Hide the class names on the x axis.
    plt.gca().xaxis.set_ticklabels([])
#     ax.axes.yaxis.set_ticklabels([])

    plt.show()
    
def compute_relative_abundance(raw_count, data, group_col='class'):
    """Add a relative-abundance (percent) column computed from a raw-count column.

    Every value is divided by the total of `raw_count` within its `group_col`
    group (by default: per class) and multiplied by 100. Groups whose total is
    zero are left unchanged.

    Args:
        raw_count: name of the raw-count column. If it contains 'micro' the
            output column is 'micro cells/mL relative abundance'; otherwise it
            is '<word0> <word1> relative abundance', built from the first two
            whitespace-separated words of `raw_count`.
        data: DataFrame holding `group_col` and `raw_count`. Modified in place.
        group_col: column whose groups define the denominators.

    Returns:
        The same `data` object, with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        first, second = raw_count.split()[:2]
        relative_column = f'{first} {second} relative abundance'

    def _to_percent(group):
        # The zero check uses the same total that is used as the divisor.
        total = group.sum()
        return group / total * 100.0 if total != 0 else group

    # transform() returns values aligned to data's own index. apply() yields a
    # group-keyed index on recent pandas, which breaks the column assignment.
    data[relative_column] = data.groupby(group_col)[raw_count].transform(_to_percent)
    return data

# MODIFYING DATASET FOR ONLY SEASONAL DATES ONLY
y = df.copy()
# NOTE: this redefinition of filter_classes *removes* the listed classes and
# replaces the earlier definition that had a `high` switch.
def filter_classes(df, classes):
    return df[~df['class'].isin(classes)].reset_index(drop=True)
# "high": only the three listed sampling dates, minus the excluded classes.
y = y[y['datetime'].isin(['2019-05-23', '2019-05-28', '2019-06-03'])]
y = filter_classes(y, ['Gyrodinium', 'Chattonella', 'Pseudo-nitzschia chain'])
y = y.sort_values('class')
for rc in rc_counts:
    print(rc)
    y = compute_relative_abundance(rc, y)
high = y.copy()

# "low": every other date, with the same classes excluded.
y = df.copy()
y = y[~y['datetime'].isin(['2019-05-23', '2019-05-28', '2019-06-03'])]
y = filter_classes(y, ['Gyrodinium', 'Chattonella', 'Pseudo-nitzschia chain'])
for rc in rc_counts:
    print(rc)
    y = compute_relative_abundance(rc, y)
low = y.copy()

# display(y[['class', 'datetime'] + list(rel_counts)])
# Only the full dataset is plotted here; the high/low plots are disabled.
plot_all_classes(df, rc_counts, False, relative=False)
# plot_all_classes(high, rel_counts, True, relative=True)
# plot_all_classes(low, rel_counts, True, relative=True)

# Correlation Plot

In [None]:
# %%opts Scatter [tools=['hover'], width=600, height=600, legend_position='right', xlim=(-1, None), ylim=(-1, None), logx=True, logy=True]
# %%opts Slope [xlim=(-1, None), ylim=(-1, None), logx=True, logy=True]

def plot_correlation_hv(counts, data, cls, relative=False):
    """Build a HoloViews overlay of the three pairwise scatters with fit lines.

    Args:
        counts: sequence of three column names; the scatter labels treat them
            as micro, lab and pier in that order.
        data: DataFrame with the `counts`, 'datetime' and 'class' columns.
        cls: text used as the plot title.
        relative: if True, label both axes 'Relative Abundance' instead of
            'Count'.

    Returns:
        The overlay of three scatters and three slope lines.
    """
    axis_label = 'Relative Abundance' if relative else 'Count'
    # Shared upper axis limit: largest value over all count columns, plus padding.
    max_val = max(data[list(counts)].max()) + 10

    # correlation plot
    dot_size, alpha = 12, 0.6
    fs = 18

    def _scatter_with_fit(x_col, y_col, label):
        # One scatter of y_col against x_col plus the slope fitted to it.
        scatter = hv.Scatter(data, x_col, [y_col, 'datetime', 'class'],
                             label=label).opts(size=dot_size, alpha=alpha,
                                               tools=['hover'])
        fit = hv.Slope.from_scatter(scatter).opts(alpha=alpha, tools=['hover'])
        return scatter, fit

    sc1, reg = _scatter_with_fit(counts[0], counts[1], 'lab - micro')
    sc2, reg2 = _scatter_with_fit(counts[0], counts[2], 'pier - micro')
    sc3, reg3 = _scatter_with_fit(counts[1], counts[2], 'pier - lab')

    # NOTE(review): the range is redimensioned twice (here and on return);
    # the first call looks redundant -- confirm before removing it.
    corr = (sc1 * sc2 * sc3 * reg * reg2 * reg3).opts(xlabel=axis_label, ylabel=axis_label,
                                                      title=f'{cls}',
                                                      xlim=(0, max_val),
                                                      ylim=(0, max_val), tools=['hover'],
                                                      width=550, height=550,
#                                                       legend_position='right',
                                                      show_legend=False,
                                                      fontsize={'title': fs, 'labels': fs, 'xticks': fs, 'yticks': fs}).redim.range(y=(0.01, 1))
    return corr.redim.range(y=(0.01, None), x=(0.01, None))

def compute_relative_abundance(raw_count, data, group_col='class'):
    """Add a relative-abundance (percent) column computed from a raw-count column.

    Every value is divided by the total of `raw_count` within its `group_col`
    group (by default: per class) and multiplied by 100. Groups whose total is
    zero are left unchanged.

    Args:
        raw_count: name of the raw-count column. If it contains 'micro' the
            output column is 'micro cells/mL relative abundance'; otherwise it
            is '<word0> <word1> relative abundance', built from the first two
            whitespace-separated words of `raw_count`.
        data: DataFrame holding `group_col` and `raw_count`. Modified in place.
        group_col: column whose groups define the denominators.

    Returns:
        The same `data` object, with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        first, second = raw_count.split()[:2]
        relative_column = f'{first} {second} relative abundance'

    def _to_percent(group):
        # The zero check uses the same total that is used as the divisor.
        total = group.sum()
        return group / total * 100.0 if total != 0 else group

    # transform() returns values aligned to data's own index. apply() yields a
    # group-keyed index on recent pandas, which breaks the column assignment.
    data[relative_column] = data.groupby(group_col)[raw_count].transform(_to_percent)
    return data

# Add relative-abundance columns for both the ground-truth and the predicted
# raw counts, using the compute_relative_abundance defined in this cell.
y = df.copy()
for rc in rc_counts:
    print(rc)
    y = compute_relative_abundance(rc, y)
for rc in rc_counts_pred:
    print(rc)
    y = compute_relative_abundance(rc, y)
# Exclude the three listed sampling dates.
y = y[~y['datetime'].isin(['2019-05-23', '2019-05-28', '2019-06-03'])].reset_index(drop=True)
cls = 'Akashiwo'

COUNT = rel_counts_pred
# NOTE: this redefinition of filter_classes *keeps* the listed classes, the
# opposite of the definition in the previous plotting cell.
def filter_classes(df, classes):
    return df[df['class'].isin(classes)].reset_index(drop=True)
# First panel: the class named above.
cls_df = filter_classes(y, [cls])
corr = plot_correlation_hv(COUNT, cls_df, cls, True)

clsses = ['Ceratium falcatiforme or fusus', 'Ceratium furca', 'Cochlodinium', 'Lingulodinium polyedra', 'Prorocentrum micans']
# clsses = ['Ceratium furca', 'Cochlodinium', 'Lingulodinium polyedra', 'Prorocentrum micans']
# `+=` appends one more panel to the layout for each remaining class.
for cls in clsses:
    print(cls)
    cls_df = filter_classes(y, [cls])
    corr += plot_correlation_hv(COUNT, cls_df, cls, True)
# Final expression: the panels arranged three per row.
hv.Layout(corr.redim.range(y=(0.01, None), x=(0.01, None))).cols(3)

In [None]:
# Smaller fonts for the multi-panel correlation figures.
matplotlib.rc('xtick', labelsize=10)     
matplotlib.rc('ytick', labelsize=10)
matplotlib.rc('axes', labelsize=10)

# Three "Set2" colours, applied in reverse order as the default colour cycle.
current_palette_7 = sns.color_palette("Set2", 3)
sns.set_palette(current_palette_7[::-1])

def rsquared(x, y):
    """Fit y = b0 + b1 * x by ordinary least squares; x and y are array-like.

    Despite the name, the return value is the fitted statsmodels results
    object (callers use e.g. `.summary()`), not the R^2 number itself.
    """
    import statsmodels.api as sm
    # Design matrix: a column of ones for the intercept, then x.
    intercept = np.array([1] * len(x)).reshape(-1, 1)
    predictor = np.array(x).reshape(-1, 1)
    design = np.hstack((intercept, predictor))
    response = np.array(y).reshape(-1, 1)
    return sm.OLS(response, design).fit()

def best_fit(X, Y, log_scale=False, verbose=False):
    """Return the points (Xfit, Yfit) of a fitted curve through X and Y.

    Args:
        X, Y: equally long sequences of numbers (lists, arrays or Series).
        log_scale: if False, fit a straight line by least squares and evaluate
            it at X. If True, fit a straight line to the natural logs of the
            values (a power law y = exp(m) * x**k) and evaluate it at the
            x values sorted in increasing order.
        verbose: if True, also fit an OLS model with rsquared() and print its
            summary. The model is fitted only when requested.

    Returns:
        (Xfit, Yfit). Empty input gives (X, []).
    """
    if len(X) == 0:
        # Nothing to fit; return an empty curve instead of dividing by zero.
        return X, []

    if log_scale:
        # Sort the pairs once so the fitted curve is drawn left to right.
        pairs = sorted(zip(X, Y))
        x1 = [px for px, _ in pairs]
        y1 = [py for _, py in pairs]
        # NOTE(review): values below 1 are replaced by 1 *in log space*
        # (i.e. treated as e, not as 1) -- confirm this is intended.
        x = np.array([np.log(v) if v >= 1 else 1 for v in x1])
        y = np.array([np.log(v) if v >= 1 else 1 for v in y1])
        k, m = np.polyfit(x, y, deg=1)
        Xfit = x1
        # Convert explicitly: a plain list has no power operator of its own.
        Yfit = np.exp(m) * np.asarray(x1, dtype=float) ** k
    else:
        n = len(X)  # or len(Y)
        xbar = sum(X) / n
        ybar = sum(Y) / n

        numer = sum(xi * yi for xi, yi in zip(X, Y)) - n * xbar * ybar
        denum = sum(xi ** 2 for xi in X) - n * xbar ** 2

        # A constant X has zero variance; fall back to a horizontal line.
        b = numer / denum if denum != 0 else 0
        a = ybar - b * xbar

        Yfit = [a + b * xi for xi in X]
        Xfit = X

    if verbose:
        # Compute R2 value and other statistics from statsmodels
        print(rsquared(X, Y).summary())
    return Xfit, Yfit

def plot_correlation(data, counts, logged=False):
    """Plot pairwise count correlations: one pooled panel plus one per class.

    Args:
        data: DataFrame with a 'class' column and the `counts` columns.
        counts: sequence of three column names; the legends label them micro,
            lab and pier in that order.
        logged: if True, fit in log space (passed to best_fit) and draw the
            per-class panels on symlog axes; otherwise use linear fits on
            linear axes.
    """
    NUM_COLS = 5
    classes = sorted(data['class'].unique())
    # One panel for the pooled data plus one per class. Keep at least the two
    # rows of the fixed layout and add rows when the classes would not fit.
    num_rows = max(2, (len(classes) + 1 + NUM_COLS - 1) // NUM_COLS)
    fig, ax = plt.subplots(num_rows, NUM_COLS, figsize=(20, 4 * num_rows))

    # (x column, y column, legend label) for each comparison.
    comparisons = (
        (counts[0], counts[1], 'lab (Y) - micro (X)'),
        (counts[0], counts[2], 'pier (Y) - micro (X)'),
        (counts[1], counts[2], 'pier (Y) - lab (X)'),
    )

    # Pooled panel: scatters only, no fit lines.
    for x_col, y_col, label in comparisons:
        sns.scatterplot(x=data[x_col], y=data[y_col], ax=ax[0,0], label=label)
    ax[0,0].set_xlabel('Count (X)')
    ax[0,0].set_ylabel('Count (Y)')

    # Per-class panels start at position 1 because panel 0 is the pooled one.
    for i_ax, cls in enumerate(classes, start=1):
        cls_df = data[data['class'] == cls]
        ax_idx = ax[i_ax // NUM_COLS, i_ax % NUM_COLS]
        for x_col, y_col, label in comparisons:
            sns.scatterplot(x=cls_df[x_col], y=cls_df[y_col], ax=ax_idx, label=label)
        for x_col, y_col, _ in comparisons:
            Xfit, Yfit = best_fit(cls_df[x_col], cls_df[y_col], logged, verbose=False)
            ax_idx.plot(Xfit, Yfit)

        ax_idx.set_xlabel('Count (X)')
        ax_idx.set_ylabel('Count (Y)')

        # Same range on both axes, starting at zero.
        _, ymax = ax_idx.get_ylim()
        _, xmax = ax_idx.get_xlim()
        max_val = max(ymax, xmax)
        ax_idx.set_ylim(0, max_val)
        ax_idx.set_xlim(0, max_val)
        if logged:
            # Log axes only when a log-space fit was requested, so linear fits
            # are shown as straight lines.
            ax_idx.set_yscale('symlog')
            ax_idx.set_xscale('symlog')

        ax_idx.set_title(cls)
#         set_plotting_opts(ax_idx, logged=LOGGED)
    plt.tight_layout()
    plt.show()

In [None]:
# Switches selecting which correlation plots this cell draws.
ABSL_COUNTS = True
PREDICTED_COUNTS = False
RELATIVE_COUNTS = False
HIGH_AND_LOW = False

# Each section: (markdown heading, source frame, log flag for absolute counts).
sections = [('# Overall Abundance Days', filtered_data, True)]
if HIGH_AND_LOW:
    sections.append(('# Low Abundance Days', low, False))
    sections.append(('# High Abundance Days', high, False))

for heading, frame, log_absolute in sections:
    printmd(heading)
    data = frame.copy()
    if ABSL_COUNTS:
        printmd('### Absolute Counts')
        plot_correlation(data, rc_counts, logged=log_absolute)

    if PREDICTED_COUNTS:
        printmd('### Predicted Absolute Counts')
        # The predicted counts always come from pred_df, whatever the section.
        pred_data = pred_df.copy()
        pred_data = pred_data[pred_data['class'] != "Other"].reset_index(drop=True)
        plot_correlation(pred_data, rc_counts_pred)

    if RELATIVE_COUNTS:
        printmd('### Relative Counts')
        plot_correlation(data, rel_counts)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# Abhishek Bhatia's data & scatter plot.
x = np.array([  29.,   36.,    8.,   32.,   11.,   60.,   16.,  242.,   36.,
               115.,    5.,  102.,    3.,   16.,   71.,    0.,    0.,   21.,
               347.,   19.,   12.,  162.,   11.,  224.,   20.,    1.,   14.,
                 6.,    3.,  346.,   73.,   51.,   42.,   37.,  251.,   21.,
               100.,   11.,   53.,  118.,   82.,  113.,   21.,    0.,   42.,
                42.,  105.,    9.,   96.,   93.,   39.,   66.,   66.,   33.,
               354.,   16.,  602.])
y = np.array([ 30,  47, 115,  50,  40, 200, 120, 168,  39, 100,   2, 100,  14,
               50, 200,  63,  15, 510, 755, 135,  13,  47,  36, 425,  50,   4,
               41,  34,  30, 289, 392, 200,  37,  15, 200,  50, 200, 247, 150,
              180, 147, 500,  48,  73,  50,  55, 108,  28,  55, 100, 500,  61,
              145, 400, 500,  40, 250])
fig = plt.figure()
ax=plt.gca()
ax.scatter(x,y,c="blue",alpha=0.95,edgecolors='none', label='data')
ax.set_yscale('log')
ax.set_xscale('log')


newX = np.logspace(0, 3, base=10)  # Makes a nice domain for the fitted curves.
                                   # Goes from 10^0 to 10^3
                                   # This avoids the sorting and the swarm of lines.

# Let's fit a power law, y = a * x**b (the printed label calls it "exponential").
# This looks like a line on a log-log plot.
def myExpFunc(x, a, b):
    """Power law a * x**b."""
    return a * np.power(x, b)
popt, pcov = curve_fit(myExpFunc, x, y)
plt.plot(newX, myExpFunc(newX, *popt), 'r-',
         label="({0:.3f}*x**{1:.3f})".format(*popt))
print("Exponential Fit: y = (a*(x**b))")
print("\ta = popt[0] = {0}\n\tb = popt[1] = {1}".format(*popt))

# Let's fit a more complicated function.
# This won't look like a line.
def myComplexFunc(x, a, b, c):
    """Power law with a constant offset: a * x**b + c."""
    return a * np.power(x, b) + c
popt, pcov = curve_fit(myComplexFunc, x, y)
plt.plot(newX, myComplexFunc(newX, *popt), 'g-',
         label="({0:.3f}*x**{1:.3f}) + {2:.3f}".format(*popt))
print("Modified Exponential Fit: y = (a*(x**b)) + c")
print("\ta = popt[0] = {0}\n\tb = popt[1] = {1}\n\tc = popt[2] = {2}".format(*popt))

# Pass the flag positionally: the `b` keyword was renamed to `visible` and is
# no longer accepted by current Matplotlib.
ax.grid(True)
plt.legend(loc='lower right')
plt.show()

In [None]:
# Smaller fonts for the multi-panel ground-truth vs prediction figure.
matplotlib.rc('xtick', labelsize=10)     
matplotlib.rc('ytick', labelsize=10)
matplotlib.rc('axes', labelsize=10)


def plot_predicted_correlation(data, lab, pier):
    """Plot predicted against ground-truth counts and score each class with mase.

    Draws one pooled panel and one panel per class (scatter plus fitted line
    for each camera), then a bar chart of the per-class scores.

    Args:
        data: DataFrame with a 'class' column and the columns named in `lab`
            and `pier`.
        lab: [ground-truth column, predicted column] for the lab camera.
        pier: [ground-truth column, predicted column] for the pier camera.

    Returns:
        DataFrame with columns 'class', 'camera', 'mase': one row per class
        and camera.
    """
    from validate_exp.v_utils import best_fit
    from validate_exp.stat_fns import mase

    NUM_COLS=5
    classes = sorted(data['class'].unique())
    # One pooled panel plus one per class. Keep at least the two rows of the
    # fixed layout and add rows when the classes would not fit.
    num_rows = max(2, (len(classes) + 1 + NUM_COLS - 1) // NUM_COLS)
    fig, ax = plt.subplots(num_rows, NUM_COLS, figsize=(20, 4 * num_rows))
    sns.scatterplot(x=data[lab[0]], y=data[lab[1]], ax=ax[0, 0], label='lab')
    sns.scatterplot(x=data[pier[0]], y=data[pier[1]], ax=ax[0, 0], label='pier')

    ax[0,0].set_xlabel('Gtruth Count')
    ax[0,0].set_ylabel('Predicted Count')

    # Scores are collected as plain records and turned into a DataFrame once;
    # DataFrame.append no longer exists in pandas 2.
    score_rows = []

    # Per-class panels start at position 1 because panel 0 is the pooled one.
    for i_ax, cls in enumerate(classes, start=1):
        cls_df = data[data['class']==cls]
        ax_idx = ax[i_ax // NUM_COLS, i_ax % NUM_COLS]
        sns.scatterplot(x=cls_df[lab[0]], y=cls_df[lab[1]], ax=ax_idx, label='lab')
        sns.scatterplot(x=cls_df[pier[0]], y=cls_df[pier[1]], ax=ax_idx, label='pier')

        print(cls)
        lm = mase(cls_df[lab[0]], cls_df[lab[1]])
        pm = mase(cls_df[pier[0]], cls_df[pier[1]])
        print('MASE (Lab): {}'.format(lm))
        print('MASE (Pier): {}\n\n'.format(pm))
        score_rows.append({'class': cls, 'camera': 'lab', 'mase': lm})
        score_rows.append({'class': cls, 'camera': 'pier', 'mase': pm})

        Xfit, Yfit = best_fit(cls_df[lab[0]], cls_df[lab[1]], False, verbose=False)
        ax_idx.plot(Xfit, Yfit, color='blue', marker='_')

        Xfit, Yfit = best_fit(cls_df[pier[0]], cls_df[pier[1]], False, verbose=False)
        ax_idx.plot(Xfit, Yfit, color='orange', marker='_')

        ax_idx.set_xlabel('Gtruth Count')
        ax_idx.set_ylabel('Predicted Count')

        ax_idx.set_title(cls)
    #     set_plotting_opts(ax_idx, logged=LOGGED)
    plt.tight_layout()
    plt.show()

    scores_df = pd.DataFrame(score_rows, columns=['class', 'camera', 'mase'])
    sns.barplot(x='class', y='mase', hue='camera', data=scores_df)
    plt.show()
    return scores_df

# Score every class except "Other" on the prediction table.
pred_data = pred_df.copy()
pred_data = pred_data[pred_data['class'] != "Other"].reset_index(drop=True)
scores_df = plot_predicted_correlation(pred_data, lab_gtruth_pred, pier_gtruth_pred)

In [None]:
# Summary statistics of the scores, first for the pier camera, then the lab one.
for camera in ('pier', 'lab'):
    scores_df1 = scores_df.copy()
    scores_df1 = scores_df1[scores_df1['camera'] == camera]
    display(scores_df1['mase'].describe())

In [None]:
"""
SCALED WORK
"""
# Work on an explicit copy so the in-place rescaling below clearly edits `t`
# only and does not trigger pandas' SettingWithCopyWarning.
t = data.groupby('class').get_group('Prorocentrum micans').copy()
# Divide the lab and pier relative abundances by 100 (presumably percent -> fraction).
t.loc[:, rel_counts[1]] /= 100.
t.loc[:, rel_counts[2]] /= 100.

def plot_scatter(df, counts, logged=False):
    """Scatter the three pairwise count comparisons with fitted lines on the current axes.

    Args:
        df: DataFrame containing the `counts` columns.
        counts: sequence of three column names; the legends label them micro,
            lab and pier in that order.
        logged: passed to best_fit as its log flag; if True the axes are also
            switched to symlog.
    """
    from validate_exp.v_utils import best_fit

    # (x column, y column, legend label, fit-line colour) for each comparison.
    comparisons = (
        (counts[0], counts[1], 'lab (Y) - micro (X)', 'blue'),
        (counts[0], counts[2], 'pier (Y) - micro (X)', 'orange'),
        (counts[1], counts[2], 'pier (Y) - lab (X)', 'green'),
    )
    for x_col, y_col, legend, line_colour in comparisons:
        sns.scatterplot(x=df[x_col], y=df[y_col], label=legend)
        fit_x, fit_y = best_fit(df[x_col], df[y_col], logged, verbose=False)
        plt.plot(fit_x, fit_y, color=line_colour)

    if logged:
        plt.xscale('symlog')
        plt.yscale('symlog')

    plt.show()

# Scale factor that maps the lab relative abundance onto the microscopy
# count. Multiplying back gives the microscopy count again, except where the
# division is undefined.
t['scaling1'] = t['micro cells/mL'] / t[rel_counts[1]]
t['test1'] = t[rel_counts[1]] * t['scaling1']
t['test1'] = t['test1'].fillna(0)

# Same for the pier camera. Undefined products are zero-filled here too, so
# both estimates handle those dates the same way and no NaN reaches the fit.
t['scaling2'] = t['micro cells/mL'] / t[rel_counts[2]]
t['test2'] = t[rel_counts[2]] * t['scaling2']
t['test2'] = t['test2'].fillna(0)

print(t[['micro cells/mL', rel_counts[1], 'lab gtruth total abundance', 'scaling1', 'test1', rel_counts[2], 'test2']])
plot_scatter(t, ['micro cells/mL', 'test1', 'test2'], False)

In [None]:
# NOTE: lab_scale and pier_scale are defined in a cell further down this
# notebook; that cell has to be executed before this one.
data = df.copy()

data[rel_counts[1]] /= 100.
data[rel_counts[2]] /= 100.

# Repeat every per-date scale once per row of that date.
# NOTE(review): assumes df holds exactly ROWS_PER_DATE consecutive rows per
# date, in the same date order as the scale lists -- confirm.
ROWS_PER_DATE = 9
lab = [scale for scale in lab_scale for _ in range(ROWS_PER_DATE)]
pier = [scale for scale in pier_scale for _ in range(ROWS_PER_DATE)]
data['lab_scale'] = lab
data['pier_scale'] = pier

data['test1'] = data['lab_scale'] * data[rel_counts[1]]
# The pier estimate uses the pier relative abundance (rel_counts[2]); using
# rel_counts[1] here would scale the lab abundance by the pier factor.
data['test2'] = data['pier_scale'] * data[rel_counts[2]]
plot_correlation(data, ['micro cells/mL', 'test1', 'test2'])

In [None]:
# Sampling dates of the rows in `t`, in row order.
dates = t['datetime'].tolist()

In [None]:
# Hard-coded scale factors for the lab camera (26 values).
# NOTE(review): presumably one value per sampling date, in date order -- the
# source of these numbers is not shown in this notebook; confirm.
lab_scale = [209.56185858585857,
 67.63110000000002,
 430.7411294498382,
 111.36646808510639,
 120.40339999999998,
 84.57114285714286,
 27.5385,
 17.528000000000002,
 8,
 9.436,
 4.078,
 40.78158333333332,
 126.14485714285715,
 54.91462500000001,
 6.0064,
 34.4175,
 1351.8,
 3,
 115.64233333333333,
 19,
 16,
 10,
 41.6185,
 10,
 12,
 13]
# Hard-coded scale factors for the pier camera (26 values), same layout.
pier_scale = [226.0693222060958,
 76.54969971671387,
 609.2260073529412,
 116.65848275862069,
 125.30670588235294,
 106.41286363636362,
 33.38,
 11.0176,
 13,
 25.948999999999998,
 3.398333333333334,
 54.187294117647056,
 136.74526530612243,
 61.46822222222221,
 9.385,
 37.859249999999996,
 69,
 115.144,
 131.593,
 9,
 20.39,
 1.252,
 59.45499999999999,
 26.701333333333334,
 568.632,
 10]

# Date -> scale lookups. zip() stops at the shorter input, so any mismatch in
# length between `dates` and the scale lists is silently truncated.
lab_dt_scale = dict(zip(dates, lab_scale))
pier_dt_scale = dict(zip(dates, pier_scale))