In [None]:
import os
import json
import random
import math
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from scipy.stats import spearmanr

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Adafactor,
    get_linear_schedule_with_warmup,
)

In [2]:
# Affine label normalisation shared by the dataset and the evaluation code:
#   scaled = (rating - RATING_OFFSET) / RATING_SCALE        (training targets)
#   rating = scaled * RATING_SCALE + RATING_OFFSET          (model output -> rating)
# With these values the clamp range used at evaluation time, [0.5, 5.5],
# corresponds to [0, 1] in scaled units.
RATING_OFFSET = 0.5
RATING_SCALE = 5.0

In [3]:
@dataclass
class Config:
    """Hyperparameters, data locations and output paths for one fine-tuning run.

    All fields have defaults, so ``Config()`` yields a complete configuration.
    Field order is part of the generated ``__init__`` signature and should not
    be rearranged.
    """

    # Model identifier handed to the tokenizer and model ``from_pretrained`` calls.
    model_name: str = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
    # JSON files read by the dataset class (one for training, one for dev).
    train_path: str = "/kaggle/input/ambistory-raw/train.json"
    dev_path: str = "/kaggle/input/ambistory-raw/dev.json"
    # Evaluated once, when the class body is executed: "cuda" if available, else "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenizer truncation/padding length for the (premise, hypothesis) pair.
    max_length: int = 256
    # Per-step (micro-)batch size for training and batch size for evaluation.
    train_batch_size: int = 8
    eval_batch_size: int = 64
    epochs: int = 10
    learning_rate: float = 2e-5
    # Fraction of the total optimizer steps used for linear learning-rate warmup.
    warmup_ratio: float = 0.06
    # Passed to the optimizer as its weight-decay coefficient.
    weight_decay: float = 0.0
    # Number of micro-batches whose gradients are accumulated per optimizer step.
    gradient_accumulation_steps: int = 2

    # Enables autocast and gradient scaling; only takes effect when device is CUDA.
    fp16: bool = True
    # When true, the trainer calls ``gradient_checkpointing_enable`` on the model.
    gradient_checkpointing: bool = True

    # Where the best model ``state_dict`` and the dev predictions are written.
    output_model_path: str = "best_deberta_nli.pt"
    predictions_path: str = "predictions_deberta_nli.jsonl"
    # Seed passed to ``EnvironmentManager.set_seed`` by the driver cell.
    seed: int = 42

In [4]:
class EnvironmentManager:
    """Namespace for process-wide environment setup helpers."""

    @staticmethod
    def set_seed(seed: int = 42):
        """Seed every random number generator used by the notebook.

        Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
        (CPU and all CUDA devices), then switches cuDNN to deterministic
        algorithm selection with autotuning (benchmark mode) turned off.

        Args:
            seed: Value handed to each generator.
        """
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)

        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False


In [5]:
class AmbiStoryNliDataset(Dataset):
    """Story/word-sense records rendered as (premise, hypothesis) pairs.

    The JSON file is expected to hold one object mapping a sample id to a
    record. Each record may provide the text fields ``precontext``,
    ``sentence``, ``ending``, ``homonym``, ``judged_meaning`` and
    ``example_sentence`` plus the numeric fields ``average`` and ``choices``.
    Missing or ``null`` text fields are treated as empty strings; a missing or
    ``null`` ``average`` is treated as ``0.0``.

    Args:
        json_path: Path to the JSON file.
        tokenizer: Callable invoked as ``tokenizer(premise, hypothesis, ...)``
            that returns a mapping with ``input_ids`` and ``attention_mask``
            tensors carrying a leading batch dimension of 1.
        max_length: Length every pair is truncated/padded to.
    """

    def __init__(self, json_path: str, tokenizer, max_length: int):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere
        # and would corrupt or reject non-ASCII story text.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.sids = list(self.data.keys())

    def __len__(self):
        return len(self.sids)

    @staticmethod
    def _text(item, key: str) -> str:
        """Return ``item[key]`` as stripped text; missing or ``None`` gives ``""``."""
        value = item.get(key)
        return "" if value is None else str(value).strip()

    def __getitem__(self, idx):
        """Tokenize one record and attach its training target and gold statistics.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (batch dimension
            removed), ``labels`` (rating mapped through
            ``(average - RATING_OFFSET) / RATING_SCALE``), ``gold_mean`` and
            ``gold_stdev`` (mean and sample standard deviation of ``choices``),
            and ``id`` (an ``int`` when the key is all digits, else the key).
        """
        sid = self.sids[idx]
        item = self.data[sid]

        pre = self._text(item, "precontext")
        sent = self._text(item, "sentence")
        end = self._text(item, "ending")

        hom = self._text(item, "homonym")
        meaning = self._text(item, "judged_meaning")
        ex_sent = self._text(item, "example_sentence")

        # Empty story parts are dropped so the premise has no doubled spaces.
        premise = " ".join(p for p in [pre, sent, end] if p)
        hypothesis = (
            f'The definition of "{hom}" is: "{meaning}" '
            f'as in the following sentence: "{ex_sent}"'
        )

        enc = self.tokenizer(
            premise,
            hypothesis,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        raw_avg = item.get("average")
        avg = 0.0 if raw_avg is None else float(raw_avg)
        scaled = (avg - RATING_OFFSET) / RATING_SCALE

        choices = item.get("choices")
        if choices is None:
            choices = []
        elif isinstance(choices, (list, tuple)):
            choices = list(choices)
        else:
            # A bare scalar is treated as a single annotation.
            choices = [choices]

        if len(choices) >= 2:
            gold_mean = float(np.mean(choices))
            # ddof=1: sample standard deviation, defined only for >= 2 values.
            gold_stdev = float(np.std(choices, ddof=1))
        elif len(choices) == 1:
            gold_mean = float(choices[0])
            gold_stdev = 0.0
        else:
            gold_mean = avg
            gold_stdev = 0.0

        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(scaled, dtype=torch.float),
            "gold_mean": torch.tensor(gold_mean, dtype=torch.float),
            "gold_stdev": torch.tensor(gold_stdev, dtype=torch.float),
            "id": int(sid) if str(sid).isdigit() else sid,
        }

