In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from itertools import chain
from pathlib import Path

from typing import List, Optional

# Paper-style figure defaults: seaborn "ticks" theme with dashed grid lines,
# Times New Roman serif text scaled 1.5x in the "paper" context, Set2 colours.
params = {
    "axes.grid": True,
    "grid.linestyle": "--",
    "font.family": "serif",
    "font.serif": "Times New Roman",
}
sns.set_style(style="ticks", rc=params)
sns.set_context(context="paper", font_scale=1.5)
sns.set_palette(palette="Set2")

In [None]:
# Placeholder locations -- replace each with a real directory before running.
# Layouts expected by the loaders below:
#   <data_root>/<domain>/<lp>/{train_eval.input.txt, train_eval.output.txt, <instructions>.txt}
#   <model path>/<domain>/<lp>/<ckpt>/<instructions>/{seg_scores.txt, translations.txt}
# NOTE: despite the "flores results" wording, the model roots are read for every
# domain that `load` iterates over.
few_shot_model_path = Path("<path to best few-shot adapters 7B model flores results>")
no_few_shot_model_path = Path("<path to best zero-shot adapters 7B model flores results>")
data_root = Path("<path to data root>")

In [None]:
def read_lines(path: Path, unescape_newline: bool = False) -> List[str]:
    """Read a UTF-8 text file into a list of lines without line terminators.

    Args:
        path: File to read.
        unescape_newline: If True, replace every literal two-character ``\\n``
            sequence in a line with a real newline character.

    Returns:
        One string per line of the file.
    """
    with open(path, encoding="utf-8") as f:
        # Strip only the terminator: slicing off the last character (l[:-1])
        # would eat real text on a final line that has no trailing newline.
        lines = [line.rstrip("\n") for line in f]
    if unescape_newline:
        lines = [line.replace("\\n", "\n") for line in lines]
    return lines

def load_scores(scores_file: Path) -> dict:
    """Parse a ``<metric>: <value>`` per line score file into a dict.

    Args:
        scores_file: Text file with one ``key: value`` pair per line.

    Returns:
        Mapping from key to float value. Blank lines are ignored.

    Raises:
        ValueError: if a non-blank line has no ": " separator or the value is
            not a float.
    """
    scores = {}
    for line in scores_file.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        # Split on the LAST ": " so a key that itself contains ": " survives.
        key, sep, value = line.rpartition(": ")
        if not sep:
            raise ValueError(f"Malformed score line in {scores_file}: {line!r}")
        scores[key] = float(value)
    return scores

def load_lp(dataset_root: Path, model_dataset_root: Path, lp: str, ckpt: str, instructions: str):
    """Load one language pair's segments joined with one model run's outputs.

    Args:
        dataset_root: Domain data directory; ``<lp>/`` inside it holds
            train_eval.input.txt, train_eval.output.txt and ``<instructions>.txt``.
        model_dataset_root: Model output directory for the same domain; the run
            is read from ``<lp>/<ckpt>/<instructions>/`` (seg_scores.txt,
            translations.txt).
        lp: Language-pair directory name.
        ckpt: Checkpoint sub-directory.
        instructions: Instruction-set name; selects both the prompt file and
            the model run directory.

    Returns:
        One dict per segment with keys lp, source, reference, translation,
        instruction and score (the COMET-22 column multiplied by 100).

    Raises:
        ValueError: if the files do not all contain the same number of segments.
    """
    lp_dir = dataset_root / lp
    run_dir = model_dataset_root / lp / ckpt / instructions

    sources = read_lines(lp_dir / "train_eval.input.txt", unescape_newline=True)
    references = read_lines(lp_dir / "train_eval.output.txt", unescape_newline=True)
    instructions_lines = read_lines(lp_dir / f"{instructions}.txt", unescape_newline=True)
    translations = read_lines(run_dir / "translations.txt", unescape_newline=True)

    scores = pd.read_csv(run_dir / "seg_scores.txt")
    comet_scores = scores["COMET-22"] * 100

    # zip() silently truncates to the shortest input, which would drop or
    # misalign segments without any error -- check the counts explicitly.
    lengths = {
        "sources": len(sources),
        "references": len(references),
        "instructions": len(instructions_lines),
        "translations": len(translations),
        "scores": len(comet_scores),
    }
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Segment count mismatch for {lp} ({run_dir}): {lengths}")

    return [
        {
            "lp": lp,
            "source": s,
            "reference": r,
            "translation": t,
            "instruction": i,
            "score": c,
        }
        for s, r, t, i, c in zip(sources, references, translations, instructions_lines, comet_scores)
    ]

