In [1]:
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Loading the data

In [2]:
api = wandb.Api(timeout=30)

entity = "mosaic-ml"
project = "paper-mlm-schedule"

min_perf = 0.3

In [3]:
task_to_metric = {
    "cola": "metrics/glue_cola/MulticlassMatthewsCorrCoef",
    "mnli": ["metrics/glue_mnli/MulticlassAccuracy", "metrics/glue_mnli_mismatched/MulticlassAccuracy"],
    "mrpc": ["metrics/glue_mrpc/BinaryF1Score", "metrics/glue_mrpc/MulticlassAccuracy"],
    "qnli": "metrics/glue_qnli/MulticlassAccuracy",
    "qqp": ["metrics/glue_qqp/BinaryF1Score", "metrics/glue_qqp/MulticlassAccuracy"],
    "rte": "metrics/glue_rte/MulticlassAccuracy",
    "sst-2": "metrics/glue_sst2/MulticlassAccuracy",
    "stsb": "metrics/glue_stsb/SpearmanCorrCoef"
}
tasks = [
    'mnli',
    'qnli',
    'qqp',
    'cola',
    'sst-2',
    'rte',
    'mrpc',
    'stsb',
]
tasks_formatted = [
    "cola/MulticlassMatthewsCorrCoef",
    "mnli/MulticlassAccuracy",
    "mnli_mismatched/MulticlassAccuracy",
    "mrpc/BinaryF1Score",
    "mrpc/MulticlassAccuracy",
    "qnli/MulticlassAccuracy",
    "qqp/BinaryF1Score",
    "qqp/MulticlassAccuracy",
    "rte/MulticlassAccuracy",
    "sst2/MulticlassAccuracy",
    "stsb/SpearmanCorrCoef"
]

In [43]:
min_task_acc = {
    "cola/MulticlassMatthewsCorrCoef": 45,
    "mnli/MulticlassAccuracy": 75,
    "mnli_mismatched/MulticlassAccuracy": 75,
    "mrpc/BinaryF1Score": 80,
    "mrpc/MulticlassAccuracy": 75,
    "qnli/MulticlassAccuracy": 80,
    "qqp/BinaryF1Score": 75,
    "qqp/MulticlassAccuracy": 80,
    "rte/MulticlassAccuracy": 65,
    "sst2/MulticlassAccuracy": 80,
    "stsb/SpearmanCorrCoef": 75,
}

In [4]:
def get_runs(run_ids=[]):
    run_lookup = {}
    runs = api.runs("mosaic-ml/paper-mlm-schedule", filters={
        "$and": [{'tags': "best-ckpt"}, {"tags": "glue"}]})
    for run in tqdm(runs):
        if run.state != "finished":
            continue
        
        run_name = run.name
        task = [s for s in run_name.split("_") if "task=" in s][0].split("=")[1]
        tags = run.tags
        group = run.group
        experiment_name = group.split("-og-seed-")[0]
        scheduler = experiment_name.split("-")[0]
        init_rate = [float(t.split("-")[1]) for t in tags if "initial" in t][0]
        final_rate = [float(t.split("-")[1]) for t in tags if "final" in t][0]
        experiment_name = f"{scheduler}-{init_rate}-{final_rate}"
        pretrain_seed = int(group.split("-og-seed-")[1])
        glue_seed = int(run.name.split("seed=")[-1])
                
        if task == 'sst2':
            task = 'sst-2'
        
        if task not in tasks:
            raise ValueError(f"Task {task} not recognized.")
        
        try:
            metric_names = task_to_metric[task]
            if isinstance(metric_names, str):
                metric_names = [metric_names]
            elif not isinstance(metric_names, list):
                raise Exception("Unsupported type for 'metric_name'")

            for metric_idx, metric_name in enumerate(metric_names):
                metric_hist = run.history(keys=[metric_name]).to_numpy()[:, 1:].mean(axis=1)
                final_metric = max(metric_hist)
                
                if final_metric < min_perf:
                    continue
                
                metric_task = metric_name.replace("metrics/glue_", "")
                run_lookup[run.id + f"-{metric_idx}"] = {
                    'task': metric_task,
                    'experiment_name': experiment_name,
                    'final_metric': 100 * final_metric,
                    'pretrain_seed': pretrain_seed,
                    "glue_seed": glue_seed,
                    'scheduler': scheduler,
                    "init_rate": init_rate,
                    "final_rate": final_rate
                }
        except:
            print(f"Error for run: {group} with id ({run.id})")
                
    return run_lookup
        


In [5]:
run_lookup = get_runs()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1440/1440 [15:05<00:00,  1.59it/s]


