In [None]:
import os
from itertools import chain
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Global plot styling: seaborn "ticks" style overridden with a dashed grid and serif
# (Times New Roman) text, the "paper" context scaled by 1.5, and the Set2 palette.
params = {
    "axes.grid": True,
    "grid.linestyle": "--",
    "font.family": "serif",
    "font.serif": "Times New Roman",
}

sns.set_style(style="ticks", rc=params)
sns.set_context(context="paper", font_scale=1.5)
sns.set_palette(palette="Set2")

In [None]:
# Non-English languages; their order drives the order of language pairs on plot axes.
non_eng_lang_order = ["de", "fr", "nl", "pt", "ru", "zh"]
# Both translation directions per language, interleaved: de-en, en-de, fr-en, en-fr, ...
lang_pairs = [
    pair
    for lang in non_eng_lang_order
    for pair in (f"{lang}-en", f"en-{lang}")
]

In [None]:
# Roots of the FLORES evaluation outputs for the 7B models. Each root is expected to contain
# <lp>/<ckpt>/<instructions>/sys_scores.txt, the layout read by `load_lp`. The placeholder
# defaults can be overridden with environment variables so the notebook can be run without
# editing this cell.
pretrained_root_path = Path(
    os.environ.get("FLORES_PRETRAINED_ROOT", "<path to pretrained 7B flores results>")
)
best_adapters_root_path = Path(
    os.environ.get("FLORES_BEST_ADAPTERS_ROOT", "<path to best adapters 7B flores results>")
)
finetune_root_path = Path(
    os.environ.get("FLORES_FINETUNE_ROOT", "<path to best finetuned 7B flores results>")
)

def load_scores(scores_file: Path) -> dict:
    """Parse a ``sys_scores.txt`` file into a ``{metric: score}`` dict.

    Each non-blank line must look like ``<metric>: <score>``. The line is split on the
    *last* ``": "`` so a metric name that itself contains ``": "`` stays intact, and
    surrounding whitespace is stripped from both sides. Blank lines (e.g. trailing
    newlines) are skipped instead of crashing the parse.

    Args:
        scores_file: Path to the scores file.

    Returns:
        Mapping from metric name to its float score, in file order.

    Raises:
        ValueError: if a non-blank line has no ``": "`` separator or its value is not a float.
    """
    scores = {}
    for line_no, raw_line in enumerate(scores_file.read_text().splitlines(), start=1):
        line = raw_line.strip()
        if not line:
            continue
        key, sep, value = line.rpartition(": ")
        if not sep:
            raise ValueError(
                f"{scores_file}:{line_no}: expected '<metric>: <score>', got {raw_line!r}"
            )
        scores[key.strip()] = float(value)
    return scores

def load_lp(data_root: Path, lp: str, instructions, ckpt: str):
    """Load the system-level scores of one language pair at one checkpoint.

    Reads ``<data_root>/<lp>/<ckpt>/<instructions>/sys_scores.txt`` with ``load_scores``
    and returns one flat record: the language pair, its direction ("En-XX" when ``lp``
    starts with "en", otherwise "XX-En"), followed by every metric from the file.
    """
    scores_path = data_root.joinpath(lp, ckpt, instructions, "sys_scores.txt")
    metrics = load_scores(scores_path)
    if lp.startswith("en"):
        direction = "En-XX"
    else:
        direction = "XX-En"
    record = {"lp": lp, "Direction": direction}
    record.update(metrics)
    return record


def load_results(data_root: Path, instructions: str, ckpt: str) -> pd.DataFrame:
    """Load one checkpoint's scores for every language pair under ``data_root``.

    Every sub-directory of ``data_root`` is treated as a language pair and loaded with
    ``load_lp``; non-directory entries are ignored. Directories are visited in sorted
    order because ``Path.iterdir`` yields entries in arbitrary order, which would
    otherwise make the row order of the result depend on the filesystem.

    Args:
        data_root: Root directory holding one sub-directory per language pair.
        instructions: Prompt-style sub-directory name (e.g. "zero_shot_instructions").
        ckpt: Checkpoint sub-directory name.

    Returns:
        A DataFrame with one row per language pair.
    """
    lp_dirs = sorted(d for d in data_root.iterdir() if d.is_dir())
    return pd.DataFrame([load_lp(data_root, d.name, instructions, ckpt) for d in lp_dirs])


