## Import libraries

In [21]:
import os
from typing import List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')
from IPython.display import clear_output
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

## Config

In [22]:
class CFG:
    """Central place for the notebook's model, optimisation and data-budget constants."""
    # Checkpoint name handed to the Hugging Face `from_pretrained` loaders.
    MODEL_NAME = "Qwen/Qwen3-4B"
    TRUST_REMOTE_CODE = True
    # Fixed token length every collated sample is truncated/padded to (1536).
    MAX_LEN =1024 + 512

    # LoRA adapter settings forwarded to peft's LoraConfig.
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.1
    TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

    # AdamW learning rate / weight decay and the training-loop sizes.
    LR = 3e-5
    WD = 0.01
    EPOCHS = 20
    BATCH_SIZE = 4
    # Number of batches whose gradients are accumulated before one optimizer step.
    ACCUM_STEPS = 8
    # Fraction of the scheduler's total steps used for linear LR warmup.
    WARMUP_RATIO = 0.15

    # NOTE(review): GRPO_BETA is not referenced anywhere in the visible notebook code.
    GRPO_BETA = 0.08
    # Final multiplier applied to the combined reward.
    REWARD_SCALE = 3.0

    # Epochs (inclusive) trained with cross-entropy only before the RL terms are enabled.
    WARMUP_EPOCHS = 5

    # Share of the remaining token budget given to the context; the rest goes to the response.
    CONTEXT_RATIO = 0.6
    # Tokens kept free for the response when the prompt prefix itself is truncated.
    MIN_RESPONSE_LEN = 100

    # DataLoader worker processes.
    WORKERS = 8

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    SEED = 42

# Seed the torch and NumPy global RNGs, plus every CUDA device RNG when CUDA is available.
torch.manual_seed(CFG.SEED)
np.random.seed(CFG.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CFG.SEED)

## Convert labels to numeric IDs

In [23]:
# Canonical class name -> integer id, and the inverse lookup built from it.
LABEL2ID = {"NO": 0, "INTRINSIC": 1, "EXTRINSIC": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}

def _normalize_label(x: str) -> str:
    s = ("" if x is None else str(x)).strip().lower()
    if s in {"no", "none", "0", "negative", "non-hallu", "non_hallu", "correct"}:
        return "NO"
    if s in {"intrinsic", "intra", "1", "internal"}:
        return "INTRINSIC"
    if s in {"extrinsic", "extra", "2", "external"}:
        return "EXTRINSIC"
    raise ValueError(f"Invalid label: {x}")

