## Setup: device, imports, paths, config

In [2]:
# 01 - Setup: device, imports, paths, config

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import random
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import torch

# ---- Device ----
# Requested device; the GPU is only used when CUDA is actually available,
# otherwise the notebook falls back to the CPU branch below.
RUN_DEVICE = "gpu"  # "gpu" or "cpu"

if RUN_DEVICE.lower() == "gpu" and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    torch.backends.cudnn.enabled = False
    # os.cpu_count() returns None when the count cannot be determined; treat
    # that as a single core instead of failing on `None // 2`.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))
    print("Using CPU")

# ---- Seeds ----
# Seed Python's `random`, NumPy and PyTorch with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == "cuda":
    # Only reached when the GPU branch of the device selection was taken.
    torch.cuda.manual_seed_all(SEED)

from sklearn.metrics import f1_score, classification_report, confusion_matrix

# ---- Transformers ----
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---- WSD: NLTK WordNet (custom simple Lesk) ----
import nltk

# Make sure WordNet (+ OMW) and stopwords are available
def _ensure_nltk_resource(package: str, label: str) -> None:
    """Download the NLTK corpus `package` unless it can already be located.

    The corpus is looked up both as a directory and as a `.zip` archive under
    `corpora/`. The recorded output of this cell shows a download being
    attempted for a package NLTK then reports as "already up-to-date", i.e.
    the directory-style lookup alone can miss an installed corpus.

    Args:
        package: NLTK package id passed to `nltk.download` (e.g. "wordnet").
        label: Human-readable name used in the progress message.
    """
    for candidate in (f"corpora/{package}", f"corpora/{package}.zip"):
        try:
            nltk.data.find(candidate)
            return
        except LookupError:
            continue
    print(f"Downloading NLTK {label} data...")
    nltk.download(package)


# Each resource is checked on its own so a missing `omw-1.4` is fetched even
# when `wordnet` is already present.
for _pkg, _label in (("wordnet", "WordNet"), ("omw-1.4", "OMW 1.4"), ("stopwords", "stopwords")):
    _ensure_nltk_resource(_pkg, _label)

from nltk.corpus import wordnet as wn
from nltk.corpus import stopwords
import string

# English stopwords as a set; simple_lesk_nltk uses it to drop function words
# from both the context and the synset signatures.
STOP_WORDS = set(stopwords.words("english"))

# ---- Paths ----
# All dataset paths are relative to the working directory of the kernel.
DATA_DIR = Path("SemEval_2022_Task2-idiomaticity/SubTaskA")
TRAIN_ONE_SHOT = DATA_DIR / "Data" / "train_one_shot.csv"
TRAIN_ZERO_SHOT = DATA_DIR / "Data" / "train_zero_shot.csv"
DEV = DATA_DIR / "Data" / "dev.csv"
DEV_GOLD = DATA_DIR / "Data" / "dev_gold.csv"
EVAL = DATA_DIR / "Data" / "eval.csv"

# Prediction caches, submissions and run metadata are written here; the
# directory is created on import of this cell if it does not exist yet.
OUT_DIR = Path("outputs_en_llm_wsd")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- LLM config ----
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # adjust if you want a smaller one

# Larger batch if GPU
# NOTE: the one-shot cell (06) reassigns BATCH_GEN, so this value only applies
# to code that runs before that cell.
BATCH_GEN = 8 if device.type == "cpu" else 32

# Toggle WSD usage
# Read as a module-level global by the loaders and the prediction loops; it also
# selects the "wsd"/"baseline" suffix of the cache and output file names.
USE_WSD = True
print("USE_WSD =", USE_WSD)
print("BATCH_GEN =", BATCH_GEN)


Using GPU: NVIDIA H100 80GB HBM3 MIG 2g.20gb


  from .autonotebook import tqdm as notebook_tqdm


Downloading NLTK WordNet data...


[nltk_data] Downloading package wordnet to /home/mhossai6/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/mhossai6/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


USE_WSD = True
BATCH_GEN = 32


## Data loading, context utils, WSD helpers, EN loaders

In [3]:
# 02 - Data loading, context, WSD, EN train/dev/eval helpers

def load_any_csv(path: Path) -> pd.DataFrame:
    """Read a delimited text file with every column typed as string.

    The delimiter is not fixed: `sep=None` together with the Python engine
    lets pandas detect it from the file.
    """
    read_options = {"sep": None, "engine": "python", "dtype": str}
    frame = pd.read_csv(path, **read_options)
    return frame

def ensure_label_int(df: pd.DataFrame, col="Label") -> pd.DataFrame:
    """Cast column `col` to int when it exists and hand back the same frame.

    The assignment happens on `df` itself, so the caller's frame is modified.
    Frames without the column are returned untouched.
    """
    if col not in df.columns:
        return df
    df[col] = df[col].astype(int)
    return df