In [96]:
columns=["task", "experiment_name", "final_metric", "pretrain_seed",
         "glue_seed", "scheduler", "init_rate", "final_rate"]
results = []
for run_info in run_lookup.values():
    if run_info is None:
        continue
    if run_info["final_metric"] < min_task_acc[run_info["task"]]:
        continue
    results.append([run_info[c] for c in columns])
print(results[0])
base_df = pd.DataFrame(results, columns = columns)
base_df[["final_metric", 'pretrain_seed', "glue_seed", "init_rate", "final_rate"]] = base_df[["final_metric", 'pretrain_seed', "glue_seed", "init_rate", "final_rate"]].apply(pd.to_numeric)
base_df = base_df.sort_values(by=['experiment_name'], ascending=False)

base_df

['stsb/SpearmanCorrCoef', 'linear-0.3-0.25', 89.50774073600769, 17, 90166, 'linear', 0.3, 0.25]


Unnamed: 0,task,experiment_name,final_metric,pretrain_seed,glue_seed,scheduler,init_rate,final_rate
532,mnli_mismatched/MulticlassAccuracy,linear-0.3-0.45,84.896255,3047,717,linear,0.30,0.45
274,qqp/BinaryF1Score,linear-0.3-0.45,88.276368,3047,10536,linear,0.30,0.45
223,sst2/MulticlassAccuracy,linear-0.3-0.45,93.004584,17,90166,linear,0.30,0.45
230,sst2/MulticlassAccuracy,linear-0.3-0.45,91.972476,3047,10536,linear,0.30,0.45
231,sst2/MulticlassAccuracy,linear-0.3-0.45,92.545873,2048,10536,linear,0.30,0.45
...,...,...,...,...,...,...,...,...
1691,stsb/SpearmanCorrCoef,constant-0.15-0.15,89.078087,17,90166,constant,0.15,0.15
1687,stsb/SpearmanCorrCoef,constant-0.15-0.15,89.544421,3047,8364,constant,0.15,0.15
1806,qqp/BinaryF1Score,constant-0.15-0.15,88.270420,2048,717,constant,0.15,0.15
1807,qqp/MulticlassAccuracy,constant-0.15-0.15,91.254020,2048,717,constant,0.15,0.15


# Sanity check data

In [97]:
for task in tasks_formatted:
    print(task)
    print(base_df[base_df.task==task].experiment_name.value_counts())
    print(' ')

cola/MulticlassMatthewsCorrCoef
experiment_name
linear-0.3-0.45       15
linear-0.3-0.4        15
linear-0.3-0.35       15
linear-0.3-0.2        15
linear-0.3-0.15       15
constant-0.4-0.4      15
constant-0.35-0.35    15
constant-0.3-0.3      15
constant-0.25-0.25    15
constant-0.2-0.2      15
constant-0.15-0.15    15
linear-0.3-0.25       14
Name: count, dtype: int64
 
mnli/MulticlassAccuracy
experiment_name
linear-0.3-0.45       15
linear-0.3-0.35       15
linear-0.3-0.2        15
constant-0.4-0.4      15
constant-0.35-0.35    15
constant-0.3-0.3      15
constant-0.25-0.25    15
constant-0.2-0.2      15
constant-0.15-0.15    15
linear-0.3-0.4        14
linear-0.3-0.25       14
linear-0.3-0.15       14
Name: count, dtype: int64
 
mnli_mismatched/MulticlassAccuracy
experiment_name
linear-0.3-0.45       15
linear-0.3-0.35       15
linear-0.3-0.2        15
constant-0.4-0.4      15
constant-0.35-0.35    15
constant-0.3-0.3      15
constant-0.25-0.25    15
constant-0.2-0.2      15
const

In [98]:
base_df = base_df[base_df["glue_seed"] != 8364]
base_df = base_df[base_df["pretrain_seed"] != 2048]

In [99]:
metric_stand_err = base_df.groupby(["experiment_name", "task"])["final_metric"].sem().reset_index()
metric_stand_err
grouped_df = base_df.groupby(["experiment_name", "task"]).mean(numeric_only=True).reset_index()
grouped_df["error"] = metric_stand_err["final_metric"]
grouped_df = grouped_df.round({'final_metric': 2, 'error': 2})
grouped_df

