In [None]:
import wandb
import pandas as pd
import plotly.graph_objects as go
# W&B API client; later cells use it to list the project's runs and to fetch runs by id.
api = wandb.Api()

In [None]:
# W&B project path handed to the API when listing runs.
PROJECT = 'feifang24/mtl-uncertainty-final'

# Values substituted for the {task} / {split} placeholders in the templates below.
TASKS = ['sst', 'rte', 'mrpc']
SPLITS = ['train', 'dev', 'test']

# Metric-name templates that are parameterized by {split} only.
AGG_METRICS = ['{split}_loss']

# Short metric name -> template of the column looked up in the run-history frames.
METRIC_FORMATS = {
    'loss': '{split}_loss_by_task/{task}',
    'uncertainty': '{split}_uncertainty_by_task/{task}',
    'acc': '{task}/{split}_ACC',  # {dev, test}
    'auc': '{task}/{split}_AUC',
    'f1': '{task}/{split}_F1',
}

In [None]:
from collections import defaultdict

# Bucket run ids by run name with its last two characters removed, so runs whose
# names differ only in that suffix end up under the same key.
config_run_ids = defaultdict(list)
for wandb_run in api.runs(path=PROJECT):
    config_key = wandb_run.name[:-2]
    config_run_ids[config_key].append(wandb_run.id)

In [None]:
# Spot-check that runs can be fetched by id, using the first configuration's ids.
# `run_ids` is bound here (empty when no runs were found) and the project path comes
# from the PROJECT constant, so this cell executes on a fresh kernel.
run_ids = next(iter(config_run_ids.values()), [])
runs = [api.run(f'{PROJECT}/{run_id}') for run_id in run_ids]

In [None]:
# Display the grouping built above: configuration key -> list of run ids.
config_run_ids

In [None]:
# Build one averaged history DataFrame per configuration key.
config_metrics_df = {}
for config_name, run_ids in config_run_ids.items():
    # The project path must come from the PROJECT constant; no lowercase `project` exists.
    runs = [api.run(f'{PROJECT}/{run_id}') for run_id in run_ids]
    # Stack every run's history, then average the rows that share an index label.
    stacked_history = pd.concat([run.history() for run in runs])
    # numeric_only=True keeps the averaging well-defined if a history carries
    # non-numeric columns (pandas >= 2.0 raises on those by default).
    config_metrics_df[config_name] = stacked_history.groupby(level=0).mean(numeric_only=True)

In [None]:
# Display the column names available in one configuration's averaged history.
config_metrics_df['sampling-smoothed-r=0.125'].columns

In [None]:
import plotly.express as px
def plot_metric_over_time(split, metric, configs, plot_minima=False):
    """Plot one per-task metric for several configurations on a single figure.

    One line is drawn per (task, configuration) pair: the colour identifies the
    configuration and the dash pattern identifies the task. Colours and dash
    patterns are reused cyclically when there are more configurations or tasks
    than palette entries.

    Args:
        split: Value substituted for ``{split}`` in the metric template.
        metric: Key into ``METRIC_FORMATS`` selecting the column-name template.
        configs: Keys of ``config_metrics_df`` to include in the figure.
        plot_minima: When true, add one marker trace per configuration showing
            the minimum of each of its task curves.

    Raises:
        KeyError: If ``metric``, a configuration, or the resulting column name
            is not present in the corresponding mapping / DataFrame.
    """
    fig = go.Figure()
    metric_format = METRIC_FORMATS[metric]
    colors = px.colors.qualitative.Plotly
    dashes = ['dash', 'dot', 'dashdot']
    minima_x = {config: [] for config in configs}
    minima_y = {config: [] for config in configs}
    for i, task in enumerate(TASKS):
        column = metric_format.format(split=split, task=task)
        for j, config in enumerate(configs):
            history = config_metrics_df[config]
            m = history[column]
            fig.add_trace(go.Scatter(x=history.index, y=m,
                                mode='lines',
                                # Modulo keeps long config/task lists from overrunning the palettes.
                                line=dict(color=colors[j % len(colors)],
                                          dash=dashes[i % len(dashes)]),
                                name=f'{config}/{task}'))
            if plot_minima:
                # Only locate minima when they will actually be drawn.
                minima_x[config].append(m.idxmin())
                minima_y[config].append(m.min())
    if plot_minima:
        for i, config in enumerate(configs):
            fig.add_trace(go.Scatter(x=minima_x[config], y=minima_y[config],
                                     mode='markers',
                                     marker=dict(color=colors[i % len(colors)]),
                                     name=f'{config} minimum'))
    fig.update_layout(
        yaxis_title=f'{split} {metric}'.title(),
        xaxis_title="Iteration",
        legend_title="Method/Task",
        font=dict(size=14)
    )
    fig.show()

