In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [None]:
def metric_across_epochs(project_path, submission_id, level=2, data_set='valid', protein='agg', metric='pcorr', ylim=None, return_df=False):
    ''' Plot eval metric across epochs for a given set (eg valid)

    Reads <project_path>/results/<submission_id>/chkpt_selection/<metric>_across_epochs-level_<level>-<data_set>.csv
    and draws the metric against the epoch as a scatter plot with a line plot on the same axes.

    project_path: base project path, eg /cluster/work/grlab/projects/projects2021-multivstain/
    submission_id: job submission_id
    level: resolution level (2,4,6)
    data_set: which data split to use (train, valid, test)
    protein: selected proteins (separated by comma, whitespace around names is ignored) or "agg" for aggregated stats
    metric: evaluation metric used (eg pcorr); also the name of the column plotted for selected proteins
    ylim: limits of y-axis; None keeps the automatic limits
    return_df: if True, return the data frame that was plotted

    Returns the plotted data frame if return_df is True, otherwise None.
    Raises AssertionError if the csv file does not exist.
    '''
    csv_path = Path(project_path).joinpath('results', submission_id, 'chkpt_selection',
                                           metric+'_across_epochs-level_'+str(level)+'-'+data_set+'.csv')
    # name the missing file so a wrong submission_id / level / data_set is easy to spot
    assert os.path.exists(csv_path), 'Requested file does not exist: '+str(csv_path)
    df = pd.read_csv(csv_path, index_col=[0]).reset_index(drop=True)

    selection = protein.strip()
    if selection == 'agg':
        # keep the first row of every epoch
        # (presumably agg_per_epoch is identical on all rows of an epoch - TODO confirm)
        df = df.drop_duplicates('epoch')
        yaxis = 'agg_per_epoch'
        hue = None
    else:
        # accept "MelanA, CD3" as well as "MelanA,CD3": strip names and drop empty entries
        proteins = [name.strip() for name in selection.split(',') if name.strip()]
        df = df.loc[df.protein.isin(proteins),:]
        yaxis = metric
        hue = 'protein'

    # draw both layers on the very same Axes instead of relying on pyplot's current axes
    ax = sns.scatterplot(x='epoch', y=yaxis, data=df, hue=hue, legend=False)
    sns.lineplot(x='epoch', y=yaxis, data=df, hue=hue, ax=ax)
    if ylim is not None:
        ax.set_ylim(ylim)
    if hue is not None:
        # anchor the protein legend at the upper right corner of the axes
        ax.legend(bbox_to_anchor=(1,1))
    plt.show()
    if return_df:
        return df
    return None
    
    
def metric_across_epochs_sets(project_path, submission_id, level=2, data_set1='valid', data_set2='train', protein='agg', metric='pcorr', ylim=None):
    ''' Plot eval metric across epochs and across data splits (data_set1, data_set2 eg valid and train)

    For both splits reads <project_path>/results/<submission_id>/chkpt_selection/<metric>_across_epochs-level_<level>-<split>.csv,
    stacks the two tables (tagged in a 'data_set' column) and draws one line plot per protein,
    coloured by data split and finished with plt.show().

    project_path: base project path, eg /cluster/work/grlab/projects/projects2021-multivstain/
    submission_id: job submission_id
    level: resolution level (2,4,6)
    data_set1: first data split to use (train, valid, test)
    data_set2: second data split to use (train, valid, test)
    protein: selected proteins (separated by comma, whitespace around names is ignored) or "agg" for aggregated stats
    metric: evaluation metric used (eg pcorr); also the name of the column plotted for selected proteins
    ylim: limits of y-axis; None keeps the automatic limits

    Returns None.
    Raises AssertionError if one of the two csv files does not exist.
    '''
    chkpt_dir = Path(project_path).joinpath('results', submission_id, 'chkpt_selection')
    frames = []
    for split in (data_set1, data_set2):
        # Build every file name from scratch. Deriving the second name with
        # str.replace(data_set1, data_set2) would also rewrite the metric part
        # whenever it contains data_set1 (eg metric='valid_pcorr').
        csv_path = chkpt_dir.joinpath(metric+'_across_epochs-level_'+str(level)+'-'+split+'.csv')
        assert os.path.exists(csv_path), 'Requested file does not exist: '+str(csv_path)
        frame = pd.read_csv(csv_path, index_col=[0])
        frame['data_set'] = split
        frames.append(frame)
    df = pd.concat(frames).reset_index(drop=True)

    selection = protein.strip()
    if selection == 'agg':
        # one row per (epoch, data split); tagged with a pseudo protein 'agg'
        # so the plotting loop below handles both cases the same way
        proteins = ['agg']
        df = df.loc[:,['epoch', 'data_set', 'agg_per_epoch']].drop_duplicates(['epoch', 'data_set'])
        df['protein'] = 'agg'
        yaxis = 'agg_per_epoch'
    else:
        # accept "MelanA, CD3" as well as "MelanA,CD3": strip names and drop empty entries
        proteins = [name.strip() for name in selection.split(',') if name.strip()]
        df = df.loc[df.protein.isin(proteins),:]
        yaxis = metric

    # dedicated loop variable: the `protein` argument is no longer overwritten
    for protein_name in proteins:
        ax = sns.lineplot(x='epoch', y=yaxis, data=df.loc[df.protein==protein_name,:], hue='data_set')
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.legend(bbox_to_anchor=(1,1))
        ax.set_title(protein_name)
        plt.show()


In [None]:
# Run configuration shared by the cells below.
# NOTE(review): absolute, machine-specific base path - adjust to your environment.
project_path = '/raid/sonali/project_mvs/'
# Job whose metric tables are read from <project_path>/results/<submission_id>/chkpt_selection
submission_id = "mj3pqeyk_dataaug-v2-flip_split3_selected-snr_no-wt_no-checkerboard"
# Data split passed as `data_set` to metric_across_epochs in the cells below
data_set = 'valid'

In [None]:
# Metric across epochs for the configured data split: aggregated stats first,
# then a selection of individual proteins.
# (add e.g. ylim=(-0.1, 0.3) to either call to pin the y-axis range)
metric_across_epochs(
    project_path,
    submission_id,
    level=2,
    data_set=data_set,
    protein='agg',
)
metric_across_epochs(
    project_path,
    submission_id,
    level=2,
    data_set=data_set,
    protein='MelanA,CD3,CD8a,CD20',
)

In [None]:
# Top 10 epochs by the aggregated metric (column agg_per_epoch);
# rows are sorted ascending, so the best epoch is shown last.
df = metric_across_epochs(
    project_path,
    submission_id,
    level=2,
    data_set=data_set,
    protein='agg',
    return_df=True,
)
display(df.sort_values('agg_per_epoch').iloc[-10:])

In [None]:
# Per-protein view: plot the metric across epochs for one protein at a time and
# list its 5 best epochs by the pcorr column (sorted ascending, best epoch last).
for protein in ('CD20', 'MelanA'):
    print(protein)
    df_prot = metric_across_epochs(
        project_path,
        submission_id,
        level=2,
        data_set=data_set,
        protein=protein,
        return_df=True,
    )
    display(df_prot.sort_values('pcorr').iloc[-5:])

In [None]:
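# # across datasets: compare the metric between two data splits (valid vs train by default); uncomment the calls below to run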
# metric_across_epochs_sets(project_path, submission_id, level=2, protein='agg', ylim=(-0.1,0.3))
# metric_across_epochs_sets(project_path, submission_id, level=2, protein='MelanA,CD3,CD8a', ylim=(-0.1,0.3))