In [1]:
import wandb
import plotly.express as px 
import plotly.graph_objects as go
import numpy as np 
import pandas as pd 
from tqdm import tqdm
import matplotlib.pyplot as plt 
from pytorch_lightning.loggers import WandbLogger

In [2]:
# W&B public API client; used below to list runs of the ablation projects.
api = wandb.Api()

In [3]:
# Keep finished runs whose name has '0' as its second-to-last character,
# e.g. 'mouse_proportion=0.06' (for this naming scheme: proportions below 0.1).
mouse_runs = api.runs('jlehrer1/Ablation Study, Mouse')
runs = [r for r in mouse_runs if r.name[-2] == '0' and r.state == 'finished']

In [4]:
# Names of the selected runs (duplicates are kept), displayed as a sanity check.
run_names = [selected.name for selected in runs]
names = np.asarray(run_names)
names

array(['mouse_proportion=0.06', 'mouse_proportion=0.07',
       'mouse_proportion=0.08', 'mouse_proportion=0.05',
       'mouse_proportion=0.09', 'mouse_proportion=0.03',
       'mouse_proportion=0.07', 'mouse_proportion=0.03',
       'mouse_proportion=0.05', 'mouse_proportion=0.08',
       'mouse_proportion=0.09', 'mouse_proportion=0.06',
       'mouse_proportion=0.07', 'mouse_proportion=0.04',
       'mouse_proportion=0.05', 'mouse_proportion=0.08',
       'mouse_proportion=0.09', 'mouse_proportion=0.08',
       'mouse_proportion=0.09', 'mouse_proportion=0.04',
       'mouse_proportion=0.03', 'mouse_proportion=0.05',
       'mouse_proportion=0.07'], dtype='<U21')

In [10]:
import plotly
import seaborn as sns

def _fetch_metric_frames(runs, min_epochs=50, max_epoch=100):
    """Download validation metrics of each run into one wide DataFrame per metric.

    Parameters
    ----------
    runs : iterable of wandb runs
        Each run must expose ``name`` and ``scan_history(keys=...)``.
    min_epochs : int
        Runs whose scanned history has ``min_epochs`` rows or fewer are skipped.
    max_epoch : int
        Each history is truncated to index labels ``0..max_epoch`` (inclusive).

    Returns
    -------
    dict[str, pd.DataFrame]
        Maps each metric key to a frame indexed by history row with one column
        per run name; columns are sorted by name.
    """
    # All four keys are requested together (as before) so scan_history returns
    # the same set of rows even though balanced accuracy is not plotted.
    keys = [
        'val_loss_epoch',
        'val_weighted_accuracy',
        'val_balanced_accuracy',
        'val_median_f1',
    ]
    per_metric = {key: {} for key in keys}

    for run in tqdm(runs):
        history = pd.DataFrame(run.scan_history(keys=keys))
        if len(history) <= min_epochs:
            continue
        history = history.loc[0:max_epoch, :]
        for key in keys:
            # NOTE(review): runs sharing a name overwrite each other here
            # (last one wins); rename duplicates in W&B to plot them all.
            per_metric[key][run.name] = history[key]

    # A frame built from a dict of Series uses the union of all run indexes,
    # so a short first run no longer truncates longer later runs.
    return {
        key: pd.DataFrame(columns).sort_index(axis=1)
        for key, columns in per_metric.items()
    }


def _gradient_colors(n, low='rgb(30,129,176)', high='rgb(255, 0, 0)'):
    """Return ``n`` plotly colour strings interpolated from ``low`` to ``high``.

    ``plotly.colors.n_colors`` divides by ``n - 1``, so the single-run (and
    empty) case is handled here instead of raising ZeroDivisionError.
    """
    if n <= 1:
        return [low] * n
    return plotly.colors.n_colors(low, high, n, colortype='rgb')


def _line_figure(frame, colors, fig_title, y_title):
    """Build a per-epoch line plot with one trace per run column of ``frame``."""
    fig = go.Figure(
        layout=go.Layout(
            title=fig_title,
            xaxis=dict(title='Epoch'),
            yaxis=dict(title=y_title),
            font_family="Serif",
        )
    )
    for color, col in zip(colors, frame):
        fig.add_trace(
            go.Scatter(
                x=frame.index,
                y=frame[col],
                name=f"Proportion={col.split('=')[1]}",
                marker=dict(color=color),
            )
        )
    fig.update_layout(legend=dict(
        yanchor="top",
        xanchor="right",
    ))
    return fig


