In [46]:
import pandas as pd 
import re 
import holoviews as hv
import numpy as np
import bokeh.palettes

# Select Bokeh as the HoloViews plotting backend for every plot built below.
hv.extension('bokeh')

# Lookup table from "<dataset>_<id>" identifiers to the split names used in
# the results tables and plots below.
split_dict = dict(
    # AAV
    aav_1='des_mut',
    aav_2='mut_des',
    aav_3='one_vs_many',
    aav_4='two_vs_many',
    aav_5='seven_vs_many',
    aav_6='low_vs_high',
    aav_7='sampled',
    # Meltome
    meltome_mixed='mixed_split',
    meltome_human='human',
    meltome_humancell='human_cell',
    # GB1
    gb1_1='one_vs_rest',
    gb1_2='two_vs_rest',
    gb1_3='three_vs_rest',
    gb1_4='sampled',
    gb1_5='low_vs_high',
)


# Get data

In [150]:
def get_dataset(dataset):
    """Read ``<dataset>_results.csv`` and return its non-ridge rows.

    The file is read without a header row; the fourteen column names below
    are assigned afterwards, so a file with a different number of columns
    raises. Rows whose ``model`` equals ``'ridge'`` are dropped and the
    original row index is kept.
    """
    column_names = [
        'dataset', 'model', 'split',
        'train_rho', 'train_mse', 'test_rho', 'test_mse',
        'epochs_trained', 'lr', 'kernel_size', 'input_size',
        'dropout', 'alpha', 'gb1_shorten',
    ]

    results = pd.read_csv(dataset + '_results.csv', header=None)
    results.columns = column_names

    # Ridge results are excluded from the returned frame.
    not_ridge = results['model'] != 'ridge'
    return results[not_ridge]

# Load the three result tables (ridge rows are already filtered out by get_dataset).
aav, gb1, meltome = (get_dataset(name) for name in ('aav', 'gb1', 'meltome'))

# Bootstrapping function to get 95% CI

In [151]:
def bootstrap_samples(samples, n_resamples=1000, confidence=0.95, seed=None):
    """Bootstrap a confidence interval around the mean of ``samples``.

    Parameters
    ----------
    samples : 1-D sequence of numbers (list, ndarray, pandas Series, ...)
        Observations to resample with replacement.
    n_resamples : int, default 1000
        Number of bootstrap resamples to draw.
    confidence : float, default 0.95
        Central coverage of the interval, strictly between 0 and 1.
    seed : None, int or numpy.random.Generator, default None
        ``None`` draws from NumPy's global random state (so an earlier
        ``np.random.seed`` call still applies). Anything else is passed to
        ``np.random.default_rng`` to make the result reproducible.

    Returns
    -------
    tuple
        ``(mean, mean - lower_bound, upper_bound - mean)``, each rounded to
        three decimals, i.e. the centre followed by the negative and positive
        error-bar lengths.

    Raises
    ------
    ValueError
        If ``n_resamples`` is smaller than 1 or ``confidence`` is not in (0, 1).
    """
    if n_resamples < 1:
        raise ValueError("n_resamples must be at least 1")
    if not 0 < confidence < 1:
        raise ValueError("confidence must be strictly between 0 and 1")

    values = np.asarray(samples, dtype=float)
    centerpoint = np.mean(values)

    shape = (n_resamples, len(values))
    if seed is None:
        new_samples = np.random.choice(values, shape)
    else:
        new_samples = np.random.default_rng(seed).choice(values, shape)

    sorted_means = np.sort(new_samples.mean(axis=1))

    # Number of resampled means left out in each tail. With the defaults this
    # is 25, giving indices 24 and 975 into the sorted means.
    tail_count = int(round((1 - confidence) / 2 * n_resamples))
    last = n_resamples - 1
    lower_idx = min(max(tail_count - 1, 0), last)
    upper_idx = min(max(n_resamples - tail_count, 0), last)

    return (
        round(centerpoint, 3),
        round(centerpoint - sorted_means[lower_idx], 3),
        round(sorted_means[upper_idx] - centerpoint, 3),
    )

