In [2]:
##IMPORTS
import pandas as pd
import polars as pl
from collections import Counter
from transformers import pipeline
# from googletrans import Translator
import string
import nltk
from nltk.corpus import stopwords
from nltk import word_tokenize
import regex as re
from collections import Counter
import pickle
import numpy as np

# WEEK 36

In [3]:
#DOWNLOAD DATASET

splits = {'train': 'train.parquet', 'validation': 'validation.parquet'}
df_train = pd.read_parquet("hf://datasets/coastalcph/tydi_xor_rc/" + splits["train"])
df_val = pd.read_parquet("hf://datasets/coastalcph/tydi_xor_rc/" + splits["validation"])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Stats

In [4]:
#STATS

#SIZE

langs = ["ar", "ko", "te"]


train_counts = df_train[df_train["lang"].isin(langs)].groupby("lang").size()


val_counts = df_val[df_val["lang"].isin(langs)].groupby("lang").size()

size_df = pd.DataFrame({
    "train_size": train_counts,
    "val_size": val_counts
})

print("Dataset sizes for Arabic, Korean and Telugu:")
print(size_df)



Dataset sizes for Arabic, Korean and Telugu:
      train_size  val_size
lang                      
ar          2558       415
ko          2422       356
te          1355       384


In [5]:

## Each language punctuation
PUNCT_RE = re.compile(r"\p{P}")

# ARABIC
ar_train_q = df_train[df_train["lang"] == "ar"]["question"]
ar_val_q   = df_val[df_val["lang"] == "ar"]["question"]

ar_train_punct = Counter()
for q in ar_train_q:
    ar_train_punct.update(PUNCT_RE.findall(q))

ar_val_punct = Counter()
for q in ar_val_q:
    ar_val_punct.update(PUNCT_RE.findall(q))

print("Arabic — TRAIN punctuation count:")
print(ar_train_punct.most_common())
print("Arabic — VAL punctuation count:")
print(ar_val_punct.most_common())

# KOREAN
ko_train_q = df_train[df_train["lang"] == "ko"]["question"]
ko_val_q   = df_val[df_val["lang"] == "ko"]["question"]

ko_train_punct = Counter()
for q in ko_train_q:
    ko_train_punct.update(PUNCT_RE.findall(q))

ko_val_punct = Counter()
for q in ko_val_q:
    ko_val_punct.update(PUNCT_RE.findall(q))

print("Korean — TRAIN punctuation count:")
print(ko_train_punct.most_common())
print("Korean — VAL punctuation count:")
print(ko_val_punct.most_common())

# TELUGU
te_train_q = df_train[df_train["lang"] == "te"]["question"]
te_val_q   = df_val[df_val["lang"] == "te"]["question"]

te_train_punct = Counter()
for q in te_train_q:
    te_train_punct.update(PUNCT_RE.findall(q))

te_val_punct = Counter()
for q in te_val_q:
    te_val_punct.update(PUNCT_RE.findall(q))

print("Telugu — TRAIN punctuation count:")
print(te_train_punct.most_common())
print("Telugu — VAL punctuation count:")
print(te_val_punct.most_common())


Arabic — TRAIN punctuation count:
[('؟', 2556), ('"', 80), ('(', 25), (')', 25), ('-', 5), ('.', 2), ('/', 2), ('«', 2), ('»', 2), ('_', 2), ('\\', 1), ('—', 1), ('!', 1), ('،', 1)]
Arabic — VAL punctuation count:
[('؟', 413), ('"', 4), ('(', 3), (')', 3), ('،', 1), ('-', 1)]
Korean — TRAIN punctuation count:
[('?', 2420), (',', 23), ('.', 16), ("'", 6), ('"', 6), ('-', 5), (':', 2), ('/', 1), ('\\', 1), ('(', 1), (')', 1)]
Korean — VAL punctuation count:
[('?', 356), ('.', 9), (',', 3), ('-', 1)]
Telugu — TRAIN punctuation count:
[('?', 1355), ('.', 42), (',', 6), ('-', 3), ('%', 1), ('–', 1)]
Telugu — VAL punctuation count:
[('?', 384), ('.', 2), ('-', 1), ('%', 1)]