def load_results(data_root: Path, model_root: Path, dataset: str, ckpt: str, instructions: str):
    """Load every language pair of one dataset/domain for one model run.

    Args:
        data_root: Root containing one directory per dataset.
        model_root: Root of the model outputs, one directory per dataset.
        dataset: Dataset (domain) directory name.
        ckpt: Checkpoint sub-directory passed to ``load_lp``.
        instructions: Instruction-set name passed to ``load_lp``.

    Returns:
        DataFrame with one row per segment of every language-pair directory
        found under ``model_root / dataset`` (non-directory entries are skipped).
    """
    dataset_root = data_root / dataset
    model_dataset_root = model_root / dataset

    # iterdir() order is filesystem-dependent; sort it so the row order is
    # deterministic. Callers such as `load` pair two runs positionally.
    lps = sorted(entry.name for entry in model_dataset_root.iterdir() if entry.is_dir())

    records = []
    for lp in lps:
        records.extend(load_lp(dataset_root, model_dataset_root, lp, ckpt, instructions))
    return pd.DataFrame(records)

def load(
    data_root,
    model_path,
    domains: Optional[List[str]] = None,
    ckpt: str = "20000",
    zero_shot_instructions: str = "zero_shot_instructions",
    few_shot_instructions: str = "few_shot_instructions2",
):
    """Build a per-segment zero-shot vs. few-shot comparison for one model.

    For each domain, the zero-shot and few-shot runs of checkpoint ``ckpt`` are
    loaded with ``load_results`` and placed side by side. Language pairs whose
    name contains "zh" are dropped. ``Delta = few_shot_score - zero_shot_score``
    is added.

    Args:
        data_root: Root holding one dataset directory per domain.
        model_path: Root holding this model's outputs, one directory per domain.
        domains: Domain directory names to load. Defaults to flores, medical,
            law, tico and chat_wmt.
        ckpt: Checkpoint sub-directory to read.
        zero_shot_instructions: Instruction-set name of the zero-shot run.
        few_shot_instructions: Instruction-set name of the few-shot run.

    Returns:
        DataFrame with columns lp, source, reference, zero_shot_translation,
        instruction (the zero-shot prompt), zero_shot_score,
        few_shot_translation, few_shot_score, Domain and Delta, indexed 0..n-1.

    Raises:
        ValueError: if a domain's zero-shot and few-shot runs do not list the
            same segments in the same order.
    """
    if domains is None:
        # The NLLB domains (nllb_md_chat / _health / _news) have labels below
        # but are not loaded by default.
        domains = ["flores", "medical", "law", "tico", "chat_wmt"]
    domain2label = {
        "flores": "Flores",
        "medical": "Medical",
        "law": "Law",
        "nllb_md_chat": "NLLB Chat",
        "nllb_md_health": "NLLB Health",
        "nllb_md_news": "NLLB News",
        "tico": "Tico",
        "chat_wmt": "Chat",
    }

    per_domain = []
    for domain in domains:
        zero_shot = load_results(data_root, model_path, domain, ckpt, zero_shot_instructions).rename(
            columns={"translation": "zero_shot_translation", "score": "zero_shot_score"}
        )
        few_shot = load_results(data_root, model_path, domain, ckpt, few_shot_instructions).rename(
            columns={"translation": "few_shot_translation", "score": "few_shot_score"}
        )

        # The two runs are joined positionally (axis=1 concat on their default
        # 0..n-1 index), so row i must be the same segment in both frames.
        key_cols = ["lp", "source"]
        if not zero_shot[key_cols].equals(few_shot[key_cols]):
            raise ValueError(
                f"Zero-shot and few-shot results for domain {domain!r} are not aligned "
                "(different segments or segment order)"
            )

        paired = pd.concat([zero_shot, few_shot[["few_shot_translation", "few_shot_score"]]], axis=1)
        # Build a new frame instead of assigning into a filtered slice
        # (avoids SettingWithCopyWarning).
        paired = paired[~paired["lp"].str.contains("zh", regex=False)].assign(
            Domain=domain2label.get(domain, domain),
            Delta=lambda frame: frame["few_shot_score"] - frame["zero_shot_score"],
        )
        per_domain.append(paired)

    # ignore_index: each domain's frame restarts at 0, so keep labels unique.
    return pd.concat(per_domain, ignore_index=True)

