In [2]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pandas.io.json._normalize import nested_to_record

%matplotlib inline

In [3]:
# Weights & Biases entity and project; combined as "<username>/<project_name>"
# when the runs are fetched.
username = "lavender"
project_name = "new_manhattan_vs_brooklyn-fixed-label"

# Column holding the number of training samples (selected into the stats table).
sample_name = "data/num_train_samples"

# Loss keys. NOTE(review): not referenced in the cells visible here.
loss_name = "test/loss"
temporal_loss_name = "test/temporal_loss"

# One ROC AUC key (plain and temporal) per evaluation site, in `locations` order.
locations = ["manhattan", "brooklyn"]
metric_name = [f"test/{site}/roc_auc" for site in locations]
temporal_metric_name = [f"test/{site}/temporal_roc_auc" for site in locations]

# NOTE(review): not referenced in the cells visible here; presumably the
# identifiers of the pretraining sites -- confirm against downstream cells.
pretrained_names = ["brooklyn", "tisch"]

In [4]:
import wandb

# Fetch every run of the project through the W&B public API and flatten it
# into one table (`all_df`): one row per run holding the run name, its
# flattened config values and its summary metrics side by side.
api = wandb.Api()

# Project is specified by <entity/project-name>
runs = api.runs(f"{username}/{project_name}")

summary_list = []
config_list = []
name_list = []
for run in runs:
    # run.summary holds the output key/values like accuracy.
    # ._json_dict is used to omit large files.
    summary_list.append(run.summary._json_dict)

    # run.config holds the run's inputs. Nested dicts are flattened into
    # "outer/inner" keys first, then every key starting with "_" is dropped.
    conf = nested_to_record(run.config, sep="/")
    config = {k: v for k, v in conf.items() if not k.startswith("_")}
    config_list.append(config)

    # run.name is the name of the run.
    name_list.append(run.name)

# `pd` comes from the notebook's import cell; it is not re-imported here.
summary_df = pd.DataFrame.from_records(summary_list)
config_df = pd.DataFrame.from_records(config_list)
name_df = pd.DataFrame({"name": name_list})

# The three frames were built from lists filled in the same loop, so they
# share the same default 0..n-1 index and a column-wise concat lines up each
# run's name, config and summary on one row.
all_df = pd.concat([name_df, config_df, summary_df], axis=1)

In [7]:
# Columns kept for the exported table: the two grouping keys, the plain and
# temporal ROC AUC for each site, and the column named by `sample_name`.
stats_columns = [
    "pretrained",
    "finetuned",
    "test/brooklyn/temporal_roc_auc",
    "test/brooklyn/roc_auc",
    "test/manhattan/temporal_roc_auc",
    "test/manhattan/roc_auc",
    sample_name,
]
stats = all_df[stats_columns]

# Write the selection to Fig3cd.csv (relative path, so it lands in the current
# working directory), then show it as the cell output.
stats.to_csv("Fig3cd.csv")
stats

Unnamed: 0,pretrained,finetuned,test/brooklyn/temporal_roc_auc,test/brooklyn/roc_auc,test/manhattan/temporal_roc_auc,test/manhattan/roc_auc,data/num_train_samples
0,all_sites,manhattan,0.779699,0.827747,0.814015,0.847684,80661
1,all_sites,brooklyn,0.804892,0.845732,0.808016,0.833669,80661
2,all_sites,manhattan,0.794626,0.821461,0.815204,0.849114,80661
3,all_sites,brooklyn,0.817378,0.843731,0.809234,0.833678,80661
4,all_sites,manhattan,0.800225,0.831433,0.820215,0.845698,80661
5,all_sites,manhattan,0.798773,0.827032,0.81848,0.84825,80661
6,all_sites,brooklyn,0.812011,0.840594,0.810536,0.834868,80661
7,brooklyn,brooklyn,0.798853,0.844475,0.788279,0.820112,80661
8,brooklyn,manhattan,0.784049,0.824363,0.812227,0.842101,80661
9,manhattan,brooklyn,0.807531,0.840007,0.806726,0.832247,80661


In [8]:
# Average every remaining column over the runs that share the same
# (pretrained, finetuned) pair.
stats.groupby(by=["pretrained", "finetuned"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,test/brooklyn/temporal_roc_auc,test/brooklyn/roc_auc,test/manhattan/temporal_roc_auc,test/manhattan/roc_auc,data/num_train_samples
pretrained,finetuned,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
all_sites,brooklyn,0.812392,0.843851,0.809947,0.834677,80661.0
all_sites,manhattan,0.791405,0.826459,0.816095,0.847449,80661.0
brooklyn,brooklyn,0.796474,0.844525,0.784578,0.814815,80661.0
brooklyn,manhattan,0.787448,0.818017,0.809092,0.84109,80661.0
manhattan,brooklyn,0.805489,0.839327,0.806995,0.831399,80661.0
manhattan,manhattan,0.786353,0.825325,0.817278,0.847925,80661.0


In [9]:
# Spread of every remaining column across the runs that share the same
# (pretrained, finetuned) pair (pandas' default standard deviation).
stats.groupby(by=["pretrained", "finetuned"]).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,test/brooklyn/temporal_roc_auc,test/brooklyn/roc_auc,test/manhattan/temporal_roc_auc,test/manhattan/roc_auc,data/num_train_samples
pretrained,finetuned,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
all_sites,brooklyn,0.005466,0.002338,0.001713,0.001335,0.0
all_sites,manhattan,0.009203,0.003711,0.003174,0.001363,0.0
brooklyn,brooklyn,0.006916,0.001715,0.005974,0.003539,0.0
brooklyn,manhattan,0.005622,0.004862,0.003362,0.000883,0.0
manhattan,brooklyn,0.002864,0.001416,0.002338,0.001334,0.0
manhattan,manhattan,0.002746,0.002013,0.002582,0.000915,0.0
