In [None]:
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 1

In [None]:
import logging
import random

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR')
import scipy.stats as st
import holoviews as hv
hv.extension('bokeh')
from holoviews import dim
from IPython.display import Markdown, display
from IPython.core.display import HTML

import matplotlib
# Global matplotlib defaults: 14pt tick labels on both axes, 14pt axis labels and titles.
matplotlib.rc('xtick', labelsize=14)
matplotlib.rc('ytick', labelsize=14)
matplotlib.rc('axes', labelsize=14, titlesize=14)

# Widen pandas' console/notebook rendering so wide count tables are not truncated.
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

from counts_analysis.c_utils import COUNTS_CSV, CLASSES, set_settings, set_counts

#== Load Datasets ==#
# Read the counts table from the CSV path registered under COUNTS_CSV['counts'].
df = pd.read_csv(COUNTS_CSV['counts'])
# Subset restricted to the labels listed in CLASSES, re-indexed from 0. Per the original
# note this is the dataset without the problematic classes (Gyrodinium, Pseudo-nitzschia chain).
df_ = df[df['class'].isin(CLASSES)].reset_index(drop=True)
# Independent copy of the unfiltered table.
data = df.copy()

def printmd(string):
    """Render ``string`` as Markdown in the notebook output area."""
    rendered = Markdown(string)
    display(rendered)

#=== Set count forms & settings ===#
# Each `set_counts(label, count_form, micro_default=...)` call returns a sequence of column
# names; later cells index it as [0], [1], [2] and label those entries micro, lab and pier
# respectively -- presumably one column per sampling method (confirm in c_utils).
# Ground-truth volumetric (cells/mL) and raw-count forms with micro_default=True.
volumetric_counts = set_counts('gtruth', 'cells/mL', micro_default=True)
rc_counts = set_counts('gtruth', 'raw count', micro_default=True)
rc_counts_pred = set_counts('predicted', 'raw count', micro_default=True)

# Same raw-count forms, but with micro_default=False.
raw_counts = set_counts('gtruth', 'raw count', micro_default=False)
raw_counts_pred = set_counts('predicted', 'raw count', micro_default=False)

# Build one settings object from the raw-count columns and print it as an example.
rc_settings = set_settings(rc_counts)
print('Example of setting\n{}'.format(rc_settings))
# Relative abundance (ground truth): the first entry returned by set_counts is replaced
# with the fixed column name 'micro cells/mL relative abundance'.
rel_counts = set_counts('gtruth', 'relative abundance', micro_default=False)
rel_counts = ['micro cells/mL relative abundance'] + list(rel_counts[1:])
# Relative abundance (classifier predictions), with the same first-entry replacement.
rel_counts_pred = set_counts('predicted', 'relative abundance', micro_default=False)
rel_counts_pred = ['micro cells/mL relative abundance'] + list(rel_counts_pred[1:])

