### **Cell 1: GPU Check**

In [None]:
!nvidia-smi
import torch, sys, platform
print("CUDA available:", torch.cuda.is_available())
print("PyTorch:", torch.__version__)
print("Python:", sys.version)
print("OS:", platform.platform())

Sat Aug 16 20:03:54 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P0             53W /  400W |   31673MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

### **Cell 2 — Mount Drive & set paths**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# === change these to your folders ===
DATA_DIR = "/content/drive/MyDrive/summary_generation_data"   # your JSONLs live here
SAVE_ROOT = "/content/drive/MyDrive/model_weights_tokens_files"   # parent folder for all runs
# ====================================

# Versioning: time-stamped run folder
import os, time
RUN_ID = time.strftime("%Y%m%d-%H%M%S")  # e.g., 20250816-142501
SAVE_DIR_RUN = os.path.join(SAVE_ROOT, f"mistral7b_fp16_{RUN_ID}")
os.makedirs(SAVE_DIR_RUN, exist_ok=True)

COMM_PATH = f"{DATA_DIR}/commentary.jsonl"
SC_PATH   = f"{DATA_DIR}/scorecards.jsonl"
RPT_PATH  = f"{DATA_DIR}/reports.jsonl"

print("Data dir:", DATA_DIR)
print("Run output dir:", SAVE_DIR_RUN)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Data dir: /content/drive/MyDrive/summary_generation_data
Run output dir: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_20250816-200400


### **Cell 3: Requiremets & Dependencies Install**

In [None]:
# Keep it minimal to avoid conflicts
!pip -q install --no-cache-dir \
  "transformers==4.43.3" \
  "peft==0.12.0" \
  "accelerate==0.31.0" \
  "datasets==2.20.0" \
  "sentence-transformers==3.0.1"

import transformers, peft, accelerate, datasets, sentence_transformers
print("transformers:", transformers.__version__)
print("peft:", peft.__version__)
print("accelerate:", accelerate.__version__)
print("datasets:", datasets.__version__)
print("sentence-transformers:", sentence_transformers.__version__)

transformers: 4.43.3
peft: 0.12.0
accelerate: 0.31.0
datasets: 2.20.0
sentence-transformers: 3.0.1


### **Cell 4: model & sequence lengths**

In [None]:
# Recommended:
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
# Alternative:
# MODEL_ID = "google/gemma-7b-it"  # accept license on HF first if needed

# Sequence lengths (A100 can handle fp16 with these)
MAX_IN  = 2048
MAX_OUT = 256

print("MODEL_ID:", MODEL_ID)

MODEL_ID: mistralai/Mistral-7B-Instruct-v0.3


### **Cell 5: For Gemma only login if needed**

In [None]:
if MODEL_ID.startswith("google/"):
    from huggingface_hub import login
    print("Gemma selected — if you get 403 later, run login() and paste your HF token.")
    # login()  # uncomment if needed

### **Cell 6: Load JSONLs & build examples**

In [None]:
# Cell 6 — build examples with a REQUIRED deterministic opening sentence

import os, json, re

def load_jsonl(path):
    d = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            o = json.loads(line)
            d[o["match_id"]] = o
    return d

assert os.path.exists(COMM_PATH), f"Missing {COMM_PATH}"
assert os.path.exists(SC_PATH),   f"Missing {SC_PATH}"
assert os.path.exists(RPT_PATH),  f"Missing {RPT_PATH}"

commentary = load_jsonl(COMM_PATH)
scorecards = load_jsonl(SC_PATH)
reports    = load_jsonl(RPT_PATH)

match_ids = sorted(set(commentary) & set(scorecards) & set(reports))
print("Matches available:", len(match_ids))
print("Example match_id:", match_ids[0] if match_ids else None)

def compact_stats(stats):
    keep = ["team1","team2","winner","result","result_margin","venue","date",
            "toss_winner","toss_decision","first_innings_runs","first_innings_wkts",
            "second_innings_runs","second_innings_wkts","top_batters","top_bowlers"]
    s = {k: stats.get(k) for k in keep if k in stats}
    if s.get("top_batters"): s["top_batters"] = s["top_batters"][:2]
    if s.get("top_bowlers"): s["top_bowlers"] = s["top_bowlers"][:2]
    return s

KEYWORDS = ("wicket","out","caught","lbw","review","drs","six","four","powerplay","death","target","needed","fifty","hundred")

def pick_chunks(chunks, k=6):
    if len(chunks) <= k: return chunks
    def score(s):
        t = s.lower()
        return sum(kw in t for kw in KEYWORDS)
    half = k//2
    head_tail = chunks[:half] + chunks[-half:]
    middle = chunks[half:-half] if len(chunks) > k else []
    best_mid = sorted(middle, key=score, reverse=True)[:max(0, k-len(head_tail))]
    pos = {c:i for i,c in enumerate(chunks)}
    merged = sorted(set(head_tail + best_mid), key=lambda c: pos[c])
    return merged[:k]

def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    if not w: return ""
    return t2 if w == t1 else t1

def opening_sentence(stats):
    w = stats.get("winner","")
    l = loser_of(stats)
    m = str(stats.get("result_margin",""))
    v = stats.get("venue","")
    if not (w and l and m and v):
        return ""
    return f"{w} defeated {l} by {m} at {v}."

def build_example(mid, max_chunks=6):
    chunks = pick_chunks(commentary[mid]["commentary_chunks"], max_chunks)
    stats  = compact_stats(scorecards[mid]["stats"])
    req_open = opening_sentence(stats)  # deterministic opener from facts

    prompt = (
        "You are a cricket Expert. Produce a concise IPL match report with exactly 3 paragraphs:\n"
        "P1: REQUIRED opening sentence (copy facts) + one sentence context.\n"
        "P2: Turning events that decided the game.\n"
        "P3: Standout batter & bowler; short details (toss/DRS/injuries/pitch) + closing.\n\n"
        "Rules:\n"
        "- 180–220 words total.\n"
        "- Use only SCORECARD & EXCERPTS; do NOT invent facts.\n"
        '- REQUIRED opening sentence format: "{winner} defeated {loser} by {result_margin} at {venue}."\n\n'
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS (chronological):\n" + "\n\n".join(chunks) + "\n\n"
        "Write the report now."
    )

    gold = reports[mid]["report_text"].strip()
    if req_open and not gold.startswith(req_open):
        gold = f"{req_open}\n\n" + gold

    return {"input": prompt, "output": gold}

examples = [build_example(mid) for mid in match_ids]
len(examples), examples[0]["input"][:300]

Matches available: 4
Example match_id: 1422119


(4,
 'You are a cricket Expert. Produce a concise IPL match report with exactly 3 paragraphs:\nP1: REQUIRED opening sentence (copy facts) + one sentence context.\nP2: Turning events that decided the game.\nP3: Standout batter & bowler; short details (toss/DRS/injuries/pitch) + closing.\n\nRules:\n- 180–220 word')

### **Cell 7: Tokenize (supervised fine-tuning format)**

In [None]:
# Cell 7 — tokenize for SFT (input + eos + output, mask input in labels)

from datasets import Dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

def to_features(ex, max_in=2048, max_out=256):
    x = tok(ex["input"], truncation=True, max_length=max_in)
    y = tok(ex["output"], truncation=True, max_length=max_out)
    input_ids = x["input_ids"] + [tok.eos_token_id] + y["input_ids"]
    attention = [1] * len(input_ids)
    labels    = [-100] * (len(x["input_ids"]) + 1) + y["input_ids"]
    input_ids = input_ids[:max_in]
    attention = attention[:max_in]
    labels    = labels[:max_in]
    return {"input_ids": input_ids, "attention_mask": attention, "labels": labels}

raw_ds = Dataset.from_list(examples)
ds = raw_ds.map(lambda row: to_features(row, MAX_IN, MAX_OUT), remove_columns=["input","output"])
ds

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 4
})

In [None]:
# Cell X — free GPU memory from previous runs
import gc, torch, inspect, sys

names_to_drop = ["gen_pipe","pipe","trainer","model","base_model"]
for n in names_to_drop:
    if n in globals():
        try:
            obj = globals().pop(n)
            del obj
        except:
            pass

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Optional: print free/total memory
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()
    print(f"CUDA mem free: {free/1e9:.2f} GB / {total/1e9:.2f} GB")

CUDA mem free: 37.50 GB / 42.47 GB


### **Cell 8: Load base model (FP16) + attach LoRA**

In [None]:
# Cell 8 — load base model on GPU (FP16, no offload) + attach LoRA

import os, torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Reduce fragmentation in PyTorch allocator
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Load on CPU first (materialized), then move to GPU
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,   # prevent meta tensors
    device_map=None,           # load on CPU
    attn_implementation="sdpa"
)

# Move whole model to GPU
base_model.to("cuda", dtype=torch.float16)

# Memory savers for training
base_model.gradient_checkpointing_enable()
base_model.config.use_cache = False

# LoRA config
target_modules = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
lora_cfg = LoraConfig(
    r=16, lora_alpha=16, lora_dropout=0.05,
    target_modules=target_modules, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, lora_cfg)

# Sanity: ensure nothing is on 'meta' or CPU
bad_meta = [n for n,p in model.named_parameters() if getattr(p, "device", None).type == "meta"]
bad_cpu  = [n for n,p in model.named_parameters() if getattr(p, "device", None).type == "cpu"]
assert not bad_meta, f"Meta tensors detected: {bad_meta[:5]}"
assert not bad_cpu,  f"CPU tensors detected: {bad_cpu[:5]}"

model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 41,943,040 || all params: 7,289,966,592 || trainable%: 0.5754


### **Cell 9: Train Model (short & safe defaults)**

In [None]:
# Cell 9 — Build training dataset with concise 2-paragraph format, train, save

from transformers import TrainingArguments, Trainer
from datasets import Dataset
import os, json, re

# ---------- helpers (reuse from earlier cells if already defined) ----------
def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    if not w: return ""
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin")
    res_txt = (stats.get("result") or "").lower()
    if margin is None: return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

def opening_sentence(stats):
    w = stats.get("winner","")
    l = loser_of(stats)
    m = format_margin(stats)
    v = stats.get("venue","")
    if not (w and l and m and v): return ""
    return f"{w} defeated {l} by {m} at {v}."

# If you already have these from earlier cells, this will just reuse them:
# - commentary, scorecards, reports, match_ids
# - compact_stats(stats)
# - pick_chunks(chunks, k)
# - tok tokenizer from Cell 8
# - MAX_IN, MAX_OUT from earlier cells (fallback values below)
try:
    MAX_IN
except NameError:
    MAX_IN = 2048
try:
    MAX_OUT
except NameError:
    MAX_OUT = 256

# ---------- build training examples in the *new style* ----------
def build_training_example(mid, max_chunks=6):
    chunks = pick_chunks(commentary[mid]["commentary_chunks"], max_chunks)
    stats  = compact_stats(scorecards[mid]["stats"])
    # Deterministic opener from facts
    req_open = opening_sentence(stats)

    # Gold summary
    gold = reports[mid]["report_text"].strip()

    # If gold doesn't start with the deterministic opener, prepend it so the model learns the pattern
    if req_open and not gold.startswith(req_open):
        gold = f"{req_open}\n\n{gold}"

    # Training prompt (fan-style, 2 short paragraphs, <=6 sentences, end with <END>)
    prompt = (
        "You are a cricket expert. Write a match summary in exactly two short paragraphs (no headings).\n"
        "Paragraph 1: Describe the main turning events that decided the game (key wickets, overs, partnerships).\n"
        "Paragraph 2: Name one standout batter and one standout bowler; add brief pitch/toss/DRS notes if relevant; "
        "close with a one-line implication.\n\n"
        "Rules:\n"
        "- Maximum 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent or exaggerate.\n"
        "- Keep the style factual, concise, and fan-friendly. No filler like 'respectable total' or season stats.\n"
        "- End with the token <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "### TARGET SUMMARY:"
    )

    target = gold + " <END>"
    return {"input": prompt, "output": target}

train_examples = [build_training_example(mid, max_chunks=6) for mid in match_ids]
print("Training samples:", len(train_examples))
print(train_examples[0]["input"][:300])

# ---------- tokenize to SFT format (mask prompt in labels) ----------
def to_features(ex, max_in=MAX_IN, max_out=MAX_OUT):
    x = tok(ex["input"], truncation=True, max_length=max_in)
    y = tok(ex["output"], truncation=True, max_length=max_out)
    input_ids = x["input_ids"] + [tok.eos_token_id] + y["input_ids"]
    attention = [1] * len(input_ids)
    labels    = [-100] * (len(x["input_ids"]) + 1) + y["input_ids"]
    # final truncation
    input_ids = input_ids[:max_in]
    attention = attention[:max_in]
    labels    = labels[:max_in]
    return {"input_ids": input_ids, "attention_mask": attention, "labels": labels}

ds = Dataset.from_list(train_examples).map(
    lambda ex: to_features(ex, MAX_IN, MAX_OUT),
    remove_columns=["input", "output"]
)
print(ds)

# ---------- training ----------
args = TrainingArguments(
    output_dir="/content/lora_out",     # scratch (local)
    learning_rate=2e-4,
    num_train_epochs=2,                 # set 2–3 for final
    # max_steps=300,                    # <- uncomment to cap cost/time during trials
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=20,
    save_steps=300,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    fp16=True,                          # FP16 on A100
    report_to="none",
    max_grad_norm=0.5,
    ddp_find_unused_parameters=False,   # safer with LoRA
)

trainer = Trainer(model=model, args=args, train_dataset=ds)
trainer.train()

# ---------- save to versioned run folder ----------
os.makedirs(SAVE_DIR_RUN, exist_ok=True)
model.save_pretrained(SAVE_DIR_RUN)
tok.save_pretrained(SAVE_DIR_RUN)
print("Saved to:", SAVE_DIR_RUN)

Training samples: 4
You are a cricket expert. Write a match summary in exactly two short paragraphs (no headings).
Paragraph 1: Describe the main turning events that decided the game (key wickets, overs, partnerships).
Paragraph 2: Name one standout batter and one standout bowler; add brief pitch/toss/DRS notes if rele