# Generating plots for the overlay

In [94]:
# Column layout shared by every controls table.
control_cols = ['model', 'split', 'spearman rho']

# Baseline Spearman rho values for the AAV splits, flattened to
# [model, split, rho] rows in model-major order.
aav_data = [
    [model, split, rho]
    for model, scores in {
        'Ridge': {
            'one_vs_many': 0.22,
            'two_vs_many': 0.03,
            'seven_vs_many': 0.65,
            'low_vs_high': 0.12,
            'sampled': 0.83,
        },
        'Levenshtein': {
            'one_vs_many': -0.11,
            'two_vs_many': 0.57,
            'seven_vs_many': 0.52,
            'low_vs_high': 0.25,
            'sampled': -0.07,
        },
    }.items()
    for split, rho in scores.items()
]

aav_controls = pd.DataFrame(aav_data, columns=control_cols)

In [95]:
# Baseline Spearman rho values for the GB1 splits, flattened to
# [model, split, rho] rows in model-major order.
gb1_data = [
    [model, split, rho]
    for model, scores in {
        'Ridge': {
            'one_vs_rest': 0.28,
            'two_vs_rest': 0.59,
            'three_vs_rest': 0.76,
            'sampled': 0.82,
            'low_vs_high': 0.34,
        },
        'Levenshtein': {
            'one_vs_rest': 0.17,
            'two_vs_rest': 0.16,
            'three_vs_rest': -0.04,
            'sampled': 0.17,
            'low_vs_high': -0.10,
        },
        'BLOSUM': {
            'one_vs_rest': 0.15,
            'two_vs_rest': 0.14,
            'three_vs_rest': 0.01,
            'sampled': 0.17,
            'low_vs_high': -0.13,
        },
    }.items()
    for split, rho in scores.items()
]

gb1_controls = pd.DataFrame(gb1_data, columns=control_cols)


In [96]:
# Baseline Spearman rho values for the Meltome splits; only a Ridge control
# is listed for this dataset.
meltome_data = [
    ['Ridge', split, rho]
    for split, rho in (
        ('mixed_split', 0.17),
        ('human', 0.15),
        ('human_cell', 0.24),
    )
]

meltome_controls = pd.DataFrame(meltome_data, columns=control_cols)