In [24]:
class ViHalluSet(Dataset):
    """In-memory dataset of hallucination-detection samples built from a DataFrame.

    Each item is a dict with the keys ``id``, ``context``, ``prompt``,
    ``response`` (all stripped strings) and ``label`` (canonical class name
    produced by ``_normalize_label``).
    """

    REQUIRED_COLS = ["id", "context", "prompt", "response", "label"]

    def __init__(self, data):
        """Build the sample list from ``data``.

        Args:
            data: A ``pandas.DataFrame`` containing at least ``REQUIRED_COLS``.
                Missing values are replaced by empty strings before parsing.

        Raises:
            TypeError: If ``data`` is not a DataFrame. Previously such input
                silently produced an empty dataset, which only surfaced later
                as an obscure DataLoader/sampler failure.
            ValueError: If a required column is missing or a label is invalid.
        """
        self.samples = []
        if not isinstance(data, pd.DataFrame):
            raise TypeError(
                f"ViHalluSet expects a pandas DataFrame, got {type(data).__name__}"
            )
        self._load_from_df(data.fillna(""))

    def _load_from_df(self, df: pd.DataFrame):
        """Validate the columns of ``df`` and append one sample dict per row."""
        missing = [c for c in self.REQUIRED_COLS if c not in df.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")
        # to_dict("records") walks the rows without building a Series per row
        # (as iterrows does) and keeps each column's own value type.
        for record in df[self.REQUIRED_COLS].to_dict("records"):
            self.samples.append({
                "id": str(record["id"]).strip(),
                "context": str(record["context"]).strip(),
                "prompt": str(record["prompt"]).strip(),
                "response": str(record["response"]).strip(),
                "label": _normalize_label(record["label"])
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

class SmartCollator:
    """Turn raw sample dicts into fixed-length tensors for the classifier.

    For every sample the collator builds ``instruction + context + prompt`` as
    a prefix, appends the response, and returns tensors padded/truncated to
    ``max_len`` together with a mask marking the response tokens.
    """

    def __init__(self, tok: AutoTokenizer, max_len: int, cfg: CFG):
        """Store the tokenizer, the fixed sequence length and the config.

        Args:
            tok: Tokenizer used for encoding and decoding.
            max_len: Length every returned sequence is padded/truncated to.
            cfg: Config object providing ``CONTEXT_RATIO`` and ``MIN_RESPONSE_LEN``.
        """
        self.tok = tok
        self.max_len = max_len
        self.cfg = cfg
        self.instruction_template = """Bạn là một chuyên gia nhận diện ảo giác trong câu trả lời so với ngữ cảnh và câu hỏi, nhiệm vụ phân loại câu trả lời theo 3 loại:
Nhận diện "No":
- Response hoàn toàn nhất quán và đúng sự thật với thông tin được cung cấp trong context.
- Response không chứa bất kỳ thông tin nào sai lệch hoặc không thể suy luận trực tiếp từ context.
- Response trả lời đúng dựa trên context.

Nhận diện "Intrinsic":
- Mâu thuẫn trực tiếp trong phản hồi
- Bóp méo thông tin đã được cung cấp rõ ràng trong context
- Chứa thực thể hoặc khái niệm trong context nhưng thông tin chúng bị thay đổi, sai lệch
- Response sai lệch nhưng nghe có vẻ khá hợp lý (plausible) trong ngữ cảnh đó

Nhận diện "Extrinsic":
- Response bổ sung thông tin không có trong context
- Thông tin bổ sung không thể suy luận được từ context
- Thông tin bổ sung có thể đúng trong thế giới thực nhưng nó không được cung cấp trong context
"""

    def _keep_head_and_tail(self, tokens, budget: int, head_len: int) -> str:
        """Decode ``tokens`` reduced to ``budget`` ids taken from both ends.

        The first ``head_len`` ids and the last ``budget - head_len`` ids are
        decoded and joined with ``" [...] "``. A non-positive budget yields the
        empty string.
        """
        budget = max(0, budget)
        if budget == 0:
            return ""
        head_len = max(0, min(head_len, budget))
        tail_len = budget - head_len
        head = self.tok.decode(tokens[:head_len], skip_special_tokens=True)
        if tail_len == 0:
            return head
        # Slice from an explicit start index: `tokens[-0:]` would return the
        # whole list instead of an empty tail.
        tail = self.tok.decode(tokens[len(tokens) - tail_len:], skip_special_tokens=True)
        return head + " [...] " + tail

    def _smart_truncate(self, context: str, response: str, max_context_len: int, max_response_len: int):
        """Shorten ``context`` and ``response`` to their token budgets.

        Texts that already fit are returned untouched. An over-long context
        keeps roughly the first two thirds and the last third of its budget;
        an over-long response keeps roughly the first three quarters and the
        last quarter. Negative budgets are treated as zero.

        Returns:
            Tuple ``(context, response)`` of possibly shortened strings.
        """
        max_context_len = max(0, max_context_len)
        max_response_len = max(0, max_response_len)

        ctx_tokens = self.tok(context, add_special_tokens=False, truncation=False)["input_ids"]
        rsp_tokens = self.tok(response, add_special_tokens=False, truncation=False)["input_ids"]

        if len(ctx_tokens) > max_context_len:
            context = self._keep_head_and_tail(
                ctx_tokens, max_context_len, max_context_len * 2 // 3
            )

        if len(rsp_tokens) > max_response_len:
            response = self._keep_head_and_tail(
                rsp_tokens, max_response_len, max_response_len * 3 // 4
            )

        return context, response

    def __call__(self, batch: List[Dict[str, Any]]):
        """Collate ``batch`` into model-ready tensors.

        Args:
            batch: Sample dicts with ``context``, ``prompt``, ``response``,
                ``label`` and optionally ``id``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (both long, shape
            ``(len(batch), max_len)``), ``resp_mask`` (float, 1.0 on response
            tokens), ``labels`` (long class ids) and ``meta`` (list of dicts).
        """
        input_ids, attention_mask, resp_mask, labels, metas = [], [], [], [], []

        # The instruction text is identical for every sample: tokenize it once
        # per batch instead of once per sample.
        instruction_len = len(self.tok(self.instruction_template, add_special_tokens=False)["input_ids"])

        # Explicit None checks so that a legitimate pad/eos id of 0 is not skipped.
        pad_id = self.tok.pad_token_id
        if pad_id is None:
            pad_id = self.tok.eos_token_id
        if pad_id is None:
            pad_id = 0

        for smp in batch:
            ctx = smp.get("context", "")
            prm = smp.get("prompt", "")
            rsp = smp.get("response", "")
            lbl = smp.get("label", "NO").upper()
            # Unknown label strings fall back to class id 0.
            lbl_id = LABEL2ID.get(lbl, 0)

            prompt_len = len(self.tok(prm, add_special_tokens=False)["input_ids"])
            # 50 extra tokens are reserved for the section headers and separators.
            overhead = instruction_len + prompt_len + 50

            # A very long prompt can push the overhead past max_len; clamp so the
            # budgets never go negative (negative values used to produce wrong
            # slices in the truncation step).
            available_len = max(0, self.max_len - overhead)
            max_context_len = int(available_len * self.cfg.CONTEXT_RATIO)
            # The prefix is truncated to max_len - MIN_RESPONSE_LEN below, so the
            # response always has at least MIN_RESPONSE_LEN tokens of room.
            max_response_len = max(available_len - max_context_len, self.cfg.MIN_RESPONSE_LEN)

            ctx, rsp = self._smart_truncate(ctx, rsp, max_context_len, max_response_len)

            prefix = f"""{self.instruction_template}

Context Information:
{ctx}

User Query:
{prm}

AI Response to Analyze:
"""

            enc_pre = self.tok(prefix, add_special_tokens=True, truncation=True,
                               max_length=self.max_len - self.cfg.MIN_RESPONSE_LEN)
            enc_rsp = self.tok(rsp, add_special_tokens=False, truncation=True,
                               max_length=self.max_len - len(enc_pre["input_ids"]))

            ids = enc_pre["input_ids"] + enc_rsp["input_ids"]
            ids = ids[:self.max_len]
            attn = [1] * len(ids)

            # Response mask: 0 over the prefix, 1 over the response tokens.
            pre_len = len(enc_pre["input_ids"])
            resp_len = len(ids) - pre_len
            rmask = [0] * pre_len + [1] * resp_len

            if len(ids) < self.max_len:
                pad_n = self.max_len - len(ids)
                ids += [pad_id] * pad_n
                attn += [0] * pad_n
                rmask += [0] * pad_n

            input_ids.append(torch.tensor(ids, dtype=torch.long))
            attention_mask.append(torch.tensor(attn, dtype=torch.long))
            resp_mask.append(torch.tensor(rmask, dtype=torch.float))
            labels.append(lbl_id)
            metas.append({"id": smp.get("id", None), "label_text": lbl})

        return {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
            "resp_mask": torch.stack(resp_mask),
            "labels": torch.tensor(labels, dtype=torch.long),
            "meta": metas
        }

In [25]:
class ImprovedViHalluGRPO(nn.Module):
    """Backbone encoder topped with a 3-way classification head and a value head.

    The backbone's last hidden states are mean-pooled over the response tokens
    (as marked by ``resp_mask``); the pooled vector feeds both heads.
    """

    def __init__(self, backbone: AutoModel, hidden_size: int, num_labels: int = 3):
        """Attach freshly initialised heads to ``backbone``.

        Args:
            backbone: Model returning ``last_hidden_state`` when called.
            hidden_size: Width of the backbone's hidden states.
            num_labels: Number of output classes.
        """
        super().__init__()
        self.backbone = backbone

        half_width = hidden_size // 2

        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, half_width),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(half_width, num_labels),
        )

        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1),
        )

        # Per-sample losses: the caller decides how to reduce them.
        self.crit = nn.CrossEntropyLoss(reduction='none')
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every linear layer of both heads and zero its bias."""
        linear_layers = [
            layer
            for block in (self.head, self.value_head)
            for layer in block
            if isinstance(layer, nn.Linear)
        ]
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, input_ids, attention_mask, resp_mask, labels=None):
        """Run the backbone and both heads.

        Returns:
            Tuple ``(logits, loss, values)``; ``loss`` holds the unreduced
            cross-entropy per sample, or ``None`` when ``labels`` is omitted.
        """
        hidden = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True
        ).last_hidden_state

        # Mean over response positions; the clamp avoids dividing by zero when
        # a row has no response tokens.
        token_count = resp_mask.sum(1, keepdim=True).clamp(min=1.0)
        pool_weights = (resp_mask / token_count).unsqueeze(-1)
        pooled = (hidden * pool_weights).sum(1)

        logits = self.head(pooled)
        values = self.value_head(pooled).squeeze(-1)

        loss = self.crit(logits, labels) if labels is not None else None
        return logits, loss, values

    def get_policy_logprobs(self, logits, actions):
        """Return the log-probability the logits assign to each chosen action."""
        chosen = actions.unsqueeze(1)
        return F.log_softmax(logits, dim=-1).gather(1, chosen).squeeze(1)

In [26]:
class EnhancedRewardCalculator:
    """Hand-crafted reward for the classifier's (sampled) predictions.

    The reward mixes a class-weighted, confidence-aware term with a fixed
    table that favours the two hallucination classes, then scales the result
    by ``config.REWARD_SCALE``.
    """

    def __init__(self, config: CFG):
        self.config = config
        # Default per-class weights, used until set_class_weights is called.
        self.class_weights = torch.tensor([1.0, 2.0, 3.0])

    def set_class_weights(self, train_labels):
        """Recompute the class weights from the label frequencies in ``train_labels``."""
        class_counts = np.bincount(train_labels, minlength=3)
        total = class_counts.sum()
        # The +1 keeps the weight finite for a class that never occurs.
        inverse_freq = [total / (3 * (count + 1)) for count in class_counts]
        self.class_weights = torch.tensor(inverse_freq, dtype=torch.float32)
        print(f"Updated class weights: {self.class_weights}")

    def calculate_combined_reward(self, predictions, labels, logits=None,
                                   alpha=0.5, beta=0.5):
        """Compute one scalar reward per sample.

        Args:
            predictions: Predicted (or sampled) class ids.
            labels: Gold class ids.
            logits: Optional logits; when given, the maximum softmax
                probability of each row is used as the confidence, otherwise
                the confidence is 1.
            alpha: Weight of the class-weighted/confidence term.
            beta: Weight of the fixed reward-table term.

        Returns:
            Float32 tensor of ``len(predictions)`` rewards.
        """
        count = len(predictions)
        if logits is None:
            confidences = torch.ones(count)
        else:
            confidences = F.softmax(logits, dim=-1).max(dim=-1)[0]

        r_metric = torch.zeros(count, dtype=torch.float32)
        r_f1 = torch.zeros(count, dtype=torch.float32)

        def _scalar(value):
            # Accept both tensors and plain Python numbers.
            return value.item() if torch.is_tensor(value) else value

        for i, triple in enumerate(zip(predictions, labels, confidences)):
            pred_item, label_item, conf_item = (_scalar(v) for v in triple)
            is_hit = pred_item == label_item
            is_hallu_label = label_item > 0

            # Term 1: class weight plus confidence bonus for a hit, graded
            # penalty for a miss.
            if is_hit:
                reward = self.class_weights[label_item].item() + 0.3 * conf_item
                if is_hallu_label:
                    reward *= 1.5
                r_metric[i] = reward
            else:
                base_penalty = -1.0
                if is_hallu_label:
                    base_penalty *= 2.0
                if pred_item > 0 and label_item == 0:
                    base_penalty *= 1.5
                r_metric[i] = base_penalty + (-0.5 * conf_item)

            # Term 2: fixed reward table.
            if is_hit:
                if pred_item > 0:
                    r_f1[i] = 3.0
                elif pred_item == 0:
                    r_f1[i] = 1.0
                else:
                    r_f1[i] = -1.0
            elif is_hallu_label:
                r_f1[i] = -3.0
            elif pred_item > 0 and label_item == 0:
                r_f1[i] = -2.0
            else:
                r_f1[i] = -1.0

        combined = alpha * r_metric + beta * r_f1
        return combined * self.config.REWARD_SCALE


In [27]:
class AdvancedGRPOTrainer:
    """Single-step trainer mixing cross-entropy with policy-gradient terms.

    For the first ``WARMUP_EPOCHS`` epochs only cross-entropy is optimised.
    Afterwards a clipped policy-gradient loss (probability ratio taken against
    the frozen reference model), a KL penalty to the reference and a value
    loss are blended in.
    """

    def __init__(self, model, ref_model, config: CFG):
        """Store the models and freeze the reference model.

        Args:
            model: Trainable policy/classifier.
            ref_model: Reference model; put in eval mode with gradients disabled.
            config: Object providing DEVICE, ACCUM_STEPS, WARMUP_EPOCHS and the
                values read by EnhancedRewardCalculator.
        """
        self.model = model
        self.ref_model = ref_model
        self.config = config
        self.reward_calc = EnhancedRewardCalculator(config)

        # Read from the config that was passed in rather than the global CFG,
        # so a custom config is honoured consistently.
        self.warmup_epochs = config.WARMUP_EPOCHS
        self.current_epoch = 0

        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

    def set_epoch(self, epoch):
        """Record the current (1-based) epoch; it drives the loss schedule."""
        self.current_epoch = epoch

    def get_loss_weights(self):
        """Return the loss-term weights for the current epoch.

        Cross-entropy only up to and including the warmup epochs, a linear
        ramp over the next two epochs, then fixed weights.
        """
        if self.current_epoch <= self.warmup_epochs:
            return {"ce": 1.0, "pg": 0.0, "kl": 0.0, "value": 0.0}
        elif self.current_epoch <= self.warmup_epochs + 2:
            progress = (self.current_epoch - self.warmup_epochs) / 3.0
            return {
                "ce": 0.8 - 0.3 * progress,
                "pg": 0.4 * progress,
                "kl": 0.03 * progress,
                "value": 0.2 * progress
            }
        else:
            return {"ce": 0.5, "pg": 0.4, "kl": 0.03, "value": 0.2}

    def compute_advantages(self, rewards, values):
        """Return detached, batch-normalised advantages clamped to [-5, 5].

        A batch with a single sample is left un-normalised: the sample
        standard deviation of one element is NaN and would poison the loss.
        """
        advantages = rewards.detach() - values.detach()
        if advantages.numel() > 1:
            adv_mean = advantages.mean()
            adv_std = advantages.std() + 1e-6
            advantages = (advantages - adv_mean) / adv_std
        advantages = torch.clamp(advantages, -5.0, 5.0)
        return advantages

    def compute_kl_penalty(self, logits, ref_logits):
        """Per-sample KL(reference || policy) over the classes, capped at 5.

        Works on log-probabilities from ``log_softmax`` rather than taking the
        log of a softmax, which yields -inf when a probability underflows.
        """
        log_p = F.log_softmax(logits, dim=-1)
        log_ref = F.log_softmax(ref_logits.detach(), dim=-1)

        kl = F.kl_div(log_p, log_ref, reduction='none', log_target=True).sum(-1)
        kl = torch.clamp(kl, max=5.0)
        return kl

    def compute_policy_gradient_loss(self, logits, ref_logits, actions, advantages):
        """Clipped-surrogate policy loss with the reference model as the old policy."""
        current_log_probs = F.log_softmax(logits, dim=-1)
        current_action_log_probs = current_log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            ref_log_probs = F.log_softmax(ref_logits, dim=-1)
            ref_action_log_probs = ref_log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        ratio = torch.exp(current_action_log_probs - ref_action_log_probs)
        # Outer clamp bounds extreme ratios; the inner clamp is the PPO-style clip.
        ratio = torch.clamp(ratio, 0.1, 10.0)

        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 0.8, 1.2)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        return pg_loss

    def grpo_step(self, batch, optimizer, scheduler, accumulation_step=0):
        """Run forward/backward on one batch and step the optimizer when due.

        Args:
            batch: Collated batch with ``input_ids``, ``attention_mask``,
                ``resp_mask`` and ``labels``.
            optimizer: Optimizer over the trainable parameters.
            scheduler: LR scheduler, stepped together with the optimizer.
            accumulation_step: Index of this batch within the epoch; the
                optimizer steps every ``config.ACCUM_STEPS`` batches.

        Returns:
            Dict of Python floats: ``loss``, ``accuracy``, ``ce_loss``,
            ``pg_loss``, ``kl_loss``, ``kl_div``, ``value_loss``, ``reward``.
        """
        self.model.train()
        input_ids = batch["input_ids"].to(self.config.DEVICE)
        attention_mask = batch["attention_mask"].to(self.config.DEVICE)
        resp_mask = batch["resp_mask"].to(self.config.DEVICE)
        labels = batch["labels"].to(self.config.DEVICE)

        weights = self.get_loss_weights()
        use_rl = weights["pg"] > 0

        ref_logits = None
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits, ce_loss, values = self.model(input_ids, attention_mask, resp_mask, labels)

            # The reference model is only needed by the RL terms; skip its
            # forward pass during the cross-entropy warmup.
            if use_rl:
                with torch.no_grad():
                    ref_logits, _, _ = self.ref_model(input_ids, attention_mask, resp_mask)

        # Do the loss math in float32 whatever dtype autocast returned.
        logits = logits.float()
        values = values.float()
        if ref_logits is not None:
            ref_logits = ref_logits.float()

        total_loss = 0
        loss_info = {}

        ce_loss_mean = ce_loss.mean()
        total_loss += weights["ce"] * ce_loss_mean
        loss_info["ce_loss"] = ce_loss_mean.item()

        if use_rl:
            # Sample actions from a slightly flattened distribution (temperature 1.15).
            probs = F.softmax(logits / 1.15, dim=-1)
            if self.model.training and self.current_epoch > self.warmup_epochs:
                actions = torch.multinomial(probs, 1).squeeze(-1)
            else:
                actions = logits.argmax(-1)

            rewards = self.reward_calc.calculate_combined_reward(
                predictions=actions.cpu(),
                labels=labels.cpu(),
                logits=logits.detach().cpu(),
                alpha=0.3,  # metric
                beta=0.7    # F1
            ).to(self.config.DEVICE)

            advantages = self.compute_advantages(rewards, values)

            pg_loss = self.compute_policy_gradient_loss(logits, ref_logits, actions, advantages)
            total_loss += weights["pg"] * pg_loss
            loss_info["pg_loss"] = pg_loss.item()

            if weights["kl"] > 0:
                kl_penalty = self.compute_kl_penalty(logits, ref_logits)
                kl_loss = weights["kl"] * kl_penalty.mean()
                total_loss += kl_loss
                loss_info["kl_loss"] = kl_loss.item()
                loss_info["kl_div"] = kl_penalty.mean().item()

            if weights["value"] > 0:
                value_targets = rewards.detach()
                value_loss = F.smooth_l1_loss(values, value_targets)
                total_loss += weights["value"] * value_loss
                loss_info["value_loss"] = value_loss.item()

            loss_info["reward"] = rewards.mean().item()
        else:
            loss_info.update({
                "pg_loss": 0.0, "kl_loss": 0.0, "value_loss": 0.0,
                "reward": 0.0, "kl_div": 0.0
            })

        # Scale so the accumulated gradient is the mean over the window.
        total_loss = total_loss / self.config.ACCUM_STEPS
        total_loss.backward()

        if (accumulation_step + 1) % self.config.ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        preds = logits.argmax(-1)
        acc = (preds == labels).float().mean()

        loss_info.update({
            "loss": total_loss.item() * self.config.ACCUM_STEPS,
            "accuracy": acc.item()
        })

        return loss_info

In [28]:
def create_reference_model(model_name, lora_config, hidden_size, torch_dtype=None):
    """Build a second backbone+heads model to serve as the reference policy.

    Args:
        model_name: Checkpoint name or path for ``AutoModel.from_pretrained``.
        lora_config: peft ``LoraConfig`` applied to the backbone, so the
            parameter names match the policy model's state dict.
        hidden_size: Hidden width used to size the heads.
        torch_dtype: Backbone dtype. Defaults to float16 when CUDA is
            available and float32 otherwise, the same rule
            ``build_model_and_tokenizer`` applies to the policy backbone.
            Loading the reference in float32 next to a float16 policy would
            double its memory footprint for no benefit.

    Returns:
        An ``ImprovedViHalluGRPO`` with three output labels (left on the
        device ``from_pretrained`` loaded it to).
    """
    if torch_dtype is None:
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    base = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=CFG.TRUST_REMOTE_CODE,
        torch_dtype=torch_dtype,
    )
    base = get_peft_model(base, lora_config)
    ref_model = ImprovedViHalluGRPO(base, hidden_size=hidden_size, num_labels=3)
    return ref_model

In [29]:
def train_grpo_epoch(model, ref_model, loader, optimizer, scheduler, config, epoch):
    """Train for one epoch and return the mean of every logged metric.

    Args:
        model: Trainable model.
        ref_model: Frozen reference model.
        loader: Training DataLoader yielding collated batches.
        optimizer: Optimizer over the trainable parameters.
        scheduler: LR scheduler stepped together with the optimizer.
        config: Config object (DEVICE, ACCUM_STEPS, WARMUP_EPOCHS, ...).
        epoch: 1-based epoch number; selects warmup vs. RL loss weights.

    Returns:
        Dict mapping each metric name from ``grpo_step`` to its epoch mean.
    """
    trainer = AdvancedGRPOTrainer(model, ref_model, config)
    trainer.set_epoch(epoch)

    # NOTE(review): this is a complete extra pass over the loader (collation
    # and tokenisation included) just to count labels, and it runs every epoch.
    all_labels = []
    for batch in loader:
        all_labels.extend(batch["labels"].cpu().numpy())
    trainer.reward_calc.set_class_weights(all_labels)

    metrics = defaultdict(list)
    model.train()
    # Use the config that was passed in, not the global CFG.
    stage = 'Warmup' if epoch <= config.WARMUP_EPOCHS else 'RL'
    pbar = tqdm(loader, desc=f"Epoch {epoch} - GRPO Training (Stage: {stage})")

    num_batches = 0
    for idx, batch in enumerate(pbar):
        step_metrics = trainer.grpo_step(batch, optimizer, scheduler, idx)
        num_batches = idx + 1
        for k, v in step_metrics.items():
            metrics[k].append(v)

        # Progress bar shows a running mean over the last 100 steps.
        pbar.set_postfix({
            "loss": np.mean(metrics["loss"][-100:]),
            "acc": np.mean(metrics["accuracy"][-100:]),
            "reward": np.mean(metrics["reward"][-100:]),
            "ce": np.mean(metrics["ce_loss"][-100:])
        })

    # Flush a trailing partial accumulation window. Without this, its gradients
    # are never applied in this epoch and leak into the first optimizer step of
    # the next one. The clip value mirrors the one used inside grpo_step.
    if num_batches % config.ACCUM_STEPS != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return {k: np.mean(v) for k, v in metrics.items()}

In [30]:
@torch.no_grad()
def evaluate_grpo(model, loader, config, desc="Evaluating"):
    """Evaluate ``model`` on ``loader`` and collect classification metrics.

    Args:
        model: Model returning ``(logits, per_sample_loss, values)``.
        loader: DataLoader yielding collated batches with labels.
        config: Config object providing ``DEVICE``.
        desc: Progress-bar caption.

    Returns:
        Dict with ``loss`` (mean of per-batch mean losses), ``macro_f1``,
        ``weighted_f1``, ``per_class_f1`` (class name -> F1),
        ``confusion_matrix``, ``report`` (text), ``predictions`` and
        ``probabilities``.
    """
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    total_loss = 0
    num_batches = 0

    for batch in tqdm(loader, desc=desc):
        input_ids = batch["input_ids"].to(config.DEVICE)
        attention_mask = batch["attention_mask"].to(config.DEVICE)
        resp_mask = batch["resp_mask"].to(config.DEVICE)
        labels = batch["labels"].to(config.DEVICE)

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits, loss, _ = model(input_ids, attention_mask, resp_mask, labels)

        probs = F.softmax(logits, dim=-1)
        preds = logits.argmax(-1)

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
        all_probs.extend(probs.cpu().numpy())
        total_loss += loss.mean().item()
        num_batches += 1

    # Pin the class set. Without `labels=`, sklearn only reports classes that
    # occur in y_true/y_pred: a missing class shifts the per-class array against
    # ID2LABEL, shrinks the confusion matrix, and makes classification_report
    # reject the three target names.
    label_ids = sorted(ID2LABEL)
    target_names = [ID2LABEL[i] for i in label_ids]

    macro_f1 = f1_score(all_labels, all_preds, average="macro")
    weighted_f1 = f1_score(all_labels, all_preds, average="weighted")
    per_class_f1 = f1_score(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    report = classification_report(
        all_labels, all_preds,
        labels=label_ids,
        target_names=target_names,
        digits=4,
        zero_division=0
    )
    return {
        "loss": total_loss / max(1, num_batches),
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
        "per_class_f1": {ID2LABEL[i]: f1 for i, f1 in zip(label_ids, per_class_f1)},
        "confusion_matrix": cm,
        "report": report,
        "predictions": all_preds,
        "probabilities": all_probs
    }

In [31]:
def plot_training_curves(history):
    """Plot train/validation curves and save them to ``training_curves.png``.

    Args:
        history: Mapping from ``train_<metric>`` / ``val_<metric>`` to a list
            of per-epoch values. Missing series are skipped.
    """
    metrics_to_plot = ['loss', 'accuracy', 'macro_f1', 'reward', 'kl_div']

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    flat_axes = axes.flatten()

    for ax, metric in zip(flat_axes, metrics_to_plot):
        if f'train_{metric}' in history:
            ax.plot(history[f'train_{metric}'], label='Train', marker='o')
        if f'val_{metric}' in history:
            ax.plot(history[f'val_{metric}'], label='Val', marker='s')

        pretty_name = metric.replace('_', ' ').title()
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty_name)
        ax.set_title(f'{pretty_name} Curve')
        # Only draw a legend when something was plotted; an empty legend
        # triggers a matplotlib warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # The grid has six panels but only five metrics: hide the unused ones
    # instead of showing empty axes.
    for unused_ax in flat_axes[len(metrics_to_plot):]:
        unused_ax.set_visible(False)

    fig.tight_layout()
    fig.savefig('training_curves.png', dpi=100)
    plt.show()

In [32]:
def plot_confusion_matrix(cm, save_path='confusion_matrix.png'):
    """Draw ``cm`` as an annotated heatmap, save it to ``save_path`` and show it.

    Rows are labelled as true classes and columns as predicted classes, using
    the class names from ``ID2LABEL``.
    """
    class_names = list(ID2LABEL.values())

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(save_path, dpi=100)
    plt.show()

In [33]:
def build_model_and_tokenizer():
    """Create the tokenizer, the LoRA-wrapped policy model and its reference copy.

    All settings come from ``CFG``. The reference model is built separately and
    then initialised with the policy model's state dict, so both start from
    identical weights.

    Returns:
        Tuple ``(model, ref_model, tokenizer)``; both models are on ``CFG.DEVICE``.
    """
    tok = AutoTokenizer.from_pretrained(
        CFG.MODEL_NAME, use_fast=True, trust_remote_code=CFG.TRUST_REMOTE_CODE
    )
    # Fall back to EOS (or UNK) when the tokenizer ships without a pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token or tok.unk_token

    backbone_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    backbone = AutoModel.from_pretrained(
        CFG.MODEL_NAME,
        trust_remote_code=CFG.TRUST_REMOTE_CODE,
        torch_dtype=backbone_dtype
    )
    hidden = backbone.config.hidden_size

    lora_cfg = LoraConfig(
        r=CFG.LORA_R,
        lora_alpha=CFG.LORA_ALPHA,
        lora_dropout=CFG.LORA_DROPOUT,
        target_modules=CFG.TARGET_MODULES,
        task_type=TaskType.FEATURE_EXTRACTION,
        bias="none"
    )
    backbone = get_peft_model(backbone, lora_cfg)

    # Count parameters in a single pass over the LoRA-wrapped backbone.
    trainable_params, total_params = 0, 0
    for param in backbone.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    print(f"Trainable params: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

    model = ImprovedViHalluGRPO(backbone, hidden_size=hidden, num_labels=3).to(CFG.DEVICE)
    ref_model = create_reference_model(CFG.MODEL_NAME, lora_cfg, hidden).to(CFG.DEVICE)
    ref_model.load_state_dict(model.state_dict())

    return model, ref_model, tok

In [34]:
def create_weighted_sampler(labels):
    """Build a sampler that draws each class with roughly equal probability.

    Every sample is weighted by the inverse frequency of its class (the class
    weights are normalised to sum to the number of classes). Sampling is with
    replacement and yields ``len(labels)`` indices per epoch.

    Args:
        labels: Sequence of non-negative integer class ids, one per sample.

    Returns:
        A ``torch.utils.data.WeightedRandomSampler``.
    """
    from torch.utils.data import WeightedRandomSampler

    # The epsilon keeps the division finite for a class id with zero samples.
    inverse_freq = 1.0 / (np.bincount(labels) + 1e-6)
    num_classes = len(inverse_freq)
    inverse_freq = inverse_freq / inverse_freq.sum() * num_classes

    per_sample = [inverse_freq[y] for y in labels]
    return WeightedRandomSampler(
        weights=per_sample, num_samples=len(per_sample), replacement=True
    )

In [35]:
def save_model_for_inference(model, tokenizer, save_path):
    """Write an inference checkpoint to ``save_path``.

    The checkpoint is a dict holding the model's full state dict
    (``model_state_dict``) and the tokenizer's ``name_or_path``
    (``model_name``), which ``InferenceModel`` uses to rebuild the backbone.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'model_name': tokenizer.name_or_path,
    }
    torch.save(checkpoint, save_path)
    print(f"Inference model saved to {save_path}")