#=== Set classifier gtruth vs predictions
# Column-name pairs [gtruth, predicted] for the lab and pier raw counts.
lab_gtruth_pred = ['lab {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]
pier_gtruth_pred = ['pier {} raw count'.format(lbl) for lbl in ['gtruth', 'predicted']]

def compute_relative_abundance(raw_count, data):
    """Add a per-class relative-abundance column (in percent) derived from ``raw_count``.

    Within each ``'class'`` group, every value of ``data[raw_count]`` is divided by the
    group's total and multiplied by 100. Groups whose total is zero are left unchanged
    to avoid dividing by zero.

    Args:
        raw_count: Name of the source count column. If it contains ``'micro'`` the
            result is stored as ``'micro cells/mL relative abundance'``; otherwise the
            first two whitespace-separated words of the name are reused as the prefix
            of ``'<word0> <word1> relative abundance'``.
        data: DataFrame with a ``'class'`` column and the ``raw_count`` column.

    Returns:
        The same ``data`` object, modified in place with the new column added.
    """
    if 'micro' in raw_count:
        relative_column = 'micro cells/mL relative abundance'
    else:
        tokens = raw_count.split()
        relative_column = f'{tokens[0]} {tokens[1]} relative abundance'

    def _to_percent(counts):
        total = counts.sum()
        return counts / total * 100.0 if total != 0 else counts

    # `transform` keeps the result aligned with `data`'s own index. `apply` prepends the
    # group key to the index on pandas >= 2.0, which breaks the column assignment below.
    data[relative_column] = data.groupby('class')[raw_count].transform(_to_percent)
    return data

def filter_classes(df, classes):
    """Return a re-indexed copy of ``df`` without the rows whose ``'class'`` is in ``classes``."""
    is_excluded = df['class'].isin(classes)
    kept_rows = df.loc[~is_excluded]
    return kept_rows.reset_index(drop=True)

def load_absl_counts_dataset(data):
    """Return an independent copy of ``data``; no other transformation is applied."""
    return data.copy()

def load_baseline_dataset(data):
    """Return an independent copy of ``data``; no other transformation is applied."""
    return data.copy()

# Working datasets for the analysis cells. Both loaders copy their argument themselves,
# so the extra `.copy()` here is redundant but harmless.
absl_counts = load_absl_counts_dataset(df.copy())
baseline = load_baseline_dataset(df.copy())

In [None]:
"""
# Raw Counts Dataset
"""
y = baseline.copy()
COUNTS = raw_counts
dataset_type = 'RAW_COUNTS'.upper()
printmd(f'# {dataset_type} ERROR ANALYSIS')
printmd('Camera Counts')
display(y[['class', 'datetime'] + list(COUNTS)].head(10))
# printmd('Automated Classifier Counts')
# display(y[['class', 'datetime'] + list(rc_counts_pred)].head(10))

#=== plot distributions ===#
from counts_analysis.plot_class_summary import plot_summary_sampling_class_dist
# printmd('Original Relative Abundance')
# plot_summary_sampling_class_dist(df, rel_counts, False)
printmd(f'### {dataset_type} Camera Distribution')
plot_summary_sampling_class_dist(y, COUNTS, True, relative=False)

# printmd(f'### {dataset_type} Automated Classifier Counts Distribution')
# plot_summary_sampling_class_dist(y, rc_counts_pred, True, relative=False)

from validate_exp.stat_fns import mase, investigate_mase, pearson, concordance_correlation_coefficient

# ms = investigate_mase(y.groupby('class').get_group('Prorocentrum micans'), gtruth=rc_counts[0], pred=rc_counts[1])
# ms['scaled_error'] = ms['error'] / ms['naive']
# display(ms)
# print(np.mean(ms['scaled_error']))

# Set evaluation metric
stat = mase

# Set settings
settings_ = [set_settings(count) for count in [raw_counts, raw_counts_pred]]
count_forms = dict(zip(['raw_counts', 'raw_counts_pred'], settings_))

from eval_counts import compare_count_forms

# Evaluate count forms
printmd(f'# {dataset_type} MASE')
settings_score = compare_count_forms(count_forms, stat, y)

# Evaluate count forms
printmd(f'# {dataset_type} Pearson')
settings_score = compare_count_forms(count_forms, pearson, y)

In [None]:
# Count columns analysed by this cell: the ground-truth raw counts.
COUNTS = raw_counts

def _compute_mase(cls_df, x, y, error_name):
    """Add a per-row scaled-error column to the ``investigate_mase`` breakdown.

    Args:
        cls_df: Rows to evaluate, forwarded to ``investigate_mase``.
        x: Column name passed to ``investigate_mase`` as ``gtruth``.
        y: Column name passed to ``investigate_mase`` as ``pred``.
        error_name: Name of the column that receives ``'error' / 'naive'``.

    Returns:
        The object returned by ``investigate_mase`` with ``error_name`` added.
    """
    breakdown = investigate_mase(cls_df, gtruth=x, pred=y)
    numerator, denominator = breakdown['error'], breakdown['naive']
    breakdown[error_name] = numerator / denominator
    return breakdown

# Column names under which each pairwise scaled error is stored.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Scaled-error breakdowns for a single class, one per pair of count columns.
cls = 'Lingulodinium polyedra'
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=COUNTS[1], y=COUNTS[2], error_name=PL_MASE)

# Report each breakdown followed by the mean (and population std) of its scaled error.
printmd(f'# {cls} Error Analysis')
for _pair, _errors, _column in (
    ('lab - micro', lm, LM_MASE),
    ('pier - micro', pm, PM_MASE),
    ('pier - lab', pl, PL_MASE),
):
    printmd(f'### error {_pair}')
    display(_errors)
    _scores = _errors[_column]
    print(f'Final Score ({_pair}): {np.mean(_scores)} ({np.std(_scores)})\n')

In [None]:
# Count columns analysed by this cell: the classifier-predicted raw counts.
COUNTS = raw_counts_pred