In [153]:
def plot_overlay(data, splits, controls):
    """Build a controls panel plus a violin/mean/scatter/CI overlay per split.

    Parameters
    ----------
    data : pandas.DataFrame
        Model results. The columns ``split``, ``model`` and ``test_rho`` are read.
    splits : sequence of str
        Split names to plot, in display order.
    controls : pandas.DataFrame
        Baseline scores with columns ``model``, ``split`` and ``spearman rho``.
        Every ``model`` value must be ``'Ridge'``, ``'Levenshtein'`` or
        ``'BLOSUM'``; any other value raises ``KeyError``. The frame is not
        modified: styling columns are added to a copy.

    Returns
    -------
    holoviews.Layout
        For each split, a narrow controls panel followed by the model overlay,
        wrapped at 6 columns.

    Notes
    -----
    The y-range of every panel is fixed to ``(0, 1)``, so negative values fall
    outside the visible area.
    """
    marker_by_model = {'Ridge': 'square', 'Levenshtein': 'circle', 'BLOSUM': 'diamond'}
    size_by_model = {'Ridge': 20, 'Levenshtein': 16, 'BLOSUM': 12}

    # Work on a copy: assigning the styling columns directly onto ``controls``
    # would silently alter the caller's DataFrame.
    controls = controls.assign(
        marker=[marker_by_model[model] for model in controls['model']],
        size=[size_by_model[model] for model in controls['model']],
        label='controls',
    )

    control_cmap = (
        bokeh.palettes.Reds6[0],
        bokeh.palettes.Reds6[3],
        bokeh.palettes.Reds6[5],
    )

    overlays = []

    for i, split in enumerate(splits):
        is_last = i + 1 == len(splits)

        # Options shared by every element of the model overlay
        opts = {'height': 600, 'xrotation': 45, 'ylim': (0, 1), 'xlabel': split}

        temp_data = data.loc[data['split'] == split]

        # Controls plot
        points = hv.Points(
            data=controls.loc[controls['split'] == split],
            kdims=['label', 'spearman rho'],
        ).opts(
            width=100,
            height=600,
            marker='marker',
            size='size',
            ylim=(0, 1),
            color='model',
            line_color='gray',
            line_width=1,
            line_alpha=1,
            cmap=control_cmap,
            xlabel='',
            show_legend=False,
            xticks=None,
            fill_alpha=0.8,
            xrotation=45,
        )
        overlays.append(points)

        # Violin plot; the legend options are only set on the last panel
        violin_opts = {
            'inner': None,
            'violin_fill_alpha': 0.35,
            'violin_line_width': 0.2,
            'violin_width': 0.6,
            'width': 450,
            'yaxis': None,
        }
        if is_last:
            violin_opts['show_legend'] = True
            violin_opts['legend_position'] = 'right'

        violin = hv.Violin(
            data=temp_data,
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(
            **opts,
            **violin_opts,
        )

        # Individual runs as jittered points
        scatter_opts = {
            'fill_color': 'gray',
            'line_color': 'black',
            'jitter': 0.25,
            'size': 6,
            'fill_alpha': 1,
            'line_width': .8,
            'width': 450,
            'yaxis': None,
        }
        scatter = hv.Scatter(
            data=temp_data,
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(
            **opts,
            **scatter_opts,
        )

        # Per-model mean marker. Only ``test_rho`` is averaged: taking the mean
        # of the whole frame also hits string columns such as ``dataset`` and
        # ``split``, which raises on pandas >= 2.0.
        mean_opts = {
            'marker': 'dash',
            'size': 30,
            'color': 'black',
            'line_width': 3,
            'width': 450,
            'yaxis': None,
        }
        mean = hv.Scatter(
            data=temp_data.groupby(['model'])[['test_rho']].mean().reset_index(),
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(
            **opts,
            **mean_opts,
        )

        # Rows for the ErrorBars element: (model, mean, lower error, upper error)
        ci = temp_data.groupby(['model'])['test_rho'].agg(bootstrap_samples).reset_index()
        error_data = [
            (str(model), center, lower, upper)
            for model, (center, lower, upper) in zip(ci['model'], ci['test_rho'])
        ]

        errors = hv.ErrorBars(
            data=error_data,
        ).opts(line_width=5, line_color='black', line_alpha=1, width=450)

        overlays.append(violin * mean * scatter * errors)

    layout = hv.Layout(overlays).cols(6)
    return layout



In [44]:
# AAV: model overlays next to the baseline controls, one pair of panels per split.
aav_plot = plot_overlay(
    data=aav,
    splits=[
        'one_vs_many',
        'two_vs_many',
        'seven_vs_many',
        'low_vs_high',
        'sampled',
    ],
    controls=aav_controls,
)

aav_plot

In [155]:
# GB1: model overlays next to the baseline controls, one pair of panels per split.
gb1_plot = plot_overlay(
    data=gb1,
    splits=[
        'one_vs_rest',
        'two_vs_rest',
        'three_vs_rest',
        'low_vs_high',
        'sampled',
    ],
    controls=gb1_controls,
)

gb1_plot

In [157]:
# Meltome: model overlays next to the baseline controls, one pair of panels per split.
meltome_plot = plot_overlay(
    data=meltome,
    splits=['mixed_split', 'human', 'human_cell'],
    controls=meltome_controls,
)

meltome_plot

In [32]:
def plot_scatter_overlay(data, splits):
    """Build one mean/scatter/CI overlay per split, without violins or controls.

    Parameters
    ----------
    data : pandas.DataFrame
        Model results. The columns ``split``, ``model`` and ``test_rho`` are read.
    splits : sequence of str
        Split names to plot, in display order.

    Returns
    -------
    holoviews.Layout
        One overlay per split. Only the first panel keeps its y-axis; the last
        panel is widened to 850 and carries the legend on the right.
    """
    overlays = []

    for i, split in enumerate(splits):
        is_last = i + 1 == len(splits)

        # Options shared by every element of the overlay
        opts = {'height': 700, 'width': 700, 'xrotation': 45, 'ylim': (0, 1), 'xlabel': split}

        # Individual runs, coloured by model
        scatter_opts = {
            'fill_color': 'model',
            'cmap': 'glasbey_cool',
            'line_color': 'black',
            'jitter': 0.25,
            'size': 6,
            'fill_alpha': 1,
            'line_width': .4,
        }

        # Per-model mean marker
        mean_opts = {'marker': 'dash', 'size': 30, 'color': 'black', 'line_width': 3}

        if i != 0:
            opts['yaxis'] = None

        # The last panel is wider to make room for the legend
        if is_last:
            opts['width'] = 850
            scatter_opts['show_legend'] = True
            scatter_opts['legend_position'] = 'right'
        else:
            scatter_opts['show_legend'] = False

        temp_data = data.loc[data['split'] == split]

        scatter = hv.Scatter(
            data=temp_data,
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(
            **opts,
            **scatter_opts,
        )

        # Only ``test_rho`` is averaged: taking the mean of the whole frame
        # also hits string columns such as ``dataset`` and ``split``, which
        # raises on pandas >= 2.0.
        mean = hv.Scatter(
            data=temp_data.groupby(['model'])[['test_rho']].mean().reset_index(),
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(
            **opts,
            **mean_opts,
        )

        # Rows for the ErrorBars element: (model, mean, lower error, upper error)
        ci = temp_data.groupby(['model'])['test_rho'].agg(bootstrap_samples).reset_index()
        error_data = [
            (str(model), center, lower, upper)
            for model, (center, lower, upper) in zip(ci['model'], ci['test_rho'])
        ]

        errors = hv.ErrorBars(
            data=error_data,
        ).opts(line_width=5, line_color='black', line_alpha=1)

        overlays.append(mean * scatter * errors)

    layout = hv.Layout(overlays)
    return layout

In [33]:
# AAV: scatter overlay with per-model means and bootstrap error bars.
aav_plot = plot_scatter_overlay(
    data=aav,
    splits=[
        'one_vs_many',
        'two_vs_many',
        'seven_vs_many',
        'low_vs_high',
        'sampled',
    ],
)

aav_plot

In [34]:
def just_plot_scatter(data, splits):
    """Return a layout with one jittered ``test_rho``-by-model scatter per split.

    ``data`` is filtered on its ``split`` column for each entry of ``splits``.
    Only the first panel keeps its y-axis; the last panel is widened to 850
    and is the only one that shows a legend (placed on the right).
    """
    final_position = len(splits) - 1
    panels = []

    for position, split_name in enumerate(splits):
        is_final = position == final_position

        # Frame-level options
        panel_opts = {
            'height': 700,
            'width': 850 if is_final else 700,
            'xrotation': 45,
            'ylim': (0, 1),
            'xlabel': split_name,
        }
        if position != 0:
            panel_opts['yaxis'] = None

        # Glyph-level options; points are coloured by model
        point_opts = {
            'fill_color': 'model',
            'cmap': 'glasbey_cool',
            'line_color': 'black',
            'jitter': 0.25,
            'size': 10,
            'fill_alpha': .8,
            'line_width': .4,
            'show_legend': is_final,
        }
        if is_final:
            point_opts['legend_position'] = 'right'

        rows_for_split = data.loc[data['split'] == split_name]

        panel = hv.Scatter(
            data=rows_for_split,
            vdims=['test_rho'],
            kdims=['model'],
        ).opts(**panel_opts, **point_opts)

        panels.append(panel)

    return hv.Layout(panels)

In [35]:
# AAV: plain scatter panels, one per split.
aav_plot_scatter = just_plot_scatter(
    data=aav,
    splits=[
        'one_vs_many',
        'two_vs_many',
        'seven_vs_many',
        'low_vs_high',
        'sampled',
    ],
)

aav_plot_scatter