In [65]:
# Widen the notebook's `.container` element to 100% of the browser width.
# `display`/`HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
# NOTE(review): `.container` is a classic-Notebook class; presumably a no-op in JupyterLab.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import wandb
import pandas as pd
import plotly.graph_objects as go
# W&B public-API client; used below to list (`api.runs`) and fetch (`api.run`)
# the runs of PROJECT.
api = wandb.Api()

In [147]:
# W&B "entity/project" path holding all runs analysed in this notebook.
PROJECT = 'feifang24/mtl-uncertainty-final'

# Display name of each task -> short key used inside the logged metric names.
TASKS = {
    'SST-2': 'sst',
    'MRPC': 'mrpc',
    'RTE': 'rte',
}
SPLITS = ['train', 'dev', 'test']

# Column template of the aggregate (all-task) loss for a split.
AGG_METRICS = '{split}_loss'

# Column templates of per-task metrics; fill with .format(split=..., task=...).
# 'task_weight' has no split component (extra format kwargs are ignored).
METRIC_FORMATS = dict(
    loss='{split}_loss_by_task/{task}',
    uncertainty='{split}_uncertainty_by_task/{task}',
    acc='{task}/{split}_ACC',  # logged for dev and test
    auc='{task}/{split}_AUC',
    f1='{task}/{split}_F1',
    task_weight='task_weight/{task}',
)

In [3]:
from collections import defaultdict

# Configuration name -> ids of every W&B run of that configuration.
# The configuration name is the run name with its last two characters removed.
# NOTE(review): assumes each run name ends in a 2-char suffix such as "-1"
# (presumably a seed index); a longer suffix would break the grouping.
config_run_ids = defaultdict(list)
for run in api.runs(path=PROJECT):
    config_key = run.name[:-2]
    config_run_ids[config_key].append(run.id)

In [82]:
def average_run_histories(histories, eval_interval=200):
    """Average several run histories row-by-row and index the result by iteration.

    Args:
        histories: DataFrames returned by ``Run.history()``, one per run of the
            same configuration. Rows are aligned by their index label, so row
            ``i`` of every history is treated as the same evaluation point.
            NOTE(review): assumes each history carries a default 0..n-1 index.
        eval_interval: training iterations between consecutive logged rows.

    Returns:
        DataFrame indexed by ``i * eval_interval`` with the mean (NaNs skipped)
        of every numeric column across the runs.
    """
    stacked = pd.concat(histories)
    # Equal labels now mean "same row position in different runs"; rescale
    # them to iteration numbers before averaging.
    stacked.index = [i * eval_interval for i in stacked.index]
    # numeric_only=True: pandas >= 2.0 raises on non-numeric columns otherwise.
    return stacked.groupby(level=0).mean(numeric_only=True)


EVAL_INTERVAL = 200  # training iterations between two logged history rows

# Configuration name -> history averaged over all of its runs.
config_metrics_df = {}
for config_name, run_ids in config_run_ids.items():
    runs = [api.run(f'{PROJECT}/{run_id}') for run_id in run_ids]
    config_metrics_df[config_name] = average_run_histories(
        [run.history() for run in runs], eval_interval=EVAL_INTERVAL
    )

In [83]:
# Configuration names discovered in the project (insertion order).
list(config_run_ids)

['sampling-smoothed-r=0.375',
 'sampling-smoothed-r=0.125',
 'sampling-smoothed-r=0.25',
 'sampling-raw',
 'sampling-smoothed-r=0.5',
 'baseline-uniform',
 'baseline-data-dist']

In [19]:
# Inspect which metric columns were logged for one sampling configuration.
config_metrics_df['sampling-smoothed-r=0.125'].columns

