In [None]:
# NOTE(review): `reload` is imported here but is not referenced in the visible cells.
from IPython.lib.deepreload import reload
%load_ext autoreload
# NOTE(review): autoreload mode 1 only reloads modules registered with %aimport, and no
# %aimport appears in the visible cells -- confirm whether mode 2 was intended.
%autoreload 1

# Investigate Results

In [None]:
import logging
import random

import holoviews as hv
import hvplot.pandas  # noqa: F401 -- registers the DataFrame.hvplot accessor used below
import matplotlib
import matplotlib.colors as plt_colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st
import seaborn as sns
from holoviews import dim
from IPython.core.display import HTML
from IPython.display import Markdown, display
from matplotlib.axes._axes import _log as matplotlib_axes_logger

from counts_analysis.c_utils import COUNTS_CSV, CLASSES, get_units, set_settings, set_counts

matplotlib_axes_logger.setLevel('ERROR')
hv.extension('bokeh')

matplotlib.rc('xtick', labelsize=14)
matplotlib.rc('ytick', labelsize=14)
matplotlib.rc('axes', labelsize=14, titlesize=14)

pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

#== Load Datasets ==#
# Full counts table, read from the CSV path registered under COUNTS_CSV['counts'].
df = pd.read_csv(COUNTS_CSV['counts'])
# Dataset without problematic classes (Gyrodinium, Pseudo-nitzschia chain):
# only rows whose 'class' is listed in CLASSES are kept, re-indexed from 0.
df_ = df[df['class'].isin(CLASSES)].reset_index(drop=True)
# Independent working copy of the unfiltered table.
data = df.copy()

def printmd(string):
    """Render ``string`` as Markdown in the notebook output area."""
    rendered = Markdown(string)
    display(rendered)

#=== Set count forms & settings ===#
# The lists built here hold count column names; they are indexed positionally by later
# cells (e.g. rel_counts[0], rel_counts[1], rel_counts[2]).
# Original raw counts
rc_counts = set_counts('gtruth', 'raw count', micro_default=True)
rc_counts_pred = set_counts('predicted', 'raw count', micro_default=True)
rc_settings = set_settings(rc_counts)
print('Example of setting\n{}'.format(rc_settings))
# Relative abundance
rel_counts = set_counts('gtruth', 'relative abundance', micro_default=False)
# Replace the first entry with the micro relative-abundance column; the rest is kept as-is.
rel_counts = ['micro cells/mL relative abundance'] + list(rel_counts[1:])
# Classifier predicted counts
rel_counts_pred = set_counts('predicted', 'relative abundance', micro_default=False)
# Same first-entry replacement as for rel_counts.
rel_counts_pred = ['micro cells/mL relative abundance'] + list(rel_counts_pred[1:])

