In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from itertools import chain
from pathlib import Path

from typing import List, Optional

# Matplotlib rc overrides layered on top of seaborn's "ticks" style:
# grid lines enabled and dashed, serif text rendered in Times New Roman.
params = {
    "axes.grid": True,
    "grid.linestyle": "--",
    "font.family": "serif",
    "font.serif": "Times New Roman",
}

# Global plotting defaults for every figure produced after this cell runs.
sns.set_style("ticks", rc=params)
sns.set_context("paper", font_scale=1.5)
sns.set_palette("Set2")

In [None]:
# Root of the evaluation data. load_results/load_lp read
# <data_root>/<dataset>/<lp>/train_eval.input.txt, train_eval.output.txt and
# <instructions>.txt from beneath it.
data_root = Path("<path to data>")
# Output roots of the two models compared below. load_results/load_lp read
# <model root>/<dataset>/<lp>/<ckpt>/<instructions>/translations.txt and
# seg_scores.txt from beneath each of them.
zero_shot_model = Path("<path to best zero-shot model>")
balanced_few_shot_model = Path("<path to best adapters model>")

In [None]:
def read_lines(path: Path, unescape_newline: bool = False) -> List[str]:
    """Read a text file into a list of lines without their line terminators.

    Args:
        path: File to read.
        unescape_newline: When True, every literal two-character sequence
            backslash + "n" inside a line is turned into a real newline, so
            multi-line segments stored one-per-line are restored.

    Returns:
        One string per line of the file, in file order. An empty file yields
        an empty list.
    """
    # NOTE(review): the data files are assumed to be UTF-8; decoding them
    # explicitly avoids depending on the platform's locale default.
    with open(path, encoding="utf-8") as f:
        # Only strip an actual trailing newline: the last line of a file may
        # not have one, and blindly dropping the final character would
        # silently corrupt that segment.
        lines = [line[:-1] if line.endswith("\n") else line for line in f]
    if unescape_newline:
        lines = [line.replace("\\n", "\n") for line in lines]
    return lines

def load_scores(scores_file: Path):
    """Parse a file of ``<name>: <value>`` lines into a ``{name: float}`` dict.

    Args:
        scores_file: Text file with one ``<name>: <value>`` entry per line.

    Returns:
        Dict mapping each name to its value converted to ``float``. If a name
        occurs more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line has no ``": "`` separator or its value
            cannot be converted to ``float``.
    """
    scores = {}
    for line in scores_file.read_text().splitlines():
        # Blank lines (e.g. a trailing empty line) carry no entry.
        if not line.strip():
            continue
        # Split on the last separator so a name that itself contains ": "
        # still parses; the value is always the final field.
        key, value = line.rsplit(": ", 1)
        scores[key] = float(value)
    return scores

