In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from itertools import chain
from pathlib import Path

from typing import Optional

# Matplotlib rc overrides applied on top of seaborn's "ticks" style:
# dashed grid lines on every axes and a Times New Roman serif font.
params = {
    "axes.grid": True,
    "grid.linestyle": "--",
    "font.family": "serif",
    "font.serif": "Times New Roman",
}

# Global plotting theme for the notebook: style (with the overrides above),
# "paper" context with enlarged fonts, and the Set2 colour palette.
sns.set_style("ticks", rc=params)
sns.set_context("paper", font_scale=1.5)
sns.set_palette("Set2")

In [None]:
# Root directories of the evaluated models. The strings are placeholders and
# must be replaced with real locations before running the cells below.

# Models labelled "FT w/o few-shot" in the result tables.
zero_shot_model_7B = Path("<path to best 7B zero-shot model>")
zero_shot_model_13B = Path("<path to best 13B zero-shot model>")

# Models labelled "FT w/ few-shot" in the result tables.
balanced_few_shot_model_7B = Path("<path to best 7B few-shot model>")
balanced_few_shot_model_13B = Path("<path to best 13B few-shot model>")

def load_scores(scores_file: Path):
    """Parse a scores file into a ``{metric name: float value}`` dict.

    The file is expected to hold one ``name: value`` entry per line.
    Blank lines are ignored, and whitespace around the name and the value
    is stripped.

    Args:
        scores_file: Path of the text file to read.

    Returns:
        Dict mapping each metric name to its value as a float, in file order.

    Raises:
        ValueError: If a non-blank line has no ``:`` separator, has an empty
            name, or has a value that cannot be converted to float.
    """
    scores = {}
    for line_number, raw_line in enumerate(scores_file.read_text().splitlines(), start=1):
        line = raw_line.strip()
        if not line:
            # Tolerate blank lines instead of failing on tuple unpacking.
            continue
        # Split on the LAST colon: the value is a plain float and never
        # contains one, while the metric name is free-form.
        key, separator, value = line.rpartition(":")
        key = key.strip()
        if not separator or not key:
            raise ValueError(
                f"Malformed score line {line_number} in {scores_file}: {raw_line!r}"
            )
        scores[key] = float(value)
    return scores

def load_lp(data_root: Path, lp: str, ckpt: Optional[str] = None):
    """Load zero-shot and five-shot system scores for one language pair.

    Scores are read from
    ``<data_root>/<lp>[/<ckpt>]/zero_shot_instructions/sys_scores.txt`` and
    ``<data_root>/<lp>[/<ckpt>]/few_shot_instructions2/sys_scores.txt``.

    Args:
        data_root: Directory containing one sub-directory per language pair.
        lp: Language-pair directory name, also stored in the ``"lp"`` field.
        ckpt: Optional checkpoint sub-directory name. When given it must be
            parseable as an integer, which is stored in the ``"Step"`` field.
            When omitted, scores are read directly under the language-pair
            directory and ``"Step"`` is ``None``.

    Returns:
        A list with two dicts (zero-shot first, then five-shot), each holding
        ``"lp"``, ``"Context"``, ``"Step"`` and every score from the file.
    """
    base_dir = data_root / lp
    if ckpt is not None:
        base_dir = base_dir / ckpt

    # ckpt is optional: calling int(None) would raise TypeError, so only
    # convert when a checkpoint name was actually provided.
    step = int(ckpt) if ckpt is not None else None

    lp_results = []
    for context, instructions_dir in (
        ("Zero-shot", "zero_shot_instructions"),
        ("Five-shot", "few_shot_instructions2"),
    ):
        scores = load_scores(base_dir / instructions_dir / "sys_scores.txt")
        lp_results.append({"lp": lp, "Context": context, "Step": step, **scores})
    return lp_results