#=== Set classifier gtruth vs predictions
# [gtruth, predicted] raw-count column names for the lab and for the pier.
lab_gtruth_pred = ['lab {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]
pier_gtruth_pred = ['pier {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]

In [None]:
def display_side_by_side(dfs:list, captions:list):
    """Display tables side by side to save vertical space.

    Input:
        dfs: list of pandas.DataFrame
        captions: list of table captions, paired positionally with ``dfs``
            (surplus entries of the longer list are ignored, as with ``zip``)

    Every (caption, table) pair is rendered, even when captions repeat;
    pairing them through a dict would silently drop tables sharing a caption.
    """
    spacer = "\xa0\xa0\xa0"
    chunks = []
    for caption, table in zip(captions, dfs):
        styled = table.style.set_table_attributes("style='display:inline'").set_caption(caption)
        chunks.append(styled._repr_html_() + spacer)
    display(HTML("".join(chunks)))
    
def compute_relative_abundance(raw_count, data):
    """Add a per-class relative-abundance column (in percent) derived from ``raw_count``.

    Args:
        raw_count: name of the raw-count column. If it contains 'micro' the
            result is stored in 'micro cells/mL relative abundance'; otherwise
            the first two whitespace-separated words of ``raw_count`` are
            reused, giving '<word0> <word1> relative abundance'.
        data: DataFrame with a 'class' column and the ``raw_count`` column.
            It is modified in place.

    Returns:
        The same ``data`` object with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        words = raw_count.split()
        relative_column = f'{words[0]} {words[1]} relative abundance'

    def _to_percent(counts):
        # A class whose counts sum to zero is returned unchanged to avoid 0/0.
        total = counts.sum()
        return counts / total * 100.0 if total != 0 else counts

    # transform() keeps the original row index, so the assignment aligns row by row.
    # groupby(...).apply(...) prepends the group key to the index on pandas >= 2.0,
    # which makes the column assignment fail.
    data[relative_column] = data.groupby('class')[raw_count].transform(_to_percent)
    return data

def filter_classes(df, classes):
    """Return the rows of ``df`` whose 'class' is NOT in ``classes``, re-indexed from 0."""
    excluded = df['class'].isin(classes)
    kept = df[~excluded]
    return kept.reset_index(drop=True)

def load_baseline_dataset(data):
    """Return a copy of ``data`` with the four baseline-excluded classes removed."""
    excluded_classes = ['Gyrodinium', 'Ceratium falcatiforme or fusus', 'Chattonella', 'Pseudo-nitzschia chain']
    return filter_classes(data.copy(), excluded_classes)

def load_rel_class_sum_dataset(data):
    """Return a filtered copy of ``data`` with per-class relative abundances added.

    The same four classes as in the baseline dataset are removed, then a
    relative-abundance column is computed for every raw-count column listed in
    the module-level ``rc_counts`` and ``rc_counts_pred``.
    """
    excluded_classes = ['Gyrodinium', 'Ceratium falcatiforme or fusus', 'Chattonella', 'Pseudo-nitzschia chain']
    result = filter_classes(data.copy(), excluded_classes)

    raw_columns = list(rc_counts + rc_counts_pred)
    for raw_column in raw_columns:
        result = compute_relative_abundance(raw_column, result)
    return result

def load_seasonal_dataset(data):
    """Split a filtered, relative-abundance-augmented copy of ``data`` by date.

    'Gyrodinium' and 'Pseudo-nitzschia chain' rows are removed, relative
    abundances are computed for every column in ``rc_counts`` and
    ``rc_counts_pred``, and the rows are then split on three fixed dates.

    Returns:
        (seasonal, nonseasonal): rows whose 'datetime' is one of the three
        listed dates, and all remaining rows.
    """
    frame = filter_classes(data.copy(), ['Gyrodinium', 'Pseudo-nitzschia chain'])

    for raw_column in list(rc_counts + rc_counts_pred):
        frame = compute_relative_abundance(raw_column, frame)

    seasonal_dates = ['2019-05-23', '2019-05-28', '2019-06-03']
    in_season = frame['datetime'].isin(seasonal_dates)
    return frame[in_season], frame[~in_season]

# Build the three analysis datasets from the raw counts table.
# Each loader copies its argument again, so `df` itself is never modified.
baseline = load_baseline_dataset(df.copy())
rel_class_sum = load_rel_class_sum_dataset(df.copy())
seasonal, nonseasonal = load_seasonal_dataset(df.copy())

## Lab vs Pier Performance

In [None]:
#=== plot distributions ===#
# Work on a copy of the seasonal subset; `y` and `dataset_type` are reused by later cells.
y = seasonal.copy()
dataset_type = 'seasonal'.upper()
from counts_analysis.plot_class_summary import plot_summary_sampling_class_dist
# printmd('Original Relative Abundance')
# plot_summary_sampling_class_dist(df, rel_counts, False)
# Distribution plot for the relative-abundance columns in `rel_counts`.
printmd(f'### {dataset_type} Camera Distribution')
plot_summary_sampling_class_dist(y, rel_counts, False, relative=True)

# Same plot for the classifier-predicted columns in `rel_counts_pred`.
printmd(f'### {dataset_type} Automated Classifier Counts Distribution')
plot_summary_sampling_class_dist(y, rel_counts_pred, False, relative=True)

In [None]:
# Fresh copy of the seasonal subset for scoring.
y = seasonal.copy()

from validate_exp.stat_fns import mase, investigate_mase, pearson, concordance_correlation_coefficient

# Set evaluation metric
stat = mase

# Set settings
# One settings object per count form, keyed by a readable name.
settings_ = [set_settings(count) for count in [rel_counts, rel_counts_pred]]
count_forms = dict(zip(['relative', 'relative predicted'], settings_))

from eval_counts import compare_count_forms

# Evaluate count forms
# `settings_score` is filtered on its 'count form' and 'class' columns in the next cell.
printmd(f'# {dataset_type} MASE')
settings_score = compare_count_forms(count_forms, stat, y)

## Investigate Lab vs Pier Performance

In [None]:
# Scores of the 'relative' count form only: the table first, then a bar chart per class.
display(settings_score[settings_score['count form'] == 'relative'])
# `.hvplot` is the DataFrame accessor registered by importing hvplot.pandas.
settings_score[settings_score['count form'] == 'relative'][['class', 'lab - micro','pier - micro','pier - lab']].hvplot.bar(x='class')

Classes that increase in error from `lab - micro`: 

Akashiwo, Lingulodinium polyedra, Prorocentrum micans

In [None]:
from counts_analysis.plot_class_summary import plot_summary_both_count_forms, plot_class_summary

def filter_class(cls, x_data, y_data):
    """Return the rows of ``x_data`` and of ``y_data`` whose 'class' equals ``cls``.

    Both results are re-indexed from 0.
    """
    def _rows_for_class(frame):
        return frame[frame['class'] == cls].reset_index(drop=True)

    return _rows_for_class(x_data), _rows_for_class(y_data)
                     
def plot_summary(x, y):
    """Show tables and summary plots for two count forms of the same data.

    Args:
        x: ``(count_columns, DataFrame, relative_flag)`` for the first count form.
        y: ``(count_columns, DataFrame, relative_flag)`` for the second count form.

    Displays the two count tables side by side (captioned 'raw' and
    'relative'), the column sums of the first form, and both ``describe()``
    tables.

    Returns:
        A 3-column HoloViews layout combining the ``plot_class_summary`` output
        of both forms, with ``shared_axes`` disabled.
    """
    x_count, x_data, x_relative = x
    y_count, y_data, y_relative = y
    x_columns, y_columns = list(x_count), list(y_count)
    date_columns = ['datetime']

    display_side_by_side(
        [x_data[date_columns + x_columns], y_data[date_columns + y_columns]],
        ['raw', 'relative'])

    n_days = x_data["datetime"].nunique()
    printmd(f'### Sum total over N={n_days} days')
    display(x_data[x_columns].sum())

    display_side_by_side(
        [x_data[x_columns].describe(), y_data[y_columns].describe()],
        ['raw descriptors', 'relative descriptors'])

    x_plots = plot_class_summary(x_count, x_data, relative=x_relative)
    y_plots = plot_class_summary(y_count, y_data, relative=y_relative)
    return hv.Layout(x_plots + y_plots).cols(3).opts(shared_axes=False)

def compute_relative_abundance(raw_count, data):
    """Add a per-class relative-abundance column (in percent) derived from ``raw_count``.

    NOTE(review): this function is also defined in an earlier cell of the
    notebook; the definition here replaces it once this cell has run.

    Args:
        raw_count: name of the raw-count column. If it contains 'micro' the
            result is stored in 'micro cells/mL relative abundance'; otherwise
            the first two whitespace-separated words of ``raw_count`` are
            reused, giving '<word0> <word1> relative abundance'.
        data: DataFrame with a 'class' column and the ``raw_count`` column.
            It is modified in place.

    Returns:
        The same ``data`` object with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        words = raw_count.split()
        relative_column = f'{words[0]} {words[1]} relative abundance'

    def _to_percent(counts):
        # A class whose counts sum to zero is returned unchanged to avoid 0/0.
        total = counts.sum()
        return counts / total * 100.0 if total != 0 else counts

    # transform() keeps the original row index, so the assignment aligns row by row.
    # groupby(...).apply(...) prepends the group key to the index on pandas >= 2.0,
    # which makes the column assignment fail.
    data[relative_column] = data.groupby('class')[raw_count].transform(_to_percent)
    return data
    
def _plot_class_summary(cls, x, y, x_relative=False, y_relative=True, classifier=False):
    """Print a heading for ``cls`` and return its ``plot_summary`` layout.

    Args:
        cls: class name used to filter both frames.
        x, y: DataFrames filtered to ``cls`` before plotting.
        x_relative, y_relative: forwarded to ``plot_summary`` with each frame.
        classifier: if True use the predicted count columns
            (``rc_counts_pred`` / ``rel_counts_pred``), otherwise the
            ground-truth ones (``rc_counts`` / ``rel_counts``).
    """
    x_df, y_df = filter_class(cls, x, y)
    printmd(f'# {cls}')
    x_counts, y_counts = (rc_counts_pred, rel_counts_pred) if classifier else (rc_counts, rel_counts)
    x_spec = (x_counts, x_df, x_relative)
    y_spec = (y_counts, y_df, y_relative)
    return plot_summary(x_spec, y_spec)

In [None]:
# `cls` is also read by the following error-analysis cell.
cls = 'Akashiwo'
# Raw counts (x) vs relative abundance (y) of the seasonal subset, ground-truth columns.
_plot_class_summary(cls, seasonal, seasonal, x_relative=False, y_relative=True, classifier=False)

In [None]:
def _compute_mase(cls_df, x, y, error_name):
    """Return the ``investigate_mase`` table of one class with a scaled-error column.

    Calls ``investigate_mase(cls_df, gtruth=x, pred=y)`` and stores
    ``error / naive`` under ``error_name`` in the returned frame.
    """
    table = investigate_mase(cls_df, gtruth=x, pred=y)
    scaled = table['error'] / table['naive']
    table[error_name] = scaled
    return table

# Column names of the per-date scaled errors for each pairwise comparison.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Error tables for the class selected in the previous cell (`cls`), taken from `y`.
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=rel_counts[1], y=rel_counts[2], error_name=PL_MASE)

printmd(f'# {cls} Error Analysis')
printmd('### error lab - micro')
display(lm)
print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
printmd('### error pier - lab')
display(pl)
# This score is the pier - lab one; it used to be printed under a "lab - micro" label.
print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

In [None]:
# `cls` is also read by the following error-analysis cell.
cls = 'Lingulodinium polyedra'
# Raw counts (x) vs relative abundance (y) of the seasonal subset, ground-truth columns.
_plot_class_summary(cls, seasonal, seasonal, x_relative=False, y_relative=True, classifier=False)

In [None]:
# Error tables for the class selected in the previous cell (`cls`), taken from `y`.
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=rel_counts[1], y=rel_counts[2], error_name=PL_MASE)