def load_lp(dataset_root: Path, model_dataset_root: Path, lp: str, ckpt: str, instructions: str):
    """Collect per-segment records for one language pair of one model run.

    Reads, line by line and in parallel:

    * ``<dataset_root>/<lp>/train_eval.input.txt``    -> ``source``
    * ``<dataset_root>/<lp>/train_eval.output.txt``   -> ``reference``
    * ``<dataset_root>/<lp>/<instructions>.txt``      -> ``instruction``
    * ``<model_dataset_root>/<lp>/<ckpt>/<instructions>/translations.txt`` -> ``translation``
    * ``<model_dataset_root>/<lp>/<ckpt>/<instructions>/seg_scores.txt``
      (parsed as CSV; its ``COMET-22`` column multiplied by 100) -> ``score``

    Args:
        dataset_root: Directory holding one sub-directory per language pair
            with the source/reference/instruction files.
        model_dataset_root: Directory holding the model outputs for the same
            dataset, one sub-directory per language pair.
        lp: Name of the language-pair sub-directory.
        ckpt: Name of the checkpoint sub-directory.
        instructions: Name of the instruction set; used both as the stem of
            the instruction file and as the output sub-directory name.

    Returns:
        A list of dicts with keys ``lp``, ``source``, ``reference``,
        ``translation``, ``instruction`` and ``score``, one per segment.

    Warns:
        UserWarning: If the five inputs do not all have the same number of
            entries. The records are then truncated to the shortest input
            (as ``zip`` does), which usually indicates misaligned files.
    """
    import warnings

    lp_data_dir = dataset_root / lp
    run_dir = model_dataset_root / lp / ckpt / instructions

    sources = read_lines(lp_data_dir / "train_eval.input.txt", unescape_newline=True)
    references = read_lines(lp_data_dir / "train_eval.output.txt", unescape_newline=True)
    instructions_lines = read_lines(lp_data_dir / f"{instructions}.txt", unescape_newline=True)

    scores = pd.read_csv(run_dir / "seg_scores.txt")
    # Rescale so scores are on a 0-100-style range rather than 0-1.
    comet_scores = scores["COMET-22"] * 100

    translations = read_lines(run_dir / "translations.txt", unescape_newline=True)

    # zip() below stops at the shortest input without complaint, so surface
    # any length mismatch instead of silently dropping or mispairing segments.
    lengths = {
        "sources": len(sources),
        "references": len(references),
        "translations": len(translations),
        "instructions": len(instructions_lines),
        "scores": len(comet_scores),
    }
    if len(set(lengths.values())) > 1:
        warnings.warn(
            f"Misaligned inputs for lp={lp!r}, ckpt={ckpt!r}, instructions={instructions!r}: "
            f"{lengths}; records are truncated to the shortest input.",
            stacklevel=2,
        )

    records = [
        {
            "lp": lp,
            "source": source,
            "reference": reference,
            "translation": translation,
            "instruction": instruction,
            "score": score,
        }
        for source, reference, translation, instruction, score in zip(
            sources, references, translations, instructions_lines, comet_scores
        )
    ]

    return records

def load_results(data_root: Path, model_root: Path, dataset: str, ckpt: str, instructions: str):
    """Load the per-segment records of every language pair of one dataset.

    The language pairs are the names of the sub-directories found in
    ``<model_root>/<dataset>``; each one is loaded with ``load_lp``.

    Args:
        data_root: Root directory of the evaluation data.
        model_root: Root directory of one model's outputs.
        dataset: Name of the dataset sub-directory under both roots.
        ckpt: Checkpoint sub-directory name, forwarded to ``load_lp``.
        instructions: Instruction-set name, forwarded to ``load_lp``.

    Returns:
        A ``pandas.DataFrame`` with one row per segment, language pairs in
        sorted name order and segments in file order within each pair.
    """
    dataset_root = data_root / dataset
    model_dataset_root = model_root / dataset

    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
    # Sorting makes the row order reproducible, so frames loaded from
    # different model roots line up row by row when they cover the same pairs.
    lps = sorted(d.name for d in model_dataset_root.iterdir() if d.is_dir())

    results = []
    for lp in lps:
        results.extend(load_lp(dataset_root, model_dataset_root, lp, ckpt, instructions))
    return pd.DataFrame(results)

In [None]:
# Domains to analyse. The NLLB domains ("nllb_md_chat", "nllb_md_health",
# "nllb_md_news") are currently left out.
domains = ["medical", "law", "tico", "chat_wmt"]
domain2label = {
    "flores": "Flores",
    "medical": "Medical",
    "law": "Law",
    "nllb_md_chat": "NLLB Chat",
    "nllb_md_health": "NLLB Health",
    "nllb_md_news": "NLLB News",
    "tico": "Tico",
    "chat_wmt": "Chat",
}

# Thresholds on the scores produced by load_lp (COMET-22 multiplied by 100).
# A row counts as an improvement only if the gain exceeds `delta_threshold`
# AND both systems clear their minimum-quality bar.
delta_threshold = 5
# Minimum quality for zero-shot
zero_shot_threshold = 70
few_shot_threshold = 80