Unnamed: 0,experiment_name,task,final_metric,pretrain_seed,glue_seed,init_rate,final_rate,error
0,constant-0.15-0.15,cola/MulticlassMatthewsCorrCoef,55.27,1532.0,25359.5,0.15,0.15,0.67
1,constant-0.15-0.15,mnli/MulticlassAccuracy,84.40,1532.0,25359.5,0.15,0.15,0.07
2,constant-0.15-0.15,mnli_mismatched/MulticlassAccuracy,84.75,1532.0,25359.5,0.15,0.15,0.04
3,constant-0.15-0.15,mrpc/BinaryF1Score,92.07,1532.0,25359.5,0.15,0.15,0.21
4,constant-0.15-0.15,mrpc/MulticlassAccuracy,88.88,1532.0,25359.5,0.15,0.15,0.34
...,...,...,...,...,...,...,...,...
127,linear-0.3-0.45,qqp/BinaryF1Score,88.29,1532.0,25359.5,0.30,0.45,0.05
128,linear-0.3-0.45,qqp/MulticlassAccuracy,91.31,1532.0,25359.5,0.30,0.45,0.03
129,linear-0.3-0.45,rte/MulticlassAccuracy,76.62,1532.0,25359.5,0.30,0.45,0.33
130,linear-0.3-0.45,sst2/MulticlassAccuracy,92.33,1532.0,25359.5,0.30,0.45,0.14


In [100]:
grouped_df.pivot(index="experiment_name", columns="task", values="final_metric")

task,cola/MulticlassMatthewsCorrCoef,mnli/MulticlassAccuracy,mnli_mismatched/MulticlassAccuracy,mrpc/BinaryF1Score,mrpc/MulticlassAccuracy,qnli/MulticlassAccuracy,qqp/BinaryF1Score,qqp/MulticlassAccuracy,rte/MulticlassAccuracy,sst2/MulticlassAccuracy,stsb/SpearmanCorrCoef
experiment_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
constant-0.15-0.15,55.27,84.4,84.75,92.07,88.88,90.33,88.35,91.34,76.4,92.88,89.38
constant-0.2-0.2,56.28,84.5,84.84,91.53,88.24,90.56,88.29,91.28,75.72,92.49,89.57
constant-0.25-0.25,57.05,84.25,84.78,91.96,88.97,90.5,88.29,91.3,75.5,92.55,89.86
constant-0.3-0.3,57.17,84.53,84.86,92.5,89.61,90.54,88.33,91.31,77.21,92.65,89.81
constant-0.35-0.35,56.23,84.51,85.02,91.81,88.63,90.74,88.33,91.32,78.7,92.85,89.94
constant-0.4-0.4,56.22,84.18,84.58,91.78,88.51,90.56,88.35,91.35,74.91,92.53,89.94
linear-0.3-0.15,59.23,84.62,85.25,91.85,88.76,90.81,88.35,91.35,75.9,92.86,89.89
linear-0.3-0.2,57.1,84.55,84.98,91.14,87.53,90.85,88.26,91.28,77.03,92.85,89.7
linear-0.3-0.25,57.2,84.56,85.0,86.96,89.95,90.81,88.38,91.38,75.99,92.98,88.51
linear-0.3-0.35,56.09,84.37,84.87,91.49,88.17,90.69,88.28,91.3,77.21,92.76,89.88


In [101]:
metrics_with_errors = []
for task_name in tasks_formatted:
    metrics_with_errors.append(task_name)
    metrics_with_errors.append(task_name + " STE")
metrics_with_errors

['cola/MulticlassMatthewsCorrCoef',
 'cola/MulticlassMatthewsCorrCoef STE',
 'mnli/MulticlassAccuracy',
 'mnli/MulticlassAccuracy STE',
 'mnli_mismatched/MulticlassAccuracy',
 'mnli_mismatched/MulticlassAccuracy STE',
 'mrpc/BinaryF1Score',
 'mrpc/BinaryF1Score STE',
 'mrpc/MulticlassAccuracy',
 'mrpc/MulticlassAccuracy STE',
 'qnli/MulticlassAccuracy',
 'qnli/MulticlassAccuracy STE',
 'qqp/BinaryF1Score',
 'qqp/BinaryF1Score STE',
 'qqp/MulticlassAccuracy',
 'qqp/MulticlassAccuracy STE',
 'rte/MulticlassAccuracy',
 'rte/MulticlassAccuracy STE',
 'sst2/MulticlassAccuracy',
 'sst2/MulticlassAccuracy STE',
 'stsb/SpearmanCorrCoef',
 'stsb/SpearmanCorrCoef STE']

In [102]:
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
#pd.set_option('max_colwidth', -1)

In [103]:
task_grouped = []
for experiment_name in grouped_df["experiment_name"]:
    experiment_grouped = [experiment_name]
    for metric_name in tasks_formatted:
        performance = grouped_df[
            (grouped_df["experiment_name"] == experiment_name) & (grouped_df["task"] == metric_name)]["final_metric"].values[0]

        error = grouped_df[
            (grouped_df["experiment_name"] == experiment_name) & (grouped_df["task"] == metric_name)]["error"].values[0]
        
        experiment_grouped.append(performance)
        experiment_grouped.append(error)
        
    task_grouped.append(experiment_grouped)