printmd(f'# {cls} Error Analysis')
printmd('### error lab - micro')
display(lm)
print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
printmd('### error pier - lab')
display(pl)
# This score is the pier - lab one; it used to be printed under a "lab - micro" label.
print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

In [None]:
# `cls` is also read by the following error-analysis cell.
cls = 'Prorocentrum micans'
# Raw counts (x) vs relative abundance (y) of the seasonal subset, ground-truth columns.
_plot_class_summary(cls, seasonal, seasonal, x_relative=False, y_relative=True, classifier=False)

In [None]:
# Error tables for the class selected in the previous cell (`cls`), taken from `y`.
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=rel_counts[1], y=rel_counts[2], error_name=PL_MASE)

printmd(f'# {cls} Error Analysis')
printmd('### error lab - micro')
display(lm)
print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
printmd('### error pier - lab')
display(pl)
# This score is the pier - lab one; it used to be printed under a "lab - micro" label.
print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

In [None]:
def display_side_by_side(dfs:list, captions:list):
    """Display tables side by side to save vertical space.

    Input:
        dfs: list of pandas.DataFrame
        captions: list of table captions, paired positionally with ``dfs``
            (surplus entries of the longer list are ignored, as with ``zip``)

    Every (caption, table) pair is rendered, even when captions repeat;
    pairing them through a dict would silently drop tables sharing a caption.
    """
    spacer = "\xa0\xa0\xa0"
    chunks = []
    for caption, table in zip(captions, dfs):
        styled = table.style.set_table_attributes("style='display:inline'").set_caption(caption)
        chunks.append(styled._repr_html_() + spacer)
    display(HTML("".join(chunks)))
    