In [6]:
@torch.no_grad()
def official_scores_from_stats(
    model,
    loader: DataLoader,
    device: torch.device,
    clamp_min: float = RATING_OFFSET,
    clamp_max: float = 5.5,
    round_to_int: bool = False,
):
    """Score a model against the gold mean/stdev carried by each batch.

    The model output is mapped back to the rating scale with
    ``logit * RATING_SCALE + RATING_OFFSET``, optionally clamped and rounded,
    and compared with the per-sample ``gold_mean`` / ``gold_stdev``.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; its
            ``.logits`` must hold one value per sample. Put into eval mode and
            left there.
        loader: Yields batches with ``input_ids``, ``attention_mask``,
            ``gold_mean`` and ``gold_stdev``.
        device: Device the input tensors are moved to.
        clamp_min: Lower clamp bound; clamping is applied only when both
            bounds are not ``None``.
        clamp_max: Upper clamp bound (see ``clamp_min``).
        round_to_int: Round predictions to the nearest integer before scoring.

    Returns:
        ``(spearman, accuracy)`` as floats. ``spearman`` is ``0.0`` when the
        rank correlation is undefined (fewer than two samples, or constant
        predictions/golds), so callers never receive NaN. ``accuracy`` is the
        fraction of predictions lying strictly within one gold standard
        deviation of the gold mean, or strictly within 1.0 of it.
    """
    model.eval()
    preds = []
    gold_means = []
    gold_stdevs = []

    for batch in loader:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)

        outputs = model(input_ids=ids, attention_mask=mask)
        # Only the trailing size-1 label dimension is squeezed, so a batch of
        # one sample still yields a 1-D array.
        logits = outputs.logits.squeeze(-1)

        p = logits.detach().cpu().numpy()
        p = p * RATING_SCALE + RATING_OFFSET
        if clamp_min is not None and clamp_max is not None:
            p = np.clip(p, clamp_min, clamp_max)
        if round_to_int:
            p = np.rint(p).astype(int)

        preds.extend(p.tolist())
        gold_means.extend(batch["gold_mean"].numpy().tolist())
        gold_stdevs.extend(batch["gold_stdev"].numpy().tolist())

    total = len(preds)

    if total < 2:
        # Rank correlation needs at least two observations.
        corr = 0.0
    else:
        corr, _ = spearmanr(preds, gold_means)
        corr = float(corr)
        if np.isnan(corr):
            # spearmanr yields NaN when either input is constant (e.g. every
            # prediction clamped to the same bound). A NaN here would poison
            # any score derived from it and make ">" comparisons always False.
            corr = 0.0

    correct = 0
    for pred, m, sd in zip(preds, gold_means, gold_stdevs):
        ok = ((m - sd) < pred < (m + sd)) or (abs(m - pred) < 1.0)
        correct += int(ok)
    acc = correct / total if total else 0.0

    return corr, float(acc)