def _final_epoch_bar(frame, fig_title, y_title, last_n):
    """Bar chart of each run's mean over its last ``last_n`` non-missing epochs."""
    averages = frame.apply(lambda s: s.dropna().iloc[-last_n:].mean(), axis=0)
    proportions = [name.split('=')[1] for name in averages.index]
    return go.Figure(
        data=go.Bar(
            x=proportions,
            y=averages.values,
            text=averages.values.round(2),
            textposition='auto',
        ),
        layout=go.Layout(
            title=fig_title,
            xaxis=dict(title='Proportion'),
            yaxis=dict(title=y_title),
            font_family="Serif",
        ),
    )


def plot(runs, title, out_dir='../../ms-thesis/images/ablation',
         min_epochs=50, max_epoch=100, last_n=10):
    """Plot, save and show validation curves and final-epoch summaries of ablation runs.

    Writes five PDFs named ``{prefix}_{title}.pdf`` into ``out_dir``, with
    prefixes ``loss``, ``mf1``, ``final_mf1``, ``weighted_acc`` and
    ``final_wacc``, and displays each figure inline.

    Parameters
    ----------
    runs : iterable of wandb runs
        Run names are split on '=' and the second part labels traces and bars
        (e.g. 'Proportion=0.04' -> '0.04').
    title : str
        Suffix for the output file names (not a figure title).
    out_dir : str
        Directory the PDFs are written to.
    min_epochs : int
        Runs with ``min_epochs`` or fewer history rows are skipped.
    max_epoch : int
        History rows ``0..max_epoch`` (inclusive) are kept.
    last_n : int
        Number of final non-missing epochs averaged for the bar charts.

    Raises
    ------
    ValueError
        If no run qualifies. Raised before any file is written so existing
        figures are not overwritten with empty ones.
    """
    frames = _fetch_metric_frames(runs, min_epochs=min_epochs, max_epoch=max_epoch)
    loss = frames['val_loss_epoch']
    f1 = frames['val_median_f1']
    wacc = frames['val_weighted_accuracy']

    if loss.shape[1] == 0:
        raise ValueError(
            f'No run has more than {min_epochs} validation history rows; nothing to plot.'
        )

    colors = _gradient_colors(loss.shape[1])

    def save_and_show(fig, prefix):
        fig.write_image(f'{out_dir}/{prefix}_{title}.pdf', scale=3)
        fig.show()

    save_and_show(
        _line_figure(loss, colors, 'Validation Loss For Ablative Models', 'Loss'),
        'loss',
    )
    save_and_show(
        _line_figure(f1, colors, 'Median F1 Score For Ablative Models', 'Median F1'),
        'mf1',
    )
    save_and_show(
        _final_epoch_bar(
            f1,
            f'Average of Median F1 over {last_n} Final Epochs (Validation Set)',
            'Median F1',
            last_n,
        ),
        'final_mf1',
    )
    # The y-axis shows weighted accuracy, not loss.
    save_and_show(
        _line_figure(wacc, colors, 'Weighted Accuracy For Ablative Models', 'Weighted Accuracy'),
        'weighted_acc',
    )
    # Bars are one per proportion, so the x-axis is 'Proportion', not 'Epoch'.
    save_and_show(
        _final_epoch_bar(
            wacc,
            f'Average of Weighted Accuracy over {last_n} Final Epochs (Validation Set)',
            'Weighted Accuracy',
            last_n,
        ),
        'final_wacc',
    )

In [11]:
# Mouse-study ablation; output files get the 'ablation_4' suffix.
plot(runs, title='ablation_4')

1


In [17]:
# Keep finished runs whose name has '0' as its second-to-last character
# (e.g. 'Proportion=0.04'), matching the mouse-study filter. Without the state
# check, running or crashed runs make the figures depend on when this
# notebook is executed.
runs = [
    run for run in api.runs('jlehrer1/Ablation Study, ALL')
    if run.name[-2] == '0' and run.state == 'finished'
]
[run.name for run in runs]

['Proportion=0.04',
 'Proportion=0.09',
 'Proportion=0.07',
 'Proportion=0.08',
 'Proportion=0.06',
 'Proportion=0.05',
 'Proportion=0.03',
 'Proportion=0.02',
 'Proportion=0.01',
 'Proportion=0.01']

In [18]:
# All-data ablation; output files get the 'ablation_2' suffix.
plot(runs, title='ablation_2')

100%|████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.63it/s]