def _compute_mase(cls_df, x, y, error_name):
    """Return the ``investigate_mase`` table of one class with a scaled-error column.

    Calls ``investigate_mase(cls_df, gtruth=x, pred=y)`` and stores
    ``error / naive`` under ``error_name`` in the returned frame.
    """
    table = investigate_mase(cls_df, gtruth=x, pred=y)
    scaled = table['error'] / table['naive']
    table[error_name] = scaled
    return table

# Column names of the per-date scaled errors for each pairwise comparison.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Per-class tables are collected in a list and concatenated once after the loop:
# DataFrame.append was removed in pandas 2.0 and re-copied the frame on every iteration.
class_errors = []
for cls, cls_df in y.groupby('class'):
#     if cls in ['Ceratium falcatiforme or fusus', 'Ceratium furca', 'Cochlodinium']:
#         continue

    lm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[1], error_name=LM_MASE)
    pm = _compute_mase(cls_df, x=rel_counts[0], y=rel_counts[2], error_name=PM_MASE)
    pl = _compute_mase(cls_df, x=rel_counts[1], y=rel_counts[2], error_name=PL_MASE)
    printmd(f'# {cls}')
    printmd('### error lab - micro')
    display(lm)
    print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
    printmd('### error pier - lab')
    display(pl)
    # This score is the pier - lab one; it used to be printed under a "lab - micro" label.
    print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

    # One row per (class, datetime) carrying all three scaled-error columns.
    cls_error = lm.merge(pm, on=['class', 'datetime'])
    cls_error = cls_error.merge(pl, on=['class', 'datetime'])

    class_errors.append(cls_error)

# An empty frame is kept as the result when `y` has no classes at all.
scaled_error = pd.concat(class_errors) if class_errors else pd.DataFrame()
    

### Analysis of stats above
It seems most of the error inflation is due to using the lab as the new objective ground truth. We see this most clearly with Akashiwo.

The relative abundances appear to be quite low. This suggests that there's a discrepancy between how much the pier collects vs the lab & micro. 

We also know that this is a relatively rare species that occurred during a seasonal event of Lingulodinium polyedra and Prorocentrum micans, meaning it is almost expected that it would be difficult for the microscopy to detect it.

Thus it could be a combination of switching the objective ground truth to a system that is less able to precisely estimate similar counts. The pier is a level better than it but less than the micro.

- Microscopy is better reflected by the pier than by the lab because the counts for these dominant species are more precise/correlated. Micro and lab agree more so because of the number of rare species. Lab vs. pier is the worse of the two, because the lab … (TODO: this sentence was left unfinished)

In [None]:
# Heatmap of the lab - micro scaled error per class and date.
scaled_error.hvplot.heatmap(x='datetime', y='class', C=LM_MASE, cmap='coolwarm')

In [None]:
# Heatmap of the pier - lab scaled error per class and date.
scaled_error.hvplot.heatmap(x='datetime', y='class', C=PL_MASE, cmap='coolwarm')

In [None]:
"""
# Seasonal (Class Relative Abundance)
"""
# Reset `y` / `dataset_type` to the seasonal subset and show its relative-abundance columns.
y = seasonal.copy()
dataset_type = 'seasonal'.upper()
printmd(f'# {dataset_type} ERROR ANALYSIS')
printmd('Camera Counts')
display(y[['class', 'datetime'] + list(rel_counts)])
# printmd('Automated Classifier Counts')
# display(y[['class', 'datetime'] + list(rel_counts_pred)].head(10))