In [None]:
# Load both fine-tuned variants and stack them into one long frame; the
# "Model" column tells the variants apart in the plots below.
no_few_shot_results = load(data_root, no_few_shot_model_path).assign(Model="FT w/o few-shot")
few_shot_results = load(data_root, few_shot_model_path).assign(Model="FT w/ few-shot")
# ignore_index: both frames start at 0, so keep row labels unique after stacking.
results = pd.concat([no_few_shot_results, few_shot_results], ignore_index=True)

In [None]:
# Flores: distribution of the per-segment COMET delta (few-shot minus zero-shot
# prompting) per language pair, for the model fine-tuned WITH few-shot examples.
non_eng_lang_order = ["de", "fr", "nl", "pt", "ru"]
lang_pairs = list(chain.from_iterable([f"{lang}-en", f"en-{lang}"] for lang in non_eng_lang_order))
flores_results = results[(results["Domain"] == "Flores") & (results["Model"] == "FT w/ few-shot")]

fig, ax = plt.subplots(figsize=(5, 3))
sns.boxenplot(data=flores_results, x="lp", y="Delta", order=lang_pairs, ax=ax)
# Raw string: "\D" is an invalid escape sequence in a normal string literal.
ax.set_ylabel(r"COMET Score $\Delta$")
ax.set_xlabel("Language Pair")
ax.tick_params(axis="x", labelrotation=45)
ax.set_ylim(-40, 40);
# fig.savefig("figures/flores_deltas.pdf", bbox_inches="tight", dpi=200)


In [None]:
# Flores: distribution of the per-segment COMET delta (few-shot minus zero-shot
# prompting) per language pair, for the model fine-tuned WITHOUT few-shot examples.
non_eng_lang_order = ["de", "fr", "nl", "pt", "ru"]
lang_pairs = list(chain.from_iterable([f"{lang}-en", f"en-{lang}"] for lang in non_eng_lang_order))
flores_results = results[(results["Domain"] == "Flores") & (results["Model"] == "FT w/o few-shot")]

fig, ax = plt.subplots(figsize=(5, 3))
sns.boxenplot(data=flores_results, x="lp", y="Delta", order=lang_pairs, ax=ax)
# Raw string: "\D" is an invalid escape sequence in a normal string literal.
ax.set_ylabel(r"COMET Score $\Delta$")
ax.set_xlabel("Language Pair")
ax.tick_params(axis="x", labelrotation=45)
ax.set_ylim(-40, 40);
# fig.savefig("figures/flores_deltas_FT_no_fewshot.pdf", bbox_inches="tight", dpi=200)

In [None]:
# COMET delta (few-shot minus zero-shot prompting) per domain, one panel per
# fine-tuned model.
height = 3.5
aspect = 5 / 3.5
g = sns.catplot(
    kind="boxen",
    data=results,
    x="Domain", y="Delta", col="Model",
    height=height, aspect=aspect,
)
g.despine(left=False, bottom=False, right=False, top=False)
# Raw string: "\D" is an invalid escape sequence in a normal string literal.
g.set_ylabels(r"COMET Score $\Delta$")
g.set_xlabels("")
g.set_titles("{col_name}")
g.tight_layout()
for ax in g.axes.flat:
    ax.set_ylim(-70, 50)
# g.savefig("figures/domains_deltas.pdf", bbox_inches="tight", dpi=200)

In [None]:
# Mean scores and delta per model and domain. `results` stacks both fine-tuned
# models, so group by Model too; restrict to the numeric columns, because
# averaging the text columns raises a TypeError in pandas >= 2.
results.groupby(["Model", "Domain"])[["zero_shot_score", "few_shot_score", "Delta"]].mean()

In [None]:
# The 100 segments where few-shot prompting helped most (largest Delta).
# `results` mixes both fine-tuned models and all domains, so print those too.
for row in results.sort_values("Delta", ascending=False).head(100).itertuples():
    print("Model:", row.Model)
    print("Domain:", row.Domain)
    print("LP:", row.lp)
    print("SRC:", row.source)
    print("REF:", row.reference)
    print("0-shot:", row.zero_shot_translation)
    print("5-shot:", row.few_shot_translation)
    print("Delta:", row.Delta)
    print()

In [None]:
# The 100 segments where few-shot prompting hurt most (most negative Delta).
# `results` mixes both fine-tuned models and all domains, so print those too.
for row in results.sort_values("Delta", ascending=True).head(100).itertuples():
    print("Model:", row.Model)
    print("Domain:", row.Domain)
    print("LP:", row.lp)
    print("SRC:", row.source)
    print("REF:", row.reference)
    print("0-shot:", row.zero_shot_translation)
    print("5-shot:", row.few_shot_translation)
    print("Delta:", row.Delta)
    print()