def load_results(data_root: Path, ckpt: Optional[str] = None):
    """Gather the scores of every language pair under ``data_root``.

    Each sub-directory of ``data_root`` is treated as a language pair and
    passed to ``load_lp`` together with ``ckpt``.

    Returns:
        A DataFrame with one row per record returned by ``load_lp``.
    """
    lp_names = [entry.name for entry in data_root.iterdir() if entry.is_dir()]
    rows = []
    for lp_name in lp_names:
        rows += load_lp(data_root, lp_name, ckpt)
    return pd.DataFrame(rows)

def load_lp_all(data_root: Path, lp: str):
    """Load zero-shot and five-shot scores of one language pair at every checkpoint.

    Every sub-directory of ``<data_root>/<lp>`` is treated as a checkpoint
    whose name is converted to an integer and stored as ``"Step"``.

    Returns:
        A list of dicts, two per checkpoint directory (zero-shot first, then
        five-shot), each holding ``"lp"``, ``"Context"``, ``"Step"`` and the
        scores read from the matching ``sys_scores.txt``.
    """
    # (context label, name of the sub-directory holding its sys_scores.txt)
    settings = (
        ("Zero-shot", "zero_shot_instructions"),
        ("Five-shot", "few_shot_instructions2"),
    )
    step_dirs = [entry for entry in (data_root / lp).iterdir() if entry.is_dir()]

    rows = []
    for step_dir in step_dirs:
        for context, instructions_dir in settings:
            scores = load_scores(step_dir / instructions_dir / "sys_scores.txt")
            rows.append({"lp": lp, "Context": context, "Step": int(step_dir.name), **scores})
    return rows


def load_results_all(data_root: Path):
    """Gather the scores of every language pair under ``data_root`` at all checkpoints.

    Each sub-directory of ``data_root`` is treated as a language pair and
    passed to ``load_lp_all``.

    Returns:
        A DataFrame with one row per record returned by ``load_lp_all``.
    """
    lp_dirs = [entry for entry in data_root.iterdir() if entry.is_dir()]
    records = [
        record
        for lp_dir in lp_dirs
        for record in load_lp_all(data_root, lp_dir.name)
    ]
    return pd.DataFrame(records)

# Domains that are loaded, in processing order. The NLLB multi-domain sets
# ("nllb_md_chat", "nllb_md_health", "nllb_md_news") are currently disabled
# here but keep their display labels below.
domains = [
    "flores",
    "wmt",
    "medical",
    "law",
    "tico",
    "chat_wmt",
]

# Directory name of a domain -> label used in the "Domain" column.
domain2label = dict(
    [
        ("flores", "Flores"),
        ("wmt", "WMT 22"),
        ("medical", "Medical"),
        ("law", "Law"),
        ("nllb_md_chat", "NLLB Chat"),
        ("nllb_md_health", "NLLB Health"),
        ("nllb_md_news", "NLLB News"),
        ("tico", "Tico"),
        ("chat_wmt", "Chat"),
    ]
)

# Collects one frame per domain; merged into a single frame after the loop.
all_results = []

for domain in domains:
    # 7B models, both read at checkpoint "20000".
    zero_shot_results_7B = load_results(zero_shot_model_7B / domain, "20000").assign(
        **{"Model": "FT w/o few-shot", "Model Size": "7B"}
    )
    balanced_few_shot_results_7B = load_results(balanced_few_shot_model_7B / domain, "20000").assign(
        **{"Model": "FT w/ few-shot", "Model Size": "7B"}
    )

    # 13B models.
    zero_shot_results_13B = load_results(zero_shot_model_13B / domain, "20000").assign(
        **{"Model": "FT w/o few-shot", "Model Size": "13B"}
    )
    # NOTE(review): this is the only model read at checkpoint "40000" instead
    # of "20000" -- confirm that this is the intended checkpoint.
    balanced_few_shot_results_13B = load_results(balanced_few_shot_model_13B / domain, "40000").assign(
        **{"Model": "FT w/ few-shot", "Model Size": "13B"}
    )

    # Stack the four model variants and tag them with the domain's display label.
    results = pd.concat(
        [
            zero_shot_results_7B,
            balanced_few_shot_results_7B,
            zero_shot_results_13B,
            balanced_few_shot_results_13B,
        ]
    ).assign(Domain=domain2label[domain])
    all_results.append(results)