In [None]:
## Investigate Class Errors over Time

In [None]:
%%opts Scatter [tools=['hover'], legend_position='left', color_index='class', width=700, height=500, logx=True]
import math
from collections import defaultdict

def plot_class_mase_scores_vs_gtruth_count_for_each_setting(sampled_data,
                                                            score_settings,
                                                            interactive=False):
    """Plot per-row MASE against the ground-truth count for each comparison.

    Args:
        sampled_data: DataFrame holding, for every key of ``score_settings``,
            a column of that name (plotted on the y axis), the matching
            ground-truth column, and a 'class' column ('datetime' as well for
            the interactive plots).
        score_settings: dict mapping a MASE column name to a
            ``(gtruth_setting, experimental_setting)`` pair of column names.
            Only the ground-truth column is plotted. The static figure has
            three panels; the interactive branch looks up ``LM_MASE``,
            ``PM_MASE`` and ``PL_MASE`` explicitly.
        interactive: if True build HoloViews scatters, otherwise draw a
            seaborn/matplotlib figure and show it.

    Returns:
        Always a 3-tuple, so callers can unpack the result unconditionally:
        the three ``hv.Scatter`` objects when ``interactive`` is True, and
        ``(None, None, None)`` otherwise.
    """
    if not interactive:
        sns.set_palette(sns.color_palette("coolwarm", 9))
        fig, ax = plt.subplots(1, 3, figsize=(20, 5))
        for i, (label, (gtruth_setting, _)) in enumerate(score_settings.items()):
            s = sns.scatterplot(x=gtruth_setting, y=label, hue='class',
                                data=sampled_data, ax=ax[i], s=50)
            ax[i].set_xscale('symlog')
            ax[i].set_ylabel('MASE')
            ax[i].set_xlabel(
                "Logged Gtruth Counts ({})".format(get_units(gtruth_setting)))
            ax[i].set_title(label)
            # Keep a single legend: drop it from the first two panels.
            if i <= 1:
                s.legend_.remove()
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.show()
        return None, None, None

    def _scatter(label):
        # x axis: ground-truth column of this comparison; value dims: score, class, date.
        sc = hv.Scatter(sampled_data, score_settings[label][0],
                        [label, 'class', 'datetime'])
        # NOTE(review): this redims a dimension literally named 'gtruth_setting'; the
        # plotted dimension is score_settings[label][0] -- confirm which was intended.
        sc = sc.redim.range(gtruth_setting=(0, None))
        return sc.opts(cmap='coolwarm', size=7, title=label)

    return _scatter(LM_MASE), _scatter(PM_MASE), _scatter(PL_MASE)

# The merges that build `scaled_error` suffix duplicated columns with _x/_y;
# restore the plain names of the two ground-truth columns used as x axes below.
scaled_error = scaled_error.rename({'micro cells/mL relative abundance_x': 'micro cells/mL relative abundance',
                               'lab gtruth relative abundance_x': 'lab gtruth relative abundance' }, axis=1)

data = scaled_error.copy()
# NOTE(review): `scores_df` is not read anywhere in the visible cells.
scores_df = defaultdict(list)
counts = rel_counts

# MASE column -> (ground-truth column, compared column)
score_settings = {LM_MASE: (counts[0], counts[1]),
                  PM_MASE: (counts[0], counts[2]),
                  PL_MASE: (counts[1], counts[2])}


# plot_class_mase_scores_vs_gtruth_count_for_each_setting(sampled_data, score_settings)

print('\nInteractive')
sc1, sc2, sc3 = plot_class_mase_scores_vs_gtruth_count_for_each_setting(data,
                                                                        score_settings,
                                                                        interactive=True)
# Stack the three scatters vertically; the layout is the cell output.
hv.Layout(sc1 + sc2 + sc3).cols(1)

In [None]:
%%opts HeatMap[colorbar=True, width=1000, height=300, xrotation=60, tools=['hover'], shared_axes=True]
# Plot MASE per class over time (one heatmap per pairwise comparison)

data = data.sort_values(['datetime', 'class'])
sdata = hv.Dataset(data=data, kdims=['class', 'datetime'])

label = LM_MASE
t1 = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

label = PM_MASE
t2 = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

label = PL_MASE
t3 = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

# Stack the three heatmaps vertically; the layout is the cell output.
hv.Layout(t1 + t2 + t3).cols(1)

From the data above, it seems that the `pier - lab` error is driven up by the rare classes.

In [None]:
# Evaluate count forms
# Same comparison as for MASE, scored with `pearson`; this rebinds `settings_score`.
printmd(f'# {dataset_type} Pearson')
settings_score = compare_count_forms(count_forms, pearson, y)

# NonSeasonal Results