def mark_first_case_insensitive(text: str, needle: str, ltag="<mwe>", rtag="</mwe>") -> str:
    """Wrap the first case-insensitive occurrence of `needle` in `text` with tags.

    Args:
        text: String to search. Non-string values are returned unchanged.
        needle: Literal substring to locate, ignoring case. Non-string or empty
            values leave `text` unchanged.
        ltag: Marker inserted before the match.
        rtag: Marker inserted after the match.

    Returns:
        `text` with the first match surrounded by `ltag`/`rtag`, or `text`
        itself when there is nothing to mark.
    """
    import re  # local import so the block does not depend on a file-level import

    if not isinstance(text, str) or not isinstance(needle, str):
        return text
    # An empty needle "matches" at position 0 and would only insert an empty
    # tag pair in front of the text.
    if not needle:
        return text
    # Search the original string rather than a lowercased copy: lowercasing can
    # change the string length (e.g. "İ" lowercases to two code points), which
    # would shift the offsets used for slicing.
    hit = re.search(re.escape(needle), text, flags=re.IGNORECASE)
    if hit is None:
        return text
    return text[:hit.start()] + ltag + hit.group(0) + rtag + text[hit.end():]

def pack_context(prev: str, target: str, nxt: str, mwe: str) -> str:
    """Render the three context sentences as a labelled three-line block.

    Missing values (per `pd.isna`) become empty strings, and the first
    occurrence of `mwe` in the target sentence is tagged through
    `mark_first_case_insensitive`.
    """
    def _blank_if_missing(value):
        return "" if pd.isna(value) else value

    prev_text = _blank_if_missing(prev)
    next_text = _blank_if_missing(nxt)
    target_text = _blank_if_missing(target)
    marked_target = mark_first_case_insensitive(target_text, mwe)
    lines = [
        f"Previous: {prev_text}",
        f"Target: {marked_target}",
        f"Next: {next_text}",
    ]
    return "\n".join(lines)

# ---- WSD helpers (custom simple Lesk using NLTK WordNet) ----

def build_simple_sentence(row) -> str:
    """Join the Previous/Target/Next fields of `row` into one space-separated string.

    `row` only needs a `.get(key, default)` method. Strings are used as-is,
    missing values (per `pd.isna`) become "", anything else goes through
    `str()`. Leading/trailing whitespace of the joined result is stripped.
    """
    def _as_text(value):
        if isinstance(value, str):
            return value
        return "" if pd.isna(value) else str(value)

    pieces = [_as_text(row.get(field, "")) for field in ("Previous", "Target", "Next")]
    return " ".join(pieces).strip()


def get_mwe_head(mwe: str) -> str:
    """Return the last whitespace-separated token of `mwe`.

    Non-string input and strings without any token yield "".
    """
    if isinstance(mwe, str):
        words = mwe.split()
        if words:
            return words[-1]
    return ""


def simple_lesk_nltk(context_sentence: str, ambiguous_word: str):
    """
    Small Lesk-style word sense disambiguation on top of NLTK WordNet.

    Every synset of `ambiguous_word` is scored by the number of distinct
    non-stopword tokens its definition and examples share with the context.
    The first synset with the strictly highest positive score wins.

    Returns a Synset, or None when an argument is empty, the word has no
    synsets, or no synset shares a token with the context.
    """
    if not context_sentence or not ambiguous_word:
        return None

    def _content_words(words):
        # Strip surrounding punctuation, lowercase, drop empties and stopwords.
        cleaned = (w.strip(string.punctuation).lower() for w in words)
        return {w for w in cleaned if w and w not in STOP_WORDS}

    context_words = _content_words(context_sentence.split())

    candidates = wn.synsets(ambiguous_word)
    if not candidates:
        return None

    winner, winner_score = None, 0
    for candidate in candidates:
        raw_signature = candidate.definition().split()
        for example in candidate.examples():
            raw_signature += example.split()

        score = len(_content_words(raw_signature) & context_words)
        # Strict comparison: ties keep the earlier synset, zero overlap keeps None.
        if score > winner_score:
            winner, winner_score = candidate, score

    return winner