In [15]:
## Each language total words (not counting punctuation)
# tokenizer: split on \W+ (non-word chars); protect hyphens between letters/digits
# safeguard: build punctuation set from training+validation data, do not count these tokens as well

TOKEN_SPLIT = re.compile(r"\W+") # tokenizer
HY = "HY" # placeholder for protected hyphens
PROTECT_HYPHEN = re.compile(r"(?<=[\p{L}\p{N}])-(?=[\p{L}\p{N}])") # hyphen between letters/digits

# ARABIC
ar_train_q = df_train[df_train["lang"] == "ar"]["question"]
ar_val_q   = df_val[df_val["lang"] == "ar"]["question"]

# build punctuation set (optional safeguard)
ar_punct_set = set()
for q in pd.concat([ar_train_q, ar_val_q]):
    for ch in PUNCT_RE.findall(q):
        ar_punct_set.add(ch)

# protect hyphens, split on \W+, restore hyphens; slashes will split
ar_train_tokens = []
for q in ar_train_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in ar_punct_set]
    ar_train_tokens.extend(toks)

ar_val_tokens = []
for q in ar_val_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in ar_punct_set]
    ar_val_tokens.extend(toks)

print("Arabic - TRAIN total words:", len(ar_train_tokens))
print("Arabic - VAL total words:",   len(ar_val_tokens))

# KOREAN
ko_train_q = df_train[df_train["lang"] == "ko"]["question"]
ko_val_q   = df_val[df_val["lang"] == "ko"]["question"]

ko_punct_set = set()
for q in pd.concat([ko_train_q, ko_val_q]):
    for ch in PUNCT_RE.findall(q):
        ko_punct_set.add(ch)

ko_train_tokens = []
for q in ko_train_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in ko_punct_set]
    ko_train_tokens.extend(toks)

ko_val_tokens = []
for q in ko_val_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in ko_punct_set]
    ko_val_tokens.extend(toks)

print("Korean - TRAIN total words:", len(ko_train_tokens))
print("Korean - VAL total words:",   len(ko_val_tokens))

# TELUGU
te_train_q = df_train[df_train["lang"] == "te"]["question"]
te_val_q   = df_val[df_val["lang"] == "te"]["question"]

te_punct_set = set()
for q in pd.concat([te_train_q, te_val_q]):
    for ch in PUNCT_RE.findall(q):
        te_punct_set.add(ch)

te_train_tokens = []
for q in te_train_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in te_punct_set]
    te_train_tokens.extend(toks)

te_val_tokens = []
for q in te_val_q:
    q2 = PROTECT_HYPHEN.sub(HY, q)
    toks = [t.replace(HY, "-") for t in TOKEN_SPLIT.split(q2) if t and t not in te_punct_set]
    te_val_tokens.extend(toks)

print("Telugu - TRAIN total words:", len(te_train_tokens))
print("Telugu - VAL total words:",   len(te_val_tokens))


Arabic - TRAIN total words: 16199
Arabic - VAL total words: 2617
Korean - TRAIN total words: 11858
Korean - VAL total words: 1736
Telugu - TRAIN total words: 7690
Telugu - VAL total words: 2302


In [16]:
#Stats on numeric and hyphenated tokens

# After tokenization for Arabic
ar_numbers_train = sum(1 for t in ar_train_tokens if t.isdigit())
ar_numbers_val   = sum(1 for t in ar_val_tokens if t.isdigit())

ar_hyphen_train = sum(1 for t in ar_train_tokens if "-" in t)
ar_hyphen_val   = sum(1 for t in ar_val_tokens if "-" in t)

print("Arabic - TRAIN numeric tokens:", ar_numbers_train)
print("Arabic - VAL numeric tokens:",   ar_numbers_val)
print("Arabic - TRAIN hyphenated tokens:", ar_hyphen_train)
print("Arabic - VAL hyphenated tokens:",   ar_hyphen_val)

# After tokenization for Korean
ko_numbers_train = sum(1 for t in ko_train_tokens if t.isdigit())
ko_numbers_val   = sum(1 for t in ko_val_tokens if t.isdigit())

ko_hyphen_train = sum(1 for t in ko_train_tokens if "-" in t)
ko_hyphen_val   = sum(1 for t in ko_val_tokens if "-" in t)

