In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path

from typing import Optional

# Matplotlib rc overrides shared by every figure in this notebook.
params = {
    "axes.grid": True,
    "grid.linestyle": "--",
    "font.family": "serif",
    "font.serif": "Times New Roman",
}
sns.set_style("ticks", params)
# sns.set_style only keeps rc keys that are part of seaborn's style definition;
# "font.serif" is not one of them, so without this explicit update the
# Times New Roman setting is silently ignored and matplotlib's default serif
# font is used instead.
plt.rcParams.update(params)
sns.set_context("paper", font_scale=1.5)
sns.set_palette("Set2")

In [None]:
# Root directories of the evaluated models. The strings are placeholders and
# must be replaced with real paths before running the cells below.
# NOTE(review): the loading cells index these roots as
# <model>/<domain>/<lp>/<ckpt>/... — confirm against the actual output layout.
zero_shot_model = Path("<path to best zero-shot adapters model>")
balanced_few_shot_model = Path("<path to best balanced few-shot adapters model>")
unbalanced_few_shot_model = Path("<path to best unbalanced few-shot adapters model>")

In [None]:
def load_scores(scores_file: Path):
    """Parse a ``name: value`` scores file into a dict of floats.

    Each non-blank line must contain the separator ``": "``; the text after
    the *last* separator is converted with ``float`` and the text before it is
    used as the key, so metric names may themselves contain ``": "``.
    Blank / whitespace-only lines are skipped.

    Args:
        scores_file: Path of the text file to read.

    Returns:
        Dict mapping each metric name to its float value.

    Raises:
        ValueError: If a non-blank line has no ``": "`` separator or its value
            is not a valid float.
    """
    scores = {}
    for line_no, line in enumerate(scores_file.read_text().splitlines(), start=1):
        if not line.strip():
            # Tolerate empty lines (e.g. a trailing blank line in the file).
            continue
        key, sep, value = line.rpartition(": ")
        if not sep:
            raise ValueError(
                f"{scores_file}:{line_no}: expected 'name: value', got {line!r}"
            )
        scores[key] = float(value)
    return scores

def load_lp(data_root: Path, lp: str, ckpt: Optional[str] = None):
    """Load zero-shot and five-shot system scores for one language pair.

    Scores are read from
    ``<data_root>/<lp>[/<ckpt>]/{zero_shot_instructions,few_shot_instructions2}/sys_scores.txt``.

    Args:
        data_root: Directory containing one sub-directory per language pair.
        lp: Language-pair directory name; names starting with ``"en"`` are
            labelled ``"En-XX"``, all others ``"XX-En"``.
        ckpt: Optional checkpoint sub-directory name. When given it must be
            parseable by ``int``; when omitted, scores are read directly under
            the language-pair directory and ``"Step"`` is ``None``.

    Returns:
        A list with two dicts (zero-shot first, then five-shot), each holding
        ``"lp"``, ``"into/from"``, ``"Context"``, ``"Step"`` and every score
        returned by ``load_scores``.
    """
    base_dir = data_root / lp
    if ckpt is not None:
        base_dir = base_dir / ckpt

    # int(None) raises TypeError, which made the ckpt=None default unusable.
    step = int(ckpt) if ckpt is not None else None
    direction = "En-XX" if lp.startswith("en") else "XX-En"

    # (sub-directory holding the scores, label stored in the "Context" column)
    contexts = (
        ("zero_shot_instructions", "Zero-shot"),
        ("few_shot_instructions2", "Five-shot"),
    )

    lp_results = []
    for subdir, context in contexts:
        scores = load_scores(base_dir / subdir / "sys_scores.txt")
        lp_results.append(
            {"lp": lp, "into/from": direction, "Context": context, "Step": step, **scores}
        )
    return lp_results


def load_results(data_root: Path, ckpt: Optional[str] = None):
    """Gather the ``load_lp`` rows of every language pair under ``data_root``.

    Every sub-directory of ``data_root`` is treated as a language pair; plain
    files are ignored. ``ckpt`` is forwarded unchanged to ``load_lp``.

    Returns:
        A ``pd.DataFrame`` with one row per (language pair, context).
    """
    rows = []
    for entry in data_root.iterdir():
        if not entry.is_dir():
            continue
        rows += load_lp(data_root, entry.name, ckpt)
    return pd.DataFrame(rows)

def load_lp_all(data_root: Path, lp: str):
    """Load zero-shot and five-shot scores of one language pair for all checkpoints.

    Every sub-directory of ``<data_root>/<lp>`` is treated as a checkpoint
    whose name is converted with ``int`` and stored under ``"Step"``.

    Returns:
        A list of dicts, two per checkpoint (zero-shot first, then five-shot),
        each holding ``"lp"``, ``"Context"``, ``"Step"`` and the loaded scores.
    """
    # (sub-directory holding the scores, label stored in the "Context" column)
    variants = (
        ("zero_shot_instructions", "Zero-shot"),
        ("few_shot_instructions2", "Five-shot"),
    )
    rows = []
    for step_dir in (data_root / lp).iterdir():
        if not step_dir.is_dir():
            continue
        for subdir, label in variants:
            metrics = load_scores(step_dir / subdir / "sys_scores.txt")
            rows.append({"lp": lp, "Context": label, "Step": int(step_dir.name), **metrics})
    return rows