In [None]:
# Dev-split loss per task for four smoothed-sampling configurations, with minima marked.
plot_metric_over_time(
    'dev',
    'loss',
    [
        'sampling-smoothed-r=0.375',
        'sampling-smoothed-r=0.125',
        'sampling-smoothed-r=0.5',
        'sampling-smoothed-r=0.25',
    ],
    plot_minima=True,
)

In [None]:
# Train-split loss per task for the same four smoothed-sampling configurations, no minima markers.
plot_metric_over_time(
    'train',
    'loss',
    [
        'sampling-smoothed-r=0.375',
        'sampling-smoothed-r=0.125',
        'sampling-smoothed-r=0.5',
        'sampling-smoothed-r=0.25',
    ],
    plot_minima=False,
)

# Uncertainty Heatmap

In [None]:
from plotly.subplots import make_subplots
def plot_heatmap(config):
    """Show dev AUC curves above a per-task dev uncertainty heatmap for one configuration.

    Row 1 holds one AUC line per task; row 2 holds a heatmap with one row per
    task. Both panels take their x values from the configuration's history
    index, so they stay aligned whatever the number of logged rows.

    Args:
        config: Key of ``config_metrics_df`` identifying the configuration.

    Raises:
        KeyError: If ``config`` or one of the expected metric columns is missing.
    """
    history = config_metrics_df[config]
    fig = make_subplots(rows=2, cols=1)
    colors = ['cadetblue', 'coral', 'olive']
    for i, task in enumerate(TASKS):
        m = history[METRIC_FORMATS['auc'].format(split='dev', task=task)]
        fig.add_trace(go.Scatter(x=history.index, y=m,
                                        mode='lines',
                                        # Cycle the palette if there are more tasks than colours.
                                        line=dict(color=colors[i % len(colors)]),
                                        name=f'{task}'), row=1, col=1)
    fig.add_trace(go.Heatmap(
                       # Use the real index rather than a fixed range(13).
                       x=history.index,
                       z=[history[METRIC_FORMATS['uncertainty'].format(split='dev', task=task)].values for task in TASKS],
                       y=TASKS,
                       hoverongaps = False,
                       colorscale='blues'), row=2, col=1)
    fig.update_traces(colorbar_len=0.5, colorbar_yanchor="top", selector=dict(type='heatmap'))
    fig.update_layout(
        xaxis_title="Iteration",
        title=f'{config} AUC'
    )
    fig.show()

In [None]:
# Dev AUC curves and uncertainty heatmap for the 'baseline-uniform' configuration.
plot_heatmap('baseline-uniform')

In [None]:
# Dev AUC curves and uncertainty heatmap for the 'baseline-data-dist' configuration.
plot_heatmap('baseline-data-dist')

In [None]:
# Dev AUC curves and uncertainty heatmap for the 'sampling-raw' configuration.
plot_heatmap('sampling-raw')

In [None]:
# Dev AUC curves and uncertainty heatmap for the 'sampling-smoothed-r=0.375' configuration.
plot_heatmap('sampling-smoothed-r=0.375')

## Downstream metrics

In [None]:
# Print one " & "-separated row per configuration. For every task and metric the
# reported value is the test score at the row where the dev score is highest.
for config, metrics_df in config_metrics_df.items():
    print(config, end = " & ")
    for task in TASKS:
        for metric in ['auc', 'acc', 'f1']:
            dev_metric = metrics_df[METRIC_FORMATS[metric].format(split='dev', task=task)]
            test_metric = metrics_df[METRIC_FORMATS[metric].format(split='test', task=task)]
            # idxmax() yields an index *label*, so look it up with .loc rather than
            # using it as a positional offset into .values.
            best_row = dev_metric.idxmax()
            print(round(test_metric.loc[best_row], 2), end=" & ")
    # Finish the row so each configuration is printed on its own line.
    print()

In [None]:
# For each configuration, print its name followed by the highest dev score of every
# (metric, task) pair, one value per line, rounded to two decimals.
for config, metrics_df in config_metrics_df.items():
    print(config)
    for metric in ['auc', 'acc', 'f1']:
        template = METRIC_FORMATS[metric]
        for task in TASKS:
            best_dev = metrics_df[template.format(split='dev', task=task)].max()
            print(f"{round(best_dev, 2)}")