Map:   0%|          | 0/4 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 4
})


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss


Saved to: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_20250816-200400


In [None]:
# Cell X — free trainer/optimizer state so inference fits in VRAM
import gc, torch

# Drop Trainer to free optimizer states & dataloaders
if 'trainer' in globals():
    try:
        del trainer
    except:
        pass

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Switch model to eval & enable cache for faster decoding
model.eval()
if hasattr(model, "config"):
    model.config.use_cache = True

### **Cell 10: Inference (structured prompting)**

In [None]:
pip install rapidfuzz

Collecting rapidfuzz
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz
Successfully installed rapidfuzz-3.13.0


In [None]:
# Cell 10 — inference (no pipeline), with delayed stop + logits constraints + whitelist sanitizers + debug saves

import os, re, json, torch
from collections import Counter
from transformers import (
    StoppingCriteria, StoppingCriteriaList, LogitsProcessorList,
    NoBadWordsLogitsProcessor, InfNanRemoveLogitsProcessor
)

# ---- optional fuzzy matching (improves name correction) ----
try:
    from rapidfuzz import process as fuzz_process, fuzz
    HAVE_FUZZ = True
except Exception:
    HAVE_FUZZ = False

# ---------- base helpers ----------
def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    if not w: return ""
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin")
    res_txt = (stats.get("result") or "").lower()
    if margin is None: return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

def opening_sentence(stats):
    w = stats.get("winner","")
    l = loser_of(stats)
    m = format_margin(stats)
    v = stats.get("venue","")
    if not (w and l and m and v): return ""
    return f"{w} defeated {l} by {m} at {v}."

def contains_literal(text, value):
    if not value: return True
    t = re.sub(r"\s+", " ", text).lower()
    v = re.sub(r"\s+", " ", str(value)).lower()
    return v in t

def clean_headings(txt: str) -> str:
    # strip any headings that slip through
    txt = re.sub(r"(?m)^\s*(P\d:|Paragraph\s*\d:|Pargraph\s*\d:)\s*", "", txt)
    txt = re.sub(r"\s+\.", ".", txt)
    txt = re.sub(r"\s+,", ",", txt)
    return txt.strip()