In [None]:
"""
# NonSeasonal (Class Relative Abundance)
"""
# From here on `y` / `dataset_type` refer to the non-seasonal subset.
y = nonseasonal.copy()
dataset_type = 'nonseasonal'.upper()
printmd(f'# {dataset_type} ERROR ANALYSIS')
printmd('Camera Counts')
display(y[['class', 'datetime'] + list(rel_counts)].head(10))
printmd('Automated Classifier Counts')
display(y[['class', 'datetime'] + list(rel_counts_pred)].head(10))

# #=== plot distributions ===#
# from counts_analysis.plot_class_summary import plot_summary_sampling_class_dist
# # printmd('Original Relative Abundance')
# # plot_summary_sampling_class_dist(df, rel_counts, False)
# printmd(f'### {dataset_type} Camera Distribution')
# plot_summary_sampling_class_dist(y, rel_counts, False, relative=True)

# printmd(f'### {dataset_type} Automated Classifier Counts Distribution')
# plot_summary_sampling_class_dist(y, rel_counts_pred, False, relative=True)

from validate_exp.stat_fns import mase, investigate_mase, pearson, concordance_correlation_coefficient

# Scaled error of one class (lab vs micro); its display lines below are commented out.
ms = investigate_mase(y.groupby('class').get_group('Prorocentrum micans'), gtruth=rel_counts[0], pred=rel_counts[1])
ms['scaled_error'] = ms['error'] / ms['naive']
# display(ms)
# print(np.mean(ms['scaled_error']))

# Set evaluation metric
stat = mase

# Set settings
settings_ = [set_settings(count) for count in [rel_counts, rel_counts_pred]]
count_forms = dict(zip(['relative', 'relative predicted'], settings_))

from eval_counts import compare_count_forms

# Evaluate count forms
printmd(f'# {dataset_type} MASE')
settings_score = compare_count_forms(count_forms, stat, y)

# Evaluate count forms
# The Pearson result rebinds `settings_score`, replacing the MASE result above.
printmd(f'# {dataset_type} Pearson')
settings_score = compare_count_forms(count_forms, pearson, y)

In [None]:
def display_side_by_side(dfs:list, captions:list):
    """Display tables side by side to save vertical space.

    Input:
        dfs: list of pandas.DataFrame
        captions: list of table captions, paired positionally with ``dfs``
            (surplus entries of the longer list are ignored, as with ``zip``)

    Every (caption, table) pair is rendered, even when captions repeat;
    pairing them through a dict would silently drop tables sharing a caption.
    """
    spacer = "\xa0\xa0\xa0"
    chunks = []
    for caption, table in zip(captions, dfs):
        styled = table.style.set_table_attributes("style='display:inline'").set_caption(caption)
        chunks.append(styled._repr_html_() + spacer)
    display(HTML("".join(chunks)))
    
def investigate_mase(df, x, y):
    """Build the per-row error table behind the MASE score of one comparison.

    NOTE: this definition shadows the ``investigate_mase`` imported from
    ``validate_exp.stat_fns`` and takes positions into the module-level
    ``rel_counts`` / ``rc_counts`` lists instead of column names.

    Args:
        df: DataFrame with 'class', 'datetime' and the referenced count columns.
        x: position of the ground-truth setting in ``rel_counts`` / ``rc_counts``.
        y: position of the predicted setting in ``rel_counts`` / ``rc_counts``.

    Returns:
        A new DataFrame (``df`` is left untouched) with the selected columns plus
        'error', the absolute ground-truth/prediction difference of each row, and
        'naive', the mean absolute first difference of the ground-truth column,
        repeated on every row.
    """
    gtruth = rel_counts[x]
    pred = rel_counts[y]
    # Explicit copy: the new columns are added to an independent frame rather than
    # to a slice of `df`, which avoids pandas' SettingWithCopyWarning.
    temp = df[['class', 'datetime', rc_counts[x], gtruth, rc_counts[y], pred]].copy()
    temp['error'] = np.abs(temp[gtruth] - temp[pred])
    temp['naive'] = np.mean(np.abs(np.diff(temp[gtruth])))
    return temp
    
def _compute_mase(cls_df, x, y, error_name):
    """Return the ``investigate_mase`` table of one class with a scaled-error column.

    ``x`` and ``y`` are forwarded as ``investigate_mase(cls_df, x=x, y=y)``;
    ``error / naive`` is stored under ``error_name`` in the returned frame.
    """
    table = investigate_mase(cls_df, x=x, y=y)
    scaled = table['error'] / table['naive']
    table[error_name] = scaled
    return table

