In [None]:
import pandas as pd
import numpy as np

from pathlib import Path
from tqdm import tqdm

from templates import *

from typing import List, Tuple

In [None]:
# Root directory of the data; the dataset roots used further down are built as
# "<root_data_dir>/<dataset name>/". The value here is a placeholder and has to
# be replaced with a real location before the notebook is run.
root_data_dir = "<path to data>"

In [None]:
def load_train_eval(root: Path, lp: str) -> pd.DataFrame:
    """Load the parallel ``train_eval`` files of one language pair.

    Reads ``<root>/<lp>/train_eval.input.txt`` and
    ``<root>/<lp>/train_eval.output.txt`` and pairs them line by line.

    Args:
        root: Directory that contains one sub-directory per language pair.
        lp: Name of the language-pair sub-directory (e.g. ``"en-de"``); it is
            also stored in the ``lp`` column of every row.

    Returns:
        A DataFrame with the columns ``input``, ``output`` and ``lp`` and one
        row per line of the files.

    Raises:
        AssertionError: If the two files do not contain the same number of
            lines.
    """
    input_path = root / lp / "train_eval.input.txt"
    output_path = root / lp / "train_eval.output.txt"

    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which is not UTF-8 everywhere and would garble or reject
    # non-ASCII text.
    # NOTE: splitting on "\n" keeps a trailing empty string when a file ends
    # with a newline, so such a file yields one final empty record.
    inputs = input_path.read_text(encoding="utf-8").split("\n")
    outputs = output_path.read_text(encoding="utf-8").split("\n")
    assert len(inputs) == len(outputs), (
        f"line count mismatch: {input_path} has {len(inputs)} lines, "
        f"{output_path} has {len(outputs)}"
    )

    records = [
        {"input": source, "output": target, "lp": lp}
        for source, target in zip(inputs, outputs)
    ]
    return pd.DataFrame.from_records(records)

def load_few_shot(root: Path, lp: str) -> pd.DataFrame:
    """Load the parallel ``few_shot`` files of one language pair.

    Reads ``<root>/<lp>/few_shot.input.txt`` and
    ``<root>/<lp>/few_shot.output.txt`` and pairs them line by line.

    Args:
        root: Directory that contains one sub-directory per language pair.
        lp: Name of the language-pair sub-directory (e.g. ``"en-de"``); it is
            also stored in the ``lp`` column of every row.

    Returns:
        A DataFrame with the columns ``input``, ``output`` and ``lp`` and one
        row per line of the files.

    Raises:
        AssertionError: If the two files do not contain the same number of
            lines.
    """
    input_path = root / lp / "few_shot.input.txt"
    output_path = root / lp / "few_shot.output.txt"

    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which is not UTF-8 everywhere and would garble or reject
    # non-ASCII text.
    # NOTE: splitting on "\n" keeps a trailing empty string when a file ends
    # with a newline, so such a file yields one final empty record.
    inputs = input_path.read_text(encoding="utf-8").split("\n")
    outputs = output_path.read_text(encoding="utf-8").split("\n")
    assert len(inputs) == len(outputs), (
        f"line count mismatch: {input_path} has {len(inputs)} lines, "
        f"{output_path} has {len(outputs)}"
    )

    records = [
        {"input": source, "output": target, "lp": lp}
        for source, target in zip(inputs, outputs)
    ]
    few_shot = pd.DataFrame.from_records(records)
    return few_shot

def sample_examples(few_shot_df: pd.DataFrame, lp: str, n: int, k: int, seed: int = 42) -> List[List[Tuple[str, str]]]:
    """Draw ``n`` independent sets of ``k`` few-shot examples for one language pair.

    Args:
        few_shot_df: Frame with at least the columns ``input``, ``output`` and
            ``lp``.
        lp: Only rows whose ``lp`` column equals this value are sampled from.
        n: Number of example sets to draw.
        k: Number of examples per set, drawn without replacement within a set.
        seed: Seed of the random generator; the same seed gives the same sets.

    Returns:
        A list of ``n`` lists, each holding ``k`` ``(input, output)`` tuples.

    Raises:
        ValueError: Raised by ``Generator.choice`` when ``k`` exceeds the
            number of rows available for ``lp``.
    """
    rng = np.random.default_rng(seed)
    pool = few_shot_df[few_shot_df["lp"] == lp]
    sources = pool["input"].to_numpy()
    targets = pool["output"].to_numpy()

    examples = []
    for _ in range(n):
        # Sample row positions rather than index labels: a label lookup would
        # return extra rows if the index contained duplicate labels, and it
        # costs a DataFrame lookup per draw. For a unique index the random
        # stream, and therefore the selection, is the same as sampling labels.
        positions = rng.choice(len(pool), size=k, replace=False)
        examples.append([(sources[pos], targets[pos]) for pos in positions])
    return examples