all_results = pd.concat(all_results)

# Both COMET-based metrics are multiplied by 100 before reporting.
for comet_column in ("COMET-22", "COMETKiwi"):
    all_results[comet_column] = all_results[comet_column] * 100

# Keep only the reported columns, in table order, and give them display names.
all_results = all_results[
    ["Domain", "lp", "Model", "Model Size", "Context", "COMET-22", "COMETKiwi", "BLEU", "chrF"]
].rename(columns={"COMET-22": "COMET", "lp": "Language Pair"})

# Display orderings; sort_func ranks values by their position in these lists.
domain_order = ["Flores", "WMT 22", "Medical", "Law", "Tico", "Chat"]
non_eng_lang_order = ["de", "fr", "nl", "pt", "ru", "zh"]
# For each language: "<lang>-en" first, immediately followed by "en-<lang>".
lang_pairs = [
    pair
    for lang in non_eng_lang_order
    for pair in (f"{lang}-en", f"en-{lang}")
]
models_order = ["FT w/o few-shot", "FT w/ few-shot"]
context_order = ["Zero-shot", "Five-shot"]

metrics = ["COMETKiwi", "BLEU", "chrF"]

def sort_func(x):
    """Return the rank of ``x`` in the first ordering list that contains it.

    The lists are tried in this order: ``domain_order``, ``lang_pairs``,
    ``models_order``; anything else is looked up in ``context_order``.

    Raises:
        ValueError: If ``x`` appears in none of the four lists.
    """
    for ordering in (domain_order, lang_pairs, models_order):
        if x in ordering:
            return ordering.index(x)
    # Last resort: list.index raises ValueError for unknown values.
    return context_order.index(x)
    
# Export LaTeX tables per model size: one for Flores, one for WMT 22 and one
# covering every remaining domain.

# The output directory is not guaranteed to exist (e.g. on a fresh checkout);
# to_latex does not create it, so make sure it is there before writing.
Path("tables").mkdir(parents=True, exist_ok=True)

for model_size in ["7B", "13B"]:
    # Rows of this model size, ordered by the display orderings. The sort key
    # is applied to each "by" column separately and maps values to their rank.
    model_size_results = (
        all_results[all_results["Model Size"] == model_size]
        .drop(columns=["Model Size"])
        .sort_values(
            by=["Domain", "Language Pair", "Model", "Context"],
            key=lambda column: column.apply(sort_func),
        )
    )

    flores_results = (
        model_size_results[model_size_results["Domain"] == "Flores"]
        .drop(columns=["Domain"])
        .set_index(["Language Pair", "Model", "Context"])
    )
    flores_results.to_latex(f"tables/examples_vs_no_examples_flores_{model_size}.tex", float_format="%.2f")

    wmt_results = (
        model_size_results[model_size_results["Domain"] == "WMT 22"]
        .drop(columns=["Domain"])
        .set_index(["Language Pair", "Model", "Context"])
    )
    wmt_results.to_latex(f"tables/examples_vs_no_examples_wmt_{model_size}.tex", float_format="%.2f")

    # Despite the name, this excludes both Flores and WMT 22; the remaining
    # domains share one table, so "Domain" stays as the outermost index level.
    non_flores_results = model_size_results[
        ~model_size_results["Domain"].isin(["Flores", "WMT 22"])
    ].set_index(["Domain", "Language Pair", "Model", "Context"])
    non_flores_results.to_latex(f"tables/examples_vs_no_examples_domains_{model_size}.tex", float_format="%.2f")