# Scores of the three systems compared in the adapter-vs-finetuning bar plot. The pretrained
# model is read from its few-shot prompt run at checkpoint "0"; the finetuned and LoRA models
# from their zero-shot runs at the given checkpoint steps.
pretrained_results = load_results(pretrained_root_path, "few_shot_instructions2", "0")
pretrained_results["Model"] = "Pretrained"
finetune_results = load_results(finetune_root_path, "zero_shot_instructions", "240000")
finetune_results["Model"] = "Finetuned"
best_adapters_results = load_results(best_adapters_root_path, "zero_shot_instructions", "20000")
best_adapters_results["Model"] = "LoRA"

# ignore_index=True: each input frame is indexed 0..n-1, so a plain concat would repeat labels.
results = pd.concat(
    [pretrained_results, finetune_results, best_adapters_results], ignore_index=True
).rename(columns={"lp": "Language Pair"})
# Rescale COMET-22 by 100 (presumably stored on a 0-1 scale) for plotting.
results["COMET-22"] = results["COMET-22"] * 100
results

In [None]:
# Drop every pair involving Chinese (zh) from both the scores and the plotting order.
no_zh_results = results.loc[~results["Language Pair"].str.contains("zh")]
no_zh_lang_pairs = [pair for pair in lang_pairs if "zh" not in pair]

In [None]:
# Grouped bar chart: COMET per language pair (Chinese excluded) for the pretrained,
# finetuned and LoRA systems, with the legend moved below the axes.
_, ax = plt.subplots(figsize=(5, 3.5))
g = sns.barplot(
    data=no_zh_results,
    x="Language Pair",
    y="COMET-22",
    hue="Model",
    order=no_zh_lang_pairs,
    hue_order=["Pretrained", "Finetuned", "LoRA"],
    ax=ax,
)
# Re-create the legend from the axes' labelled artists and clear its title.
g.legend().set_title("")
g.set_ylabel("COMET")
plt.setp(g.get_xticklabels(), rotation=45)
g.set_ylim(80, 90)
sns.move_legend(
    g,
    "lower center",
    bbox_to_anchor=(0.5, -0.5),
    ncol=3,
    title=None,
    frameon=True,
)
# plt.savefig("figures/adapter_vs_finetuning.pdf", bbox_inches="tight", dpi=200)

In [None]:
def load_lp_all_ckpts(data_root: Path, lp: str):
    """Load zero-shot and few-shot scores for every checkpoint of one language pair.

    Each sub-directory of ``data_root / lp`` is a checkpoint whose name is its integer
    training step. For each checkpoint, ``zero_shot_instructions/sys_scores.txt`` and
    ``few_shot_instructions2/sys_scores.txt`` are read with ``load_scores``.

    Args:
        data_root: Root directory holding one sub-directory per language pair.
        lp: Language-pair directory name, e.g. "en-de".

    Returns:
        A list of flat records, two per checkpoint (zero-shot first, then few-shot),
        ordered by increasing step.

    Raises:
        ValueError: if a checkpoint directory name is not an integer.
    """
    direction = "En-XX" if lp.startswith("en") else "XX-En"
    # (prompt sub-directory, label stored in the "Context" column)
    contexts = (
        ("zero_shot_instructions", "Zero-Shot"),
        ("few_shot_instructions2", "Few-Shot"),
    )
    # Path.iterdir order is arbitrary; sort numerically so e.g. "5000" precedes "20000".
    ckpt_dirs = sorted(
        (d for d in (data_root / lp).iterdir() if d.is_dir()),
        key=lambda d: int(d.name),
    )
    lp_results = []
    for ckpt_dir in ckpt_dirs:
        step = int(ckpt_dir.name)
        for instructions, context in contexts:
            scores = load_scores(ckpt_dir / instructions / "sys_scores.txt")
            lp_results.append(
                {"lp": lp, "Direction": direction, "Context": context, "Step": step, **scores}
            )
    return lp_results


def load_results_all_ckpts(data_root: Path) -> pd.DataFrame:
    """Load every checkpoint's scores for every language pair under ``data_root``.

    Each sub-directory of ``data_root`` is a language pair whose records come from
    ``load_lp_all_ckpts``; non-directory entries are ignored. Language pairs are visited
    in sorted order because ``Path.iterdir`` order is arbitrary.

    Returns:
        A DataFrame with one row per (language pair, checkpoint, prompt context) record.
    """
    lp_dirs = sorted(d for d in data_root.iterdir() if d.is_dir())
    return pd.DataFrame(
        [record for lp_dir in lp_dirs for record in load_lp_all_ckpts(data_root, lp_dir.name)]
    )