print("Korean — TRAIN numeric tokens:", ko_numbers_train)
print("Korean — VAL numeric tokens:",   ko_numbers_val)
print("Korean — TRAIN hyphenated tokens:", ko_hyphen_train)
print("Korean — VAL hyphenated tokens:",   ko_hyphen_val)

# After tokenization for Telugu
te_numbers_train = sum(1 for t in te_train_tokens if t.isdigit())
te_numbers_val   = sum(1 for t in te_val_tokens if t.isdigit())

te_hyphen_train = sum(1 for t in te_train_tokens if "-" in t)
te_hyphen_val   = sum(1 for t in te_val_tokens if "-" in t)

print("Telugu - TRAIN numeric tokens:", te_numbers_train)
print("Telugu - VAL numeric tokens:",   te_numbers_val)
print("Telugu - TRAIN hyphenated tokens:", te_hyphen_train)
print("Telugu - VAL hyphenated tokens:",   te_hyphen_val)


Arabic - TRAIN numeric tokens: 78
Arabic - VAL numeric tokens: 11
Arabic - TRAIN hyphenated tokens: 3
Arabic - VAL hyphenated tokens: 0
Korean — TRAIN numeric tokens: 9
Korean — VAL numeric tokens: 1
Korean — TRAIN hyphenated tokens: 5
Korean — VAL hyphenated tokens: 1
Telugu - TRAIN numeric tokens: 107
Telugu - VAL numeric tokens: 39
Telugu - TRAIN hyphenated tokens: 0
Telugu - VAL hyphenated tokens: 0


In [7]:
#5 Most common words (not counting punctuation); with English translations and their count

translator = Translator()

# ARABIC
# (skip pure numbers)
ar_counts = Counter([t.lower() for t in ar_train_tokens if t ])
ar_top5 = ar_counts.most_common(5)

print("Arabic — Top 5 most common words (TRAIN):")
for w, c in ar_top5:
    try:
        en = translator.translate(w, src='ar', dest='en').text
    except Exception as e:
        en = f"[translation error: {e}]"
    print(f"{w}\tcount={c}\t→ {en}")

# KOREAN
ko_counts = Counter([t.lower() for t in ko_train_tokens if t ])
ko_top5 = ko_counts.most_common(5)

print("\nKorean — Top 5 most common words (TRAIN):")
for w, c in ko_top5:
    try:
        en = translator.translate(w, src='ko', dest='en').text
    except Exception as e:
        en = f"[translation error: {e}]"
    print(f"{w}\tcount={c}\t→ {en}")

#  TELUGU
te_counts = Counter([t.lower() for t in te_train_tokens if t])
te_top5 = te_counts.most_common(5)

print("\nTelugu — Top 5 most common words (TRAIN):")
for w, c in te_top5:
    try:
        en = translator.translate(w, src='te', dest='en').text
    except Exception as e:
        en = f"[translation error: {e}]"
    print(f"{w}\tcount={c}\t→ {en}")


' #5 Most common words (not counting punctuation); with English translations and their count\n\ntranslator = Translator()\n\n# ARABIC\n# (skip pure numbers)\nar_counts = Counter([t.lower() for t in ar_train_tokens if t ])\nar_top5 = ar_counts.most_common(5)\n\nprint("Arabic — Top 5 most common words (TRAIN):")\nfor w, c in ar_top5:\n    try:\n        en = translator.translate(w, src=\'ar\', dest=\'en\').text\n    except Exception as e:\n        en = f"[translation error: {e}]"\n    print(f"{w}\tcount={c}\t→ {en}")\n\n# KOREAN\nko_counts = Counter([t.lower() for t in ko_train_tokens if t ])\nko_top5 = ko_counts.most_common(5)\n\nprint("\nKorean — Top 5 most common words (TRAIN):")\nfor w, c in ko_top5:\n    try:\n        en = translator.translate(w, src=\'ko\', dest=\'en\').text\n    except Exception as e:\n        en = f"[translation error: {e}]"\n    print(f"{w}\tcount={c}\t→ {en}")\n\n#  TELUGU\nte_counts = Counter([t.lower() for t in te_train_tokens if t])\nte_top5 = te_counts.most_