task_grouped_df = pd.DataFrame(task_grouped, columns=["experiment_name"]+metrics_with_errors)
task_grouped_df = task_grouped_df.drop_duplicates()
task_grouped_df["glue_mean"] = task_grouped_df[tasks_formatted].mean(axis=1)
constant_avg = task_grouped_df[
    task_grouped_df["experiment_name"] == "constant-0.15-0.15"]["glue_mean"].values[0]
task_grouped_df["glue_delta"] = task_grouped_df["glue_mean"].map(lambda avg: avg - constant_avg)
task_grouped_df = task_grouped_df.sort_values(by=["glue_mean"], ascending=False)
task_grouped_df

Unnamed: 0,experiment_name,cola/MulticlassMatthewsCorrCoef,cola/MulticlassMatthewsCorrCoef STE,mnli/MulticlassAccuracy,mnli/MulticlassAccuracy STE,mnli_mismatched/MulticlassAccuracy,mnli_mismatched/MulticlassAccuracy STE,mrpc/BinaryF1Score,mrpc/BinaryF1Score STE,mrpc/MulticlassAccuracy,mrpc/MulticlassAccuracy STE,qnli/MulticlassAccuracy,qnli/MulticlassAccuracy STE,qqp/BinaryF1Score,qqp/BinaryF1Score STE,qqp/MulticlassAccuracy,qqp/MulticlassAccuracy STE,rte/MulticlassAccuracy,rte/MulticlassAccuracy STE,sst2/MulticlassAccuracy,sst2/MulticlassAccuracy STE,stsb/SpearmanCorrCoef,stsb/SpearmanCorrCoef STE,glue_mean,glue_delta
66,linear-0.3-0.15,59.23,0.46,84.62,0.07,85.25,0.06,91.85,0.29,88.76,0.36,90.81,0.1,88.35,0.04,91.35,0.03,75.9,0.25,92.86,0.06,89.89,0.05,85.351818,0.438182
33,constant-0.3-0.3,57.17,0.55,84.53,0.05,84.86,0.05,92.5,0.19,89.61,0.26,90.54,0.05,88.33,0.04,91.31,0.03,77.21,0.42,92.65,0.13,89.81,0.15,85.32,0.406364
44,constant-0.35-0.35,56.23,0.4,84.51,0.04,85.02,0.07,91.81,0.26,88.63,0.3,90.74,0.08,88.33,0.04,91.32,0.03,78.7,0.45,92.85,0.14,89.94,0.09,85.28,0.366364
77,linear-0.3-0.2,57.1,0.66,84.55,0.07,84.98,0.03,91.14,0.24,87.53,0.37,90.85,0.04,88.26,0.04,91.28,0.03,77.03,0.27,92.85,0.08,89.7,0.15,85.024545,0.110909
99,linear-0.3-0.35,56.09,0.74,84.37,0.06,84.87,0.09,91.49,0.27,88.17,0.34,90.69,0.08,88.28,0.04,91.3,0.03,77.21,0.83,92.76,0.11,89.88,0.09,85.01,0.096364
22,constant-0.25-0.25,57.05,0.55,84.25,0.1,84.78,0.09,91.96,0.08,88.97,0.16,90.5,0.06,88.29,0.03,91.3,0.02,75.5,0.26,92.55,0.18,89.86,0.05,85.000909,0.087273
121,linear-0.3-0.45,55.88,0.55,84.08,0.04,84.67,0.08,91.93,0.38,88.79,0.51,90.77,0.09,88.29,0.05,91.31,0.03,76.62,0.33,92.33,0.14,89.86,0.07,84.957273,0.043636
0,constant-0.15-0.15,55.27,0.67,84.4,0.07,84.75,0.04,92.07,0.21,88.88,0.34,90.33,0.08,88.35,0.04,91.34,0.02,76.4,0.4,92.88,0.13,89.38,0.07,84.913636,0.0
11,constant-0.2-0.2,56.28,0.45,84.5,0.05,84.84,0.14,91.53,0.27,88.24,0.36,90.56,0.06,88.29,0.04,91.28,0.02,75.72,0.32,92.49,0.15,89.57,0.16,84.845455,-0.068182
55,constant-0.4-0.4,56.22,0.46,84.18,0.06,84.58,0.06,91.78,0.19,88.51,0.29,90.56,0.07,88.35,0.03,91.35,0.03,74.91,0.24,92.53,0.11,89.94,0.1,84.81,-0.103636