# ---------- whitelist + sanitizers ----------
NAME_RX = re.compile(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*")

_GENERIC_BAD = {
    "Chennai","Super","Kings","Royal","Challengers","Bengaluru","Bangalore",
    "Stadium","Chepauk","Match","Overs","Runs","Wickets","Powerplay",
    "Review","DRS","Paragraph","Pargraph","Section","Heading","Outline"
}

def extract_names_from_text(txt: str):
    cands = [m.group(0).strip() for m in NAME_RX.finditer(txt)]
    return [c for c in cands if c not in _GENERIC_BAD and len(c) <= 30]

def build_player_whitelist(stats, chunks):
    allowed = set()
    for key in ("top_batters","top_bowlers","players","batting_card","bowling_card"):
        val = stats.get(key)
        if isinstance(val, list):
            for item in val:
                if isinstance(item, dict):
                    for k in ("name","player","batter","bowler"):
                        if k in item and item[k]:
                            allowed.add(str(item[k]).strip())
                elif isinstance(item, str):
                    allowed.add(item.strip())
    counter = Counter()
    for c in chunks:
        for nm in extract_names_from_text(c):
            counter[nm] += 1
    for nm, cnt in counter.items():
        if cnt >= 2:
            allowed.add(nm)
    return sorted({re.sub(r"\s+", " ", a).strip() for a in allowed})

def fuzzy_fix_name(name, whitelist):
    if not HAVE_FUZZ or not whitelist:
        return name
    cand, score, _ = fuzz_process.extractOne(name, whitelist, scorer=fuzz.WRatio)
    return cand if score >= 88 else name

def sanitize_names(text, whitelist):
    if not whitelist:
        return text
    wl_lower = {w.lower(): w for w in whitelist}
    lines = []
    for line in text.splitlines():
        fixed = line
        # correct each detected name
        for n in set(extract_names_from_text(line)):
            rep = wl_lower.get(n.lower(), None)
            if rep is None:
                rep = fuzzy_fix_name(n, whitelist)
            if rep != n:
                fixed = re.sub(rf"\b{re.escape(n)}\b", rep, fixed)
        # drop lines that still contain OOV names AND assert actions
        oov = [n for n in extract_names_from_text(fixed) if n.lower() not in wl_lower]
        if oov and re.search(r"\b(dismissed|bowled|caught|lbw|stumped|scored|hit|took|figures|overs|partnership)\b", fixed.lower()):
            continue
        lines.append(fixed)
    return "\n".join(lines).strip()

def valid_over_notation(text):
    # Normalize X.Y overs so that Y∈{0..5}
    return re.sub(r"(\b\d+)\.(\d{1,2})\s*overs", lambda m: f"{m.group(1)}.{min(int(m.group(2)),5)} overs", text)

def sanitize_numbers(text, stats):
    # Bowler max overs T20 = 4
    txt = re.sub(r"(\b4)\.(\d)\s*overs", r"\1 overs", text)
    # De-fluff
    txt = re.sub(r"\brespectable total\b", "a total", txt, flags=re.I)
    res = (stats.get("result") or "").lower()
    if "wicket" in res:
        txt = re.sub(r"\bfell short\b.*?(\.|\n)", ". ", txt, flags=re.I)
    return valid_over_notation(txt)

def enforce_two_paragraphs(text):
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    return "\n\n".join(paras[:3])  # opener + 2

# ---------- generation controls (delayed stop) ----------
class StopOnTokensAfterMin(StoppingCriteria):
    def __init__(self, stop_ids, prompt_len, min_new_tokens):
        super().__init__()
        self.stop_ids = stop_ids
        self.prompt_len = prompt_len
        self.min_new_tokens = min_new_tokens
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        gen_len = input_ids.shape[-1] - self.prompt_len
        if gen_len < self.min_new_tokens:
            return False
        S = self.stop_ids.shape[-1]
        if input_ids.shape[-1] < S:
            return False
        tail = input_ids[0, -S:]
        return torch.equal(tail.cpu(), self.stop_ids.cpu())

def make_stop_criteria_after_min(tokenizer, stop_str: str, prompt_len: int, min_new_tokens: int):
    stop_ids = tokenizer(stop_str, add_special_tokens=False, return_tensors="pt").input_ids[0]
    return StoppingCriteriaList([StopOnTokensAfterMin(stop_ids, prompt_len, min_new_tokens)])

def build_bad_words_ids(tokenizer, words_or_phrases):
    ids = []
    for w in words_or_phrases:
        toks = tokenizer(w, add_special_tokens=False).input_ids
        if toks: ids.append(toks)
    return ids

def contradiction_phrases(stats):
    res = (stats.get("result") or "").lower()
    if "wicket" in res:
        return ["fell short", "fell just short", "could not chase", "defended the total",
                "won by runs", "victory by runs", "ran out of overs"]
    if "run" in res:
        return ["won by wickets", "victory by wickets", "chased down comfortably",
                "reached the target", "got over the line in the chase"]
    return []

# ---------- main inference ----------
def infer_for_match(mid, max_chunks=4, save=True):
    chunks = pick_chunks(commentary[mid]["commentary_chunks"], max_chunks)
    stats  = compact_stats(scorecards[mid]["stats"])
    req_open = opening_sentence(stats) or ""
    whitelist = build_player_whitelist(stats, chunks)

    # Prompt: concise, two paragraphs, no headings/filler, finish with <END>
    prompt = (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n"
        "- Maximum 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        "- Keep it fan-friendly and factual. Do not include season-wide claims unless present in SCORECARD.\n"
        "- Do not write headings or labels.\n"
        "- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    )

    enc = tok(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to("cuda")
    attention_mask = enc.attention_mask.to("cuda")

    banned = [
        "P1:", "P2:", "P3:", "Paragraph 1:", "Paragraph 2:", "Paragraph 3:",
        "Pargraph 1:", "Pargraph 2:", "Section", "Heading", "Outline:"
    ]
    for p in contradiction_phrases(stats):
        banned.append(p)

    bad_words_ids = build_bad_words_ids(tok, banned)

    logits_processors = LogitsProcessorList([
        InfNanRemoveLogitsProcessor(),
        NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=tok.eos_token_id),
    ])

    MIN_NEW = 180  # keep in sync with gen_kwargs
    stop_criteria = make_stop_criteria_after_min(tok, "<END>", prompt_len=input_ids.shape[-1], min_new_tokens=MIN_NEW)

    gen_kwargs = dict(
        do_sample=False,
        repetition_penalty=1.02,
        no_repeat_ngram_size=4,
        min_new_tokens=MIN_NEW,
        max_new_tokens=260,
        pad_token_id=tok.pad_token_id,
        use_cache=True,
        eos_token_id=None,              # don't stop on EOS
        logits_processor=logits_processors,
        stopping_criteria=stop_criteria,
    )

    with torch.no_grad():
        out_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    # remove prompt (no echo) and cut at <END>
    prompt_len = input_ids.shape[-1]
    cont_ids = out_ids[0, prompt_len:]
    raw_out = tok.decode(cont_ids, skip_special_tokens=True)
    raw_out = raw_out.split("<END>")[0].strip()
    raw_out = clean_headings(raw_out)

    # --- assemble full report (raw) ---
    raw_report = req_open + ("\n\n" if req_open else "") + raw_out

    # --- sanitize with whitelist & numbers ---
    san_report = sanitize_names(raw_report, whitelist)
    san_report = sanitize_numbers(san_report, stats)
    final_report = enforce_two_paragraphs(san_report)

    # --- failsafe: if sanitization nuked too much, fall back to raw ---
    if len(final_report) < 120:   # tweak if needed
        final_report = enforce_two_paragraphs(raw_report)

    # factual guardrail; if it fails, try one retry and prefer longer passing candidate
    ok = (contains_literal(final_report, stats.get("winner","")) and
          contains_literal(final_report, format_margin(stats)) and
          contains_literal(final_report, stats.get("venue","")))
    if not ok:
        with torch.no_grad():
            out_ids2 = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                repetition_penalty=1.05,
                no_repeat_ngram_size=4,
                min_new_tokens=170,
                max_new_tokens=250,
                pad_token_id=tok.pad_token_id,
                use_cache=True,
                eos_token_id=None,
                logits_processor=logits_processors,
                stopping_criteria=make_stop_criteria_after_min(tok, "<END>", prompt_len, 170),
            )
        cont_ids2 = out_ids2[0, prompt_len:]
        raw_out2 = tok.decode(cont_ids2, skip_special_tokens=True).split("<END>")[0].strip()
        raw_out2 = clean_headings(raw_out2)
        raw_report2 = req_open + ("\n\n" if req_open else "") + raw_out2
        san2 = sanitize_numbers(sanitize_names(raw_report2, whitelist), stats)
        cand = enforce_two_paragraphs(san2)

        def passes(txt):
            return (contains_literal(txt, stats.get("winner","")) and
                    contains_literal(txt, format_margin(stats)) and
                    contains_literal(txt, stats.get("venue","")))
        if passes(cand) and len(cand) > len(final_report):
            final_report = cand

    # save RAW / SANITIZED / FINAL to Drive
    if save:
        out_dir = os.path.join(SAVE_DIR_RUN, "reports_gen")
        os.makedirs(out_dir, exist_ok=True)
        base = os.path.join(out_dir, f"match_{mid}")
        with open(base + "_RAW.txt", "w", encoding="utf-8") as f: f.write(raw_report)
        with open(base + "_SANITIZED.txt", "w", encoding="utf-8") as f: f.write(san_report)
        with open(base + ".txt", "w", encoding="utf-8") as f: f.write(final_report)
        print("Saved:", base + ".txt")

    return final_report

# Run once (prints full final report)
test_mid = match_ids[0]
final_txt = infer_for_match(test_mid, max_chunks=3, save=True)  # 3 chunks → faster & tighter
print(final_txt)

Saved: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_20250816-200400/reports_gen/match_1422119.txt
Chennai Super Kings defeated Royal Challengers Bengaluru by 6 wickets at MA Chidambaram Stadium, Chepauk, Chennai.

In the high-scoring encounter at the MA Chidambaran Stadium, Chennaii Super Kings (CSK) emerged victorious against Royal Challengers Bengluru (RCB) by 6 wickets. RCB, after winning the toss, opted to bat first and posted a competitive total of 174 runs, losing 6 wickers in their 20 overs. Anuj Rawat top-scored for RCB with 48 runs off 26 balls, while KD Karthick contributed 38 runs. For CSK, Mustafizur Rehman and C Green took 4 and 2 wickets respectively.

Chasing a modest target, CSK got off to a steady start, losing their first wicket in the 5th over. However, Ruturai Gaikwade and Rachin Ravindra provided the impetus with a 100-run partnership for the second wicket. Gaikawade scored 64 runs off 42 balls, while Ravindra remained unbeaten on 84 runs off ju

In [None]:
# Cell 10 — inference (no pipeline), with delayed stop + logits constraints + whitelist sanitizers + typo fixes + debug saves

import os, re, json, torch
from collections import Counter
from transformers import (
    StoppingCriteria, StoppingCriteriaList, LogitsProcessorList,
    NoBadWordsLogitsProcessor, InfNanRemoveLogitsProcessor
)

# ---- optional fuzzy matching (improves name correction) ----
try:
    from rapidfuzz import process as fuzz_process, fuzz
    HAVE_FUZZ = True
except Exception:
    HAVE_FUZZ = False

# ---------- base helpers ----------
def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    if not w: return ""
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin")
    res_txt = (stats.get("result") or "").lower()
    if margin is None: return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

def opening_sentence(stats):
    w = stats.get("winner","")
    l = loser_of(stats)
    m = format_margin(stats)
    v = stats.get("venue","")
    if not (w and l and m and v): return ""
    return f"{w} defeated {l} by {m} at {v}."

def contains_literal(text, value):
    if not value: return True
    t = re.sub(r"\s+", " ", text).lower()
    v = re.sub(r"\s+", " ", str(value)).lower()
    return v in t

def clean_headings(txt: str) -> str:
    # strip any headings that slip through
    txt = re.sub(r"(?m)^\s*(P\d:|Paragraph\s*\d:|Pargraph\s*\d:)\s*", "", txt)
    txt = re.sub(r"\s+\.", ".", txt)
    txt = re.sub(r"\s+,", ",", txt)
    return txt.strip()

# ---------- whitelist + sanitizers ----------
NAME_RX = re.compile(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*")

_GENERIC_BAD = {
    "Chennai","Super","Kings","Royal","Challengers","Bengaluru","Bangalore",
    "Stadium","Chepauk","Match","Overs","Runs","Wickets","Powerplay",
    "Review","DRS","Paragraph","Pargraph","Section","Heading","Outline"
}

def extract_names_from_text(txt: str):
    cands = [m.group(0).strip() for m in NAME_RX.finditer(txt)]
    return [c for c in cands if c not in _GENERIC_BAD and len(c) <= 30]

def build_player_whitelist(stats, chunks):
    allowed = set()
    for key in ("top_batters","top_bowlers","players","batting_card","bowling_card"):
        val = stats.get(key)
        if isinstance(val, list):
            for item in val:
                if isinstance(item, dict):
                    for k in ("name","player","batter","bowler"):
                        if k in item and item[k]:
                            allowed.add(str(item[k]).strip())
                elif isinstance(item, str):
                    allowed.add(item.strip())
    # names appearing in commentary multiple times
    counter = Counter()
    for c in chunks:
        for nm in extract_names_from_text(c):
            counter[nm] += 1
    for nm, cnt in counter.items():
        if cnt >= 2:
            allowed.add(nm)
    return sorted({re.sub(r"\s+", " ", a).strip() for a in allowed})

def fuzzy_fix_name(name, whitelist):
    if not HAVE_FUZZ or not whitelist:
        return name
    cand, score, _ = fuzz_process.extractOne(name, whitelist, scorer=fuzz.WRatio)
    return cand if score >= 88 else name

# quick deterministic typo fixes seen in outputs
def fix_common_typos(txt: str) -> str:
    replacements = {
        "Chennaii": "Chennai",
        "Bengluru": "Bengaluru",
        "MA Chidambaran": "MA Chidambaram",
        "Chidambaran": "Chidambaram",
        "Gaikwade": "Gaikwad",
        "Gaikawade": "Gaikwad",
        "Ruturai": "Ruturaj",
        "Karthick": "Karthik",    # ensure this matches your scorecard spelling
        "Rehman": "Rahman",
        "wickers": "wickets",
        "CSK' s": "CSK's",
    }
    for bad, good in replacements.items():
        txt = re.sub(rf"\b{re.escape(bad)}\b", good, txt)
    return txt

def sanitize_names(text, whitelist, drop_action_lines=False):
    """
    Replace OOV names with 'Player' instead of dropping the whole line.
    Set drop_action_lines=True to restore old behavior (drop lines).
    """
    if not whitelist:
        return text
    wl_lower = {w.lower(): w for w in whitelist}

    def is_action_line(s: str) -> bool:
        return re.search(
            r"\b(dismissed|bowled|caught|lbw|stumped|scored|hit|took|figures|overs|partnership)\b",
            s.lower()
        ) is not None

    lines = []
    for line in text.splitlines():
        fixed = line

        # 1) correct known names (exact or fuzzy)
        for n in set(extract_names_from_text(line)):
            rep = wl_lower.get(n.lower())
            if rep is None and HAVE_FUZZ:
                rep = fuzzy_fix_name(n, whitelist)
                if rep.lower() not in wl_lower:
                    rep = None
            if rep and rep != n:
                fixed = re.sub(rf"\b{re.escape(n)}\b", rep, fixed)

        # 2) handle any remaining OOV names
        oov = [n for n in extract_names_from_text(fixed) if n.lower() not in wl_lower]
        if oov and is_action_line(fixed):
            if drop_action_lines:
                continue
            for n in oov:
                fixed = re.sub(rf"\b{re.escape(n)}\b", "Player", fixed)

        lines.append(fixed)

    out = "\n".join(lines).strip()
    out = re.sub(r"\s+\.", ".", out)
    out = re.sub(r"\s+,", ",", out)
    return out

def valid_over_notation(text):
    # Normalize X.Y overs so that Y∈{0..5}
    return re.sub(r"(\b\d+)\.(\d{1,2})\s*overs", lambda m: f"{m.group(1)}.{min(int(m.group(2)),5)} overs", text)

def sanitize_numbers(text, stats):
    # Bowler max overs T20 = 4
    txt = re.sub(r"(\b4)\.(\d)\s*overs", r"\1 overs", text)
    # De-fluff
    txt = re.sub(r"\brespectable total\b", "a total", txt, flags=re.I)
    res = (stats.get("result") or "").lower()
    if "wicket" in res:
        txt = re.sub(r"\bfell short\b.*?(\.|\n)", ". ", txt, flags=re.I)
    return valid_over_notation(txt)

def enforce_two_paragraphs(text):
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    return "\n\n".join(paras[:3])  # opener + 2

# ---------- generation controls (delayed stop) ----------
class StopOnTokensAfterMin(StoppingCriteria):
    def __init__(self, stop_ids, prompt_len, min_new_tokens):
        super().__init__()
        self.stop_ids = stop_ids
        self.prompt_len = prompt_len
        self.min_new_tokens = min_new_tokens
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        gen_len = input_ids.shape[-1] - self.prompt_len
        if gen_len < self.min_new_tokens:
            return False
        S = self.stop_ids.shape[-1]
        if input_ids.shape[-1] < S:
            return False
        tail = input_ids[0, -S:]
        return torch.equal(tail.cpu(), self.stop_ids.cpu())

def make_stop_criteria_after_min(tokenizer, stop_str: str, prompt_len: int, min_new_tokens: int):
    stop_ids = tokenizer(stop_str, add_special_tokens=False, return_tensors="pt").input_ids[0]
    return StoppingCriteriaList([StopOnTokensAfterMin(stop_ids, prompt_len, min_new_tokens)])

def build_bad_words_ids(tokenizer, words_or_phrases):
    ids = []
    for w in words_or_phrases:
        toks = tokenizer(w, add_special_tokens=False).input_ids
        if toks: ids.append(toks)
    return ids

def contradiction_phrases(stats):
    res = (stats.get("result") or "").lower()
    if "wicket" in res:
        return ["fell short", "fell just short", "could not chase", "defended the total",
                "won by runs", "victory by runs", "ran out of overs"]
    if "run" in res:
        return ["won by wickets", "victory by wickets", "chased down comfortably",
                "reached the target", "got over the line in the chase"]
    return []

# ---------- main inference ----------
def infer_for_match(mid, max_chunks=3, save=True):
    chunks = pick_chunks(commentary[mid]["commentary_chunks"], max_chunks)
    stats  = compact_stats(scorecards[mid]["stats"])
    req_open = opening_sentence(stats) or ""
    whitelist = build_player_whitelist(stats, chunks)

    # Prompt: concise, two paragraphs, no headings/filler, finish with <END>, with allowed names hint
    allowed_names = whitelist
    prompt = (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n"
        "- Maximum 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        "- Use ONLY these player names if you mention players; if unsure, use generic phrases like 'the opener' or 'the seamer':\n"
        f"  ALLOWED NAMES: {', '.join(allowed_names) if allowed_names else '—'}\n"
        "- Keep it fan-friendly and factual. Do not include season-wide or streak claims unless present in SCORECARD.\n"
        "- Do not write headings or labels.\n"
        "- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    )

    enc = tok(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to("cuda")
    attention_mask = enc.attention_mask.to("cuda")

    # ban headings/labels & obvious contradiction phrases & common misspellings
    banned = [
        "P1:", "P2:", "P3:", "Paragraph 1:", "Paragraph 2:", "Paragraph 3:",
        "Pargraph 1:", "Pargraph 2:", "Section", "Heading", "Outline:",
        # season/streak hallucinations
        "winning streak", "now won all their matches", "perfect record this season",
        # common misspellings seen
        "Chennaii", "Bengluru", "Gaikwade", "Gaikawade", "Karthick", "Rehman", "wickers",
    ]
    for p in contradiction_phrases(stats):
        banned.append(p)

    bad_words_ids = build_bad_words_ids(tok, banned)
    logits_processors = LogitsProcessorList([
        InfNanRemoveLogitsProcessor(),
        NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=tok.eos_token_id),
    ])

    MIN_NEW = 190  # increased to avoid short outputs
    stop_criteria = make_stop_criteria_after_min(tok, "<END>", prompt_len=input_ids.shape[-1], min_new_tokens=MIN_NEW)

    gen_kwargs = dict(
        do_sample=False,
        repetition_penalty=1.02,
        no_repeat_ngram_size=4,
        min_new_tokens=MIN_NEW,
        max_new_tokens=300,   # more room for two paragraphs
        pad_token_id=tok.pad_token_id,
        use_cache=True,
        eos_token_id=None,              # don't stop on EOS
        logits_processor=logits_processors,
        stopping_criteria=stop_criteria,
    )

    with torch.no_grad():
        out_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    # remove prompt (no echo) and cut at <END>
    prompt_len = input_ids.shape[-1]
    cont_ids = out_ids[0, prompt_len:]
    raw_out = tok.decode(cont_ids, skip_special_tokens=True)
    raw_out = raw_out.split("<END>")[0].strip()
    raw_out = clean_headings(raw_out)

    # --- assemble full report (raw) + typo fixes ---
    raw_report = req_open + ("\n\n" if req_open else "") + raw_out
    raw_report = fix_common_typos(raw_report)

    # --- sanitize with whitelist & numbers ---
    san_report = sanitize_names(raw_report, whitelist)
    san_report = sanitize_numbers(san_report, stats)
    # if sanitization removed too much, keep raw
    if len(san_report.split()) < 40:
        san_report = raw_report

    final_report = enforce_two_paragraphs(san_report)

    # factual guardrail; if it fails, try one retry and prefer longer passing candidate
    ok = (contains_literal(final_report, stats.get("winner","")) and
          contains_literal(final_report, format_margin(stats)) and
          contains_literal(final_report, stats.get("venue","")))
    if not ok:
        with torch.no_grad():
            out_ids2 = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                repetition_penalty=1.05,
                no_repeat_ngram_size=4,
                min_new_tokens=180,
                max_new_tokens=290,
                pad_token_id=tok.pad_token_id,
                use_cache=True,
                eos_token_id=None,
                logits_processor=logits_processors,
                stopping_criteria=make_stop_criteria_after_min(tok, "<END>", prompt_len, 180),
            )
        cont_ids2 = out_ids2[0, prompt_len:]
        raw_out2 = tok.decode(cont_ids2, skip_special_tokens=True).split("<END>")[0].strip()
        raw_out2 = clean_headings(raw_out2)
        raw_report2 = req_open + ("\n\n" if req_open else "") + raw_out2
        raw_report2 = fix_common_typos(raw_report2)
        san2 = sanitize_numbers(sanitize_names(raw_report2, whitelist), stats)
        if len(san2.split()) < 40:
            san2 = raw_report2
        cand = enforce_two_paragraphs(san2)

        def passes(txt):
            return (contains_literal(txt, stats.get("winner","")) and
                    contains_literal(txt, format_margin(stats)) and
                    contains_literal(txt, stats.get("venue","")))
        if passes(cand) and len(cand) > len(final_report):
            final_report = cand

    # save RAW / SANITIZED / FINAL to Drive
    if save:
        out_dir = os.path.join(SAVE_DIR_RUN, "reports_gen")
        os.makedirs(out_dir, exist_ok=True)
        base = os.path.join(out_dir, f"match_{mid}")
        with open(base + "_RAW.txt", "w", encoding="utf-8") as f: f.write(raw_report)
        with open(base + "_SANITIZED.txt", "w", encoding="utf-8") as f: f.write(san_report)
        with open(base + ".txt", "w", encoding="utf-8") as f: f.write(final_report)
        print("Saved:", base + ".txt")

    return final_report

# Run once (prints full final report)
test_mid = match_ids[0]
final_txt = infer_for_match(test_mid, max_chunks=3, save=True)  # 3 chunks → tighter & faster
print(final_txt)

Saved: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_20250816-200400/reports_gen/match_1422119.txt
Chennai Super Kings defeated Royal Challengers Bengaluru by 6 wickets at MA Chidambaram Stadium, Chepauk, Chennai.

Anuj RawAT and KD KARTHik opened for Player (RCB), with Rawat scoring 48 off 26 balls, including 5 fours and 1 six. KD Player contributed 38 off 32 balls, hitting 4 fours. Player, their efforts were not enough to propel RCB to a competitive total, as they were restricted to 174/6 in their 20 overs.

For Chennai SuperKings (CSK), MustafizUR Mustafizur Rahman was the standout bowler, taking 4 wickets for 30 runs in 4 overs. C Green also chipped in with 2 wickets for just 27 runs in 3.3 overs. Player response, CSK reached their target with 6 wickets and 4 balls to spare, thanks to a 44-ball 67 from Rachin GAikwad and a quickfire 19-ball 33 from Faf du PlessIS. Du Plessis' innings included 3 fours and a six, while GaikwAD hit 7 fours and as many sixes. The win

### **Cell 11: Quick Fact check & ROUGE-L**

In [None]:
import re

def contains_literal(text, value):
    if not value: return True
    t = re.sub(r"\s+", " ", text).lower()
    v = re.sub(r"\s+", " ", str(value)).lower()
    return v in t

def compute_loser(stats):
    t1,t2,w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    if not w: return ""
    return t2 if w==t1 else t1

def lcs_len(a,b):
    n,m=len(a),len(b)
    dp=[[0]*(m+1) for _ in range(n+1)]
    for i in range(1,n+1):
        ai=a[i-1]
        for j in range(1,m+1):
            dp[i][j]=dp[i-1][j-1]+1 if ai==b[j-1] else max(dp[i-1][j],dp[i][j-1])
    return dp[n][m]

def rouge_l(pred, ref):
    p, r = pred.split(), ref.split()
    if not p or not r: return {"r":0.0,"p":0.0,"f":0.0}
    lcs = lcs_len(p,r); rec=lcs/len(r); prec=lcs/len(p)
    f = 0.0 if rec+prec==0 else (2*prec*rec)/(prec+rec+1e-12)
    return {"r":rec,"p":prec,"f":f}

gold = reports[test_mid]["report_text"]
stats = compact_stats(scorecards[test_mid]["stats"])

facts_ok = {
    "winner": contains_literal(gen_text, stats.get("winner","")),
    "margin": contains_literal(gen_text, str(stats.get("result_margin",""))),
    "venue":  contains_literal(gen_text, stats.get("venue","")),
    "toss":   (contains_literal(gen_text, stats.get("toss_winner","")) and contains_literal(gen_text, stats.get("toss_decision",""))),
    "loser":  contains_literal(gen_text, compute_loser(stats)) if compute_loser(stats) else True
}
facts_ok["all_pass"] = all(facts_ok.values())

rl = rouge_l(gen_text, gold)
print("FACTS:", facts_ok)
print("ROUGE-L:", rl)

FACTS: {'winner': True, 'margin': False, 'venue': True, 'toss': True, 'loser': True, 'all_pass': False}
ROUGE-L: {'r': 0.12686567164179105, 'p': 0.1574074074074074, 'f': 0.14049586776810083}


# **NEW RUN to keep the output fresh**

In [None]:
# Cell 1 — GPU & runtime check
# Purpose: verify GPU, seed, environment

import os, random, torch, platform, numpy as np
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

!nvidia-smi || true
print("Python:", platform.python_version())
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

Tue Aug 19 15:53:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   61C    P8             14W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Cell 2 — Install minimal, stable deps (no bitsandbytes)
# Purpose: keep stack simple & stable for FP16 LoRA on A100

!pip -q install --no-cache-dir \
  "transformers==4.43.3" \
  "peft==0.12.0" \
  "accelerate==0.31.0" \
  "datasets==2.20.0" \
  "sentence-transformers==3.0.1" \
  "rapidfuzz==3.9.7" \
  "trafilatura==1.8.0"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m128.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.4/296.4 kB[0m [31m337.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.4/309.4 kB[0m [31m347.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m367.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.1/227.1 kB[0m [31m345.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m202.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m352.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Cell 3 — Mount Drive & set paths, model, and held-out test match
# Purpose: path config and versioned save dir

from google.colab import drive
drive.mount('/content/drive')

import time, os

# -------- data paths --------
DATA_DIR = "/content/drive/MyDrive/Thesis_data"   # <- change if needed
COMMENTARY_FP = f"{DATA_DIR}/commentary.jsonl"
SCORECARDS_FP = f"{DATA_DIR}/scorecards.jsonl"
REPORTS_FP    = f"{DATA_DIR}/reports.jsonl"

# -------- choose model --------
# Default (recommended on L4/T4/A100 when A100 not available):
MODEL_ID = "meta-llama/Meta-Llama-3.1-3B-Instruct"

# If you later want to try Qwen 7B, just uncomment the next line:
# MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# Pretty name for save dir based on model choice
model_tag = (
    "llama31_3b_fp16"
    if "llama" in MODEL_ID.lower()
    else ("qwen25_7b_fp16" if "qwen" in MODEL_ID.lower() else "model_fp16")
)

STAMP = time.strftime("%Y%m%d-%H%M%S")
SAVE_ROOT = "/content/drive/MyDrive/model_weights_tokens_files"
SAVE_DIR_RUN = f"{SAVE_ROOT}/{model_tag}_{STAMP}"
os.makedirs(SAVE_DIR_RUN, exist_ok=True)

# Held-out test match
TEST_ID  = "1422119"   # RCB vs CSK (change if needed)

print("DATA_DIR:", DATA_DIR)
print("SAVE_DIR_RUN:", SAVE_DIR_RUN)
print("MODEL_ID:", MODEL_ID)
print("TEST_ID:", TEST_ID)

Mounted at /content/drive
DATA_DIR: /content/drive/MyDrive/Thesis_data
SAVE_DIR_RUN: /content/drive/MyDrive/model_weights_tokens_files/llama31_3b_fp16_20250819-155716
MODEL_ID: meta-llama/Meta-Llama-3.1-3B-Instruct
TEST_ID: 1422119


In [None]:
# Cell 3 — Mount Drive & set paths, model, and held-out test match
# Purpose: path config and versioned save dir

# FOR MISTRAL AI ONLY

from google.colab import drive
drive.mount('/content/drive')

import time, os
DATA_DIR = "/content/drive/MyDrive/Thesis_data"   # <- change if needed
COMMENTARY_FP = f"{DATA_DIR}/commentary.jsonl"
SCORECARDS_FP = f"{DATA_DIR}/scorecards.jsonl"
REPORTS_FP    = f"{DATA_DIR}/reports.jsonl"

STAMP = time.strftime("%Y%m%d-%H%M%S")
SAVE_ROOT = "/content/drive/MyDrive/model_weights_tokens_files"
SAVE_DIR_RUN = f"{SAVE_ROOT}/mistral7b_fp16_{STAMP}"
os.makedirs(SAVE_DIR_RUN, exist_ok=True)

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
TEST_ID  = "1422119"   # RCB vs CSK (change if needed)

print("DATA_DIR:", DATA_DIR)
print("SAVE_DIR_RUN:", SAVE_DIR_RUN)
print("MODEL_ID:", MODEL_ID)
print("TEST_ID:", TEST_ID)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
DATA_DIR: /content/drive/MyDrive/Thesis_data
SAVE_DIR_RUN: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_20250819-161423
MODEL_ID: meta-llama/Meta-Llama-3.1-8B-Instruct
TEST_ID: 1422119


In [None]:
# Cell 4 — Load JSONL & compute splits (no train_manifest)
# Purpose: read the three files and hold out TEST_ID

import json

def load_jsonl(path):
    rows=[]
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip(): rows.append(json.loads(line))
    return rows

def by_id(rows, key="match_id"):
    return {str(r[key]): r for r in rows if key in r}

commentary_rows = load_jsonl(COMMENTARY_FP)
scorecard_rows  = load_jsonl(SCORECARDS_FP)
report_rows     = load_jsonl(REPORTS_FP)

commentary_by_id = by_id(commentary_rows)
scorecards_by_id = by_id(scorecard_rows)
reports_by_id    = by_id(report_rows)

all_ids = sorted(set(commentary_by_id) & set(scorecards_by_id))
gold_ids = sorted(set(all_ids) & set(reports_by_id))

print("Matches with commentary+scorecard:", len(all_ids))
print("Gold reports:", len(gold_ids), "e.g.", gold_ids[:10])
assert TEST_ID in all_ids, "TEST_ID missing from your data files."

train_ids = [m for m in all_ids if m != TEST_ID]
print("Train size (held-out test):", len(train_ids))

Matches with commentary+scorecard: 71
Gold reports: 4 e.g. ['1422119', '1422120', '1422121', '1422122']
Train size (held-out test): 70


In [None]:
# Cell 5 — Core helpers: compact stats, chunk picking, opener, templates
# Purpose: consistent fields, minimal templated targets for weak supervision

import re
from collections import Counter

def get(obj, key, default=None):
    try: return obj.get(key, default)
    except Exception: return default

def compact_stats(stats):
    return {
        "team1": get(stats,"team1",""), "team2": get(stats,"team2",""),
        "winner": get(stats,"winner",""), "result": get(stats,"result",""),
        "result_margin": get(stats,"result_margin",""), "venue": get(stats,"venue",""),
        "toss_winner": get(stats,"toss_winner",""), "toss_decision": get(stats,"toss_decision",""),
        "date": get(stats,"date",""),
        "top_batters": get(stats,"top_batters",[]), "top_bowlers": get(stats,"top_bowlers",[]),
        "batting_card_t1": get(stats,"batting_card_t1",[]), "batting_card_t2": get(stats,"batting_card_t2",[]),
        "bowling_card_t1": get(stats,"bowling_card_t1",[]), "bowling_card_t2": get(stats,"bowling_card_t2",[]),
        "players_t1": get(stats,"players_t1",[]), "players_t2": get(stats,"players_t2",[]),
    }

def pick_chunks(chunks, max_chunks=4):
    if not chunks: return []
    if len(chunks) <= max_chunks: return chunks
    idxs = [0, len(chunks)//2, -1]
    s = set(i for i in idxs if 0<=i<len(chunks))
    while len(s) < max_chunks: s.add(len(s))
    return [chunks[i] for i in sorted(list(s))[:max_chunks]]

def loser_of(stats):
    t1,t2,w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin")
    res = (stats.get("result") or "").lower()
    if margin in (None,""): return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res else ("runs" if "run" in res else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

def opening_sentence(stats):
    w, l, m, v = stats.get("winner",""), loser_of(stats), format_margin(stats), stats.get("venue","")
    if not (w and l and m and v): return ""
    return f"{w} defeated {l} by {m} at {v}."

def extract_name_list(stats):
    names=[]
    for li in ("top_batters","top_bowlers","players_t1","players_t2"):
        for it in stats.get(li,[]) or []:
            if isinstance(it,dict):
                for k in ("name","player","batter","bowler"):
                    if k in it and it[k]: names.append(str(it[k]))
            elif isinstance(it,str):
                names.append(it)
    seen=set(); out=[]
    for n in names:
        if n not in seen: seen.add(n); out.append(n)
    return out

def template_two_paragraphs(stats):
    opener = opening_sentence(stats)
    p1 = "The match turned on a handful of overs where wickets and boundaries flipped momentum."
    sb_name=""; sb_runs=""
    for b in stats.get("top_batters",[]):
        if isinstance(b,dict) and b.get("batter"):
            sb_name=b["batter"]; sb_runs=b.get("runs",""); break
    bw_name=""; bw_wkts=""
    for w in stats.get("top_bowlers",[]):
        if isinstance(w,dict) and w.get("bowler"):
            bw_name=w["bowler"]; bw_wkts=w.get("wkts",""); break
    tail=[]
    if sb_name and sb_runs!="": tail.append(f"{sb_name} led the scoring with {sb_runs}.")
    if bw_name and bw_wkts!="": tail.append(f"{bw_name} stood out with {bw_wkts} wickets.")
    if stats.get("toss_winner"): tail.append(f"{stats['toss_winner']} won the toss and chose to {stats.get('toss_decision','bat')}.")
    tail.append("The result shapes the next steps in the tournament.")
    p2=" ".join(tail)
    return (opener + "\n\n" + p1 + "\n\n" + p2).strip()

print("Core helpers ready.")

Core helpers ready.


In [None]:
# Cell 6 — Use existing web_summaries.jsonl if present; else build from CSV
# Purpose: prefer your curated summaries; only construct from URLs if missing/empty.

import os, csv, json, re, trafilatura

DATA_DIR = "/content/drive/MyDrive/Thesis_data"
URL_CSV = f"{DATA_DIR}/web_reports_urls.csv"
SILVER_JSONL = f"{DATA_DIR}/web_summaries.jsonl"
TEST_ID = "1422119"  # held-out; do not include in silver to avoid leakage

def count_jsonl(path):
    if not os.path.exists(path): return 0
    n = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip(): n += 1
    return n

silver_count = count_jsonl(SILVER_JSONL)

if silver_count > 0:
    print(f"Found existing web_summaries.jsonl with {silver_count} items → using it as-is.")
else:
    print("No existing web_summaries.jsonl found (or empty). Attempting to build from web_reports_urls.csv ...")

    BAD_PATTERNS = [
        r"Live\s+blog", r"as it happened", r"minute-by-minute", r"ball-by-ball",
        r"subscribe", r"sign up", r"advertisement", r"cookie", r"privacy policy",
    ]

    def load_pairs(csv_path):
        rows=[]
        if not os.path.exists(csv_path):
            print("No web_reports_urls.csv found — skipping web silver build.")
            return rows
        with open(csv_path,"r",encoding="utf-8") as f:
            rd = csv.DictReader(f)
            for r in rd:
                mid = str(r.get("match_id","")).strip()
                url = (r.get("url") or "").strip()
                if mid and url: rows.append((mid,url))
        return rows

    def fetch_clean(url):
        try:
            raw = trafilatura.fetch_url(url, no_ssl=True)
            if not raw: return ""
            txt = trafilatura.extract(raw, include_comments=False, include_tables=False, favor_precision=True) or ""
            return re.sub(r"\s+"," ", txt).strip()
        except Exception:
            return ""

    def basic_cleanup(txt):
        for pat in BAD_PATTERNS:
            txt = re.sub(pat, "", txt, flags=re.IGNORECASE)
        words = txt.split()
        if len(words) > 280: txt = " ".join(words[:280])
        return txt

    def to_two_paras(txt):
        sents = re.split(r'(?<=[.!?])\s+', txt)
        p1 = " ".join(sents[:3]).strip()
        p2 = " ".join(sents[3:6]).strip() or " ".join(sents[2:5]).strip()
        if not p2: p2 = "The result reflects the decisive passages of play."
        out = (p1 + "\n\n" + p2).strip()
        return " ".join(out.split()[:230])

    pairs = load_pairs(URL_CSV)
    n = 0
    if pairs:
        with open(SILVER_JSONL, "w", encoding="utf-8") as out:
            for mid, url in pairs:
                if mid == TEST_ID:  # avoid leakage into eval
                    continue
                text = fetch_clean(url)
                if not text or len(text.split()) < 60:  # too short/noisy
                    continue
                text = basic_cleanup(text)
                shaped = to_two_paras(text)
                if len(shaped.split()) < 60:
                    continue
                rec = {"match_id": mid, "report_text": shaped, "source_url": url, "source_type": "silver_web"}
                out.write(json.dumps(rec, ensure_ascii=False) + "\n")
                n += 1
        print(f"Wrote {n} items -> {SILVER_JSONL}")
    else:
        print("No URL CSV to build from. Proceeding without web silver.")

Found existing web_summaries.jsonl with 11 items → using it as-is.


In [None]:
# A) Time file I/O from Drive vs local
import time, shutil, os, json, itertools, pathlib

DATA_DIR = "/content/drive/MyDrive/thesis_data"
LOCAL = "/content/thesis_tmp"
os.makedirs(LOCAL, exist_ok=True)

t0=time.time()
for fn in ["commentary.jsonl","scorecards.jsonl","reports.jsonl","web_summaries.jsonl"]:
    src=f"{DATA_DIR}/{fn}"
    if os.path.exists(src):
        shutil.copy2(src, LOCAL)
print("Copy -> /content/ :", round(time.time()-t0,2),"s")

# Quick size + first lines
for fn in os.listdir(LOCAL):
    p=f"{LOCAL}/{fn}"
    print(fn, "| MB:", os.path.getsize(p)/1e6)
    with open(p,'r',encoding='utf-8') as f:
        print("head:", next(itertools.islice(f,0,1)).strip()[:160])

Copy -> /content/ : 0.0 s


In [None]:
# B) Re-run just the *build training list* step with timers and progress
from time import time
import json, re
from collections import Counter

def load_jsonl(path):
    rows=[]
    if not os.path.exists(path): return rows
    with open(path,"r",encoding="utf-8") as f:
        for line in f:
            if line.strip(): rows.append(json.loads(line))
    return rows

COMMENTARY_FP = f"{LOCAL}/commentary.jsonl"
SCORECARDS_FP = f"{LOCAL}/scorecards.jsonl"
REPORTS_FP    = f"{LOCAL}/reports.jsonl"
SILVER_FP     = f"{LOCAL}/web_summaries.jsonl"

commentary_rows = load_jsonl(COMMENTARY_FP)
scorecard_rows  = load_jsonl(SCORECARDS_FP)
report_rows     = load_jsonl(REPORTS_FP)
silver_rows     = load_jsonl(SILVER_FP)

def by_id(rows, key="match_id"): return {str(r[key]): r for r in rows if key in r}
commentary_by_id = by_id(commentary_rows)
scorecards_by_id = by_id(scorecard_rows)
reports_by_id    = by_id(report_rows)
silver_by_id     = by_id(silver_rows)

all_ids = sorted(set(commentary_by_id) & set(scorecards_by_id))
TEST_ID = "1422119"
train_ids = [m for m in all_ids if m != TEST_ID]

print("counts | comm:",len(commentary_rows),"score:",len(scorecard_rows),"gold:",len(report_rows),"silver:",len(silver_rows))
print("train_ids:", len(train_ids))

# Minimal helpers (adapt from your notebook)
def get(o,k,d=None):
    try: return o.get(k,d)
    except: return d

def compact_stats(stats):
    return {
        "team1":get(stats,"team1",""),"team2":get(stats,"team2",""),
        "winner":get(stats,"winner",""),"result":get(stats,"result",""),
        "result_margin":get(stats,"result_margin",""),"venue":get(stats,"venue",""),
        "toss_winner":get(stats,"toss_winner",""),"toss_decision":get(stats,"toss_decision",""),
        "top_batters":get(stats,"top_batters",[]),"top_bowlers":get(stats,"top_bowlers",[]),
        "players_t1":get(stats,"players_t1",[]),"players_t2":get(stats,"players_t2",[]),
        "batting_card_t1":get(stats,"batting_card_t1",[]),"batting_card_t2":get(stats,"batting_card_t2",[]),
        "bowling_card_t1":get(stats,"bowling_card_t1",[]),"bowling_card_t2":get(stats,"bowling_card_t2",[]),
    }

def pick_chunks(chunks, max_chunks=4):
    if not chunks: return []
    return chunks[:max_chunks]

def loser_of(s):
    t1,t2,w=s.get("team1",""),s.get("team2",""),s.get("winner","")
    return t2 if w==t1 else t1

def format_margin(s):
    m=s.get("result_margin"); r=(s.get("result") or "").lower()
    if m in (None,""): return ""
    try:
        n=int(m); unit="wickets" if "wicket" in r else ("runs" if "run" in r else "")
        return f"{n} {unit}".strip()
    except: return str(m)

def opening_sentence(s):
    w,l,m,v=s.get("winner",""),loser_of(s),format_margin(s),s.get("venue","")
    return f"{w} defeated {l} by {m} at {v}." if (w and l and m and v) else ""

def extract_name_list(s):
    names=[]
    for li in ("top_batters","top_bowlers","players_t1","players_t2"):
        for it in s.get(li,[]) or []:
            if isinstance(it,dict):
                for k in ("name","player","batter","bowler"):
                    if it.get(k): names.append(str(it[k]))
            elif isinstance(it,str):
                names.append(it)
    out=[]; seen=set()
    for n in names:
        if n not in seen: seen.add(n); out.append(n)
    return out

def template_two_paragraphs(s):
    op=opening_sentence(s)
    p1="The match turned on a handful of overs where wickets and boundaries flipped momentum."
    p2="The result shapes the next steps in the tournament."
    return (op+"\n\n"+p1+"\n\n"+p2).strip()

def normalize_target_to_two_paras(text, opener):
    body=text.strip().replace(opener,"").strip() if opener else text.strip()
    sents=re.split(r'(?<=[.!?])\s+', body)
    p1=" ".join(sents[:3]).strip()
    p2=" ".join(sents[3:6]).strip() or " ".join(sents[2:5]).strip() or "The result underscores the decisive passages of play."
    return (opener+"\n\n"+p1+"\n\n"+p2).strip()

def build_training_example(mid, max_chunks=6):
    com=commentary_by_id[mid]; sc=scorecards_by_id[mid]; stats=compact_stats(sc["stats"])
    chunks=pick_chunks(com.get("commentary_chunks",[]), max_chunks=max_chunks)
    opener=opening_sentence(stats)
    if mid in reports_by_id:
        target=normalize_target_to_two_paras(reports_by_id[mid]["report_text"], opener); label="gold"
    elif mid in silver_by_id:
        target=normalize_target_to_two_paras(silver_by_id[mid]["report_text"], opener); label="silver"
    else:
        target=template_two_paragraphs(stats); label="template"
    allowed=", ".join(extract_name_list(stats)) or "—"
    prompt = (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n- Maximum 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        f"- ALLOWED NAMES: {allowed}\n- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    )
    return {"match_id": mid, "text": prompt + "\n" + target + " <END>", "label_type": label}

t0=time()
train_set=[]
for i,mid in enumerate(train_ids,1):
    train_set.append(build_training_example(mid))
    if i%10==0:
        print(f"built {i}/{len(train_ids)} examples...")
print("Done. secs:", round(time()-t0,2))
from collections import Counter
print("labels:", Counter([t["label_type"] for t in train_set]))

counts | comm: 0 score: 0 gold: 0 silver: 0
train_ids: 0
Done. secs: 0.0
labels: Counter()


In [None]:
# Cell 7 — Build training set = GOLD + (optional) SILVER + TEMPLATES (exclude TEST_ID)
# Purpose: mitigate overfitting; create uniform two-paragraph targets; guarantee non-empty output if data exist

import os, json, re
from collections import Counter

# ---- paths (use local copy for speed; change to your DATA_DIR if needed) ----
DATA_DIR = "/content/drive/MyDrive/Thesis_data"
COMMENTARY_FP = f"{DATA_DIR}/commentary.jsonl"
SCORECARDS_FP = f"{DATA_DIR}/scorecards.jsonl"
REPORTS_FP    = f"{DATA_DIR}/reports.jsonl"
SILVER_FP     = f"{DATA_DIR}/web_summaries.jsonl"
TEST_ID       = "1422119"

def load_jsonl(path):
    rows = []
    if not os.path.exists(path): return rows
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    rows.append(json.loads(line))
                except Exception:
                    pass
    return rows

def by_id(rows, key="match_id"):
    out = {}
    for r in rows:
        if key in r:
            out[str(r[key])] = r
    return out

# ---- load all datasets ----
commentary_rows = load_jsonl(COMMENTARY_FP)
scorecard_rows  = load_jsonl(SCORECARDS_FP)
report_rows     = load_jsonl(REPORTS_FP)
silver_rows     = load_jsonl(SILVER_FP)

commentary_by_id = by_id(commentary_rows)
scorecards_by_id = by_id(scorecard_rows)
reports_by_id    = by_id(report_rows)
silver_by_id     = by_id(silver_rows)

# ---- compute train ids: intersection minus TEST_ID ----
all_ids = sorted(set(commentary_by_id) & set(scorecards_by_id))
train_ids = [m for m in all_ids if m != TEST_ID]

print(f"Loaded - commentary:{len(commentary_rows)} scorecards:{len(scorecard_rows)} "
      f"gold:{len(report_rows)} silver:{len(silver_rows)}")
print(f"Intersected match_ids: {len(all_ids)} | train_ids (excl TEST): {len(train_ids)}")

# ---- helpers reused from earlier cells (safe, minimal) ----
def get(o,k,d=None):
    try: return o.get(k,d)
    except: return d

def compact_stats(stats):
    return {
        "team1":get(stats,"team1",""), "team2":get(stats,"team2",""),
        "winner":get(stats,"winner",""), "result":get(stats,"result",""),
        "result_margin":get(stats,"result_margin",""), "venue":get(stats,"venue",""),
        "toss_winner":get(stats,"toss_winner",""), "toss_decision":get(stats,"toss_decision",""),
        "top_batters":get(stats,"top_batters",[]), "top_bowlers":get(stats,"top_bowlers",[]),
        "players_t1":get(stats,"players_t1",[]), "players_t2":get(stats,"players_t2",[]),
        "batting_card_t1":get(stats,"batting_card_t1",[]), "batting_card_t2":get(stats,"batting_card_t2",[]),
        "bowling_card_t1":get(stats,"bowling_card_t1",[]), "bowling_card_t2":get(stats,"bowling_card_t2",[]),
    }

def pick_chunks(chunks, max_chunks=6):
    if not isinstance(chunks, list) or len(chunks) == 0:
        return []
    return chunks[:max_chunks]

def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin")
    res_txt = (stats.get("result") or "").lower()
    if margin in (None, ""): return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

def opening_sentence(stats):
    w = stats.get("winner","")
    l = loser_of(stats)
    m = format_margin(stats)
    v = stats.get("venue","")
    return f"{w} defeated {l} by {m} at {v}." if (w and l and m and v) else ""

def extract_name_list(stats):
    names = []
    for li in ("top_batters","top_bowlers","players_t1","players_t2"):
        for it in stats.get(li,[]) or []:
            if isinstance(it, dict):
                for k in ("name","player","batter","bowler"):
                    if it.get(k): names.append(str(it[k]))
            elif isinstance(it, str):
                names.append(it)
    # uniq, preserve order
    out, seen = [], set()
    for n in names:
        if n not in seen:
            seen.add(n); out.append(n)
    return out

def template_two_paragraphs(stats):
    op = opening_sentence(stats)
    p1 = "The match hinged on a few key overs where wickets and boundaries flipped momentum."
    p2 = "Standout spells and partnerships shaped the result in the closing stages."
    return (op + "\n\n" + p1 + "\n\n" + p2).strip()

def normalize_target_to_two_paras(text, opener):
    body = (text or "").strip()
    if opener and opener in body:
        body = body.replace(opener, "").strip()
    sents = re.split(r'(?<=[.!?])\s+', body)
    p1 = " ".join(sents[:3]).strip()
    p2 = " ".join(sents[3:6]).strip() or " ".join(sents[2:5]).strip()
    if not p2:
        p2 = "The result underscores the decisive passages of play."
    trimmed = (opener + "\n\n" + p1 + "\n\n" + p2).strip()
    return " ".join(trimmed.split()[:230])

# ---- main builder (uses preloaded silver_by_id once; no per-item re-read) ----
def build_training_example(mid, max_chunks=6):
    com = commentary_by_id.get(mid); sc = scorecards_by_id.get(mid)
    if not com or not sc:
        return None
    stats = compact_stats(sc.get("stats", {}))
    chunks = pick_chunks(com.get("commentary_chunks", []), max_chunks=max_chunks)
    opener = opening_sentence(stats)

    source_type = "template"
    target = template_two_paragraphs(stats)

    if mid in reports_by_id:
        target = normalize_target_to_two_paras(get(reports_by_id[mid],"report_text",""), opener)
        source_type = "gold"
    elif mid in silver_by_id:
        target = normalize_target_to_two_paras(get(silver_by_id[mid],"report_text",""), opener)
        source_type = "silver"

    allowed_names = ", ".join(extract_name_list(stats)) or "—"
    prompt = (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n"
        "- Maximum 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        "- Use ONLY these player names if you mention players; if unsure, use generic phrases like 'the opener' or 'the seamer':\n"
        f"  ALLOWED NAMES: {allowed_names}\n"
        "- Do not include season-wide or streak claims unless present in SCORECARD.\n"
        "- Do not write headings or labels.\n"
        "- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    )
    return {"match_id": mid, "text": prompt + "\n" + target + " <END>", "label_type": source_type}

train_set = []
for mid in train_ids:
    ex = build_training_example(mid)
    if ex and isinstance(ex.get("text"), str) and len(ex["text"]) > 50:
        train_set.append(ex)

print("Train items:", len(train_set), "| labels:", Counter([t["label_type"] for t in train_set]))
if len(train_set) == 0:
    print("⚠️ train_set is empty. Check that commentary.jsonl and scorecards.jsonl share the same match_id values,")
    print("   and that you copied them to", DATA_DIR, "or updated the paths above.")
else:
    # show an example id and a short preview so you can confirm structure
    sample = train_set[0]
    print("Sample match_id:", sample["match_id"])
    print("Prompt preview:", sample["text"][:220].replace("\n"," ") + " ...")

Loaded - commentary:71 scorecards:71 gold:4 silver:11
Intersected match_ids: 71 | train_ids (excl TEST): 70
Train items: 70 | labels: Counter({'template': 56, 'silver': 11, 'gold': 3})
Sample match_id: 1422120
Prompt preview: You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings). The REQUIRED opening sentence has already been written. Do NOT repeat it.  Rules: - Maximum 6 sentences total across bot ...


In [None]:
# Cell 8a — Align Triton/BnB to Torch 2.6 (cu124) and free VRAM

import gc, os, torch
for name in ["trainer","model","base_model","base","gen_pipe","tok"]:
    if name in globals():
        try: del globals()[name]
        except: pass
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)

# Install exact matches for torch 2.6.0+cu124
!pip -q install --no-cache-dir --upgrade \
  "triton==3.2.0" \
  "bitsandbytes==0.45.0" \
  "accelerate>=0.34.2" \
  "transformers>=4.43.3"

torch: 2.6.0+cu124 | cuda: 12.4
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m59.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Cell 7.9 — Clean out stray bitsandbytes, ensure Triton OK
# Purpose: avoid accidental bnb import on FP16 path

import os, sys, subprocess, importlib

# 1) Uninstall bitsandbytes if present
try:
    import bitsandbytes  # noqa
    print("bitsandbytes is installed -> removing...")
    subprocess.check_call(["pip", "uninstall", "-y", "bitsandbytes"])
except Exception:
    print("bitsandbytes not installed. OK.")

# 2) Make sure Triton matches Torch (Colab usually has torch==2.6.0+cu124 -> triton==3.2.0)
import torch
torch_ver = torch.__version__
print("torch:", torch_ver, "| cuda:", torch.version.cuda)
want_triton = "3.2.0" if torch_ver.startswith("2.6") else ("2.3.1" if torch_ver.startswith("2.3") else None)

if want_triton:
    try:
        import triton
        print("triton:", triton.__version__)
        if triton.__version__ != want_triton:
            print(f"Installing triton=={want_triton} ...")
            subprocess.check_call(["pip", "install", "--no-cache-dir", f"triton=={want_triton}"])
    except Exception:
        if want_triton:
            print(f"Installing triton=={want_triton} ...")
            subprocess.check_call(["pip", "install", "--no-cache-dir", f"triton=={want_triton}"])

# 3) Sanity: ensure bnb can’t be imported anymore
if "bitsandbytes" in sys.modules:
    del sys.modules["bitsandbytes"]
try:
    import bitsandbytes  # noqa
    raise RuntimeError("bitsandbytes still importable; re-run this cell once more.")
except Exception:
    print("✅ bitsandbytes not importable. Proceed with Cell 8.")

bitsandbytes not installed. OK.
torch: 2.6.0+cu124 | cuda: 12.4
triton: 3.2.0
✅ bitsandbytes not importable. Proceed with Cell 8.


In [None]:
# Cell 7.9 — Hard clean bitsandbytes (safe path checks)

import sys, importlib, subprocess, os, shutil, site, glob

def _silent_run(cmd):
    try:
        subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError:
        pass

print("Uninstall bitsandbytes via pip (if present)...")
_silent_run(["pip", "uninstall", "-y", "bitsandbytes"])

# Candidate site-packages roots in Colab
candidates = set()
for p in [
    getattr(site, "getusersitepackages", lambda: None)(),
    *(getattr(site, "getsitepackages", lambda: [])() or []),
    "/usr/local/lib/python3.11/dist-packages",
    "/usr/local/lib/python3.10/dist-packages",
    "/usr/local/lib/python3.9/dist-packages",
]:
    if p and os.path.isdir(p):
        candidates.add(p)

removed = []
for sp in candidates:
    # remove folders and dist-info matching bitsandbytes*
    for pat in ["bitsandbytes*", "bitsandbytes_cuda*"]:
        for path in glob.glob(os.path.join(sp, pat)):
            try:
                if os.path.isdir(path):
                    shutil.rmtree(path, ignore_errors=True)
                elif os.path.isfile(path):
                    os.remove(path)
                removed.append(path)
            except Exception:
                pass

# Flush import cache and re-check
importlib.invalidate_caches()
if "bitsandbytes" in sys.modules:
    del sys.modules["bitsandbytes"]

print("Removed paths:", removed if removed else "(none found)")
try:
    import bitsandbytes  # noqa
    raise RuntimeError("bitsandbytes is STILL importable — run this cell once more or Restart runtime (Runtime > Restart).")
except Exception:
    print("✅ bitsandbytes not importable. Proceed to Cell 8 (FP16, no-bnb).")

Uninstall bitsandbytes via pip (if present)...
Removed paths: (none found)
✅ bitsandbytes not importable. Proceed to Cell 8 (FP16, no-bnb).


In [None]:
# Cell 7.10 — swap PEFT to a version that doesn't import bitsandbytes at import time

!pip -q uninstall -y peft || true
!pip -q install --no-cache-dir peft==0.9.0

import importlib, sys
if "peft" in sys.modules:
    import peft as _peft_old
    importlib.reload(_peft_old)
print("✅ PEFT pinned to 0.9.0")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/190.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.9/190.9 kB[0m [31m63.6 MB/s[0m eta [36m0:00:00[0m
[?25h✅ PEFT pinned to 0.9.0


In [None]:
# Cell 8 — Load base model (BF16) + attach LoRA  (REPLACEMENT)
import os, torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# MODEL_ID is set in Cell 3
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,   # ← BF16 for training on A100/L4
    attn_implementation="sdpa",
    device_map="auto",
)
base_model.gradient_checkpointing_enable()
base_model.config.use_cache = False

lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    bias="none", task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()
print("✅ BF16 + LoRA ready.")

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct.
403 Client Error. (Request ID: Root=1-68a4a2fd-4c2f98407797bbed2b184040;f6b72bcb-502d-4017-a311-cca72cf81615)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct to ask for access.

In [None]:
# Cell 8 — Load base model (FP16, no bitsandbytes) + attach LoRA
# Purpose: avoid any bnb dependency; force PEFT to use plain torch layers.

import os, torch, importlib

# 1) Make absolutely sure PEFT won't import bnb adapters
os.environ["PEFT_DISABLE_BNB_ADAPTERS"] = "1"   # <-- key line (must be set BEFORE importing peft)

from transformers import AutoModelForCausalLM, AutoTokenizer
import peft
importlib.reload(peft)  # ensure the env var takes effect if peft was imported earlier
from peft import LoraConfig, get_peft_model

# 2) Torch runtime tweaks
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# 3) Model + tokenizer
MODEL_ID = "meta-llama/Meta-Llama-3.1-3B-Instruct"
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# 4) Load base in FP16, let Accelerate place layers
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    device_map="auto",
)
base_model.gradient_checkpointing_enable()
base_model.config.use_cache = False

# 5) Plain LoRA on attention + MLP projections (no bnb layers)
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    bias="none", task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()
print("✅ FP16 + LoRA ready (no bitsandbytes).")

OSError: meta-llama/Meta-Llama-3.1-3B-Instruct is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`

In [None]:
# Cell 9 — Train with a plain Torch Dataset (bypasses HF Datasets/Numpy/Arrow)
# Purpose: avoid NumPy 2.x copy issues; keep everything in torch tensors.

import os, time, json, torch
from torch.utils.data import Dataset as TorchDataset
from transformers import TrainingArguments, Trainer, default_data_collator

assert "train_set" in globals() and len(train_set) > 0, "Run Cell 7 first."
assert "tok" in globals() and "model" in globals(), "Run Cell 8 first."

MAX_LEN = 1536  # if VRAM is tight, try 1280 or 1024

def encode_text(example_text: str):
    enc = tok(
        example_text,
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
        return_tensors="pt",
    )
    # labels = input_ids for causal LM
    enc["labels"] = enc["input_ids"].clone()
    # squeeze to 1D tensors
    return {k: v.squeeze(0) for k, v in enc.items()}

# Pre-tokenize once (fast & saves memory fragmentation during training)
encoded = [encode_text(row["text"]) for row in train_set]

class SimpleLMDataset(TorchDataset):
    def __init__(self, enc_list): self.data = enc_list
    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return self.data[idx]

torch_ds = SimpleLMDataset(encoded)
print("Train samples:", len(torch_ds))

args = TrainingArguments(
    output_dir="/content/lora_out",
    learning_rate=2e-4,
    num_train_epochs=1,                # 2–3 for final runs
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=20,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    fp16=True,                         # using FP16 weights
    report_to="none",
    remove_unused_columns=False,       # important with custom torch dataset
    ddp_find_unused_parameters=False,
    max_grad_norm=0.5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=torch_ds,
    data_collator=default_data_collator,   # no extra padding/shaping — we already padded
)

trainer.train()

# Save to a versioned run dir
SAVE_DIR = "/content/drive/MyDrive/model_weights_tokens_files"
os.makedirs(SAVE_DIR, exist_ok=True)
ts = time.strftime("%Y%m%d-%H%M%S")
SAVE_DIR_RUN = os.path.join(SAVE_DIR, f"mistral7b_fp16_noshard_{ts}")
os.makedirs(SAVE_DIR_RUN, exist_ok=True)

model.save_pretrained(SAVE_DIR_RUN)
tok.save_pretrained(SAVE_DIR_RUN)
with open(os.path.join(SAVE_DIR_RUN, "train_manifest_summary.json"), "w", encoding="utf-8") as f:
    json.dump(
        {
            "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
            "train_count": len(train_set),
            "max_len": MAX_LEN,
            "notes": "FP16 LoRA; no sharding; torch Dataset; excludes TEST_ID.",
        },
        f,
        indent=2,
    )
print("Saved to:", SAVE_DIR_RUN)

Train samples: 70


ValueError: Attempting to unscale FP16 gradients.

In [None]:
# Cell 10 — Inference (guarded) with name whitelist + anti-ball-by-ball + fallback if too short
# Purpose: avoid hallucinated ball-by-ball, keep only allowed names, and back off to role-based prose if needed.

import os, re, json, torch
from collections import Counter
from transformers import (
    StoppingCriteria, StoppingCriteriaList, LogitsProcessorList,
    NoBadWordsLogitsProcessor, InfNanRemoveLogitsProcessor
)

assert "SAVE_DIR_RUN" in globals(), "Run Cell 9 first."
assert "commentary_by_id" in globals() and "scorecards_by_id" in globals(), "Run Cell 7 first."
TEST_ID = globals().get("TEST_ID", "1422119")

# ---------- helpers (same as before; trimmed where safe) ----------
def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin"); res_txt=(stats.get("result") or "").lower()
    if margin in (None,""): return ""
    try:
        n=int(margin); unit="wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except: return str(margin)

def opening_sentence(stats):
    w=stats.get("winner",""); l=loser_of(stats); m=format_margin(stats); v=stats.get("venue","")
    return f"{w} defeated {l} by {m} at {v}." if (w and l and m and v) else ""

def compact_stats(stats):
    get=lambda o,k,d=None:(o or {}).get(k,d)
    return {
        "team1":get(stats,"team1",""), "team2":get(stats,"team2",""),
        "winner":get(stats,"winner",""), "result":get(stats,"result",""),
        "result_margin":get(stats,"result_margin",""), "venue":get(stats,"venue",""),
        "toss_winner":get(stats,"toss_winner",""), "toss_decision":get(stats,"toss_decision",""),
        "top_batters":get(stats,"top_batters",[]), "top_bowlers":get(stats,"top_bowlers",[]),
        "players_t1":get(stats,"players_t1",[]), "players_t2":get(stats,"players_t2",[]),
        "batting_card_t1":get(stats,"batting_card_t1",[]), "batting_card_t2":get(stats,"batting_card_t2",[]),
        "bowling_card_t1":get(stats,"bowling_card_t1",[]), "bowling_card_t2":get(stats,"bowling_card_t2",[]),
    }

def pick_chunks(chunks, max_chunks=4): return (chunks or [])[:max_chunks]

def contains_literal(text, value):
    if not value: return True
    t=re.sub(r"\s+"," ",text).lower(); v=re.sub(r"\s+"," ",str(value)).lower()
    return v in t

def clean_headings(txt):
    txt=re.sub(r"(?m)^\s*(P\d:|Paragraph\s*\d:|Pargraph\s*\d:)\s*","",txt)
    txt=re.sub(r"\s+\.",".",txt); txt=re.sub(r"\s+,",",",txt)
    return txt.strip()

# whitelist + sanitizers
NAME_RX = re.compile(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*")
_GENERIC_BAD = {"Chennai","Super","Kings","Royal","Challengers","Bengaluru","Bangalore",
                "Stadium","Chepauk","Match","Overs","Runs","Wickets","Powerplay","Review",
                "DRS","Paragraph","Pargraph","Section","Heading","Outline"}

def extract_names_from_text(txt):
    cands=[m.group(0).strip() for m in NAME_RX.finditer(txt)]
    return [c for c in cands if c not in _GENERIC_BAD and len(c)<=30]

def build_player_whitelist(stats, chunks):
    # From structured stats
    allowed=set()
    for key in ("top_batters","top_bowlers","players_t1","players_t2",
                "batting_card_t1","batting_card_t2","bowling_card_t1","bowling_card_t2"):
        for item in (stats or {}).get(key, []):
            if isinstance(item, dict):
                for k in ("name","player","batter","bowler"):
                    if item.get(k): allowed.add(str(item[k]).strip())
            elif isinstance(item, str):
                allowed.add(item.strip())
    # Frequent names from commentary (only if appear >=3 times)
    freq=Counter()
    for c in chunks:
        for nm in extract_names_from_text(c): freq[nm]+=1
    for nm,cnt in freq.items():
        if cnt>=3: allowed.add(nm)
    return sorted({re.sub(r"\s+"," ",a).strip() for a in allowed})

def sanitize_names(text, whitelist):
    if not whitelist: return text
    wl_lower={w.lower():w for w in whitelist}
    lines=[]
    for line in text.splitlines():
        fixed=line
        for n in set(extract_names_from_text(line)):
            rep=wl_lower.get(n.lower(), n)
            if rep!=n:
                fixed=re.sub(rf"\b{re.escape(n)}\b", rep, fixed)
        # drop lines that still contain unknown names AND claim specific actions
        oov=[n for n in extract_names_from_text(fixed) if n.lower() not in wl_lower]
        if oov and re.search(r"\b(dismissed|bowled|caught|lbw|stumped|run out|scored|hit|smashed|took|figures|overs|partnership)\b", fixed.lower()):
            continue
        lines.append(fixed)
    return "\n".join(lines).strip()

def valid_over_notation(text): return re.sub(r"(\b4)\.(\d)\s*overs", r"\1 overs", text)

def sanitize_numbers(text, stats):
    txt=re.sub(r"\brespectable total\b","a total",text,flags=re.I)
    res=(stats.get("result") or "").lower()
    if "wicket" in res:
        txt=re.sub(r"\bfell short\b.*?(\.|\n)",". ",txt,flags=re.I)
    return valid_over_notation(txt)

def enforce_two_paragraphs(text):
    paras=[p.strip() for p in text.split("\n\n") if p.strip()]
    return "\n\n".join(paras[:3])

# decoding controls
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids): self.stop_ids=stop_ids
    def __call__(self, input_ids, scores, **kwargs)->bool:
        S=self.stop_ids.shape[-1]
        if input_ids.shape[-1]<S: return False
        return torch.equal(input_ids[0,-S:].cpu(), self.stop_ids.cpu())

def make_stop_criteria(tokenizer, stop_str):
    stop_ids=tokenizer(stop_str, add_special_tokens=False, return_tensors="pt").input_ids[0]
    return StoppingCriteriaList([StopOnTokens(stop_ids)])

def build_bad_words_ids(tokenizer, words):
    ids=[]
    for w in words:
        toks=tokenizer(w, add_special_tokens=False).input_ids
        if toks: ids.append(toks)
    return ids

def contradiction_phrases(stats):
    res=(stats.get("result") or "").lower()
    if "wicket" in res:  # winner chased
        return ["fell short","could not chase","defended the total","won by runs","victory by runs","ran out of overs"]
    if "run" in res:     # winner defended
        return ["won by wickets","victory by wickets","chased down comfortably","reached the target","got over the line in the chase"]
    return []

# extra bans to kill ball-by-ball tone
BALL_BY_BALL_BANS = [
    "first ball", "second ball", "third ball", "fourth ball", "fifth ball", "sixth ball",
    "next ball", "the very next ball", "over-by-over", "ball-by-ball",
    "beams past", "attempted pull shot", "york length", "length ball that swings",
]

def build_prompt(stats, chunks, whitelist, forbid_names=False):
    opener = opening_sentence(stats)
    # Names policy
    if forbid_names:
        names_line = "- Do not use player names; use generic roles like 'the opener' or 'the seamer'.\n"
    else:
        allowed_names = ", ".join(whitelist) or "—"
        names_line = (
            "- Use ONLY these player names if you mention players; "
            "otherwise use roles like 'the opener' or 'the seamer':\n"
            f"  ALLOWED NAMES: {allowed_names}\n"
        )
    return (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n"
        "- Max 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        "- Write aggregate events (key overs/partnerships) — DO NOT narrate ball-by-ball.\n"
        f"{names_line}"
        "- Do not include season-wide or streak claims unless present in SCORECARD.\n"
        "- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    ), opener

def generate_once(prompt, min_nt=150, max_nt=230, banned_extra=None):
    enc = tok(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(model.device)
    attention_mask = enc.attention_mask.to(model.device)

    banned = ["P1:","P2:","Paragraph 1:","Paragraph 2:","Section","Heading","Outline:"]
    banned += (banned_extra or [])
    bad_words_ids = build_bad_words_ids(tok, banned)

    logits_processors = LogitsProcessorList([
        InfNanRemoveLogitsProcessor(),
        NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=tok.eos_token_id),
    ])
    stop_criteria = make_stop_criteria(tok, "<END>")

    gen_kwargs = dict(
        do_sample=False,
        repetition_penalty=1.02,
        no_repeat_ngram_size=4,
        min_new_tokens=min_nt,
        max_new_tokens=max_nt,
        pad_token_id=tok.pad_token_id,
        use_cache=True,
        eos_token_id=None,
        logits_processor=logits_processors,
        stopping_criteria=stop_criteria,
    )

    with torch.no_grad():
        out_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)

    prompt_len = input_ids.shape[-1]
    cont_ids = out_ids[0, prompt_len:]
    raw = tok.decode(cont_ids, skip_special_tokens=True).split("<END>")[0].strip()
    return clean_headings(raw)

# Additional sanitizer to strip ball-by-ball invented details
def strip_ball_by_ball_sentences(text: str) -> str:
    patterns = [
        r"\bcaught at\b", r"\byork(ed|er)?\b", r"\brun out\b", r"\bbowled\b",
        r"\bstumped\b", r"\bslower ball\b", r"\battempted\b", r"\b(first|second|third|fourth|fifth|sixth) ball\b",
        r"\b18th over\b", r"\b19th over\b", r"\b20th over\b"
    ]
    out_sents=[]
    for sent in re.split(r'(?<=[.!?])\s+', text):
        if any(re.search(p, sent, flags=re.I) for p in patterns):
            continue
        out_sents.append(sent)
    return " ".join(out_sents).strip()

def infer_for_match(mid: str, max_chunks=4, save=True):
    com = commentary_by_id[mid]; sc = scorecards_by_id[mid]
    stats = compact_stats(sc["stats"])
    chunks = pick_chunks(com.get("commentary_chunks", []), max_chunks=max_chunks)
    whitelist = build_player_whitelist(stats, chunks)

    # Pass 1 — allow names (from whitelist), ban ball-by-ball phrases
    prompt1, opener = build_prompt(stats, chunks, whitelist, forbid_names=False)
    raw1 = generate_once(prompt1, min_nt=140, max_nt=220, banned_extra=contradiction_phrases(stats)+BALL_BY_BALL_BANS)
    report_raw = opener + ("\n\n" if opener else "") + raw1

    # Sanitize
    report_san = sanitize_names(report_raw, whitelist)
    report_san = sanitize_numbers(report_san, stats)
    report_san = strip_ball_by_ball_sentences(report_san)
    report_san = enforce_two_paragraphs(report_san)

    # Guardrails: ensure key literals and minimal length; else fallback Pass 2 (no names)
    need_fallback = False
    for needed in (stats.get("winner",""), format_margin(stats), stats.get("venue","")):
        if not contains_literal(report_san, needed):
            need_fallback = True
            break
    if len(report_san) < 240:  # too short after sanitization
        need_fallback = True

    if need_fallback:
        prompt2, opener2 = build_prompt(stats, chunks, whitelist, forbid_names=True)
        raw2 = generate_once(prompt2, min_nt=130, max_nt=210, banned_extra=contradiction_phrases(stats)+BALL_BY_BALL_BANS)
        report_raw2 = opener2 + ("\n\n" if opener2 else "") + raw2
        report_san2 = sanitize_numbers(report_raw2, stats)  # no name sanitization needed (we forbade names)
        report_san2 = enforce_two_paragraphs(report_san2)
        # prefer the longer, constraint-satisfying version
        if len(report_san2) >= len(report_san):
            report_san = report_san2
            report_raw = report_raw2

    # Save both RAW & SANITIZED
    out_dir = os.path.join(SAVE_DIR_RUN, "reports_gen"); os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{mid}_RAW.txt"), "w", encoding="utf-8") as f: f.write(report_raw)
    with open(os.path.join(out_dir, f"{mid}_SANITIZED.txt"), "w", encoding="utf-8") as f: f.write(report_san)
    print("Saved:", os.path.join(out_dir, f"{mid}_RAW.txt"))
    print("Saved:", os.path.join(out_dir, f"{mid}_SANITIZED.txt"))
    return report_san

# ---- run for the held-out match ----
final_text = infer_for_match(TEST_ID, max_chunks=4, save=True)
print("\n=== PREVIEW ===\n")
for i, p in enumerate([q for q in final_text.split("\n\n") if q.strip()][:3]):
    print(p)
    if i < 2: print()

Saved: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_noshard_20250818-221113/reports_gen/1422119_RAW.txt
Saved: /content/drive/MyDrive/model_weights_tokens_files/mistral7b_fp16_noshard_20250818-221113/reports_gen/1422119_SANITIZED.txt

=== PREVIEW ===

Chennai Super Kings defeated Royal Challengers Bengaluru by 6 wickets at MA Chidambaram Stadium, Chepauk, Chennai.

The opener's 48-ball 48 was the cornerstone of RCB's 132-run opening stand. But CSK's bowlers struck back with regular intervals. The Fizz, in particular, was on fire. He dismissed Kohli, Maxwell and Patidar in quick succession. The latter two were golden ducks. The pressure built up on the RCB batsmen, and they could not cope with it.
The middle order failed to capitalise on the start provided by the openers. Karthick and Ravindran scored 38 and 37 respectively, but they could not accelerate. The required rate kept climbing, and RCB lost wickets at regular intervals. CSK' s bowlers kept striking back. Gr

In [None]:
# Cell X — Evaluate a generated summary (facts + style + ROUGE-L)
# Purpose: quick automatic checks against scorecard facts and a gold report (if available).

import os, re, json

# --------- utilities (lightweight, self-contained) ----------
def contains_literal(text, value):
    if not value: return True
    t = re.sub(r"\s+", " ", text or "").lower()
    v = re.sub(r"\s+", " ", str(value)).lower()
    return v in t

def lcs_len(a, b):
    n, m = len(a), len(b)
    dp = [ [0]*(m+1) for _ in range(n+1) ]
    for i in range(1, n+1):
        ai = a[i-1]
        dpi = dp[i]
        dpim1 = dp[i-1]
        for j in range(1, m+1):
            dpi[j] = dpim1[j-1] + 1 if ai == b[j-1] else (dpi[j-1] if dpi[j-1] > dpim1[j] else dpim1[j])
    return dp[n][m]

def rouge_l(pred, ref):
    p = (pred or "").split()
    r = (ref or "").split()
    if not p or not r:
        return {"r": 0.0, "p": 0.0, "f": 0.0}
    lcs = lcs_len(p, r)
    rec = lcs / len(r)
    prec = lcs / len(p)
    f = 0.0 if (rec + prec) == 0 else (2 * prec * rec) / (prec + rec + 1e-12)
    return {"r": round(rec, 4), "p": round(prec, 4), "f": round(f, 4)}

def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    return t2 if w and w == t1 else (t1 if w else "")

def format_margin(stats):
    margin = stats.get("result_margin")
    res_txt = (stats.get("result") or "").lower()
    if margin in (None, ""): return ""
    try:
        n = int(margin)
        unit = "wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except Exception:
        return str(margin)

BALL_BY_BALL_PATTERNS = [
    r"\b(first|second|third|fourth|fifth|sixth)\s+ball\b",
    r"\bnext ball\b", r"\bover\-by\-over\b", r"\bball\-by\-ball\b",
    r"\bcaught at\b", r"\byork(ed|er)\b", r"\brun out\b", r"\bstumped\b",
    r"\bslower ball\b", r"\b(\d+)(th)?\s+over\b",
]

NAME_RX = re.compile(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*")
_GENERIC_BAD = {"Chennai","Super","Kings","Royal","Challengers","Bengaluru","Bangalore",
                "Stadium","Chepauk","Match","Overs","Runs","Wickets","Powerplay","Review","DRS"}

def extract_names_from_text(txt):
    cands = [m.group(0).strip() for m in NAME_RX.finditer(txt or "")]
    return [c for c in cands if c not in _GENERIC_BAD and len(c) <= 30]

def build_whitelist(stats):
    wl = set()
    for key in ("top_batters","top_bowlers","players_t1","players_t2",
                "batting_card_t1","batting_card_t2","bowling_card_t1","bowling_card_t2"):
        for item in (stats or {}).get(key, []) or []:
            if isinstance(item, dict):
                for k in ("name","player","batter","bowler"):
                    if item.get(k): wl.add(str(item[k]).strip())
            elif isinstance(item, str):
                wl.add(item.strip())
    return {w.lower() for w in wl}

# --------- load inputs ---------
mid = globals().get("TEST_ID", "1422119")
save_dir = globals().get("SAVE_DIR_RUN", None)
assert save_dir, "SAVE_DIR_RUN not found (run training cell that sets it)."

# Prefer SANITIZED; fall back to RAW if needed
p_san = os.path.join(save_dir, "reports_gen", f"{mid}_SANITIZED.txt")
p_raw = os.path.join(save_dir, "reports_gen", f"{mid}_RAW.txt")
with open(p_san if os.path.exists(p_san) else p_raw, "r", encoding="utf-8") as f:
    pred_text = f.read()

# scorecard + (optional) gold
assert "scorecards_by_id" in globals(), "Run the data-loading cell (Cell 7)."
stats = scorecards_by_id[mid]["stats"]
gold_text = None
if "reports_by_id" in globals() and mid in reports_by_id:
    gold_text = reports_by_id[mid].get("report_text", None)

# --------- factual checks ---------
facts_ok = {
    "winner": contains_literal(pred_text, stats.get("winner","")),
    "loser":  contains_literal(pred_text, loser_of(stats)) if loser_of(stats) else True,
    "margin_text": contains_literal(pred_text, format_margin(stats)),
    "venue":  contains_literal(pred_text, stats.get("venue","")),
    "toss":   (contains_literal(pred_text, stats.get("toss_winner","")) and
               contains_literal(pred_text, stats.get("toss_decision",""))),
}
facts_ok["all_pass"] = all(facts_ok.values())

# --------- style & safety checks ---------
ball_by_ball_hits = sum(bool(re.search(p, pred_text, flags=re.I)) for p in BALL_BY_BALL_PATTERNS)
names_in_text = extract_names_from_text(pred_text)
whitelist = build_whitelist(stats)
oov_names = [n for n in names_in_text if n.lower() not in whitelist]

style = {
    "chars": len(pred_text),
    "sentences": len(re.split(r'(?<=[.!?])\s+', pred_text.strip())) if pred_text.strip() else 0,
    "ball_by_ball_flags": ball_by_ball_hits,
    "names_total": len(names_in_text),
    "names_oov": len(oov_names),
    "oov_samples": oov_names[:5],
}

# --------- ROUGE-L (if gold exists) ---------
rouge = None
if gold_text:
    rouge = rouge_l(pred_text, gold_text)

# --------- display ---------
print("MATCH_ID:", mid)
print("\nFACTS:", facts_ok)
print("\nSTYLE:", style)
if rouge:
    print("\nROUGE-L:", rouge)
else:
    print("\nROUGE-L: (skipped — no gold text available for this match)")

# Small excerpt print for human glance
print("\n--- PRED EXCERPT ---")
print(pred_text[:600].replace("\n"," ") + ("..." if len(pred_text) > 600 else ""))

MATCH_ID: 1422119

FACTS: {'winner': True, 'loser': True, 'margin_text': True, 'venue': True, 'toss': True, 'all_pass': True}

STYLE: {'chars': 867, 'sentences': 15, 'ball_by_ball_flags': 0, 'names_total': 20, 'names_oov': 20, 'oov_samples': ['Chennai Super Kings', 'Royal Challengers Bengaluru', 'Chidambaram Stadium', 'The', 'But']}

ROUGE-L: {'r': 0.1791, 'p': 0.1667, 'f': 0.1727}

--- PRED EXCERPT ---
Chennai Super Kings defeated Royal Challengers Bengaluru by 6 wickets at MA Chidambaram Stadium, Chepauk, Chennai.  The opener's 48-ball 48 was the cornerstone of RCB's 132-run opening stand. But CSK's bowlers struck back with regular intervals. The Fizz, in particular, was on fire. He dismissed Kohli, Maxwell and Patidar in quick succession. The latter two were golden ducks. The pressure built up on the RCB batsmen, and they could not cope with it. The middle order failed to capitalise on the start provided by the openers. Karthick and Ravindran scored 38 and 37 respectively, but they 

In [None]:
# Cell 10 — Inference that always completes (delayed stop + safe auto-continue)
# Purpose: guarantee a full two-paragraph summary; never end mid-thought.

import os, re, json, torch
from transformers import (
    StoppingCriteria, StoppingCriteriaList, LogitsProcessorList,
    NoBadWordsLogitsProcessor, InfNanRemoveLogitsProcessor
)

assert "SAVE_DIR_RUN" in globals(), "Run training/saving (Cell 9) first."
assert "commentary_by_id" in globals() and "scorecards_by_id" in globals(), "Run data load (Cell 7) first."
TEST_ID = globals().get("TEST_ID", "1422119")

# -------- helpers --------
def loser_of(stats):
    t1, t2, w = stats.get("team1",""), stats.get("team2",""), stats.get("winner","")
    return t2 if w == t1 else t1

def format_margin(stats):
    margin = stats.get("result_margin"); res_txt=(stats.get("result") or "").lower()
    if margin in (None,""): return ""
    try:
        n=int(margin); unit="wickets" if "wicket" in res_txt else ("runs" if "run" in res_txt else "")
        return f"{n} {unit}".strip()
    except: return str(margin)

def opening_sentence(stats):
    w=stats.get("winner",""); l=loser_of(stats); m=format_margin(stats); v=stats.get("venue","")
    return f"{w} defeated {l} by {m} at {v}." if (w and l and m and v) else ""

def compact_stats(stats):
    g=lambda o,k,d=None:(o or {}).get(k,d)
    return {
        "team1":g(stats,"team1",""), "team2":g(stats,"team2",""),
        "winner":g(stats,"winner",""), "result":g(stats,"result",""),
        "result_margin":g(stats,"result_margin",""), "venue":g(stats,"venue",""),
        "toss_winner":g(stats,"toss_winner",""), "toss_decision":g(stats,"toss_decision",""),
        "top_batters":g(stats,"top_batters",[]), "top_bowlers":g(stats,"top_bowlers",[]),
        "players_t1":g(stats,"players_t1",[]), "players_t2":g(stats,"players_t2",[]),
        "batting_card_t1":g(stats,"batting_card_t1",[]), "batting_card_t2":g(stats,"batting_card_t2",[]),
        "bowling_card_t1":g(stats,"bowling_card_t1",[]), "bowling_card_t2":g(stats,"bowling_card_t2",[]),
    }

def pick_chunks(chunks, k=3): return (chunks or [])[:k]

def clean_headings(txt):
    txt=re.sub(r"(?m)^\s*(P\d:|Paragraph\s*\d:|Pargraph\s*\d:)\s*","",txt)
    txt=re.sub(r"\s+\.",".",txt); txt=re.sub(r"\s+,",",",txt)
    return txt.strip()

def contains_literal(text, value):
    if not value: return True
    t=re.sub(r"\s+"," ", text or "").lower()
    v=re.sub(r"\s+"," ", str(value)).lower()
    return v in t

# --- soft name handling (no line drops) ---
NAME_RX = re.compile(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*")
def extract_names(txt): return [m.group(0) for m in NAME_RX.finditer(txt or "")]

def build_whitelist(stats):
    wl=set()
    for key in ("top_batters","top_bowlers","players_t1","players_t2",
                "batting_card_t1","batting_card_t2","bowling_card_t1","bowling_card_t2"):
        for it in (stats or {}).get(key, []) or []:
            if isinstance(it, dict):
                for k in ("name","player","batter","bowler"):
                    if it.get(k): wl.add(str(it[k]).strip())
            elif isinstance(it, str):
                wl.add(it.strip())
    # also allow team names, venue & common aliases
    for k in ("team1","team2","venue"):
        v=(stats or {}).get(k,"")
        if v: wl.add(v)
    t1, t2 = stats.get("team1",""), stats.get("team2","")
    if "Chennai Super Kings" in (t1+t2): wl.update({"CSK","Chennai"})
    if "Royal Challengers" in (t1+t2): wl.update({"RCB","Bengaluru","Bangalore"})
    return {w.lower():w for w in wl}

def sanitize_names_soft(text, wl_map):
    out=text
    for n in set(extract_names(text)):
        rep = wl_map.get(n.lower())
        if rep and rep != n:
            out = re.sub(rf"\b{re.escape(n)}\b", rep, out)
    # fix a couple of annoying artifacts
    out = out.replace("Player response", "In response")
    out = out.replace("Player (IPL)", "the IPL")
    return out

# --- keep only egregious ball-by-ball phrases out ---
BB_PATTERNS = [r"\bball\-by\-ball\b", r"\bover\-by\-over\b"]
def strip_ball_by_ball(text):
    sents = re.split(r'(?<=[.!?])\s+', text)
    kept=[s for s in sents if not any(re.search(p, s, flags=re.I) for p in BB_PATTERNS)]
    return " ".join(kept).strip()

def enforce_two_paragraphs(text):
    paras=[p.strip() for p in text.split("\n\n") if p.strip()]
    if len(paras)>=3: return "\n\n".join(paras[:3])
    if len(paras)==1:
        sents=re.split(r'(?<=[.!?])\s+', paras[0])
        cut=max(2, len(sents)//2)
        return ( " ".join(sents[:cut]).strip() + "\n\n" + " ".join(sents[cut:]).strip() ).strip()
    return "\n\n".join(paras)

# --- decoding control: delayed stop on <END> ---
class StopOnTokensAfterMin(StoppingCriteria):
    def __init__(self, stop_ids, prompt_len, min_new_tokens):
        self.stop_ids=stop_ids; self.prompt_len=prompt_len; self.min_new_tokens=min_new_tokens
    def __call__(self, input_ids, scores, **kwargs)->bool:
        gen_len = input_ids.shape[-1]-self.prompt_len
        if gen_len < self.min_new_tokens: return False
        S=self.stop_ids.shape[-1]
        if input_ids.shape[-1] < S: return False
        return torch.equal(input_ids[0,-S:].cpu(), self.stop_ids.cpu())

def make_stop_after_min(tokenizer, stop_str, prompt_len, min_new_tokens):
    stop_ids = tokenizer(stop_str, add_special_tokens=False, return_tensors="pt").input_ids[0]
    return StoppingCriteriaList([StopOnTokensAfterMin(stop_ids, prompt_len, min_new_tokens)])

def build_bad_words_ids(tokenizer, words):
    ids=[];
    for w in words:
        toks=tokenizer(w, add_special_tokens=False).input_ids
        if toks: ids.append(toks)
    return ids

def contradiction_phrases(stats):
    res=(stats.get("result") or "").lower()
    if "wicket" in res:  # winner chased
        return ["fell short","could not chase","defended the total","won by runs","victory by runs","ran out of overs"]
    if "run" in res:     # winner defended
        return ["won by wickets","victory by wickets","chased down comfortably","reached the target","got over the line in the chase"]
    return []

# --- prompting & generation ---
def build_prompt(stats, chunks, wl_map):
    allowed = ", ".join(sorted(set(wl_map.values())))
    return (
        "You are a cricket expert. Continue the match summary in exactly two short paragraphs (no headings).\n"
        "The REQUIRED opening sentence has already been written. Do NOT repeat it.\n\n"
        "Rules:\n"
        "- Max 6 sentences total across both paragraphs.\n"
        "- Use only facts from SCORECARD and COMMENTARY EXCERPTS. Never invent.\n"
        "- Summarise key overs/partnerships; do NOT narrate ball-by-ball.\n"
        f"- If you mention players, use ONLY these names (or use roles like 'the opener'/'the seamer'): {allowed or '—'}\n"
        "- End with <END>.\n\n"
        f"SCORECARD:\n{json.dumps(stats, ensure_ascii=False)}\n\n"
        "COMMENTARY EXCERPTS:\n" + "\n".join(chunks) + "\n\n"
        "Continue with the two paragraphs only."
    )

def generate_once(prompt, min_nt=220, max_nt=340, banned_extra=None):
    enc = tok(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(model.device)
    attention_mask = enc.attention_mask.to(model.device)

    banned = ["P1:","P2:","Paragraph 1:","Paragraph 2:","Section","Heading","Outline:"]
    if banned_extra: banned += banned_extra
    bad_words_ids = build_bad_words_ids(tok, banned)

    stop_crit = make_stop_after_min(tok, "<END>", prompt_len=input_ids.shape[-1], min_new_tokens=min_nt)
    logits_processors = LogitsProcessorList([
        InfNanRemoveLogitsProcessor(),
        NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=tok.eos_token_id),
    ])

    gen_kwargs = dict(
        do_sample=False,
        repetition_penalty=1.02,
        no_repeat_ngram_size=4,
        min_new_tokens=min_nt,
        max_new_tokens=max_nt,
        pad_token_id=tok.pad_token_id,
        use_cache=True,
        eos_token_id=None,  # stop only on <END>
        logits_processor=logits_processors,
        stopping_criteria=stop_crit,
    )

    with torch.no_grad():
        out_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)

    prompt_len = input_ids.shape[-1]
    cont_ids = out_ids[0, prompt_len:]
    text = tok.decode(cont_ids, skip_special_tokens=True)
    # Return up to <END> if present; else full continuation
    return text.split("<END>")[0].strip() if "<END>" in text else text.strip()

def continue_once(opener, partial):
    # Finish cleanly with one extra sentence and <END>
    cont_prompt = (
        "You are a cricket expert. The summary stopped early. "
        "Write ONE closing sentence that wraps up why the result happened. Then output <END>.\n\n"
        f"OPENING:\n{opener}\n\n"
        f"SUMMARY SO FAR:\n{partial}\n\n"
        "One sentence only. End with <END>."
    )
    add = generate_once(cont_prompt, min_nt=12, max_nt=40, banned_extra=[])
    if "<END>" in add: add = add.split("<END>")[0].strip()
    if not add.endswith((".", "!", "?")): add += "."
    # append to paragraph 2
    if "\n\n" in partial:
        return (partial.rstrip() + " " + add).strip()
    # if model returned one long para, split then append
    sents = re.split(r'(?<=[.!?])\s+', partial)
    cut = max(2, len(sents)//2)
    return (" ".join(sents[:cut]).strip() + "\n\n" + " ".join(sents[cut:]).strip() + " " + add).strip()

def infer_for_match(mid, max_chunks=3, save=True):
    com = commentary_by_id[mid]; sc = scorecards_by_id[mid]
    stats = compact_stats(sc["stats"])
    chunks = pick_chunks(com.get("commentary_chunks", []), k=max_chunks)
    wl_map = build_whitelist(stats)

    prompt = build_prompt(stats, chunks, wl_map)
    banned_extra = contradiction_phrases(stats)

    body = generate_once(prompt, min_nt=220, max_nt=340, banned_extra=banned_extra)
    opener = opening_sentence(stats)
    body = clean_headings(body)

    if "<END>" not in body:
        body = continue_once(opener, body)

    # Assemble & soft-clean
    raw_report = (opener + ("\n\n" if opener else "") + body).strip()
    san = sanitize_names_soft(raw_report, wl_map)
    san = strip_ball_by_ball(san)
    final_report = enforce_two_paragraphs(san)

    # guardrail: keep crucial literals
    need = [stats.get("winner",""), format_margin(stats), stats.get("venue","")]
    if not all(contains_literal(final_report, x) for x in need):
        if not final_report.endswith((".", "!", "?")): final_report += "."
        final_report += f" The result reflected decisive spells and partnerships at {stats.get('venue','the venue')}."

    # Save RAW / SANITIZED / FINAL
    out_dir = os.path.join(SAVE_DIR_RUN, "reports_gen"); os.makedirs(out_dir, exist_ok=True)
    base = os.path.join(out_dir, f"match_{mid}")
    with open(base + "_RAW.txt", "w", encoding="utf-8") as f: f.write(raw_report)
    with open(base + "_SANITIZED.txt", "w", encoding="utf-8") as f: f.write(final_report)
    with open(base + ".txt", "w", encoding="utf-8") as f: f.write(final_report)
    print("Saved:", base + "_RAW.txt")
    print("Saved:", base + "_SANITIZED.txt")
    print("Saved:", base + ".txt")
    return final_report

# Run once for the held-out match (prints full final report)
final_text = infer_for_match(TEST_ID, max_chunks=3, save=True)
print("\n=== PREVIEW ===\n")
print(final_text)