### We conclude the words are "stop words" that we learned in the lecture

In [17]:
# Stats about answerable vs unanswerable questions

# Define languages and splits

split_dfs = {
    "train": df_train,
    "val":   df_val
}


rows = []
for split_name, df in split_dfs.items():
    for lang in langs:
        total = df[df["lang"] == lang].shape[0]
        ans   = df[(df["lang"] == lang) & (df["answerable"])].shape[0]
        unans = total - ans
        ratio = ans / total if total > 0 else 0
        rows.append([split_name, lang, total, ans, unans, ratio])

# Create summary DataFrame
summary = pd.DataFrame(rows, columns=["Split", "Language", "Total", "Answerable", "Unanswerable", "Answerable Ratio"])
print(summary.to_string(index=False))


Split Language  Total  Answerable  Unanswerable  Answerable Ratio
train       ar   2558        2303           255          0.900313
train       ko   2422        2359            63          0.973988
train       te   1355        1310            45          0.966790
  val       ar    415         363            52          0.874699
  val       ko    356         337            19          0.946629
  val       te    384         291            93          0.757812


## RULE BASE CLASSIFIER

In [20]:
import os
import numpy as np
import pandas as pd
import regex as re
import string
#from unidecode import unidecode
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

import nltk
nltk.download("stopwords", quiet=True)
# Needed for word_tokenize
try:
    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass

EN_STOP = set(stopwords.words('english')) | set(string.punctuation)
LANGS = ["ar", "ko", "te"]


In [21]:
import torch
from transformers import pipeline

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
DEVICE = 0 if torch.cuda.is_available() else -1
print("Using device idx for HF pipeline:", DEVICE)

MODEL_ID = "facebook/nllb-200-distilled-600M"
SRC_CODES = {"ar": "arb_Arab", "ko": "kor_Hang", "te": "tel_Telu"}
TGT_CODE = "eng_Latn"

# One translator pipeline reused for all batches
nllb = pipeline("translation", model=MODEL_ID, tokenizer=MODEL_ID, device=DEVICE)


PyTorch: 2.8.0+cu126
CUDA available: True
Using device idx for HF pipeline: 0


config.json:   0%|          | 0.00/846 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/564 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/4.85M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Device set to use cuda:0


In [24]:
from typing import List, Iterable, Dict, Optional
from tqdm import tqdm

