In [1]:
import json
from pathlib import Path
import re
import pandas as pd
import plotly.express as px

In [2]:
def load_scores(name, eval_dir="./.eval/"):
    """Collect per-checkpoint test scores for every run that has an eval file ``name``.

    Files are looked up as ``<eval_dir>/*/<name>``. The name of each run's
    directory becomes a key of the returned dict. Paths whose run directory or
    file name contain "news" are skipped (per the author's note, those files are
    in an invalid format).

    Each file is expected to be a JSON object mapping a checkpoint step (a string
    parseable as int) to a dict holding at least a ``"test_score"`` entry.

    Args:
        name: file name of the evaluation result inside each run directory,
            e.g. "flores_dev.json".
        eval_dir: root directory holding one sub-directory per run.

    Returns:
        dict mapping run name -> {step (int): test_score}. Steps are in
        ascending order, so plots and ``groupby(...).last()`` see checkpoints
        chronologically regardless of the key order in the JSON file.
    """
    root = Path(eval_dir)
    data = {}
    # sorted(): glob order is filesystem-dependent; make run order deterministic.
    for fname in sorted(root.glob(f"*/{name}")):
        if "news" in str(fname.relative_to(root)):
            continue  # skip news runs: invalid format
        with fname.open("rb") as fp:
            raw = json.load(fp)
        # Use the parent directory name instead of a backslash-only regex, so
        # this works with both POSIX ("/") and Windows ("\\") separators.
        data[fname.parent.name] = dict(
            sorted((int(step), entry["test_score"]) for step, entry in raw.items())
        )
    return data

def convert_to_df_1(data):
    """Flatten mBART run scores into a long-format DataFrame for plotting.

    Only runs whose name ends with one of the suffixes in ``labels`` are kept.
    Each kept run is relabelled with a readable variant name plus its first five
    characters (the language pair, e.g. "en-ja").

    Args:
        data: mapping run name -> {step: score}, as returned by ``load_scores``.

    Returns:
        DataFrame with columns ``model``, ``steps``, ``score`` (one row per
        checkpoint). The columns exist even when no run matches, so
        ``groupby("model")`` and ``px.line`` still work on an empty result.
    """
    # run-name suffix -> display label; checked in order, first match wins
    labels = (
        ("-bt-500k", "base+BT"),
        ("+bt-250k", "base"),
        ("-mixed-500k", "extended"),
    )
    rows = []
    for run, scores in data.items():
        # The label depends only on the run name, so resolve it once per run.
        label = next((lab for suffix, lab in labels if run.endswith(suffix)), None)
        if label is None:
            continue
        model = f"{label} [{run[:5]}]"
        rows.extend(
            {"model": model, "steps": step, "score": score}
            for step, score in scores.items()
        )
    return pd.DataFrame(rows, columns=["model", "steps", "score"])

def convert_to_df_0(data):
    """Flatten BERT-GPT2 and mBART (en-ja) run scores into a long-format DataFrame.

    Runs are selected and relabelled by ``_label``; every other run is dropped.
    The first five characters of the run name (the language pair, e.g. "en-ja")
    are appended to the label in brackets.

    Args:
        data: mapping run name -> {step: score}, as returned by ``load_scores``.

    Returns:
        DataFrame with columns ``model``, ``steps``, ``score`` (one row per
        checkpoint). The columns exist even when no run matches, so
        ``groupby("model")`` and ``px.line`` still work on an empty result.
    """
    def _label(run):
        # Map a run name to its display label, or None if it is not plotted here.
        if run.endswith("-BERT-GPT2-xattn"):
            return "BERT-GPT2 (xattn)"
        if run.endswith("-BERT-GPT2-xattn-LoRA"):
            return "BERT-GPT2 (xattn->LoRA)"
        if run.endswith("+bt-250k") and run.startswith("en-ja"):
            return "mBART"
        return None

    rows = []
    for run, scores in data.items():
        # The label depends only on the run name, so resolve it once per run.
        label = _label(run)
        if label is None:
            continue
        model = f"{label} [{run[:5]}]"
        rows.extend(
            {"model": model, "steps": step, "score": score}
            for step, score in scores.items()
        )
    return pd.DataFrame(rows, columns=["model", "steps", "score"])

#### BERT-GPT2 vs mBART

In [3]:
# FLORES dev score per checkpoint; x-axis limited to steps 2,500-25,000.
px.line(
    convert_to_df_0(load_scores("flores_dev.json")),
    x="steps", y="score", color="model",
    range_x=(2500, 25000),
    title="FLORES dev score: BERT-GPT2 vs mBART",
    labels={"steps": "checkpoint step", "score": "FLORES dev score"},
)

In [7]:
# Score from wmt_vat.json at each model's latest checkpoint; sort by step first so `.last()` is the latest row.
convert_to_df_0(load_scores("wmt_vat.json")).sort_values("steps", kind="stable").groupby("model").last()

Unnamed: 0_level_0,steps,score
model,Unnamed: 1_level_1,Unnamed: 2_level_1
BERT-GPT2 (xattn) [en-ja],25000,1.699591
BERT-GPT2 (xattn->LoRA) [en-ja],25000,2.470439
mBART [en-ja],25000,9.392265


#### mBART variants: base vs. base+BT vs. extended

In [5]:
# FLORES dev score per checkpoint for the mBART variants; axes fixed to steps 5k-50k and score 9-15.
px.line(
    convert_to_df_1(load_scores("flores_dev.json")),
    x="steps", y="score", color="model",
    range_x=(5000, 50000), range_y=(9, 15),
    title="FLORES dev score: mBART variants",
    labels={"steps": "checkpoint step", "score": "FLORES dev score"},
)

In [8]:
# Score from wmt_vat.json at each model's latest checkpoint; sort by step first so `.last()` is the latest row.
convert_to_df_1(load_scores("wmt_vat.json")).sort_values("steps", kind="stable").groupby("model").last()

Unnamed: 0_level_0,steps,score
model,Unnamed: 1_level_1,Unnamed: 2_level_1
base [en-ja],25000,9.392265
base [ja-en],25000,10.160598
base+BT [en-ja],50000,9.538963
base+BT [ja-en],50000,10.555422
extended [en-ja],50000,9.491911
extended [ja-en],50000,10.897228