# Column names of the per-date scaled errors for each pairwise comparison.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Per-class tables are collected in a list and concatenated once after the loop:
# DataFrame.append was removed in pandas 2.0 and re-copied the frame on every iteration.
class_errors = []
for cls, cls_df in y.groupby('class'):

    # x / y are positions into rel_counts / rc_counts.
    lm = _compute_mase(cls_df, x=0, y=1, error_name=LM_MASE)
    pm = _compute_mase(cls_df, x=0, y=2, error_name=PM_MASE)
    pl = _compute_mase(cls_df, x=1, y=2, error_name=PL_MASE)

    # Detailed report only for the class under investigation.
    if cls in ['Akashiwo']:
        printmd(f'# {cls}')
        printmd('### error lab - micro')
        display(lm.sort_values(by='error'))
        print(f'Sum counts (lab - micro):\n{cls_df[[rc_counts[0], rc_counts[1]]].sum()}')
        print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
        printmd('### error pier - micro')
        display(pm.sort_values(by='error'))
        print(f'Sum counts (pier - micro):\n{cls_df[[rc_counts[0], rc_counts[2]]].sum()}')
        print(f'Final Score (pier - micro): {np.mean(pm[PM_MASE])}\n')
        printmd('### error pier - lab')
        display(pl.sort_values(by='error'))
        print(f'Sum counts (pier - lab):\n{cls_df[[rc_counts[1], rc_counts[2]]].sum()}')
        print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

    # One row per (class, datetime) carrying all three scaled-error columns.
    cls_error = lm.merge(pm, on=['class', 'datetime'])
    cls_error = cls_error.merge(pl, on=['class', 'datetime'])

    class_errors.append(cls_error)

# An empty frame is kept as the result when `y` has no classes at all.
scaled_error1 = pd.concat(class_errors) if class_errors else pd.DataFrame()
    

In [None]:
%%opts HeatMap[colorbar=True, width=1000, height=300, xrotation=60, tools=['hover'], shared_axes=True]

import hvplot.pandas

# Clip a copy of the non-seasonal errors for plotting. This cell used to rebind
# `scaled_error1` to a copy of the seasonal `scaled_error`, which discarded the
# non-seasonal table computed in the previous cell.
# `scaled_error1` itself stays unclipped for the combined analysis further down.
mase_columns = [LM_MASE, PM_MASE, PL_MASE]
clipped_error = scaled_error1.copy()
# Cap every scaled error at 1.0 so a few extreme values do not wash out the colour scale.
clipped_error[mase_columns] = clipped_error[mase_columns].clip(upper=1.0)

h1 = clipped_error.hvplot.heatmap(x='datetime', y='class', C=LM_MASE, cmap='coolwarm').opts(title=LM_MASE)
h2 = clipped_error.hvplot.heatmap(x='datetime', y='class', C=PM_MASE, cmap='coolwarm').opts(title=PM_MASE)
h3 = clipped_error.hvplot.heatmap(x='datetime', y='class', C=PL_MASE, cmap='coolwarm').opts(title=PL_MASE)
hv.Layout(h1 + h2 + h3).cols(1)

In [None]:
# Long format: one row per (class, datetime, setting) with its relative abundance.
# NOTE(review): the next cell rebinds `counts_df` to a copy of `y`, so this result
# is not used by the visible cells.
counts_df = y[['datetime', 'class'] + list(rel_counts)]
counts_df = counts_df.melt(id_vars=['class', 'datetime'], var_name=['setting'], value_name='relative abundance')

In [None]:
%%opts HeatMap[colorbar=True, width=1000, height=300, xrotation=60, tools=['hover'], shared_axes=True]

import hvplot.pandas

# Heatmaps of the relative abundance itself, one per column of rel_counts.
counts_df = y.copy()
counts_df = counts_df.sort_values(by=['class', 'datetime'])
sdata = hv.Dataset(data=counts_df, kdims=['class', 'datetime'])

label = rel_counts[0]
t = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

label = rel_counts[1]
t1 = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

label = rel_counts[2]
t2 = sdata.to(hv.HeatMap, ['datetime', 'class'], label).opts(title=label)

# Stack the three heatmaps vertically; the layout is the cell output.
hv.Layout(t+t1+t2).cols(1)

In [None]:
%%opts Scatter [tools=['hover'], legend_position='left', color_index='class', width=700, height=500, logx=True]

def plot_class_mase_scores_vs_gtruth_count_for_each_setting(sampled_data,
                                                            score_settings,
                                                            interactive=False):
    """Plot per-row MASE against the ground-truth count for each comparison.

    Args:
        sampled_data: DataFrame with one column per key of ``score_settings``,
            the matching ground-truth columns, and a 'class' column
            ('datetime' as well for the interactive plots).
        score_settings: dict mapping a MASE column name to a
            ``(gtruth_setting, experimental_setting)`` pair of column names.
        interactive: if True build HoloViews scatters, otherwise draw a
            seaborn/matplotlib figure with symlog axes and show it.

    Returns:
        The three ``hv.Scatter`` objects for ``LM_MASE``, ``PM_MASE`` and
        ``PL_MASE`` when ``interactive`` is True, else ``(None, None, None)``.
    """
    if interactive:
        def _build_scatter(mase_column):
            scatter = hv.Scatter(sampled_data, score_settings[mase_column][0],
                                 [mase_column, 'class', 'datetime'])
            scatter = scatter.redim.range(gtruth_setting=(0, None))
            return scatter.opts(cmap='coolwarm', size=7, title=mase_column)

        return tuple(_build_scatter(column) for column in (LM_MASE, PM_MASE, PL_MASE))

    sns.set_palette(sns.color_palette("coolwarm", 7))
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    for position, (label, (gtruth_setting, experimental_setting)) in enumerate(score_settings.items()):
        panel = ax[position]
        s = sns.scatterplot(x=gtruth_setting, y=label, hue='class',
                            data=sampled_data, ax=panel, s=50)
        panel.set_yscale('symlog')
        panel.set_xscale('symlog')
        panel.set_ylabel('MASE')
        panel.set_xlabel(
            "Logged Gtruth Counts ({})".format(get_units(gtruth_setting)))
        panel.set_title(label)
        # Only the last panel keeps its legend.
        if position <= 1:
            s.legend_.remove()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()
    return None, None, None