Index(['train_loss_by_task/rte', 'dev_loss', 'mrpc/test_F1', 'sst/test_AUC',
       'test_uncertainty_by_task/mrpc', '_step', 'mrpc/dev_F1',
       'dev_uncertainty_by_task/sst', 'train_loss', 'sst/test_ACC',
       'task_weight/sst', '_runtime', 'test_uncertainty_by_task/rte',
       'test_loss_by_task/rte', 'test_loss_by_task/mrpc', 'mrpc/dev_ACC',
       'sst/dev_F1', 'dev_loss_by_task/mrpc', 'mrpc/dev_AUC', 'rte/dev_AUC',
       'mrpc/test_ACC', 'task_weight/mrpc', 'dev_loss_by_task/sst',
       'sst/test_F1', 'rte/test_F1', 'sst/dev_ACC', 'rte/dev_ACC',
       'dev_uncertainty_by_task/rte', 'task_weight/rte', 'rte/test_ACC',
       'train_loss_by_task/sst', 'rte/test_AUC', 'train_loss_by_task/mrpc',
       'dev_loss_by_task/rte', 'test_loss', 'test_uncertainty_by_task/sst',
       'test_loss_by_task/sst', 'sst/dev_AUC', 'rte/dev_F1', '_timestamp',
       'mrpc/test_AUC', 'dev_uncertainty_by_task/mrpc'],
      dtype='object')