def translate_list_pipe(texts: List[str],
                        src_lang: str,
                        batch_size: int = 128,
                        max_length: int = 320,
                        show_progress: bool = True) -> List[str]:

    if src_lang not in SRC_CODES:
        raise ValueError(f"Unknown lang code '{src_lang}'. Expected one of {list(SRC_CODES)}")

    outputs = []
    iterator = range(0, len(texts), batch_size)
    if show_progress:
        iterator = tqdm(iterator, total=(len(texts) + batch_size - 1)//batch_size, desc=f"Translating {src_lang}->EN")

    # Iterate over batches and translate
    for i in iterator:
        batch = [x if isinstance(x, str) else "" for x in texts[i:i+batch_size]]
        try:
            preds = nllb(
                batch,
                src_lang=SRC_CODES[src_lang],
                tgt_lang=TGT_CODE,
                truncation=True,
                max_length=max_length
            )
            # Extract translated strings
            outputs.extend([p.get("translation_text", "") for p in preds])
        except Exception as e:
            print(f"[WARN] Batch {i}:{i+len(batch)} failed: {type(e).__name__}: {e}")
            outputs.extend([""] * len(batch))
    return outputs

In [26]:
def ensure_column(df: pd.DataFrame, col: str):
    if col not in df.columns:
        df[col] = pd.Series([np.nan]*len(df), index=df.index)

# Ensure the target columns exist
ensure_column(df_train, "question_en")
ensure_column(df_val,   "question_en")

TRANSLATE_CONTEXT = False
if TRANSLATE_CONTEXT:
    ensure_column(df_train, "context_en")
    ensure_column(df_val,   "context_en")

def cache_translations(df: pd.DataFrame,
                       text_col: str,
                       out_col: str,
                       langs: Iterable[str] = LANGS,
                       batch_size: int = 128):

    for lg in langs:
        mask_lang = (df["lang"] == lg)
        mask_need = df[out_col].isna() | (df[out_col].astype(str).str.strip() == "")
        mask = mask_lang & mask_need
        if not mask.any():
            print(f"[{text_col}] '{lg}' - already cached, skipping.")
            continue

        texts = df.loc[mask, text_col].astype(str).tolist()
        print(f"[{text_col}] Translating {sum(mask)} rows for lang='{lg}'...")
        df.loc[mask, out_col] = translate_list_pipe(texts, src_lang=lg, batch_size=batch_size)

# Cache questions
cache_translations(df_train, text_col="question", out_col="question_en", batch_size=128)
cache_translations(df_val,   text_col="question", out_col="question_en", batch_size=128)

# cache contexts
if TRANSLATE_CONTEXT:
    cache_translations(df_train, text_col="context", out_col="context_en", batch_size=128)
    cache_translations(df_val,   text_col="context", out_col="context_en", batch_size=128)

# Persist to disk to not have to re-translate because it's very slow
df_train.to_parquet("df_train_translated.parquet")
df_val.to_parquet("df_val_translated.parquet")
print("Saved: df_train_translated.parquet, df_val_translated.parquet")

[question] Translating 2558 rows for lang='ar'...


Translating ar->EN:   0%|          | 0/20 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def sanity_check_translations(df, text_col="question", trans_col="question_en", langs=LANGS, n=3):
    # Prints n sample original vs. translated texts for each language to verify translation
    for lg in langs:
        subset = df[df["lang"] == lg].dropna(subset=[text_col, trans_col])
        if subset.empty:
            print(f"[WARN] No data for lang '{lg}' in {text_col}")
            continue

        print(f"\n=== {lg.upper()} → EN (showing {n} samples) ===")
        for i, row in subset.head(n).iterrows():
            print("SRC:", row[text_col])
            print("EN :", row[trans_col])
            print("-"*40)

# Run sanity check on train set
sanity_check_translations(df_train, text_col="question", trans_col="question_en", langs=LANGS, n=3)

# Check val set
sanity_check_translations(df_val, text_col="question", trans_col="question_en", langs=LANGS, n=3)

[WARN] No data for lang 'ar' in question
[WARN] No data for lang 'ko' in question
[WARN] No data for lang 'te' in question
[WARN] No data for lang 'ar' in question
[WARN] No data for lang 'ko' in question
[WARN] No data for lang 'te' in question


In [29]:
import pandas as pd
import numpy as np
import regex as re
import string
import nltk
from nltk.corpus import stopwords


nltk.download('stopwords', quiet=True)
try:
    nltk.download('punkt', quiet=True)
except Exception:
    pass

STOP_WORDS = set(stopwords.words('english')) | set(string.punctuation)


def pick_cols(df, translate_contexts=False):
    q_col = 'question_en' if 'question_en' in df.columns else 'question'
    if translate_contexts and 'context_en' in df.columns:
        c_col = 'context_en'
    else:
        c_col = 'context'
    return q_col, c_col


def tokenize(text: str):
    tokens = re.split(r'\W+', str(text) if text is not None else "")
    return [t.lower() for t in tokens if t and t.lower() not in STOP_WORDS]

def overlap_score_question(question: str, context: str):
    q_toks = tokenize(question)
    c_toks = tokenize(context)
    if not q_toks:
        return 0.0, 0
    matched = set()
    for q in q_toks:
        for c in c_toks:
            if q == c or (q in c) or (c in q):
                matched.add(q)
                break
    matches = len(matched)
    ratio = matches / max(1, len(q_toks))
    return ratio, matches

def tune_parameters(train_df, q_col, c_col,
                    match_grid=(1,2,3,4,5,6,7,8,9,10),
                    thr_grid=(0.3,0.4,0.5,0.6,0.7,0.8,0.9)):
    data = [(overlap_score_question(getattr(r, q_col), getattr(r, c_col)), int(r.answerable))
            for r in train_df.itertuples(index=False)]
    best_acc, best_k, best_thr = 0.0, 1, 0.5
    for k in match_grid:
        for thr in thr_grid:
            correct = 0
            for (ratio, m), y in data:
                pred = int((m >= k) and (ratio >= thr))
                correct += (pred == y)
            acc = correct / len(data) if data else 0.0
            if acc > best_acc:
                best_acc, best_k, best_thr = acc, k, thr
    return {"min_match_count": best_k, "min_ratio_threshold": best_thr, "best_train_acc": best_acc}


def eval_metrics(df, q_col, c_col, min_matches, ratio_threshold):
    y_true, y_pred = [], []
    for r in df.itertuples(index=False):
        ratio, m = overlap_score_question(getattr(r, q_col), getattr(r, c_col))
        y_true.append(int(r.answerable))
        y_pred.append(int((m >= min_matches) and (ratio >= ratio_threshold)))
    y_true = np.asarray(y_true, int)
    y_pred = np.asarray(y_pred, int)

    tp = int(((y_pred==1) & (y_true==1)).sum())
    fp = int(((y_pred==1) & (y_true==0)).sum())
    fn = int(((y_pred==0) & (y_true==1)).sum())
    tn = int(((y_pred==0) & (y_true==0)).sum())

    acc  = (tp+tn)/max(1, tp+tn+fp+fn)
    prec = tp/max(1, tp+fp)
    rec  = tp/max(1, tp+fn)
    f1   = 0.0 if (prec+rec)==0 else 2*prec*rec/(prec+rec)

    return {
        "acc": round(acc, 4),
        "prec": round(prec, 4),
        "rec": round(rec, 4),
        "f1": round(f1, 4),
        "cm": {"TP": tp, "FP": fp, "FN": fn, "TN": tn}
    }


def run_rule_classifier(df_train, df_val, translate_contexts=False):
    results = {}
    for code, name in [("ar","Arabic"), ("ko","Korean"), ("te","Telugu")]:
        tr = df_train[df_train["lang"] == code].copy()
        va = df_val[df_val["lang"] == code].copy()
        if tr.empty or va.empty:
            results[name] = {
                "train_acc": None, "val_acc": None, "val_prec": None, "val_rec": None, "val_f1": None,
                "min_matches": None, "min_ratio": None, "cm": None, "n_train": len(tr), "n_val": len(va)
            }
            continue

        q_col_tr, c_col_tr = pick_cols(tr, translate_contexts=translate_contexts)
        q_col_va, c_col_va = pick_cols(va, translate_contexts=translate_contexts)

        params = tune_parameters(tr, q_col_tr, c_col_tr)
        metrics_val = eval_metrics(va, q_col_va, c_col_va,
                                   params["min_match_count"], params["min_ratio_threshold"])

        results[name] = {
            "n_train": len(tr),
            "n_val": len(va),
            "train_acc": round(params["best_train_acc"], 4),
            "val_acc": metrics_val["acc"],
            "val_prec": metrics_val["prec"],
            "val_rec": metrics_val["rec"],
            "val_f1": metrics_val["f1"],
            "min_matches": params["min_match_count"],
            "min_ratio": params["min_ratio_threshold"],
            "cm": metrics_val["cm"],
        }


    summary = pd.DataFrame({
        lang: {k:v for k,v in res.items() if k not in ("cm",)}
        for lang, res in results.items()
    }).T
    print(summary.to_string())

    # Also print confusion matrices
    print("\nConfusion matrices:")
    for lang, res in results.items():
        print(f"{lang}: {res['cm']}")

    return results


_ = run_rule_classifier(df_train, df_val, translate_contexts=False)


        n_train  n_val  train_acc  val_acc  val_prec  val_rec  val_f1  min_matches  min_ratio
Arabic   2558.0  415.0     0.1501   0.1807    0.8966   0.0716  0.1327          1.0        0.3
Korean   2422.0  356.0     0.0995   0.1292    0.9355   0.0861  0.1576          1.0        0.3
Telugu   1355.0  384.0     0.1461   0.2500    0.5366   0.0756  0.1325          1.0        0.3

Confusion matrices:
Arabic: {'TP': 26, 'FP': 3, 'FN': 337, 'TN': 49}
Korean: {'TP': 29, 'FP': 2, 'FN': 308, 'TN': 17}
Telugu: {'TP': 22, 'FP': 19, 'FN': 269, 'TN': 74}