def _compute_mase(cls_df, x, y, error_name):
    """Add a per-row scaled-error column to the ``investigate_mase`` breakdown.

    Args:
        cls_df: Rows to evaluate, forwarded to ``investigate_mase``.
        x: Column name passed to ``investigate_mase`` as ``gtruth``.
        y: Column name passed to ``investigate_mase`` as ``pred``.
        error_name: Name of the column that receives ``'error' / 'naive'``.

    Returns:
        The object returned by ``investigate_mase`` with ``error_name`` added.
    """
    breakdown = investigate_mase(cls_df, gtruth=x, pred=y)
    numerator, denominator = breakdown['error'], breakdown['naive']
    breakdown[error_name] = numerator / denominator
    return breakdown

# Column names under which each pairwise scaled error is stored.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Scaled-error breakdowns for a single class, one per pair of count columns.
cls = 'Lingulodinium polyedra'
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=COUNTS[1], y=COUNTS[2], error_name=PL_MASE)

# Report each breakdown followed by the mean (and population std) of its scaled error.
printmd(f'# {cls} (CLASSIFIER) Error Analysis')
for _pair, _errors, _column in (
    ('lab - micro', lm, LM_MASE),
    ('pier - micro', pm, PM_MASE),
    ('pier - lab', pl, PL_MASE),
):
    printmd(f'### error {_pair}')
    display(_errors)
    _scores = _errors[_column]
    print(f'Final Score ({_pair}): {np.mean(_scores)} ({np.std(_scores)})\n')

In [None]:
# Pin the count columns explicitly. Without this the cell silently reuses whatever
# COUNTS the previously executed cell left behind (raw_counts_pred on a top-to-bottom
# run), which would make this report identical to the classifier report that follows.
COUNTS = raw_counts

# Scaled-error breakdowns for a single class, one per pair of count columns.
# Relies on _compute_mase and the LM_MASE / PM_MASE / PL_MASE names from earlier cells.
cls = 'Prorocentrum micans'
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=COUNTS[1], y=COUNTS[2], error_name=PL_MASE)

# Report each breakdown followed by the mean (and population std) of its scaled error.
printmd(f'# {cls} Error Analysis')
printmd('### error lab - micro')
display(lm)
print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])} ({np.std(lm[LM_MASE])})\n')
printmd('### error pier - micro')
display(pm)
print(f'Final Score (pier - micro): {np.mean(pm[PM_MASE])} ({np.std(pm[PM_MASE])})\n')
printmd('### error pier - lab')
display(pl)
print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])} ({np.std(pl[PL_MASE])})\n')

In [None]:
# Count columns analysed by this cell: the classifier-predicted raw counts.
COUNTS = raw_counts_pred

# Scaled-error breakdowns for a single class, one per pair of count columns.
# Relies on _compute_mase and the LM_MASE / PM_MASE / PL_MASE names from earlier cells.
cls = 'Prorocentrum micans'
cls_df = y.groupby('class').get_group(cls)
lm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[1], error_name=LM_MASE)
pm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[2], error_name=PM_MASE)
pl = _compute_mase(cls_df, x=COUNTS[1], y=COUNTS[2], error_name=PL_MASE)

# The heading carries the "(CLASSIFIER)" marker, matching the Lingulodinium classifier
# report, so this output can be told apart from the camera-count report for the same class.
printmd(f'# {cls} (CLASSIFIER) Error Analysis')
printmd('### error lab - micro')
display(lm)
print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])} ({np.std(lm[LM_MASE])})\n')
printmd('### error pier - micro')
display(pm)
print(f'Final Score (pier - micro): {np.mean(pm[PM_MASE])} ({np.std(pm[PM_MASE])})\n')
printmd('### error pier - lab')
display(pl)
print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])} ({np.std(pl[PL_MASE])})\n')

In [None]:
def display_side_by_side(dfs:list, captions:list):
    """Display tables side by side to save vertical space.

    Each frame is rendered through its Styler with an inline-display table attribute and
    its caption, followed by three non-breaking spaces as a horizontal gap.

    Input:
        dfs: list of pandas.DataFrame
        captions: list of table captions, paired positionally with ``dfs``; pairing
            stops at the shorter of the two lists
    """
    fragments = []
    # Iterate the pairs directly: routing them through a dict keyed by caption would
    # silently drop every table that shares a caption with a later one.
    for caption, df in zip(captions, dfs):
        styled = df.style.set_table_attributes("style='display:inline'").set_caption(caption)
        fragments.append(styled._repr_html_())
        fragments.append("\xa0\xa0\xa0")
    display(HTML("".join(fragments)))
    