def annotate_with_wsd(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` with two extra columns, SenseID and SenseGloss.

    For each row the context sentence (build_simple_sentence) and the last
    token of the MWE (get_mwe_head) are passed to simple_lesk_nltk. Rows with
    no sentence, no head, no selected synset, or a failing lookup get empty
    strings in both columns.
    """
    annotated = df.copy()
    senses = []  # one (sense id, gloss) pair per row, in row order

    for _, row in annotated.iterrows():
        sentence = build_simple_sentence(row)
        head = get_mwe_head(row.get("MWE", "") or "")

        chosen = None
        if sentence and head:
            try:
                chosen = simple_lesk_nltk(sentence, head)
            except Exception:
                # Best effort: a failed lookup is treated like "no sense found".
                chosen = None

        if chosen is None:
            senses.append(("", ""))
        else:
            senses.append((chosen.name(), chosen.definition()))

    annotated["SenseID"] = [sense_id for sense_id, _ in senses]
    annotated["SenseGloss"] = [gloss for _, gloss in senses]
    return annotated


def load_train_dev(language="EN", oneshot=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the training split and the labelled dev split for one language.

    Args:
        language: Value of the "Language" column to keep.
        oneshot: Selects TRAIN_ONE_SHOT when truthy, TRAIN_ZERO_SHOT otherwise.

    Returns:
        (train_df, dev_lab) where dev_lab is the dev file left-joined on "ID"
        with the gold labels. "Label" is cast to int in both frames, and both
        carry SenseID/SenseGloss columns when the global USE_WSD is set.
    """
    train_path = TRAIN_ONE_SHOT if oneshot else TRAIN_ZERO_SHOT
    train_df = load_any_csv(train_path)
    dev_df = load_any_csv(DEV)
    gold_df = load_any_csv(DEV_GOLD)

    # Header names may carry stray whitespace.
    for frame in (train_df, dev_df, gold_df):
        frame.columns = [name.strip() for name in frame.columns]

    train_df = train_df[train_df["Language"] == language].copy()
    dev_df = dev_df[dev_df["Language"] == language].copy()

    gold = gold_df.loc[gold_df["Language"] == language, ["ID", "Label"]].copy()
    gold["ID"] = gold["ID"].astype(str)
    dev_df["ID"] = dev_df["ID"].astype(str)
    dev_lab = dev_df.merge(gold, on="ID", how="left")

    train_df = ensure_label_int(train_df, "Label")
    dev_lab = ensure_label_int(dev_lab, "Label")

    if not USE_WSD:
        return train_df, dev_lab

    print(f"Annotating train/dev ({language}, oneshot={oneshot}) with WSD...")
    train_df = annotate_with_wsd(train_df)
    dev_lab = annotate_with_wsd(dev_lab)
    return train_df, dev_lab


def load_eval(language="EN") -> pd.DataFrame:
    """Load the evaluation split restricted to `language`.

    Column names are stripped of surrounding whitespace. When the global
    USE_WSD is set, the result also carries SenseID/SenseGloss columns.
    """
    eval_df = load_any_csv(EVAL)
    eval_df.columns = [name.strip() for name in eval_df.columns]
    eval_df = eval_df[eval_df["Language"] == language].copy()

    if not USE_WSD:
        return eval_df

    print(f"Annotating eval ({language}) with WSD...")
    return annotate_with_wsd(eval_df)


## Load Qwen model & tokenizer

In [4]:
# 03 - Load Qwen model & tokenizer

# If you need HF login, you can uncomment this and set HF_TOKEN env var.
# from huggingface_hub import login as hf_login
# HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
# if HF_TOKEN:
#     hf_login(token=HF_TOKEN)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    trust_remote_code=True,
)
# The classifier in this notebook reads next-token logits at the final sequence
# position of every row in a padded batch. With right padding that position
# holds a pad token for every prompt shorter than the longest one in the batch,
# so pad on the left to keep the last real token in the final column.
tokenizer.padding_side = "left"

# NOTE(review): the recorded output of this cell warns that `torch_dtype` is
# deprecated in favour of `dtype`. The old keyword is kept here; confirm the
# minimum transformers version in use before renaming it.
# NOTE(review): `trust_remote_code=True` allows code shipped with the checkpoint
# to run locally — confirm it is actually required for MODEL_NAME.
if device.type == "cuda":
    # Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

# Batched tokenization needs a pad token: reuse EOS when there is one, otherwise
# register a new "[PAD]" token and grow the embedding matrix to match.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

model.eval()
print("Loaded LLM:", MODEL_NAME)


`torch_dtype` is deprecated! Use `dtype` instead!
2025-11-26 23:44:23.609503: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-26 23:44:23.623598: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764229463.633895  503544 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764229463.636564  503544 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764229463.645714  503544 computation_placer.cc:177] computation placer already r

Loaded LLM: Qwen/Qwen2.5-7B-Instruct


## Chat/plain encoding & logits-based 0/1 classification

In [5]:
# 04 - Helpers: chat/plain encoding & logits-based 0/1 classification

def _apply_chat_or_plain_batch(texts: list, max_len: int = 512) -> dict:
    """
    Tokenize a batch of prompts into model inputs on `device`.

    Each prompt is wrapped in a system + user chat turn when the tokenizer
    offers `apply_chat_template`; otherwise the raw text is tokenized.

    Padding and truncation are both forced to the left for the duration of the
    call (the tokenizer's own settings are restored afterwards):
      * left padding keeps the last real token of every row in the final
        column, which is where the caller reads next-token logits;
      * left truncation keeps the tail of an over-long prompt, i.e. the answer
        instruction and the generation prompt, instead of cutting them off.

    The attention mask is the one produced by the tokenizer. Rebuilding it from
    `input_ids != pad_token_id` would also mask genuine prompt tokens whenever
    the pad id doubles as another token (e.g. pad token set to EOS).

    Args:
        texts: Prompt strings.
        max_len: Maximum number of tokens per row after truncation.

    Returns:
        Dict with "input_ids" and "attention_mask" tensors moved to `device`.
    """
    saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        if hasattr(tokenizer, "apply_chat_template"):
            messages_batch = [[
                {"role": "system", "content": "You are a concise classifier."},
                {"role": "user", "content": t}
            ] for t in texts]
            encoded = tokenizer.apply_chat_template(
                messages_batch,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,  # needed to get the attention mask back
                padding=True,
                truncation=True,
                max_length=max_len,  # <<< important
            )
        else:
            encoded = tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len,  # <<< important
            )
    finally:
        # Do not leak the temporary settings to other users of the tokenizer.
        tokenizer.padding_side, tokenizer.truncation_side = saved_sides

    return {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].long().to(device),
    }


# Vocabulary ids of the answer tokens for "0" and "1". They are filled in by
# _init_digit_ids() and stay None when no single-token encoding is found, in
# which case classify_prompts_logits takes its fallback path.
_id0 = None
_id1 = None

def _candidate_token_id_for_digit(d: str) -> Optional[int]:
    """Find a single-token encoding for the digit string `d`.

    Three surface forms are tried in order: the bare digit, the digit with a
    leading space, and the digit followed by a newline. The id of the first
    form that encodes to exactly one token is returned; None if none does.
    """
    for surface in (d, " " + d, d + "\n"):
        pieces = tokenizer.encode(surface, add_special_tokens=False)
        if len(pieces) == 1:
            return pieces[0]
    return None

def _init_digit_ids():
    """Populate the module-level `_id0` / `_id1` token ids if still unset.

    Ids that are already set are left alone, so calling this again only
    retries the digits that previously had no single-token encoding.
    """
    global _id0, _id1
    _id0 = _candidate_token_id_for_digit("0") if _id0 is None else _id0
    _id1 = _candidate_token_id_for_digit("1") if _id1 is None else _id1

# Resolve the ids once, as soon as the tokenizer-dependent helpers exist.
_init_digit_ids()

def classify_prompts_logits(prompts: list) -> list:
    """
    Classify each prompt as 0 or 1 from the model's next-token distribution.

    The next-token logits are taken at the last *attended* position of every
    row, located through the attention mask. Reading position -1 directly is
    only correct for left-padded batches; with right padding it would be a pad
    position for every row shorter than the longest one.

    When token ids for "0" and "1" are known (`_id0`, `_id1`), the prediction
    is 1 if logit("1") >= logit("0"), else 0. Otherwise the most likely next
    token is decoded and searched for a digit, defaulting to 1 when neither
    digit is present. The fallback uses the same logits rather than
    `model.generate`, so it does not depend on the batch's padding side.

    Args:
        prompts: Prompt strings.

    Returns:
        List of ints (0 or 1), one per prompt, in input order.
    """
    # if you want you can pass a smaller max_len here, e.g. 384
    enc = _apply_chat_or_plain_batch(prompts, max_len=512)

    with torch.no_grad():
        logits = model(**enc).logits  # [B, T, V]

    # Index of the last position with attention_mask == 1 in each row. The
    # product is 0 at masked positions, so argmax lands on the largest attended
    # index (or 0 when only the first position is attended).
    mask = enc["attention_mask"]
    positions = torch.arange(mask.shape[1], device=mask.device)
    last_real = (mask * positions).argmax(dim=1).to(logits.device)
    rows = torch.arange(logits.shape[0], device=logits.device)
    next_logits = logits[rows, last_real]  # [B, V]

    if _id0 is not None and _id1 is not None:
        logit0 = next_logits[:, _id0]
        logit1 = next_logits[:, _id1]
        # Ties go to 1, matching the default bias of the fallback below.
        return (logit1 >= logit0).long().detach().cpu().tolist()

    # Fallback: take the greedy next token and parse 0/1 from its text
    outs = []
    for token_id in next_logits.argmax(dim=-1).detach().cpu().tolist():
        text = tokenizer.decode([token_id], skip_special_tokens=True)
        if "0" in text and "1" in text:
            outs.append(1 if text.index("1") < text.index("0") else 0)
        elif "1" in text:
            outs.append(1)
        elif "0" in text:
            outs.append(0)
        else:
            outs.append(1)  # default bias towards idiomatic
    return outs


## Zero-shot EN with optional WSD

In [7]:
# 05 - Zero-shot EN with optional WSD

from time import perf_counter
from tqdm.auto import tqdm

def make_zero_shot_prompt(mwe: str, ctx: str) -> str:
    """Build the zero-shot prompt: task statement, the MWE, its context, answer format.

    The four sections are separated by blank lines.
    """
    task = "You are a classifier that decides whether a multiword expression (MWE) is used literally (0) or idiomatically (1)."
    answer_format = "Answer with a single digit: 0 for literal, 1 for idiomatic."
    sections = [task, f"MWE: {mwe}", f"Context:\n{ctx}", answer_format]
    return "\n\n".join(sections)

def make_zero_shot_prompt_wsd(mwe: str, ctx: str, sense_gloss: str) -> str:
    """Build the zero-shot prompt, optionally quoting a WordNet gloss.

    A non-empty `sense_gloss` (after stripping whitespace) is appended directly
    below the MWE line; a falsy or blank gloss yields the plain zero-shot
    prompt layout.
    """
    gloss = (sense_gloss or "").strip()

    mwe_section = f"MWE: {mwe}"
    if gloss:
        mwe_section += (
            "\nWe also have a WordNet sense definition for this expression in this context:\n"
            f'"{gloss}".'
        )

    task = "You are a classifier that decides whether a multiword expression (MWE) is used literally (0) or idiomatically (1)."
    answer_format = "Answer with a single digit: 0 for literal, 1 for idiomatic."
    return "\n\n".join([task, mwe_section, f"Context:\n{ctx}", answer_format])

def _load_cache(cache_path: Path) -> dict:
    if cache_path.exists():
        df = pd.read_csv(cache_path, dtype={"ID": str, "Label": int})
        return dict(zip(df["ID"].astype(str), df["Label"].astype(int)))
    return {}

def _append_one(cache_path: Path, rec: tuple):
    _id, _lab = rec
    header_needed = not cache_path.exists()
    with open(cache_path, "a") as f:
        if header_needed:
            f.write("ID,Label\n")
        f.write(f"{_id},{int(_lab)}\n")

def progressive_predict_zero_shot_batched(df: pd.DataFrame, cache_path: Path, desc: str = "Zero-shot EN") -> list:
    """Predict a 0/1 label for every row of `df`, resuming from a CSV cache.

    Rows whose ID is already present in the cache are not recomputed. The rest
    are processed in chunks of the global BATCH_GEN; each new prediction is
    appended to the cache as soon as its chunk has been classified. The global
    USE_WSD decides whether the row's SenseGloss is added to the prompt.

    Args:
        df: Frame with at least "ID" and "MWE"; "Previous"/"Target"/"Next" and
            "SenseGloss" are read when present.
        cache_path: CSV file with ID,Label rows (created on first write).
        desc: Label for the progress bar and the log lines.

    Returns:
        Predictions in the row order of `df`.
    """
    frame = df.copy()
    frame["ID"] = frame["ID"].astype(str)

    preds_map = _load_cache(cache_path)
    cached_ids = set(preds_map)
    pending = [pos for pos, row_id in enumerate(frame["ID"]) if row_id not in cached_ids]
    n_cached, n_total, n_pending = len(cached_ids), len(frame), len(pending)

    print(f"{desc} | Resuming with {n_cached} cached / {n_total} total")
    started = perf_counter()

    for offset in tqdm(range(0, n_pending, BATCH_GEN), desc=desc, leave=True):
        chunk = pending[offset:offset + BATCH_GEN]
        batch_ids, batch_prompts = [], []
        for pos in chunk:
            row = frame.iloc[pos]
            row_id = row["ID"]
            mwe = row["MWE"]
            ctx = pack_context(row.get("Previous",""), row.get("Target",""), row.get("Next",""), mwe)
            if USE_WSD:
                gloss = (row.get("SenseGloss", "") or "").strip()
                batch_prompts.append(make_zero_shot_prompt_wsd(mwe, ctx, gloss))
            else:
                batch_prompts.append(make_zero_shot_prompt(mwe, ctx))
            batch_ids.append(row_id)
        if not batch_prompts:
            continue
        for row_id, label in zip(batch_ids, classify_prompts_logits(batch_prompts)):
            preds_map[row_id] = int(label)
            # Persist immediately so an interrupted run can resume.
            _append_one(cache_path, (row_id, int(label)))

    elapsed = perf_counter() - started
    per_example = elapsed / max(1, n_pending)
    print(f"{desc} | Newly computed: {n_pending} | Cached at start: {n_cached} | Total: {n_total} | "
          f"Elapsed: {elapsed:.1f}s | {per_example:.3f}s/example (new only)")

    return [preds_map[str(row_id)] for row_id in frame["ID"]]

# ---- Run zero-shot EN (dev + eval) ----

# The zero-shot training split is loaded but only the dev frame is used below.
train_0s_en, dev_0s_en = load_train_dev(language="EN", oneshot=False)

# NOTE: the cache file name encodes only the split and the WSD flag, and cached
# rows are matched by ID alone. Predictions cached under a different MODEL_NAME
# or prompt wording are therefore reused as-is; delete the file to recompute.
cache_dev_0s = OUT_DIR / f"cache_llm_zeroshot_dev_en_{'wsd' if USE_WSD else 'baseline'}.csv"
yhat_dev_0s = progressive_predict_zero_shot_batched(dev_0s_en, cache_dev_0s, desc=f"Zero-shot EN (dev, USE_WSD={USE_WSD})")
ytrue_dev_0s = dev_0s_en["Label"].tolist()
# f1_0s is read again by the run-metadata cell at the end of the notebook.
f1_0s = f1_score(ytrue_dev_0s, yhat_dev_0s, average="macro")
print(f"[LLM Zero-shot EN] Dev macro-F1: {f1_0s:.4f}")
print(classification_report(ytrue_dev_0s, yhat_dev_0s, digits=4))
print(confusion_matrix(ytrue_dev_0s, yhat_dev_0s))

# Eval split has no labels here: predict and write a submission file only.
eval_en = load_eval(language="EN")
cache_eval_0s = OUT_DIR / f"cache_llm_zeroshot_eval_en_{'wsd' if USE_WSD else 'baseline'}.csv"
yhat_eval_0s = progressive_predict_zero_shot_batched(eval_en, cache_eval_0s, desc=f"Zero-shot EN (eval, USE_WSD={USE_WSD})")

sub_0s = pd.DataFrame({
    "ID": eval_en["ID"].astype(str),
    "Language": eval_en["Language"],
    "Setting": ["zero_shot"] * len(eval_en),
    "Label": yhat_eval_0s
})
sub_0s_path = OUT_DIR / f"eval_submission_en_llm_zeroshot_{'wsd' if USE_WSD else 'baseline'}.csv"
sub_0s.to_csv(sub_0s_path, index=False)
print(f"Wrote {sub_0s_path}")


Annotating train/dev (EN, oneshot=False) with WSD...
Zero-shot EN (dev, USE_WSD=True) | Resuming with 466 cached / 466 total


Zero-shot EN (dev, USE_WSD=True): 0it [00:00, ?it/s]


Zero-shot EN (dev, USE_WSD=True) | Newly computed: 0 | Cached at start: 466 | Total: 466 | Elapsed: 0.0s | 0.001s/example (new only)
[LLM Zero-shot EN] Dev macro-F1: 0.4145
              precision    recall  f1-score   support

           0     0.3548    0.6044    0.4472       182
           1     0.5385    0.2958    0.3818       284

    accuracy                         0.4163       466
   macro avg     0.4467    0.4501    0.4145       466
weighted avg     0.4667    0.4163    0.4073       466

[[110  72]
 [200  84]]
Annotating eval (EN) with WSD...
Zero-shot EN (eval, USE_WSD=True) | Resuming with 483 cached / 483 total


Zero-shot EN (eval, USE_WSD=True): 0it [00:00, ?it/s]

Zero-shot EN (eval, USE_WSD=True) | Newly computed: 0 | Cached at start: 483 | Total: 483 | Elapsed: 0.0s | 0.001s/example (new only)
Wrote outputs_en_llm_wsd/eval_submission_en_llm_zeroshot_wsd.csv





## One-shot EN with optional WSD

In [8]:
# 06 - One-shot EN with optional WSD (self-contained, independent of zero-shot cell)

# NOTE: this rebinds the module-level BATCH_GEN from the setup cell. Every
# prediction loop executed after this line (including a re-run of the zero-shot
# cell) and the run-metadata file will see 4 rather than the setup value.
BATCH_GEN = 4   # smaller batch to avoid OOM for long one-shot + WSD prompts

from time import perf_counter
from tqdm.auto import tqdm

def _load_cache(cache_path: Path) -> dict:
    if cache_path.exists():
        df = pd.read_csv(cache_path, dtype={"ID": str, "Label": int})
        return dict(zip(df["ID"].astype(str), df["Label"].astype(int)))
    return {}

def _append_one(cache_path: Path, rec: tuple):
    _id, _lab = rec
    header_needed = not cache_path.exists()
    with open(cache_path, "a") as f:
        if header_needed:
            f.write("ID,Label\n")
        f.write(f"{_id},{int(_lab)}\n")


def build_oneshot_index(train_df: pd.DataFrame) -> dict:
    """
    Map each MWE to its first training example per label.

    The result has the shape {mwe: {0: entry_or_None, 1: entry_or_None}} where
    an entry is {"context": <packed context string>}. Only the first row seen
    for a given (MWE, label) pair is kept.
    """
    index = {}
    for _, example in train_df.iterrows():
        mwe = example["MWE"]
        label = int(example["Label"])
        context = pack_context(example.get("Previous",""), example.get("Target",""), example.get("Next",""), mwe)
        slots = index.setdefault(mwe, {0: None, 1: None})
        if slots[label] is None:
            slots[label] = {"context": context}
    return index


def pick_global_oneshot_fallback(train_df: pd.DataFrame) -> dict:
    """
    Choose one example per label, regardless of MWE, as a fallback pool.

    For each label the first matching training row is used. If one label has
    no rows at all, it borrows the other label's example; if neither has rows,
    both slots stay None.
    """
    picks = {0: None, 1: None}
    for label in (0, 1):
        matching = train_df[train_df["Label"] == label]
        if len(matching) == 0:
            continue
        first = matching.iloc[0]
        mwe = first["MWE"]
        context = pack_context(first.get("Previous",""), first.get("Target",""), first.get("Next",""), mwe)
        picks[label] = {"context": context}

    # in case one class missing, duplicate the other
    for label, other in ((0, 1), (1, 0)):
        if picks[label] is None and picks[other] is not None:
            picks[label] = picks[other]
    return picks


def make_one_shot_prompt(mwe: str, pos_ctx: str, neg_ctx: str, test_ctx: str) -> str:
    """Build the one-shot prompt: task, two labelled examples, then the query.

    `pos_ctx` is presented as the idiomatic (label 1) example and `neg_ctx` as
    the literal (label 0) example. Sections are separated by blank lines.
    """
    sections = [
        "You are a classifier that decides whether a multiword expression (MWE) is used literally (0) or idiomatically (1).",
        "First, see two labeled examples for the same MWE.",
        f"Example A (idiomatic, label 1):\n{pos_ctx}",
        f"Example B (literal, label 0):\n{neg_ctx}",
        "Now classify the following new occurrence of the same MWE:",
        f"MWE: {mwe}",
        f"Context:\n{test_ctx}",
        "Answer with a single digit: 0 for literal, 1 for idiomatic.",
    ]
    return "\n\n".join(sections)


def make_one_shot_prompt_wsd(mwe: str, pos_ctx: str, neg_ctx: str, test_ctx: str, sense_gloss: str) -> str:
    """Build the one-shot prompt, optionally quoting a WordNet gloss for the query.

    A non-empty `sense_gloss` (after stripping whitespace) is appended directly
    below the MWE line; a falsy or blank gloss yields the plain one-shot prompt
    layout.
    """
    gloss = (sense_gloss or "").strip()

    mwe_section = f"MWE: {mwe}"
    if gloss:
        mwe_section += (
            "\nWe also have a WordNet sense definition for this new occurrence:\n"
            f'"{gloss}".'
        )

    sections = [
        "You are a classifier that decides whether a multiword expression (MWE) is used literally (0) or idiomatically (1).",
        "First, see two labeled examples for the same MWE.",
        f"Example A (idiomatic, label 1):\n{pos_ctx}",
        f"Example B (literal, label 0):\n{neg_ctx}",
        "Now classify the following new occurrence of the same MWE:",
        mwe_section,
        f"Context:\n{test_ctx}",
        "Answer with a single digit: 0 for literal, 1 for idiomatic.",
    ]
    return "\n\n".join(sections)


def progressive_predict_one_shot_batched(df: pd.DataFrame,
                                         oneshot_index: dict,
                                         global_pool: dict,
                                         cache_path: Path,
                                         desc: str = "One-shot EN") -> list:
    """Predict a 0/1 label for every row of `df` with one-shot prompts, resuming from a cache.

    For each row the idiomatic and literal exemplars come from
    `oneshot_index[mwe]`; a missing exemplar is replaced by the corresponding
    entry of `global_pool`. Rows whose ID is already cached are not recomputed.
    New rows are processed in chunks of the global BATCH_GEN and appended to
    the cache after each chunk. The global USE_WSD decides whether the row's
    SenseGloss is added to the prompt.

    Args:
        df: Frame with at least "ID" and "MWE".
        oneshot_index: {mwe: {0: entry_or_None, 1: entry_or_None}} with
            entries of the form {"context": str}.
        global_pool: {0: entry, 1: entry} fallback exemplars.
        cache_path: CSV file with ID,Label rows (created on first write).
        desc: Label for the progress bar and the log lines.

    Returns:
        Predictions in the row order of `df`.
    """
    def _exemplar(entry: dict, label: int) -> str:
        # Per-MWE exemplar when available, otherwise the global fallback.
        slot = entry.get(label)
        context = slot.get("context") if slot else None
        return global_pool[label]["context"] if context is None else context

    frame = df.copy()
    frame["ID"] = frame["ID"].astype(str)

    preds_map = _load_cache(cache_path)
    cached_ids = set(preds_map)
    pending = [pos for pos, row_id in enumerate(frame["ID"]) if row_id not in cached_ids]
    n_cached, n_total, n_pending = len(cached_ids), len(frame), len(pending)

    print(f"{desc} | Resuming with {n_cached} cached / {n_total} total")
    started = perf_counter()

    for offset in tqdm(range(0, n_pending, BATCH_GEN), desc=desc, leave=True):
        chunk = pending[offset:offset + BATCH_GEN]
        batch_ids, batch_prompts = [], []
        for pos in chunk:
            row = frame.iloc[pos]
            row_id = row["ID"]
            mwe = row["MWE"]
            test_ctx = pack_context(row.get("Previous",""), row.get("Target",""), row.get("Next",""), mwe)

            entry = oneshot_index.get(mwe, {0: None, 1: None})
            pos_ctx = _exemplar(entry, 1)
            neg_ctx = _exemplar(entry, 0)

            if USE_WSD:
                gloss = (row.get("SenseGloss", "") or "").strip()
                batch_prompts.append(make_one_shot_prompt_wsd(mwe, pos_ctx, neg_ctx, test_ctx, gloss))
            else:
                batch_prompts.append(make_one_shot_prompt(mwe, pos_ctx, neg_ctx, test_ctx))
            batch_ids.append(row_id)

        if not batch_prompts:
            continue
        for row_id, label in zip(batch_ids, classify_prompts_logits(batch_prompts)):
            preds_map[row_id] = int(label)
            # Persist immediately so an interrupted run can resume.
            _append_one(cache_path, (row_id, int(label)))

    elapsed = perf_counter() - started
    per_example = elapsed / max(1, n_pending)
    print(f"{desc} | Newly computed: {n_pending} | Cached at start: {n_cached} | Total: {n_total} | "
          f"Elapsed: {elapsed:.1f}s | {per_example:.3f}s/example (new only)")

    return [preds_map[str(row_id)] for row_id in frame["ID"]]


# ---- Run one-shot EN (dev + eval) ----

# Exemplars are drawn from the one-shot training split: per-MWE first, with a
# global example per label as fallback for MWEs (or labels) without one.
train_1s_en, dev_1s_en = load_train_dev(language="EN", oneshot=True)
oneshot_index = build_oneshot_index(train_1s_en)
global_pool = pick_global_oneshot_fallback(train_1s_en)

# NOTE: the cache file name encodes only the split and the WSD flag, and cached
# rows are matched by ID alone. Predictions cached under a different MODEL_NAME
# or prompt wording are therefore reused as-is; delete the file to recompute.
cache_dev_1s = OUT_DIR / f"cache_llm_oneshot_dev_en_{'wsd' if USE_WSD else 'baseline'}.csv"
yhat_dev_1s = progressive_predict_one_shot_batched(
    dev_1s_en, oneshot_index, global_pool, cache_dev_1s,
    desc=f"One-shot EN (dev, USE_WSD={USE_WSD})"
)
ytrue_dev_1s = dev_1s_en["Label"].tolist()
# f1_1s is read again by the run-metadata cell at the end of the notebook.
f1_1s = f1_score(ytrue_dev_1s, yhat_dev_1s, average="macro")
print(f"[LLM One-shot EN] Dev macro-F1: {f1_1s:.4f}")
print(classification_report(ytrue_dev_1s, yhat_dev_1s, digits=4))
print(confusion_matrix(ytrue_dev_1s, yhat_dev_1s))

# The eval frame is reloaded (and re-annotated) here so this cell does not
# depend on the zero-shot cell having been run; it rebinds `eval_en`.
eval_en = load_eval(language="EN")
cache_eval_1s = OUT_DIR / f"cache_llm_oneshot_eval_en_{'wsd' if USE_WSD else 'baseline'}.csv"
yhat_eval_1s = progressive_predict_one_shot_batched(
    eval_en, oneshot_index, global_pool, cache_eval_1s,
    desc=f"One-shot EN (eval, USE_WSD={USE_WSD})"
)

sub_1s = pd.DataFrame({
    "ID": eval_en["ID"].astype(str),
    "Language": eval_en["Language"],
    "Setting": ["one_shot"] * len(eval_en),
    "Label": yhat_eval_1s
})
sub_1s_path = OUT_DIR / f"eval_submission_en_llm_oneshot_{'wsd' if USE_WSD else 'baseline'}.csv"
sub_1s.to_csv(sub_1s_path, index=False)
print(f"Wrote {sub_1s_path}")


Annotating train/dev (EN, oneshot=True) with WSD...
One-shot EN (dev, USE_WSD=True) | Resuming with 466 cached / 466 total


One-shot EN (dev, USE_WSD=True): 0it [00:00, ?it/s]


One-shot EN (dev, USE_WSD=True) | Newly computed: 0 | Cached at start: 466 | Total: 466 | Elapsed: 0.0s | 0.002s/example (new only)
[LLM One-shot EN] Dev macro-F1: 0.3836
              precision    recall  f1-score   support

           0     0.3802    0.8022    0.5159       182
           1     0.5610    0.1620    0.2514       284

    accuracy                         0.4120       466
   macro avg     0.4706    0.4821    0.3836       466
weighted avg     0.4904    0.4120    0.3547       466

[[146  36]
 [238  46]]
Annotating eval (EN) with WSD...
One-shot EN (eval, USE_WSD=True) | Resuming with 483 cached / 483 total


One-shot EN (eval, USE_WSD=True): 0it [00:00, ?it/s]

One-shot EN (eval, USE_WSD=True) | Newly computed: 0 | Cached at start: 483 | Total: 483 | Elapsed: 0.0s | 0.001s/example (new only)
Wrote outputs_en_llm_wsd/eval_submission_en_llm_oneshot_wsd.csv





## 

## Save run metadata for LLM + WSD

In [9]:
# 07 - Save run metadata for LLM + WSD

# Writes a small key=value text file next to the caches and submissions.
# Requires both the zero-shot and the one-shot cell to have run in this kernel
# (f1_0s and f1_1s are defined there).
with open(OUT_DIR / f"run_en_llm_{'wsd' if USE_WSD else 'baseline'}.txt", "w") as f:
    f.write(f"MODEL_NAME={MODEL_NAME}\n")
    f.write(f"DEVICE={device.type}\n")
    # NOTE: this records the current value of the global BATCH_GEN. The one-shot
    # cell rebinds it, so the value written here is not necessarily the batch
    # size the zero-shot run used.
    f.write(f"BATCH_GEN={BATCH_GEN}\n")
    f.write(f"USE_WSD={USE_WSD}\n")
    f.write(f"ZERO_SHOT_DEV_F1={f1_0s:.4f}\n")
    f.write(f"ONE_SHOT_DEV_F1={f1_1s:.4f}\n")

print("Saved run metadata for LLM + WSD.")
print("Output dir:", OUT_DIR)


Saved run metadata for LLM + WSD.
Output dir: outputs_en_llm_wsd