In [None]:
# Zero-shot learning curve of the best LoRA adapter over all saved checkpoints, with the
# zero-shot pretrained model prepended as step 0. Chinese pairs are excluded.
full_adapters_results = load_results_all_ckpts(best_adapters_root_path)
zero_shot_pretrained_results = load_results(pretrained_root_path, "zero_shot_instructions", "0")
zero_shot_pretrained_results["Step"] = 0
zero_shot_pretrained_results["Context"] = "Zero-Shot"

# ignore_index=True: both inputs restart their index at 0, so avoid duplicate labels.
full_adapters_results = pd.concat(
    [zero_shot_pretrained_results, full_adapters_results], ignore_index=True
)
# Each step is assumed to consume `batch_size` training examples (the x-axis of the
# learning-curve plot) — TODO confirm 8 is the effective batch size.
batch_size = 8
full_adapters_results["Sequences"] = full_adapters_results["Step"] * batch_size
full_adapters_results["COMET-22"] = full_adapters_results["COMET-22"] * 100
# .copy() so later column assignments act on an independent frame rather than a filtered
# view of the unfiltered one (avoids pandas' SettingWithCopyWarning).
full_adapters_results = full_adapters_results[
    (full_adapters_results["Context"] == "Zero-Shot")
    & (~full_adapters_results["lp"].str.contains("zh"))
].copy()
full_adapters_results

In [None]:
# Few-shot pretrained baselines: mean COMET-22 per direction, drawn as dashed reference
# lines in the learning-curve plot.
few_shot_pretrained_results = load_results(pretrained_root_path, "few_shot_instructions2", "0")
few_shot_pretrained_results["COMET-22"] = few_shot_pretrained_results["COMET-22"] * 100
# Exclude Chinese pairs so the baselines average over the same language pairs as the
# learning curves they are compared against (which drop "zh" pairs as well).
# numeric_only=True keeps the mean restricted to metric columns on every pandas version
# (pandas >= 2.0 would otherwise raise on any non-numeric column).
grouped = (
    few_shot_pretrained_results
    .loc[~few_shot_pretrained_results["lp"].str.contains("zh")]
    .drop(columns=["lp"])
    .groupby("Direction")
    .mean(numeric_only=True)
)
print(grouped)
en_xx_baseline = grouped.loc["En-XX", "COMET-22"]
xx_en_baseline = grouped.loc["XX-En", "COMET-22"]
en_xx_baseline, xx_en_baseline

In [None]:
# Learning curve: zero-shot COMET of the LoRA adapter vs. number of training examples,
# one line per direction, with the few-shot pretrained baselines as dashed references.
palette = sns.color_palette()
_, ax = plt.subplots(figsize=(5, 3.5))
# assign() builds a new frame, so the added column never targets a filtered view.
full_adapters_results = full_adapters_results.assign(
    Legend="Finetuned " + full_adapters_results["Direction"]
)
# axhline's xmin/xmax are axes fractions in [0, 1], not data coordinates; the defaults
# already span the full width, which is what the old (0, 20000) arguments produced after clipping.
ax.axhline(en_xx_baseline, linestyle="dashed", color=palette[0], label="Pretrained En-XX")
ax.axhline(xx_en_baseline, linestyle="dashed", color=palette[1], label="Pretrained XX-En")
g = sns.lineplot(
    data=full_adapters_results,
    x="Sequences",
    y="COMET-22",
    hue="Legend",
    style="Legend",
    hue_order=["Finetuned En-XX", "Finetuned XX-En"],
    markers=["o", "^"],
    dashes=False,
    markersize=7,
    ax=ax,
)
# Re-create the legend from the axes' labelled artists, which includes the labelled
# baseline lines (presumably absent from seaborn's own hue/style legend).
g.legend().set_title("")
g.set_ylabel("COMET")
# Tick positions are checkpoint steps converted to examples; labels are in thousands.
ticks = np.array([0, 1000, 5000, 10000, 15000, 20000]) * batch_size
g.set_xticks(ticks)
g.set_xticklabels([f"{t // 1000}" for t in ticks])
g.set_ylim(70, 90)
g.set_xlabel("Training Examples (in thousands)")
sns.move_legend(
    g,
    "lower center",
    bbox_to_anchor=(0.5, -0.5),
    ncol=2,
    title=None,
    frameon=True,
)
# plt.savefig("figures/number_of_instructions.pdf", bbox_inches="tight", dpi=200)