def _compute_mase(cls_df, x, y, error_name):
    """Add a per-row scaled-error column to the ``investigate_mase`` breakdown.

    Args:
        cls_df: Rows to evaluate, forwarded to ``investigate_mase``.
        x: Column name passed to ``investigate_mase`` as ``gtruth``.
        y: Column name passed to ``investigate_mase`` as ``pred``.
        error_name: Name of the column that receives ``'error' / 'naive'``.

    Returns:
        The object returned by ``investigate_mase`` with ``error_name`` added.
    """
    breakdown = investigate_mase(cls_df, gtruth=x, pred=y)
    numerator, denominator = breakdown['error'], breakdown['naive']
    breakdown[error_name] = numerator / denominator
    return breakdown

# Count columns analysed by this cell: the ground-truth raw counts.
COUNTS = raw_counts

# Column names under which each pairwise scaled error is stored.
LM_MASE = 'MASE (lab - micro)'
PM_MASE = 'MASE (pier - micro)'
PL_MASE = 'MASE (pier - lab)'

# Collect one merged error table per class and concatenate once at the end.
# DataFrame.append was removed in pandas 2.0 and re-copied the whole frame on every call.
class_errors = []
for cls, cls_df in y.groupby('class'):
#     if cls in ['Ceratium falcatiforme or fusus', 'Ceratium furca', 'Cochlodinium']:
#         continue

    lm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[1], error_name=LM_MASE)
    pm = _compute_mase(cls_df, x=COUNTS[0], y=COUNTS[2], error_name=PM_MASE)
    pl = _compute_mase(cls_df, x=COUNTS[1], y=COUNTS[2], error_name=PL_MASE)
    printmd(f'# {cls}')
    printmd('### error lab - micro')
    display(lm)
    print(f'Final Score (lab - micro): {np.mean(lm[LM_MASE])}\n')
    printmd('### error pier - lab')
    display(pl)
    # The label now matches the score being printed (it previously read "lab - micro").
    print(f'Final Score (pier - lab): {np.mean(pl[PL_MASE])}\n')

    # Join the three breakdowns row-wise on their shared (class, datetime) keys.
    cls_error = lm.merge(pm, on=['class', 'datetime'])
    cls_error = cls_error.merge(pl, on=['class', 'datetime'])

    class_errors.append(cls_error)

# pd.concat raises on an empty list, so fall back to an empty frame when there are no classes.
scaled_error = pd.concat(class_errors) if class_errors else pd.DataFrame()
    

In [None]:
from bokeh.models.formatters import DatetimeTickFormatter
# Datetime tick formatter whose `months` format is month/day ('%m/%d').
# NOTE(review): `formatter` is not referenced by any of the visible plotting cells --
# confirm whether it is still needed.
formatter = DatetimeTickFormatter(months='%m/%d')

In [None]:
%%opts HeatMap[colorbar=True, width=1000, height=300, xrotation=60, tools=['hover'], shared_axes=True, fontscale=1.5]

import hvplot.pandas

scaled_error1 = scaled_error.copy()

scaled_error1 = scaled_error1[scaled_error1['class'].isin(['Lingulodinium polyedra', 'Prorocentrum micans'])]

h1 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=LM_MASE, cmap='coolwarm').opts(title=LM_MASE)
h2 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=PM_MASE, cmap='coolwarm').opts(title=PM_MASE)
h3 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=PL_MASE, cmap='coolwarm').opts(title=PL_MASE)
hv.Layout(h1 + h2 + h3).cols(1)

In [None]:
%%opts HeatMap[colorbar=True, width=1000, height=300, xrotation=60, tools=['hover'], shared_axes=True, fontscale=1.5]

import hvplot.pandas

scaled_error1 = scaled_error.copy()

scaled_error1 = scaled_error1[scaled_error1['class'].isin(['Lingulodinium polyedra', 'Prorocentrum micans'])]

for ms in [LM_MASE, PM_MASE, PL_MASE]:
    condition = scaled_error1[ms] > 1.0
    scaled_error1.loc[condition, ms] = 1.0

h1 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=LM_MASE, cmap='coolwarm').opts(title=LM_MASE)
h2 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=PM_MASE, cmap='coolwarm').opts(title=PM_MASE)
h3 = scaled_error1.hvplot.heatmap(x='datetime', y='class', C=PL_MASE, cmap='coolwarm').opts(title=PL_MASE)
hv.Layout(h1 + h2 + h3).cols(1)