def write_escaped_lines(lines, path):
    """Write ``lines`` to ``path`` with one item per physical line.

    Every newline character inside an item is replaced by the two characters
    backslash + ``n`` so that the item stays on a single line. Items are
    joined with newlines and no trailing newline is appended.

    Args:
        lines: Iterable of strings to write.
        path: Destination file; an existing file is overwritten.
    """
    escaped = [line.replace("\n", "\\n") for line in lines]
    # Write UTF-8 explicitly: the platform default encoding may be unable to
    # represent non-ASCII text and would then raise UnicodeEncodeError.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(escaped))


In [None]:
def process_dataset(root_dir: Path, lps: List[str], k: int = 5, seed: int = 42):
    """Build and write the zero-shot and few-shot instruction files of a dataset.

    For every language pair, each ``train_eval`` input is paired with ``k``
    randomly drawn few-shot examples and rendered through the four template
    helpers (``instruction_template`` and ``format{1,2,3}_few_shot_instruction_template``;
    not defined in this notebook, presumably supplied by ``from templates import *``).
    The rendered instructions are written to ``<root_dir>/<lp>/`` as
    ``zero_shot_instructions.txt`` and ``few_shot_instructions{1,2,3}.txt``.

    Args:
        root_dir: Dataset directory containing one sub-directory per language pair.
        lps: Language pairs (sub-directory names) to process.
        k: Number of few-shot examples attached to each input (default 5).
        seed: Seed forwarded to ``sample_examples`` (default 42).
    """
    # All language pairs are loaded up front, so a missing input file aborts
    # the run before any output file is written.
    train_eval_dfs = {lp: load_train_eval(root_dir, lp) for lp in lps}
    few_shot_dfs = {lp: load_few_shot(root_dir, lp) for lp in lps}

    # (column name, template function, output file name) per few-shot format.
    few_shot_formats = (
        ("few_shot_instruction1", format1_few_shot_instruction_template, "few_shot_instructions1.txt"),
        ("few_shot_instruction2", format2_few_shot_instruction_template, "few_shot_instructions2.txt"),
        ("few_shot_instruction3", format3_few_shot_instruction_template, "few_shot_instructions3.txt"),
    )

    for lp, train_eval_df in train_eval_dfs.items():
        # One independent set of k examples per train_eval row.
        train_eval_df["few_shot_examples"] = sample_examples(
            few_shot_dfs[lp], lp=lp, n=len(train_eval_df), k=k, seed=seed,
        )
        train_eval_df["zero_shot_instruction"] = train_eval_df.apply(
            lambda row: instruction_template(lp, row["input"]), axis=1,
        )
        for column, template, _ in few_shot_formats:
            # apply() runs immediately, so the lambda sees this iteration's template.
            train_eval_df[column] = train_eval_df.apply(
                lambda row: template(lp, row["input"], row["few_shot_examples"]), axis=1,
            )

        # Write only after every column has been rendered successfully.
        lp_dir = root_dir / lp
        write_escaped_lines(
            train_eval_df["zero_shot_instruction"].tolist(),
            lp_dir / "zero_shot_instructions.txt",
        )
        for column, _, file_name in few_shot_formats:
            write_escaped_lines(train_eval_df[column].tolist(), lp_dir / file_name)

In [None]:
# Each dataset is stored under "<root_data_dir>/<dataset name>/" and is processed
# for the language pairs listed next to its name.
datasets = {
    name: {"root_dir": Path(f"{root_data_dir}/{name}/"), "lps": pairs}
    for name, pairs in (
        (
            "flores",
            [
                "de-en", "en-de", "fr-en", "en-fr", "nl-en", "en-nl",
                "pt-en", "en-pt", "ru-en", "en-ru", "zh-en", "en-zh",
            ],
        ),
        ("wmt", ["de-en", "en-de", "ru-en", "en-ru", "zh-en", "en-zh"]),
        ("law", ["de-en", "en-de"]),
        ("medical", ["de-en", "en-de"]),
        ("tico", ["en-fr", "en-pt"]),
        ("chat_wmt", ["en-de", "en-fr", "en-pt"]),
    )
}

# Generate the instruction files dataset by dataset, showing a progress bar.
for dataset, args in tqdm(datasets.items(), total=len(datasets)):
    process_dataset(root_dir=args["root_dir"], lps=args["lps"])