In [214]:
import plotly.express as px
def plot_metric_over_time(split, metric, configs, plot_minima=False, plot_maxima=False):
    """Plot a per-task metric over training iterations for several configurations.

    Each configuration gets one colour and each task its own dash style. For
    ``metric == 'loss'`` the aggregate loss of the split is added as a solid line.

    Args:
        split: split name substituted into the metric column template
            (e.g. 'train', 'dev', 'test').
        metric: key of ``METRIC_FORMATS`` ('loss', 'uncertainty', 'acc', 'auc',
            'f1', 'task_weight').
        configs: configuration names (keys of ``config_metrics_df``).
        plot_minima: add a marker at the minimum of every curve.
        plot_maxima: add a marker at the maximum of every curve.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    fig = go.Figure()
    metric_format = METRIC_FORMATS[metric]
    colors = px.colors.qualitative.Plotly
    dashes = ['dash', 'dot', 'dashdot']
    minima_x = {config: [] for config in configs}
    minima_y = {config: [] for config in configs}
    maxima_x = {config: [] for config in configs}
    maxima_y = {config: [] for config in configs}

    def record_extrema(config, series):
        # Remember where each curve reaches its min / max for the optional markers.
        minima_x[config].append(series.idxmin())
        minima_y[config].append(series.min())
        maxima_x[config].append(series.idxmax())
        maxima_y[config].append(series.max())

    for j, config in enumerate(configs):
        df = config_metrics_df[config]
        # Cycle the palettes so more configs / tasks than entries do not raise IndexError.
        color = colors[j % len(colors)]
        for i, task in enumerate(TASKS.values()):
            m = df[metric_format.format(split=split, task=task)]
            fig.add_trace(go.Scatter(x=df.index, y=m,
                                     mode='lines',
                                     line=dict(color=color, dash=dashes[i % len(dashes)]),
                                     name=f'{config}/{task}'))
            record_extrema(config, m)
        if metric == 'loss':
            # AGG_METRICS is an aggregate *loss* column, so it only applies here.
            m = df[AGG_METRICS.format(split=split)]
            fig.add_trace(go.Scatter(x=df.index, y=m,
                                     mode='lines',
                                     line=dict(color=color),
                                     name=f'{config}/aggregate'))
            record_extrema(config, m)

    # All minimum markers first, then all maximum markers (one trace per config).
    for enabled, xs, ys, label in ((plot_minima, minima_x, minima_y, 'minimum'),
                                   (plot_maxima, maxima_x, maxima_y, 'maximum')):
        if not enabled:
            continue
        for j, config in enumerate(configs):
            fig.add_trace(go.Scatter(x=xs[config], y=ys[config],
                                     mode='markers',
                                     marker=dict(color=colors[j % len(colors)]),
                                     name=f'{config} {label}'))

    fig.update_layout(
        yaxis_title=f'{split} {metric}'.title(),
        xaxis_title="Iteration",
        legend_title="Method/Task",
        font=dict(size=14)
    )
    fig.show()

In [209]:
# Dev loss per task (plus aggregate) for 'sampling-smoothed-r=0.5', minima marked.
# Also toggled here during exploration: r=0.375, r=0.25, r=0.125, sampling-raw.
plot_metric_over_time('dev', 'loss', ['sampling-smoothed-r=0.5'], plot_minima=True)

In [332]:
# Dev loss across sampling configurations ('sampling-smoothed-r=0.25' left out).
sampling_configs_shown = [
    'sampling-smoothed-r=0.5',
    'sampling-smoothed-r=0.375',
    'sampling-smoothed-r=0.125',
    'sampling-raw',
]
plot_metric_over_time('dev', 'loss', sampling_configs_shown, plot_minima=True)

In [268]:
# Per-task dev uncertainty for 'sampling-smoothed-r=0.5'.
# Also toggled here during exploration: r=0.375 and both baselines.
plot_metric_over_time(
    'dev',
    'uncertainty',
    ['sampling-smoothed-r=0.5'],
    plot_minima=False,
)

In [24]:
# Training loss per task for all smoothed-sampling configurations.
plot_metric_over_time(
    'train',
    'loss',
    [
        'sampling-smoothed-r=0.375',
        'sampling-smoothed-r=0.125',
        'sampling-smoothed-r=0.5',
        'sampling-smoothed-r=0.25',
    ],
    plot_minima=False,
)

## Dev loss vs. task sampling probability and uncertainty

In [341]:
# plotly.subplots.make_subplots(rows=1, cols=1, shared_xaxes=False, shared_yaxes=False, 
#                               start_cell='top-left', print_grid=False, 
#                               horizontal_spacing=None, vertical_spacing=None, 
#                               subplot_titles=None, column_widths=None, row_heights=None, 
#                               specs=None, insets=None, 
#                               column_titles=None, row_titles=None, x_title=None, y_title=None, figure=None, **kwargs)


from plotly.subplots import make_subplots
def _task_weight_rows(config, n_cells):
    """Return one row of task-sampling probabilities per task (TASKS order) for `config`."""
    if 'uniform' in config:
        weight = 1 / len(TASKS)
        return [[weight] * n_cells for _ in TASKS]
    if 'data-dist' in config:
        # Fixed per-task sample counts; probability is proportional to dataset size.
        task_samples = {'SST-2': 8000, 'RTE': 2490, 'MRPC': 3668}
        total_samples = sum(task_samples.values())
        return [[task_samples[task] / total_samples] * n_cells for task in TASKS]
    # Learned weights: the first cell is pinned to the uniform prior, then the
    # value logged at row k (k >= 1) fills cell k; the last row has no following cell.
    weight_format = METRIC_FORMATS['task_weight']
    df = config_metrics_df[config]
    return [[1 / len(TASKS)] + df[weight_format.format(task=task)].values.tolist()[1:-1]
            for task in TASKS.values()]


def _uncertainty_rows(config):
    """Return one row of dev uncertainties per task (TASKS order) for `config`."""
    # First cell pinned to 1 (NOTE(review): presumably the initial uncertainty),
    # then rows 1..n-2, mirroring the layout of _task_weight_rows.
    uncertainty_format = METRIC_FORMATS['uncertainty']
    df = config_metrics_df[config]
    return [[1] + df[uncertainty_format.format(task=task, split='dev')].values.tolist()[1:-1]
            for task in TASKS.values()]


def plot_heatmap(configs, config_titles, zmin=0.17, zmax=0.57,
                 uncertainty_zmin=1, uncertainty_zmax=3):
    """Plot dev loss, task sampling probabilities and dev uncertainty per configuration.

    One column per configuration, three rows sharing the iteration axis:
      1. per-task and aggregate dev loss, each curve's minimum marked, plus a
         dashed vertical line through the aggregate-loss minimum;
      2. heatmap of the probability of sampling each task;
      3. heatmap of the per-task dev uncertainty.
    The heatmap z-matrices have one column fewer than x, so x acts as cell edges.

    Args:
        configs: configuration names (keys of ``config_metrics_df``).
        config_titles: one column title per config, or None. ``\\n`` is turned
            into ``<br>`` because plotly renders titles as HTML.
        zmin, zmax: colour range of the task-probability heatmap.
        uncertainty_zmin, uncertainty_zmax: colour range of the uncertainty heatmap.
    """
    if config_titles is not None:
        config_titles = [title.replace('\n', '<br>') for title in config_titles]
    fig = make_subplots(rows=3, cols=len(configs), shared_xaxes=True, row_heights=[10, 5, 5],
                        column_titles=config_titles,
                        x_title='Number of iterations')
    # One colour per task, last one for the aggregate (assumes len(TASKS) == 3).
    colors = ['cadetblue', 'coral', 'olive', 'purple']
    minima_x = {config: [] for config in configs}
    minima_y = {config: [] for config in configs}
    for idx, config in enumerate(configs):
        col = idx + 1
        first = idx == 0  # legend entries and colour bars are shown once
        df = config_metrics_df[config]
        x = df.index

        # Row 1: per-task dev loss curves.
        for i, (task_name, task) in enumerate(TASKS.items()):
            m = df[METRIC_FORMATS['loss'].format(split='dev', task=task)]
            fig.add_trace(go.Scatter(x=x, y=m, mode='lines',
                                     line=dict(color=colors[i]),
                                     name=task_name, showlegend=first), row=1, col=col)
            minima_x[config].append(m.idxmin())
            minima_y[config].append(m.min())
        # Row 1: aggregate dev loss.
        m = df[AGG_METRICS.format(split='dev')]
        fig.add_trace(go.Scatter(x=x, y=m, mode='lines',
                                 line=dict(color=colors[-1]),
                                 name='Aggregate', showlegend=first), row=1, col=col)
        minima_x[config].append(m.idxmin())
        minima_y[config].append(m.min())
        # Minimum of every curve, coloured like its curve.
        fig.add_trace(go.Scatter(x=minima_x[config], y=minima_y[config],
                                 mode='markers',
                                 marker=dict(color=colors, size=9),
                                 name=f'{config} minimum', showlegend=False), row=1, col=col)
        # Dashed vertical line through the aggregate-loss minimum (fixed y-span .2-.9).
        fig.add_trace(go.Scatter(x=[minima_x[config][-1]] * 2, y=[.2, .9],
                                 mode='lines',
                                 line=dict(color=colors[-1], dash='dash'),
                                 showlegend=False), row=1, col=col)

        # Row 2: task sampling probabilities. All columns share zmin/zmax, so a
        # single colour bar (first column) describes every column.
        fig.add_trace(go.Heatmap(
            x=x,
            z=_task_weight_rows(config, len(x) - 1),
            zmin=zmin,
            zmax=zmax,
            y=list(TASKS.keys()),
            hoverongaps=False,
            showscale=first,
            colorscale='blues', ygap=1, xgap=1,
            colorbar=go.heatmap.ColorBar(title='Probability', len=0.5, yanchor='middle')),
            row=2, col=col)

        # Row 3: per-task dev uncertainty.
        fig.add_trace(go.Heatmap(
            x=x,
            z=_uncertainty_rows(config),
            zmin=uncertainty_zmin,
            zmax=uncertainty_zmax,
            y=list(TASKS.keys()),
            hoverongaps=False,
            showscale=first,
            colorscale='reds', ygap=1, xgap=1,
            colorbar=go.heatmap.ColorBar(title='Uncertainty', len=0.5, y=0)),
            row=3, col=col)

    fig.update_xaxes(nticks=20)
    fig.show()

In [342]:
# Baselines vs. uncertainty-based active sampling, side by side.
plot_heatmap(
    ['baseline-uniform', 'baseline-data-dist', 'sampling-smoothed-r=0.5'],
    [
        '(a) Uniform Sampling\n(baseline)',
        '(b) Union Sampling\n(baseline)',
        '(c) Uncertainty-based\nActive Sampling',
    ],
)

In [346]:
# Raw vs. smoothed sampling variants, with a narrower probability colour range.
plot_heatmap(
    configs=['sampling-raw', 'sampling-smoothed-r=0.25', 'sampling-smoothed-r=0.5'],
    config_titles=None,
    zmin=0.28,
    zmax=0.4,
)

## Downstream metrics

In [323]:
# Inspect the metric columns logged for the uniform baseline.
config_metrics_df['baseline-uniform'].columns

Index(['train_loss_by_task/rte', 'dev_loss', 'mrpc/test_F1', 'sst/test_AUC',
       'test_uncertainty_by_task/mrpc', '_step', 'mrpc/dev_F1',
       'dev_uncertainty_by_task/sst', 'train_loss', 'sst/test_ACC', '_runtime',
       'test_uncertainty_by_task/rte', 'test_loss_by_task/rte',
       'test_loss_by_task/mrpc', 'mrpc/dev_ACC', 'sst/dev_F1',
       'dev_loss_by_task/mrpc', 'mrpc/dev_AUC', 'rte/dev_AUC', 'mrpc/test_ACC',
       'dev_loss_by_task/sst', 'sst/test_F1', 'rte/test_F1', 'sst/dev_ACC',
       'rte/dev_ACC', 'dev_uncertainty_by_task/rte', 'rte/test_ACC',
       'train_loss_by_task/sst', 'rte/test_AUC', 'train_loss_by_task/mrpc',
       'dev_loss_by_task/rte', 'test_loss', 'test_uncertainty_by_task/sst',
       'test_loss_by_task/sst', 'sst/dev_AUC', 'rte/dev_F1', 'global_step',
       '_timestamp', 'mrpc/test_AUC', 'dev_uncertainty_by_task/mrpc'],
      dtype='object')

In [324]:
def _summary_columns(metrics_df, split):
    """Sorted columns of `metrics_df` mentioning `split` and holding AUC/F1/ACC/uncertainty."""
    tags = ('AUC', 'F1', 'ACC', 'uncertainty')
    return sorted(
        column for column in metrics_df.columns
        if split in column and any(tag in column for tag in tags)
    )


# Per configuration: only the downstream-metric and uncertainty columns of each split.
dev_summary_dfs = {}
test_summary_dfs = {}
for config, metrics_df in config_metrics_df.items():
    dev_summary_dfs[config] = metrics_df[_summary_columns(metrics_df, 'dev')]
    test_summary_dfs[config] = metrics_df[_summary_columns(metrics_df, 'test')]

In [325]:
# Method name -> (configuration, iteration) reported in the summary tables.
# Chosen by hand from the plots below.
best_checkpoints = dict(
    union=('baseline-data-dist', 2400),
    uniform=('baseline-uniform', 2400),
    active=('sampling-smoothed-r=0.5', 2200),
)

In [329]:
# Metric suffixes averaged across tasks into 'aggregate/<split>_<metric>' entries.
METRICS = ['ACC', 'AUC', 'F1']


def summarize_checkpoint(per_task, split, metrics=None, ndigits=2):
    """Add cross-task averages to one checkpoint's metrics and round every value.

    Args:
        per_task: column name -> value for one checkpoint, e.g.
            ``{'mrpc/dev_ACC': 83.17, 'dev_uncertainty_by_task/mrpc': 2.58, ...}``.
        split: split name used in the aggregate keys ('dev' or 'test').
        metrics: metric suffixes to average; defaults to ``METRICS``.
        ndigits: decimals kept by ``round``.

    Returns:
        New dict: the input entries followed by one ``'aggregate/{split}_{metric}'``
        entry per metric (unweighted mean over tasks), all rounded to ``ndigits``.

    Raises:
        ValueError: if no column ends in ``_{metric}`` for some metric.
    """
    if metrics is None:
        metrics = METRICS
    summary = dict(per_task)
    for metric in metrics:
        # Match the exact suffix in the *input* snapshot only, so aggregate keys
        # added in earlier iterations can never be averaged in.
        values = [val for key, val in per_task.items() if key.endswith(f'_{metric}')]
        if not values:
            raise ValueError(f'no {split} columns found for metric {metric!r}')
        summary[f'aggregate/{split}_{metric}'] = sum(values) / len(values)
    return {key: round(val, ndigits) for key, val in summary.items()}


# Method name -> rounded metric summary at its chosen checkpoint.
dev_summaries = {}
test_summaries = {}
for name, (run_name, step) in best_checkpoints.items():
    dev_summaries[name] = summarize_checkpoint(dev_summary_dfs[run_name].loc[step].to_dict(), 'dev')
    test_summaries[name] = summarize_checkpoint(test_summary_dfs[run_name].loc[step].to_dict(), 'test')

In [330]:
import json
# Pretty-print the dev-split summary of every method as indented JSON.
print(json.dumps(dev_summaries, indent=4))

{
    "union": {
        "dev_uncertainty_by_task/mrpc": 2.58,
        "dev_uncertainty_by_task/rte": 2.28,
        "dev_uncertainty_by_task/sst": 2.27,
        "mrpc/dev_ACC": 83.17,
        "mrpc/dev_AUC": 88.37,
        "mrpc/dev_F1": 88.21,
        "rte/dev_ACC": 66.18,
        "rte/dev_AUC": 72.02,
        "rte/dev_F1": 68.61,
        "sst/dev_ACC": 90.97,
        "sst/dev_AUC": 96.37,
        "sst/dev_F1": 91.34,
        "aggregate/dev_ACC": 80.11,
        "aggregate/dev_AUC": 85.59,
        "aggregate/dev_F1": 82.72
    },
    "uniform": {
        "dev_uncertainty_by_task/mrpc": 2.42,
        "dev_uncertainty_by_task/rte": 1.85,
        "dev_uncertainty_by_task/sst": 2.41,
        "mrpc/dev_ACC": 83.17,
        "mrpc/dev_AUC": 88.41,
        "mrpc/dev_F1": 88.61,
        "rte/dev_ACC": 67.15,
        "rte/dev_AUC": 72.46,
        "rte/dev_F1": 71.06,
        "sst/dev_ACC": 90.37,
        "sst/dev_AUC": 96.24,
        "sst/dev_F1": 90.84,
        "aggregate/dev_ACC": 80.23,
     

In [331]:
# Pretty-print the test-split summary of every method as indented JSON.
print(json.dumps(test_summaries, indent=4))

{
    "union": {
        "mrpc/test_ACC": 78.76,
        "mrpc/test_AUC": 85.96,
        "mrpc/test_F1": 85.12,
        "rte/test_ACC": 62.11,
        "rte/test_AUC": 69.35,
        "rte/test_F1": 62.97,
        "sst/test_ACC": 91.2,
        "sst/test_AUC": 96.72,
        "sst/test_F1": 92.18,
        "test_uncertainty_by_task/mrpc": 2.54,
        "test_uncertainty_by_task/rte": 2.48,
        "test_uncertainty_by_task/sst": 2.28,
        "aggregate/test_ACC": 77.36,
        "aggregate/test_AUC": 84.01,
        "aggregate/test_F1": 80.09
    },
    "uniform": {
        "mrpc/test_ACC": 79.74,
        "mrpc/test_AUC": 85.02,
        "mrpc/test_F1": 86.28,
        "rte/test_ACC": 65.47,
        "rte/test_AUC": 69.96,
        "rte/test_F1": 69.06,
        "sst/test_ACC": 90.03,
        "sst/test_AUC": 96.07,
        "sst/test_F1": 91.23,
        "test_uncertainty_by_task/mrpc": 2.46,
        "test_uncertainty_by_task/rte": 2.21,
        "test_uncertainty_by_task/sst": 2.31,
        "aggreg

In [216]:
# Dev AUC for the uniform baseline, maxima marked.
plot_metric_over_time('dev', 'auc', ['baseline-uniform'], plot_maxima=True)

In [217]:
# Dev AUC for the data-distribution ("union") baseline, maxima marked.
plot_metric_over_time(split='dev', metric='auc', configs=['baseline-data-dist'], plot_maxima=True)

In [219]:
# Dev AUC for 'sampling-smoothed-r=0.5', maxima marked.
plot_metric_over_time(split='dev', metric='auc', configs=['sampling-smoothed-r=0.5'], plot_maxima=True)

In [338]:
# Dev AUC for 'sampling-raw', maxima marked.
plot_metric_over_time(split='dev', metric='auc', configs=['sampling-raw'], plot_maxima=True)