# Columns that come from the dataset itself; they must agree row by row
# between the two models before their outputs can be placed side by side.
shared_columns = ["lp", "source", "reference", "instruction"]

domain_frames = []
for domain in domains:
    # Both models are loaded with the same checkpoint name and instruction set.
    # The stable sort by language pair puts both frames in the same order
    # regardless of the order in which the directories were listed, while
    # keeping segments in file order within each pair.
    zero_shot = (
        load_results(data_root, zero_shot_model, domain, "20000", "few_shot_instructions2")
        .rename(columns={"translation": "zero_shot_translation", "score": "zero_shot_score"})
        .sort_values("lp", kind="stable")
        .reset_index(drop=True)
    )
    balanced_few_shot = (
        load_results(data_root, balanced_few_shot_model, domain, "20000", "few_shot_instructions2")
        .rename(columns={"translation": "few_shot_translation", "score": "few_shot_score"})
        .sort_values("lp", kind="stable")
        .reset_index(drop=True)
    )

    # The concat below pairs rows purely by position, so refuse to continue
    # if the two frames do not describe exactly the same segments.
    if not zero_shot[shared_columns].equals(balanced_few_shot[shared_columns]):
        raise ValueError(
            f"Outputs of the two models are not aligned for domain {domain!r}: "
            f"{len(zero_shot)} vs {len(balanced_few_shot)} rows, or differing "
            "language pairs / segments."
        )

    df = pd.concat([zero_shot, balanced_few_shot.drop(columns=shared_columns)], axis=1)
    # Drop every language pair whose name contains "zh". The copy makes the
    # column assignments below operate on an independent frame instead of a
    # slice (avoids SettingWithCopyWarning).
    df = df[~df["lp"].str.contains("zh")].copy()
    df["Domain"] = domain2label[domain]
    df["delta"] = df["few_shot_score"] - df["zero_shot_score"]

    df["is_improvement"] = (
        (df["delta"] > delta_threshold) &
        (df["zero_shot_score"] > zero_shot_threshold) &
        (df["few_shot_score"] > few_shot_threshold)
    )
    domain_frames.append(df)

results = pd.concat(domain_frames)
results

In [None]:
# Number of rows per language pair in the combined `results` frame.
results["lp"].value_counts()

In [None]:
# Print a few random rows flagged as improvements, restricted to language
# pairs whose name ends in "en" (presumably translation into English -- confirm
# against the lp naming scheme).
improved_rows = results[results["is_improvement"] & results["lp"].str.endswith("en")]

# sample() raises when asked for more rows than exist, so cap the sample size.
# Sorting happens AFTER sampling so the printed examples really are ordered by
# gain (sorting before sampling has no effect on the output).
# The sample is intentionally unseeded; pass random_state=<int> to sample()
# to get the same examples on every run.
best_deltas = (
    improved_rows
    .sample(min(5, len(improved_rows)))
    .sort_values("delta", ascending=False)
)
for _, row in best_deltas.iterrows():
    print("-" * 80)
    print("Reference:", row["reference"])
    print("Zero-shot:", row["zero_shot_translation"])
    print("Few-shot:", row["few_shot_translation"])
    print()

In [None]:
# Literal substrings to look up in the few-shot translations.
selected_spans = [
    "genetically modified feed",
    "538/2000",
]

# regex=False: the spans are plain text, not patterns, so characters such as
# "(", "." or "?" in a future span must not be interpreted as regex syntax.
# na=False: a missing translation counts as "no match" instead of producing
# a NaN mask that cannot be used for boolean indexing.
# Rows are grouped by span, in the order the spans are listed; a row that
# contains several spans is listed once per span.
selected_records = pd.concat(
    [
        results[results["few_shot_translation"].str.contains(span, regex=False, na=False)]
        for span in selected_spans
    ]
)

for _, row in selected_records.iterrows():
    print("-" * 80)
    print("Source:", row["source"])
    print("Reference:", row["reference"])
    print("Zero-shot:", row["zero_shot_translation"])
    print("Few-shot:", row["few_shot_translation"])
    print()