def load_results_all(data_root: Path):
    """Gather the ``load_lp_all`` rows of every language pair under ``data_root``.

    Every sub-directory of ``data_root`` is treated as a language pair; plain
    files are ignored.

    Returns:
        A ``pd.DataFrame`` with one row per (language pair, checkpoint, context).
    """
    rows = [
        row
        for lp_dir in data_root.iterdir()
        if lp_dir.is_dir()
        for row in load_lp_all(data_root, lp_dir.name)
    ]
    return pd.DataFrame(rows)

In [None]:
# Evaluation domains included in the comparison. The NLLB multi-domain sets
# ("nllb_md_chat", "nllb_md_health", "nllb_md_news") have labels below but are
# not part of this list.
domains = ["flores", "medical", "law", "tico"]
# Directory name of each domain -> label shown in tables and figures.
domain2label = {
    "flores": "Flores",
    "medical": "Medical",
    "law": "Law",
    "nllb_md_chat": "NLLB Chat",
    "nllb_md_health": "NLLB Health",
    "nllb_md_news": "NLLB News",
    "tico": "Tico",
    "chat_wmt": "Chat",
}

# Checkpoint directory name evaluated for both data mixtures.
eval_ckpt = "20000"

# One frame per domain; concatenated once after the loop.
domain_frames = []

for domain in domains:
    balanced_few_shot_results = load_results(balanced_few_shot_model / domain, eval_ckpt)
    balanced_few_shot_results["Data Mixture"] = "Balanced"
    unbalanced_few_shot_results = load_results(unbalanced_few_shot_model / domain, eval_ckpt)
    unbalanced_few_shot_results["Data Mixture"] = "Unbalanced"
    results = pd.concat([balanced_few_shot_results, unbalanced_few_shot_results])
    results["Domain"] = domain2label[domain]
    domain_frames.append(results)

# Every frame returned by load_results starts its index at 0, so a plain
# concat would repeat the same row labels many times. ignore_index=True gives
# each row a unique label, keeping .loc and index-aligned operations unambiguous.
all_results = pd.concat(domain_frames, ignore_index=True)
# Multiply by 100; the plotting cell uses a y-range of 80-88 for this column.
all_results["COMET-22"] = all_results["COMET-22"] * 100
all_results

In [None]:
# Mean of every remaining score column for each
# (Domain, Data Mixture, Context) combination, i.e. averaged over language pairs.
(
    all_results
    .drop(columns=["into/from", "lp", "Step"])
    .groupby(["Domain", "Data Mixture", "Context"])
    .agg("mean")
)

In [None]:
# Figure: COMET-22 per data mixture, one panel per domain, bars coloured by
# context (zero-shot vs. five-shot).
#_, ax = plt.subplots(figsize=(10, 5))
# Facet sizing: seaborn draws each facet `height` inches tall and
# `height * aspect` inches wide.
# NOTE(review): the division by 5 presumably targets a 14in-wide figure with
# five panels, while `domains` currently lists four — confirm the intended size.
aspect = 0.9
width = 14 / aspect
height = width / 5
# Define sns color palette to use
sns.set_palette("Set2")
g = sns.catplot(
    data=all_results, 
    kind="bar", 
    x="Data Mixture", y="COMET-22", hue="Context", col="Domain", 
    hue_order=["Zero-shot", "Five-shot"], 
    errorbar=None,
    height=height,
    aspect=aspect,
)
# Passing False for every side keeps all four spines visible (boxed panels).
g.despine(left=False, bottom=False, top=False, right=False)
# Fixed y-range; COMET-22 was multiplied by 100 in the aggregation cell.
g.set(ylim=(80, 88))
g.set_xlabels("")
g.set_ylabels("COMET")
g.set_titles("{col_name}")
g.tight_layout()

# Group header above the first panel. get_xaxis_transform() interprets x in
# that axes' data coordinates and y in axes-fraction coordinates, so y > 1
# places the text/line above the panel.
# NOTE(review): assumes the first facet is the Flores domain — this depends on
# the order in which domains appear in `all_results`.
flores_ax = g.axes[0][0]
flores_ax_trans = flores_ax.get_xaxis_transform()
# plt.annotate / plt.plot act on pyplot's *current* axes; the position comes
# from the explicit transform, not from the axes they are attached to.
# NOTE(review): rewriting these as flores_ax.plot / medical_ax.plot may change
# autoscaling of the shared x axes — verify the rendered figure before changing.
plt.annotate("General Domain", xy=(0.5, 1.2), xycoords=flores_ax_trans, ha="center", va="bottom")
# `height` is reused here: from now on it is the y position (axes fraction) of
# the horizontal rules, no longer the facet height.
height = 1.17
plt.plot([-.4, 1.4],[height,height], color="k", transform=flores_ax_trans, clip_on=False, linewidth=.9)

# Group header starting at the second panel; x values are in that panel's data
# coordinates and run past its right edge (clip_on=False).
# NOTE(review): 2.7 and 5.8 look hand-tuned so the label/rule span the
# remaining panels — re-check if the number of domains or the spacing changes.
medical_ax = g.axes[0][1]
medical_ax_trans = medical_ax.get_xaxis_transform()
plt.annotate("Specialized Domains", xy=(2.7, 1.2), xycoords=medical_ax_trans, ha="center", va="bottom")
plt.plot([-.4, 5.8],[height,height], color="k", transform=medical_ax_trans, clip_on=False, linewidth=.9)

# Move the legend below the panels, without a title and with a frame.
sns.move_legend(
    g, "lower center",
    bbox_to_anchor=(.43, -.1), ncol=3, title=None, frameon=True,
)
#plt.savefig("figures/data_mixes.pdf", bbox_inches="tight", dpi=300)