In [None]:
def train():
    """Run the full training loop: data split, model setup, training, early stopping.

    Reads ``vihallu-train.csv`` from the working directory, holds out 10% as a
    stratified validation set, trains for up to ``CFG.EPOCHS`` epochs, saves
    the best model (by validation macro-F1) to ``outputs/best_model.pt`` and
    stops early after three epochs without improvement.

    Raises:
        FileNotFoundError: If the training CSV is not present.
    """
    train_csv = "vihallu-train.csv"
    if not os.path.exists(train_csv):
        # The original cell had only a "download data" comment here, which left
        # the `if` without a body and made the whole function a syntax error.
        raise FileNotFoundError(
            f"{train_csv} not found; download the training data into the working directory first."
        )
    df_train = pd.read_csv(train_csv)
    df_train, df_val = train_test_split(
        df_train, test_size=0.1, random_state=CFG.SEED, stratify=df_train['label']
    )

    model, ref_model, tok = build_model_and_tokenizer()

    train_dataset = ViHalluSet(df_train)
    val_dataset = ViHalluSet(df_val)

    collator = SmartCollator(tok, CFG.MAX_LEN, CFG)
    # The dataset already holds normalised labels in row order; reuse them
    # rather than re-parsing the DataFrame.
    train_labels = [LABEL2ID[sample["label"]] for sample in train_dataset.samples]
    sampler = create_weighted_sampler(train_labels)

    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, sampler=sampler,
        collate_fn=collator, num_workers=CFG.WORKERS, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=CFG.BATCH_SIZE * 2, shuffle=False,
        collate_fn=collator, num_workers=CFG.WORKERS, pin_memory=True
    )

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=CFG.LR, weight_decay=CFG.WD
    )
    # The scheduler is stepped once per optimizer step, i.e. once every
    # ACCUM_STEPS batches, so its horizon must be counted in optimizer steps.
    # Counting batches made the schedule ACCUM_STEPS times too long, and the
    # learning rate never left the warmup ramp. Rounded up so a trailing
    # partial accumulation window is covered.
    steps_per_epoch = (len(train_loader) + CFG.ACCUM_STEPS - 1) // CFG.ACCUM_STEPS
    total_steps = steps_per_epoch * CFG.EPOCHS
    warmup_steps = int(total_steps * CFG.WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    history = defaultdict(list)
    best_f1 = 0
    limit = 3   # early-stopping patience, in epochs
    end_ = 0    # consecutive epochs without improvement

    save_dir = "outputs"
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(1, CFG.EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{CFG.EPOCHS} ---")
        train_metrics = train_grpo_epoch(model, ref_model, train_loader, optimizer, scheduler, CFG, epoch)
        for k, v in train_metrics.items():
            history[f'train_{k}'].append(v)

        val_metrics = evaluate_grpo(model, val_loader, CFG, desc=f"Validating Epoch {epoch}")
        for k in ['loss', 'macro_f1', 'weighted_f1']:
            history[f'val_{k}'].append(val_metrics[k])

        print(f"\n Validation Macro F1: {val_metrics['macro_f1']:.4f}")
        improved = val_metrics['macro_f1'] > best_f1
        if improved:
            end_ = 0
            best_f1 = val_metrics['macro_f1']
            print(f">> New best model found with Macro F1: {best_f1:.4f}")
            print(val_metrics['report'])
            save_model_for_inference(model, tok, os.path.join(save_dir, "best_model.pt"))

        # The confusion matrix is plotted every epoch, improved or not.
        plot_confusion_matrix(val_metrics['confusion_matrix'], save_path=f'confusion_matrix_epoch_{epoch}.png')

        if not improved:
            end_ += 1
            print(f"No improvement for {end_}/{limit} epochs")
            if end_ >= limit:
                print(f"Early stopping triggered after {limit} epochs without improvement")
                break
    plot_training_curves(history)

In [37]:
class InferenceModel:
    """Load a saved checkpoint and classify single (context, prompt, response) triples."""

    def __init__(self, checkpoint_path: str, lora_config: Dict):
        """Load the checkpoint right away.

        Args:
            checkpoint_path: File written by ``save_model_for_inference``.
            lora_config: Dict with the keys ``r``, ``lora_alpha`` and
                ``target_modules`` used during training.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lora_config = lora_config
        self.collator = None
        self.load_model(checkpoint_path)

    def load_model(self, checkpoint_path: str):
        """Rebuild tokenizer, backbone, LoRA adapters and heads, then load the weights."""
        # Keep the checkpoint on the CPU: load_state_dict copies the tensors to
        # the model's device, so a second full copy on the GPU is not needed.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model_name = checkpoint['model_name']

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Same dtype rule as build_model_and_tokenizer, so inference runs under
        # the numeric conditions the model was trained and validated in.
        backbone_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        base_model = AutoModel.from_pretrained(model_name, torch_dtype=backbone_dtype)

        peft_config = LoraConfig(
            r=self.lora_config['r'],
            lora_alpha=self.lora_config['lora_alpha'],
            target_modules=self.lora_config['target_modules'],
            task_type=TaskType.FEATURE_EXTRACTION,
            bias="none"
        )
        peft_model = get_peft_model(base_model, peft_config)

        self.model = ImprovedViHalluGRPO(
            peft_model,
            hidden_size=base_model.config.hidden_size,
            num_labels=len(LABEL2ID)
        ).to(self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        self.collator = SmartCollator(self.tokenizer, CFG.MAX_LEN, CFG)

    @staticmethod
    def _clean_text(value) -> str:
        """Apply the cleaning ViHalluSet does at training time.

        Missing values (None/NaN) become the empty string; everything else is
        stringified and stripped.
        """
        if value is None:
            return ""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value).strip()

    @torch.no_grad()
    def predict(self, context: str, prompt: str, response: str):
        """Classify one sample.

        Returns:
            Dict with ``prediction`` (class name) and ``confidence`` (softmax
            probability of that class).
        """
        # Training samples go through fillna("") / str().strip(); without the
        # same cleaning a NaN cell from a test DataFrame would reach the tokenizer.
        sample = [{
            "context": self._clean_text(context),
            "prompt": self._clean_text(prompt),
            "response": self._clean_text(response),
            "label": "NO",  # placeholder; the label is not used for prediction
        }]
        batch = self.collator(sample)

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        resp_mask = batch["resp_mask"].to(self.device)

        # Mirror evaluate_grpo, which runs the forward pass under autocast.
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits, _, _ = self.model(input_ids, attention_mask, resp_mask)
        probs = F.softmax(logits.float(), dim=-1).flatten()

        pred_id = torch.argmax(probs).item()
        prediction = ID2LABEL[pred_id]
        confidence = probs[pred_id].item()

        return {"prediction": prediction, "confidence": confidence}

In [38]:
# train()

In [None]:
# Load the private test set. The original cell had only a "download data"
# comment under the `if`, which is a syntax error; fail with a clear message
# instead when the file is missing.
if not os.path.exists("vihallu-private-test.csv"):
    raise FileNotFoundError(
        "vihallu-private-test.csv not found; download the test data into the working directory first."
    )
df_test = pd.read_csv("vihallu-private-test.csv")

In [None]:
# Make sure the checkpoint directory exists, then copy the trained checkpoint into it.
# NOTE(review): both paths in the shell copy are hard-coded absolute paths under /content
# (presumably a Colab working directory) — confirm they match where best_model.pt lives.
save_dir = "outputs"
os.makedirs(save_dir, exist_ok=True)
!cp /content/best_model.pt   /content/outputs/best_model.pt

In [41]:
# LoRA settings must match the ones used for training so the adapter weights
# in the checkpoint line up with the rebuilt model.
lora_config_for_inference = {
        "r": CFG.LORA_R,
        "lora_alpha": CFG.LORA_ALPHA,
        "target_modules": CFG.TARGET_MODULES,
    }

# Constructing InferenceModel loads the tokenizer, backbone and checkpoint immediately.
inference_model = InferenceModel(
        checkpoint_path="outputs/best_model.pt",
        lora_config=lora_config_for_inference
    )

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [43]:
# Classify every test row one at a time and collect the predicted class names.
predictions = []
for row in tqdm(df_test.itertuples(), total=len(df_test), desc="Predicting on test set"):
    result = inference_model.predict(
            context=row.context,
            prompt=row.prompt,
            response=row.response
        )
    predictions.append(result['prediction'])
# Adds a column to df_test in place, so re-running this cell overwrites it.
df_test['predict_label'] = predictions
print("Inference complete. Displaying results:")
# Submission frame: id plus the lower-cased predicted label.
results = pd.DataFrame({
        "id": df_test['id'],
        "predict_label" : df_test['predict_label'].apply(lambda x: str(x).lower())
    })
# NOTE(review): this bare expression is not the last statement of the cell,
# so the notebook does not display it.
results
results.to_csv("submit.csv", index=False)

# Package the submission file with the shell `zip` tool.
!zip submit.zip submit.csv

Predicting on test set: 100%|██████████| 2000/2000 [27:59<00:00,  1.19it/s]

Inference complete. Displaying results:
  adding: submit.csv (deflated 49%)





In [44]:
# Show how many test rows were assigned to each predicted class.
print(df_test['predict_label'].value_counts())

predict_label
INTRINSIC    690
NO           667
EXTRINSIC    643
Name: count, dtype: int64