# Restore the plain names of the ground-truth columns that the merges suffixed with _x.
# NOTE(review): `scaled_error` was built while `y` held the seasonal subset, although this
# cell sits in the non-seasonal section -- confirm whether `scaled_error1` was intended.
scaled_error = scaled_error.rename({'micro cells/mL relative abundance_x': 'micro cells/mL relative abundance',
                               'lab gtruth relative abundance_x': 'lab gtruth relative abundance' }, axis=1)

data = scaled_error.copy()
scores_df = defaultdict(list)
counts = rel_counts

# MASE column -> (ground-truth column, compared column)
score_settings = {LM_MASE: (counts[0], counts[1]),
                  PM_MASE: (counts[0], counts[2]),
                  PL_MASE: (counts[1], counts[2])}

# Single switch for the rendering mode, so the banner, the call and the final layout
# cannot disagree. In static mode the plot function returns (None, None, None), and
# combining those into a Layout raised a TypeError.
use_interactive = False
print('\nInteractive' if use_interactive else '\nStatic')
sc1, sc2, sc3 = plot_class_mase_scores_vs_gtruth_count_for_each_setting(data,
                                                                        score_settings,
                                                                        interactive=use_interactive)

# Histogram of each ground-truth column, one panel per comparison, on symlog axes.
fig, ax = plt.subplots(1, 3, figsize=(20, 5))
for i, (label, (gtruth_setting, experimental_setting)) in enumerate(score_settings.items()):
    ax[i].hist(data[gtruth_setting], bins=100)
    ax[i].set_yscale('symlog')
    ax[i].set_xscale('symlog')

# The scatters only exist in interactive mode; the cell has no output otherwise.
hv.Layout(sc1 + sc2 + sc3).cols(3) if use_interactive else None

In [None]:
from counts_analysis.c_utils import get_units

def plot_class_smape_scores_vs_gtruth_count_combined_settings(data, score_settings):
    """Scatter the lab-micro and pier-micro MASE against the micro ground-truth count.

    Despite the function name, the plotted values are the ``LM_MASE`` and
    ``PM_MASE`` columns of ``data``. Two plots are produced: the first with
    symlog x and y axes, the second (in a new figure) with default axes.

    Args:
        data: DataFrame with 'class', 'datetime', ``LM_MASE``, ``PM_MASE`` and
            the ground-truth column named by ``score_settings[LM_MASE][0]``.
        score_settings: dict mapping a MASE column name to a
            ``(gtruth_setting, experimental_setting)`` pair of column names.
    """
    gtruth_column = score_settings[LM_MASE][0]
    id_columns = [gtruth_column, 'class', 'datetime']
    # Long format: one row per (sample, setting) with its score in 'mase'.
    long_scores = data[id_columns + [LM_MASE, PM_MASE]].melt(
        id_vars=id_columns, var_name=['setting'], value_name='mase')

    sns.set_palette(sns.color_palette("coolwarm", 7))
    setting_markers = {LM_MASE: "X", PM_MASE: "^"}

    def _scatter_scores():
        sns.scatterplot(x=gtruth_column, y='mase', hue='class', markers=setting_markers,
                        style='setting', data=long_scores, s=75)

    # First plot: symlog axes.
    _scatter_scores()
    plt.xscale('symlog')
    plt.yscale('symlog')
    plt.xlabel('Logged Gtruth Micro ({})'.format(get_units(gtruth_column)))
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    # Second plot: same data in a new figure with default axes.
    plt.figure()
    _scatter_scores()
    plt.xlabel('Gtruth Micro ({})'.format(get_units(gtruth_column)))
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    
# Draw both scatter figures for `scaled_error`, then render them.
plot_class_smape_scores_vs_gtruth_count_combined_settings(scaled_error, score_settings)
plt.show();

# Combined Seasonal/NonSeasonal Error Analysis

In [None]:
# Stack both error tables, keeping their row indices as-is.
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
mase_errors = pd.concat([scaled_error, scaled_error1])
mase_errors.describe()

In [None]:
# Same two scatter figures, drawn for the combined error table.
plot_class_smape_scores_vs_gtruth_count_combined_settings(mase_errors, score_settings); plt.show()