In [7]:
@torch.no_grad()
def save_predictions_jsonl(
    model,
    loader: DataLoader,
    out_path: str,
    device: torch.device,
    clamp_min: float = RATING_OFFSET,
    clamp_max: float = 5.5,
    round_to_int: bool = True,
):
    """Run the model over ``loader`` and write one JSON object per sample.

    Each output line has the form ``{"id": ..., "prediction": ...}``. The model
    output is mapped to the rating scale with
    ``logit * RATING_SCALE + RATING_OFFSET``, clamped when both bounds are
    given, and rounded to the nearest integer when ``round_to_int`` is set.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; put into
            eval mode and left there.
        loader: Yields batches with ``input_ids``, ``attention_mask`` and ``id``.
        out_path: Destination file; its parent directory is created if needed.
        device: Device the input tensors are moved to.
        clamp_min: Lower clamp bound (ignored unless ``clamp_max`` is also set).
        clamp_max: Upper clamp bound (ignored unless ``clamp_min`` is also set).
        round_to_int: Emit integer predictions instead of floats.
    """
    model.eval()

    target_dir = os.path.dirname(out_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    should_clamp = clamp_min is not None and clamp_max is not None

    with open(out_path, "w") as sink:
        for batch in loader:
            result = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            raw = result.logits.squeeze(-1).detach().cpu().numpy()

            ratings = raw * RATING_SCALE + RATING_OFFSET
            if should_clamp:
                ratings = np.clip(ratings, clamp_min, clamp_max)
            if round_to_int:
                ratings = np.rint(ratings).astype(int)

            for sample_id, rating in zip(batch["id"], ratings.tolist()):
                # Numeric ids arrive as tensor elements after collation;
                # string ids arrive as plain Python strings.
                if isinstance(sample_id, torch.Tensor):
                    sample_id = sample_id.item()
                line = json.dumps({"id": sample_id, "prediction": rating})
                sink.write(line + "\n")


In [8]:
class Trainer:
    """Fine-tunes a sequence-classification model as a one-output regressor.

    Construction loads the tokenizer, both datasets, the model, the optimizer
    and the learning-rate schedule. ``fit`` trains with gradient accumulation
    (and mixed precision on CUDA when ``cfg.fp16`` is set), evaluates on the
    dev set after every epoch and keeps the ``state_dict`` with the best
    combined score. ``load_best_and_predict`` reloads that checkpoint and
    writes dev predictions.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self._prepare_components()

        # -1.0 is below any reachable combined score, so the first finite
        # evaluation always becomes the best.
        self.best_combined = -1.0
        self.best_epoch = -1

        # Loss scaling is only active for fp16 training on CUDA; otherwise the
        # scaler is a pass-through.
        self.scaler = GradScaler(
            enabled=(self.device.type == "cuda" and self.cfg.fp16)
        )

    def _prepare_components(self):
        """Build tokenizer, datasets, loaders, model, optimizer and scheduler."""
        print("Loading tokenizer and datasets...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)

        self.train_dataset = AmbiStoryNliDataset(
            json_path=self.cfg.train_path,
            tokenizer=self.tokenizer,
            max_length=self.cfg.max_length,
        )
        self.dev_dataset = AmbiStoryNliDataset(
            json_path=self.cfg.dev_path,
            tokenizer=self.tokenizer,
            max_length=self.cfg.max_length,
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.cfg.train_batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )
        self.dev_loader = DataLoader(
            self.dev_dataset,
            batch_size=self.cfg.eval_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )

        print("Loading model...")
        # ignore_mismatched_sizes lets a checkpoint whose classification head
        # has a different number of labels load with a fresh 1-output head
        # (presumably the pretrained head is multi-class NLI - TODO confirm).
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.cfg.model_name,
            num_labels=1,
            problem_type="regression",
            ignore_mismatched_sizes=True,
        )

        if self.cfg.gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
            except TypeError:
                # Fallback for library versions whose method does not accept
                # the keyword argument.
                self.model.gradient_checkpointing_enable()

        if hasattr(self.model.config, "use_cache"):
            self.model.config.use_cache = False

        self.model.to(self.device)

        self.optimizer = Adafactor(
            self.model.parameters(),
            lr=self.cfg.learning_rate,
            scale_parameter=False,
            relative_step=False,
            weight_decay=self.cfg.weight_decay,
        )

        # One optimizer update per accumulation group; the trailing partial
        # group of an epoch counts as an update too (fit() flushes it).
        num_update_steps_per_epoch = math.ceil(
            len(self.train_loader) / self.cfg.gradient_accumulation_steps
        )
        num_training_steps = self.cfg.epochs * num_update_steps_per_epoch
        num_warmup_steps = int(self.cfg.warmup_ratio * num_training_steps)

        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def _optimizer_step(self):
        """Apply accumulated gradients, then advance the LR schedule if applied."""
        scale_before = self.scaler.get_scale()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        # When the scaler finds inf/NaN gradients it skips optimizer.step()
        # and lowers the scale; the schedule must not advance for that update.
        if self.scaler.get_scale() >= scale_before:
            self.lr_scheduler.step()

    def fit(self):
        """Train for ``cfg.epochs`` epochs, checkpointing the best dev score.

        After each epoch the dev set is scored and ``0.5 * (spearman + acc)``
        is compared with the best value so far; on improvement the model
        ``state_dict`` is written to ``cfg.output_model_path``.
        """
        print(f"Starting training on {self.device} for {self.cfg.epochs} epochs.")
        self.optimizer.zero_grad(set_to_none=True)

        accum = self.cfg.gradient_accumulation_steps
        num_batches = len(self.train_loader)
        # Batches with index < full_groups_end belong to complete accumulation
        # groups; the remaining `tail_size` batches form one shorter group.
        full_groups_end = (num_batches // accum) * accum
        tail_size = num_batches - full_groups_end
        use_amp = self.device.type == "cuda" and self.cfg.fp16

        for epoch in range(self.cfg.epochs):
            self.model.train()
            total_loss = 0.0
            num_seen = 0

            for batch_idx, batch in enumerate(self.train_loader):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                # Divide by the size of the group this batch belongs to, so
                # every update uses the mean gradient of its micro-batches.
                group_size = accum if batch_idx < full_groups_end else tail_size

                with autocast(device_type=self.device.type, enabled=use_amp):
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                    )
                    raw_loss = outputs.loss
                    loss = raw_loss / group_size

                self.scaler.scale(loss).backward()
                # Track the undivided loss so the reported average is the true
                # per-batch loss, independent of the accumulation factor.
                total_loss += raw_loss.item()
                num_seen += 1

                is_group_end = (batch_idx + 1) % accum == 0
                is_last_batch = (batch_idx + 1) == num_batches
                # Flushing on the last batch keeps leftover gradients from
                # leaking into the next epoch and makes the number of updates
                # per epoch equal the ceil() the scheduler was sized with.
                if is_group_end or is_last_batch:
                    self._optimizer_step()

            avg_loss = total_loss / max(1, num_seen)
            print(f"Epoch {epoch+1}/{self.cfg.epochs} - train loss: {avg_loss:.4f}")

            corr, acc = official_scores_from_stats(
                self.model,
                self.dev_loader,
                device=self.device,
                clamp_min=RATING_OFFSET,
                clamp_max=5.5,
                round_to_int=False,
            )
            combined = 0.5 * (corr + acc)
            print(
                f"Epoch {epoch+1} - dev Spearman: {corr:.4f}, "
                f"dev Acc@SD/1: {acc:.4f}, combined: {combined:.4f}"
            )

            # A NaN combined score compares False here and never replaces the
            # current best.
            if combined > self.best_combined:
                self.best_combined = combined
                self.best_epoch = epoch + 1
                torch.save(self.model.state_dict(), self.cfg.output_model_path)
                print(
                    f"New best model saved to {self.cfg.output_model_path} "
                    f"(epoch {self.best_epoch}, combined={self.best_combined:.4f})"
                )

        print("Training finished.")
        print(f"Best epoch: {self.best_epoch}, best combined: {self.best_combined:.4f}")

    def load_best_and_predict(self):
        """Reload the saved checkpoint (if present) and write dev predictions.

        Predictions are clamped to ``[RATING_OFFSET, 5.5]``, rounded to
        integers and written to ``cfg.predictions_path``.
        """
        # NOTE(review): the existence check cannot tell a checkpoint written by
        # this run from one left at the same path by an earlier run.
        if os.path.exists(self.cfg.output_model_path):
            # The file holds only a tensor state_dict, so restrict unpickling
            # to tensors/containers instead of arbitrary objects.
            self.model.load_state_dict(
                torch.load(
                    self.cfg.output_model_path,
                    map_location=self.device,
                    weights_only=True,
                )
            )
            print(
                f"Loaded best model from {self.cfg.output_model_path} "
                f"(epoch {self.best_epoch}, combined={self.best_combined:.4f})"
            )
        else:
            print("Best model file not found, using current model weights.")

        save_predictions_jsonl(
            model=self.model,
            loader=self.dev_loader,
            out_path=self.cfg.predictions_path,
            device=self.device,
            clamp_min=RATING_OFFSET,
            clamp_max=5.5,
            round_to_int=True,
        )
        print(f"Dev predictions saved to {self.cfg.predictions_path}")

In [None]:
# Driver cell: configure, seed, train, then write dev predictions.
if __name__ == "__main__":
    # All hyperparameters come from the dataclass defaults.
    cfg = Config()
    # Seed the RNGs before the Trainer builds the model and data loaders.
    EnvironmentManager.set_seed(cfg.seed)

    print("Config:")
    print(cfg)

    # Constructing the Trainer loads the tokenizer, datasets and model.
    trainer = Trainer(cfg)
    trainer.fit()
    # Reloads the best checkpoint saved during fit() and writes predictions.
    trainer.load_best_and_predict()