In [104]:
task_grouped_df[["experiment_name", "glue_delta"]]

Unnamed: 0,experiment_name,glue_delta
66,linear-0.3-0.15,0.438182
33,constant-0.3-0.3,0.406364
44,constant-0.35-0.35,0.366364
77,linear-0.3-0.2,0.110909
99,linear-0.3-0.35,0.096364
22,constant-0.25-0.25,0.087273
121,linear-0.3-0.45,0.043636
0,constant-0.15-0.15,0.0
11,constant-0.2-0.2,-0.068182
55,constant-0.4-0.4,-0.103636


In [105]:
paper_just_metrics = ["mnli/MulticlassAccuracy", "mnli_mismatched/MulticlassAccuracy", "qnli/MulticlassAccuracy",
                 "qqp/BinaryF1Score", "rte/MulticlassAccuracy", "sst2/MulticlassAccuracy",
                 "mrpc/BinaryF1Score", "cola/MulticlassMatthewsCorrCoef", "stsb/SpearmanCorrCoef"
                ]
paper_metrics = []
for metric_name in paper_just_metrics:
    paper_metrics.append(metric_name)
    paper_metrics.append(metric_name + " STE")

In [106]:
paper_df = task_grouped_df[["experiment_name", *paper_just_metrics]]
paper_df["glue_mean"] = paper_df[paper_just_metrics].mean(axis=1)
constant_avg = paper_df[
    paper_df["experiment_name"] == "constant-0.15-0.15"]["glue_mean"].values[0]
paper_df["glue_delta"] = paper_df["glue_mean"].map(lambda avg: avg - constant_avg)
paper_df = paper_df.sort_values(by="glue_mean", ascending=False)
paper_df

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  paper_df["glue_mean"] = paper_df[paper_just_metrics].mean(axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  paper_df["glue_delta"] = paper_df["glue_mean"].map(lambda avg: avg - constant_avg)


Unnamed: 0,experiment_name,mnli/MulticlassAccuracy,mnli_mismatched/MulticlassAccuracy,qnli/MulticlassAccuracy,qqp/BinaryF1Score,rte/MulticlassAccuracy,sst2/MulticlassAccuracy,mrpc/BinaryF1Score,cola/MulticlassMatthewsCorrCoef,stsb/SpearmanCorrCoef,glue_mean,glue_delta
66,linear-0.3-0.15,84.62,85.25,90.81,88.35,75.9,92.86,91.85,59.23,89.89,84.306667,0.547778
44,constant-0.35-0.35,84.51,85.02,90.74,88.33,78.7,92.85,91.81,56.23,89.94,84.236667,0.477778
33,constant-0.3-0.3,84.53,84.86,90.54,88.33,77.21,92.65,92.5,57.17,89.81,84.177778,0.418889
77,linear-0.3-0.2,84.55,84.98,90.85,88.26,77.03,92.85,91.14,57.1,89.7,84.051111,0.292222
99,linear-0.3-0.35,84.37,84.87,90.69,88.28,77.21,92.76,91.49,56.09,89.88,83.96,0.201111
22,constant-0.25-0.25,84.25,84.78,90.5,88.29,75.5,92.55,91.96,57.05,89.86,83.86,0.101111
121,linear-0.3-0.45,84.08,84.67,90.77,88.29,76.62,92.33,91.93,55.88,89.86,83.825556,0.066667
0,constant-0.15-0.15,84.4,84.75,90.33,88.35,76.4,92.88,92.07,55.27,89.38,83.758889,0.0
11,constant-0.2-0.2,84.5,84.84,90.56,88.29,75.72,92.49,91.53,56.28,89.57,83.753333,-0.005556
55,constant-0.4-0.4,84.18,84.58,90.56,88.35,74.91,92.53,91.78,56.22,89.94,83.672222,-0.086667


In [107]:
paper_df[["experiment_name", "glue_mean", "glue_delta"]]

Unnamed: 0,experiment_name,glue_mean,glue_delta
66,linear-0.3-0.15,84.306667,0.547778
44,constant-0.35-0.35,84.236667,0.477778
33,constant-0.3-0.3,84.177778,0.418889
77,linear-0.3-0.2,84.051111,0.292222
99,linear-0.3-0.35,83.96,0.201111
22,constant-0.25-0.25,83.86,0.101111
121,linear-0.3-0.45,83.825556,0.066667
0,constant-0.15-0.15,83.758889,0.0
11,constant-0.2-0.2,83.753333,-0.005556
55,constant-0.4-0.4,83.672222,-0.086667
