## Data Preparation

In [26]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig

In [27]:
class NERDataset(Dataset):
    def __init__(self, data_path, tokenizer, label_pad_id=-100, max_length=128):
        with open(data_path, "r", encoding="utf-8") as f:
            raw = json.load(f)["examples"]
        self.data = raw
        self.tokenizer = tokenizer
        self.label_pad_id = label_pad_id
        self.max_length = max_length

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens = self.data[idx]["tokens"]
        ner_tags = self.data[idx]["ner_tags"]

        # buat encoding untuk tokens 
        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )

        # align labels dengan tokens yang sudah diencoding (jadi kepotong2 sesuai tokenization)
        word_ids = encoding.word_ids(batch_index=0)
        aligned_labels = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                aligned_labels.append(self.label_pad_id)
            elif word_idx != previous_word_idx:
                aligned_labels.append(ner_tags[word_idx])
            else:
                aligned_labels.append(self.label_pad_id)
            previous_word_idx = word_idx
        
        item = {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(aligned_labels, dtype=torch.long)
        }

        return item

In [28]:
def load_label_info(model_name):
    config = AutoConfig.from_pretrained(model_name)
    id2label = config.id2label
    label2id = config.label2id
    num_labels = config.num_labels

    label_info = {
        "id2label": id2label,
        "label2id": label2id,
        "num_labels": num_labels
    }

    return label_info

def create_dataloaders(
        train_path, val_path, test_path,
        model_name,
        batch_size=32,
        max_length=128
):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = NERDataset(train_path, tokenizer, max_length=max_length)
    val_dataset = NERDataset(val_path, tokenizer, max_length=max_length)
    test_dataset = NERDataset(test_path, tokenizer, max_length=max_length)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [29]:
train_loader, val_loader, test_loader = create_dataloaders(
    train_path=r"D:\Dafa\Project\queryner-kd\data\processed\train.json",
    val_path=r"D:\Dafa\Project\queryner-kd\data\processed\validation.json",
    test_path=r"D:\Dafa\Project\queryner-kd\data\processed\test.json",
    model_name="bltlab/queryner-augmented-data-bert-base-uncased",
    batch_size=16,
    max_length=128
)

label_info = load_label_info("bltlab/queryner-augmented-data-bert-base-uncased")

## Model Architecture

In [30]:
from torch import nn
from torchcrf import CRF
from transformers import AutoModel, AutoConfig

In [31]:
class CRFOutputLayer(nn.Module):
    def __init__(self, hidden_dim, num_labels):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_tags=num_labels, batch_first=True)

    def forward(self, outputs, labels=None, mask=None):
        emissions = self.fc(outputs)

        if labels is not None:
            # CRF requires first token to be valid, so we create a modified mask
            # that ensures first token is always included
            if mask is None:
                mask = torch.ones_like(labels, dtype=torch.bool)
            else:
                mask = mask.bool()
            
            # Ensure first position is always valid for CRF
            mask[:, 0] = True
            
            # Replace -100 with 0 (dummy label) to avoid index issues
            labels_crf = labels.clone()
            labels_crf[labels == -100] = 0
            
            # Calculate loss
            log_likelihood = self.crf(emissions, tags=labels_crf, mask=mask, reduction="mean")
            loss = -log_likelihood
            return {"logits": emissions, "loss": loss}
        else:
            if mask is None:
                mask = torch.ones(outputs.shape[:2], dtype=torch.bool, device=outputs.device)
            pred = self.crf.decode(emissions, mask=mask.bool())
            return {"logits": emissions, "pred": pred}


In [32]:
class BaseNERModel(nn.Module):
    def __init__(self, num_labels, use_crf=False):
        super().__init__()
        self.num_labels = num_labels
        self.use_crf = use_crf

    def forward(self, input_ids, attention_mask, labels=None):
        raise NotImplementedError("Forward method must be implemented in subclass.")

In [33]:
class QueryNERTeacher(BaseNERModel):
    def __init__(self, model_name, label_info, use_crf=False):
        super().__init__(num_labels=label_info["num_labels"], use_crf=use_crf)

        self.config = AutoConfig.from_pretrained(
            model_name,
            num_labels=label_info["num_labels"],
            id2label=label_info["id2label"],
            label2id=label_info["label2id"]
        )

        self.bert = AutoModel.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)

        if self.use_crf:
            self.crf_output = CRFOutputLayer(self.config.hidden_size, self.config.num_labels)
        else:
            self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, attention_mask, labels=None):

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)

        if self.use_crf:
            mask = attention_mask.bool()
            result = self.crf_output(sequence_output, labels=labels, mask=mask)
            return result

        else:
            logits = self.classifier(sequence_output)
            if labels is not None:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))
                return {"logits": logits, "loss": loss}
            else:
                pred = logits.argmax(dim=-1)
                return {"logits": logits, "pred": pred}

In [52]:
class DistilBERTStudent(BaseNERModel):
    def __init__(self, model_name="distilbert-base-uncased", label_info=None, use_crf=False):
        self.use_crf = use_crf
        self.num_labels = label_info["num_labels"]
        super().__init__(num_labels=self.num_labels, use_crf=self.use_crf)

        self.config = AutoConfig.from_pretrained(
            model_name,
            num_labels=label_info["num_labels"],
            id2label=label_info["id2label"],
            label2id=label_info["label2id"]
        )

        self.bert = AutoModel.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)

        if self.use_crf:
            self.crf_output = CRFOutputLayer(self.config.hidden_size, self.num_labels)
        else:
            self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)

        if self.use_crf:
            mask = attention_mask.bool()
            result = self.crf_output(sequence_output, labels=labels, mask=mask)
            return result
        else:
            logits = self.classifier(sequence_output)
            if labels is not None:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))
                return {"logits": logits, "loss": loss}
            else:
                pred = logits.argmax(dim=-1)
                return {"logits": logits, "pred": pred}


In [22]:
class TinyBertStudent(BaseNERModel):
    def __init__(self, model_name="huawei-noah/TinyBERT_General_4L_312D", label_info=None, use_crf=False):
        self.use_crf = use_crf
        self.num_labels = label_info["num_labels"]
        super().__init__(num_labels=self.num_labels, use_crf=self.use_crf)

        self.config = AutoConfig.from_pretrained(
            model_name,
            num_labels=label_info["num_labels"],
            id2label=label_info["id2label"],
            label2id=label_info["label2id"]
        )

        self.bert = AutoModel.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)

        if self.use_crf:
            self.crf_output = CRFOutputLayer(self.config.hidden_size, self.num_labels)
        else:
            self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)

        if self.use_crf:
            mask = attention_mask.bool()
            result = self.crf_output(sequence_output, labels=labels, mask=mask)
            return result
        else:
            logits = self.classifier(sequence_output)
            if labels is not None:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))
                return {"logits": logits, "loss": loss}
            else:
                pred = logits.argmax(dim=-1)
                return {"logits": logits, "pred": pred}

In [45]:
class BiLSTMStudent(BaseNERModel):
    def __init__(
            self, 
            num_labels, 
            use_crf=False,
            model_name_for_vocab = 'bert-base-uncased',
            emb_dim = 300,
            lstm_hidden = 300,
            label_info = None,
            pad_token_id = 0
        ):
        super().__init__(num_labels, use_crf)
        self.use_crf = use_crf
        self.num_labels = num_labels

        config = AutoConfig.from_pretrained(model_name_for_vocab)
        vocab_size = config.vocab_size
        pad_token_id = config.pad_token_id

        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_token_id)
        self.dropout = nn.Dropout(0.1)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.classifier = nn.Linear(lstm_hidden * 2, num_labels)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

        if self.use_crf:
            self.crf_output = CRFOutputLayer(hidden_dim=lstm_hidden * 2, num_labels=num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        emb = self.embedding(input_ids)
        emb = self.dropout(emb)
        outputs, _ = self.lstm(emb)
        sequence_output = outputs

        if self.use_crf:
            mask = attention_mask.bool()
            result = self.crf_output(sequence_output, labels=labels, mask=mask)
            return result
        else:
            logits = self.classifier(sequence_output)
            if labels is not None:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))
                return {"logits": logits, "loss": loss}
            else:
                pred = logits.argmax(dim=-1)
                return {"logits": logits, "pred": pred}


In [53]:
teacher = QueryNERTeacher(model_name="bltlab/queryner-augmented-data-bert-base-uncased", label_info=label_info, use_crf=False)
student = DistilBERTStudent(model_name="distilbert-base-uncased", label_info=label_info, use_crf=False)

## Knowledge Distillation Scheme

In [25]:
import torch.nn.functional as F

In [26]:
# src/training/kd_trainer.py
def softmax_with_temperature(logits, temperature):
    return F.softmax(logits / temperature, dim=-1)

def kl_divergence_loss(student_logits, teacher_logits, temperature):
    p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
    p_student = F.softmax(student_logits / temperature, dim=-1)
    loss = F.kl_div(p_teacher, p_student, reduction='batchmean')
    loss = loss * (temperature ** 2)
    return loss

def kl_divergence_loss_masked(student_logits, teacher_logits, temperature, mask=None, eps=1e-12):
    T = float(temperature)

    student_log_prob = F.log_softmax(student_logits / T, dim=-1)   # (B, L, C)
    teacher_prob = F.softmax(teacher_logits / T, dim=-1)           # (B, L, C)

    kl_elem = F.kl_div(student_log_prob, teacher_prob, reduction='none')  # (B, L, C)

    kl_token = kl_elem.sum(dim=-1)  # (B, L)

    if mask is not None:
        mask = mask.bool()
        valid_sum = mask.float().sum()
        if valid_sum.item() == 0:
            return torch.tensor(0.0, device=student_logits.device)
        kl_sum = (kl_token * mask.float()).sum()
        return (kl_sum / valid_sum) * (T * T)
    else:
        return kl_token.mean() * (T * T)

In [27]:
# src/training/kd_trainer.py
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm


def kl_divergence_loss_masked(student_logits, teacher_logits, temperature, mask=None, eps=1e-12):
    """
    Compute masked KL divergence loss for knowledge distillation.
    
    Args:
        student_logits: (B, L, C) logits from student
        teacher_logits: (B, L, C) logits from teacher
        temperature: Temperature for softening distributions
        mask: (B, L) attention mask (1 for valid tokens)
        eps: Small constant for numerical stability
    
    Returns:
        Scalar loss value
    """
    T = float(temperature)

    student_log_prob = F.log_softmax(student_logits / T, dim=-1)   # (B, L, C)
    teacher_prob = F.softmax(teacher_logits / T, dim=-1)           # (B, L, C)

    kl_elem = F.kl_div(student_log_prob, teacher_prob, reduction='none')  # (B, L, C)
    kl_token = kl_elem.sum(dim=-1)  # (B, L)

    if mask is not None:
        mask = mask.bool()
        valid_sum = mask.float().sum()
        if valid_sum.item() == 0:
            return torch.tensor(0.0, device=student_logits.device)
        kl_sum = (kl_token * mask.float()).sum()
        return (kl_sum / valid_sum) * (T * T)
    else:
        return kl_token.mean() * (T * T)


def _to_tensor_preds(preds, batch_size, seq_len, device):
    """
    Convert CRF decode output (list[list[int]] or list of tensors) into a tensor
    of shape (batch_size, seq_len) padded with 0s. Caller must mask invalid tokens.
    """
    pred_tensor = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
    for i, p in enumerate(preds):
        if isinstance(p, torch.Tensor):
            p = p.tolist()
        L = len(p)
        if L > 0:
            pred_tensor[i, :L] = torch.tensor(p, dtype=torch.long, device=device)
    return pred_tensor


def _safe_get_pred_tensor(output, batch_size, seq_len, device):
    """
    Return a (batch, seq_len) tensor of predictions from model output.
    Handles:
      - output["pred"] is a tensor (batch, seq_len)
      - output["pred"] is a list of lists (per-seq predicted label ids)
      - output has no "pred" (use logits.argmax)
    """
    if "pred" in output:
        pred = output["pred"]
        if isinstance(pred, torch.Tensor):
            return pred.to(device)
        else:
            # assume list of lists
            return _to_tensor_preds(pred, batch_size, seq_len, device)
    elif "logits" in output:
        return output["logits"].argmax(dim=-1).to(device)
    else:
        raise ValueError("No 'pred' or 'logits' in model output to produce predictions.")


def _accumulate_confusion_counts(preds_flat, labels_flat):
    """
    Compute per-class TP, predicted_counts, actual_counts using vectors.
    preds_flat and labels_flat are 1D torch.Long tensors on CPU or device.
    Returns (tp_sum, pred_sum, actual_sum) and also total_tp, total_pred, total_actual per class sums.
    """
    if preds_flat.numel() == 0:
        return 0, 0, 0, None  # no valid tokens in this batch

    max_label = int(max(int(preds_flat.max().item()), int(labels_flat.max().item())))
    num_classes = max_label + 1

    # compute per-class counts
    tp_per_class = torch.zeros(num_classes, dtype=torch.long, device=preds_flat.device)
    pred_per_class = torch.zeros(num_classes, dtype=torch.long, device=preds_flat.device)
    actual_per_class = torch.zeros(num_classes, dtype=torch.long, device=preds_flat.device)

    for c in range(num_classes):
        pred_mask = preds_flat == c
        lab_mask = labels_flat == c
        tp_per_class[c] = int((pred_mask & lab_mask).sum().item())
        pred_per_class[c] = int(pred_mask.sum().item())
        actual_per_class[c] = int(lab_mask.sum().item())

    tp_sum = int(tp_per_class.sum().item())
    pred_sum = int(pred_per_class.sum().item())
    actual_sum = int(actual_per_class.sum().item())

    return tp_sum, pred_sum, actual_sum, (tp_per_class.cpu().numpy(), pred_per_class.cpu().numpy(), actual_per_class.cpu().numpy())


def _batch_metrics(pred_tensor, label_tensor, attention_mask):
    """
    pred_tensor: (B, L)
    label_tensor: (B, L) with -100 for ignored positions
    attention_mask: (B, L) with 1 for valid tokens
    Returns TP, predicted_count, actual_count (ints)
    """
    mask = attention_mask.bool()
    # also ensure labels not equal to -100 in valid positions
    valid = mask & (label_tensor != -100)
    if valid.sum().item() == 0:
        return 0, 0, 0

    preds_flat = pred_tensor[valid].view(-1)
    labels_flat = label_tensor[valid].view(-1)

    tp_sum, pred_sum, actual_sum, _ = _accumulate_confusion_counts(preds_flat, labels_flat)
    return tp_sum, pred_sum, actual_sum


def _final_metrics(tp_sum, pred_sum, actual_sum):
    """
    Compute micro precision, recall, f1 from aggregated counts.
    """
    precision = tp_sum / pred_sum if pred_sum > 0 else 0.0
    recall = tp_sum / actual_sum if actual_sum > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1


class KDTrainer:
    """
    Trainer for both baseline (fine-tuning) and knowledge distillation.
    
    For baseline:
        teacher_model=None, alpha=0, beta=1
    
    For KD:
        teacher_model=<trained_model>, alpha=0.5, beta=0.5
    """
    
    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        optimizer,
        scheduler=None,
        device="cuda",
        alpha=0.5,
        beta=0.5,
        temperature=2.0,
        scheduler_type="plateau"  # "plateau", "cosine", "step", or None
    ):
        """
        Args:
            teacher_model: Teacher model (None for baseline)
            student_model: Student model to train
            train_loader: Training DataLoader
            val_loader: Validation DataLoader
            optimizer: Optimizer for student model
            scheduler: Learning rate scheduler (optional)
            device: Device to use
            alpha: Weight for KD loss (0 for baseline)
            beta: Weight for student loss (1 for baseline)
            temperature: Temperature for KD
            scheduler_type: Type of scheduler for proper step() call
        """
        self.student = student_model.to(device)
        self.teacher = None  # ← FIX: Initialize to None
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scheduler_type = scheduler_type
        self.device = device
        self.alpha = alpha
        self.beta = beta
        self.T = temperature

        # Setup teacher if provided
        if teacher_model is not None:
            self.teacher = teacher_model.to(device)
            self.teacher.eval()  # Set to eval mode
            for p in self.teacher.parameters():
                p.requires_grad = False
        
        # Validate configuration
        if self.alpha > 0 and self.teacher is None:
            raise ValueError("alpha > 0 requires a teacher model!")
        
        # Print training mode
        mode = "BASELINE" if self.teacher is None else "KNOWLEDGE DISTILLATION"
        print(f"\n{'='*60}")
        print(f"Training Mode: {mode}")
        print(f"Alpha (KD loss weight): {self.alpha}")
        print(f"Beta (Student loss weight): {self.beta}")
        print(f"Temperature: {self.T}")
        print(f"{'='*60}\n")

    def compute_losses(self, batch):
        """
        Compute losses for one batch.
        
        Returns:
            loss_total: Combined loss
            loss_kd: KD loss (0 if no teacher)
            loss_student: Student task loss
            pred_tensor: Predictions for metrics
            labels: Ground truth labels
            attention_mask: Attention mask
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        batch_size, seq_len = input_ids.shape

        # Get teacher logits if needed
        teacher_logits = None
        if self.alpha > 0 and self.teacher is not None:
            with torch.no_grad():
                self.teacher.eval()
                teacher_out = self.teacher(
                    input_ids=input_ids, 
                    attention_mask=attention_mask
                )
                teacher_logits = teacher_out["logits"]

        # Get student outputs
        student_out = self.student(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            labels=labels
        )
        student_logits = student_out["logits"]

        # Compute KD loss if teacher provided
        if teacher_logits is not None:
            loss_kd = kl_divergence_loss_masked(
                student_logits, 
                teacher_logits, 
                self.T, 
                mask=attention_mask
            )
        else:
            loss_kd = torch.tensor(0.0, device=self.device)

        # Get student task loss
        loss_student = student_out.get("loss", torch.tensor(0.0, device=self.device))

        # Combined loss
        loss_total = self.alpha * loss_kd + self.beta * loss_student

        # Get predictions for metrics
        pred_tensor = _safe_get_pred_tensor(student_out, batch_size, seq_len, self.device)

        return loss_total, loss_kd, loss_student, pred_tensor, labels, attention_mask

    def train_epoch(self):
        """Train for one epoch."""
        self.student.train()
        total_loss, total_kd, total_stu = 0.0, 0.0, 0.0
        tp_acc, pred_acc, actual_acc = 0, 0, 0

        for batch in tqdm(self.train_loader, desc="Training"):
            self.optimizer.zero_grad()
            
            loss_total, loss_kd, loss_student, pred_tensor, labels, attention_mask = \
                self.compute_losses(batch)
            
            loss_total.backward()
            self.optimizer.step()

            total_loss += float(loss_total.item())
            total_kd += float(loss_kd.item())
            total_stu += float(loss_student.item()) if isinstance(loss_student, torch.Tensor) else float(loss_student)

            tp, pred_count, actual_count = _batch_metrics(pred_tensor, labels, attention_mask)
            tp_acc += tp
            pred_acc += pred_count
            actual_acc += actual_count

        avg_loss = total_loss / len(self.train_loader)
        avg_kd = total_kd / len(self.train_loader)
        avg_stu = total_stu / len(self.train_loader)

        # Update scheduler if provided
        if self.scheduler:
            if self.scheduler_type == "plateau":
                self.scheduler.step(avg_loss)
            else:
                # For cosine, step, etc. that don't need metrics
                self.scheduler.step()

        precision, recall, f1 = _final_metrics(tp_acc, pred_acc, actual_acc)
        return avg_loss, avg_kd, avg_stu, precision, recall, f1

    def validate(self):
        """Validate on validation set."""
        self.student.eval()
        total_loss, total_kd, total_stu = 0.0, 0.0, 0.0
        tp_acc, pred_acc, actual_acc = 0, 0, 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                loss_total, loss_kd, loss_student, pred_tensor, labels, attention_mask = \
                    self.compute_losses(batch)
                
                total_loss += float(loss_total.item())
                total_kd += float(loss_kd.item())
                total_stu += float(loss_student.item()) if isinstance(loss_student, torch.Tensor) else float(loss_student)

                tp, pred_count, actual_count = _batch_metrics(pred_tensor, labels, attention_mask)
                tp_acc += tp
                pred_acc += pred_count
                actual_acc += actual_count

        avg_loss = total_loss / len(self.val_loader)
        avg_kd = total_kd / len(self.val_loader)
        avg_stu = total_stu / len(self.val_loader)

        precision, recall, f1 = _final_metrics(tp_acc, pred_acc, actual_acc)
        return avg_loss, avg_kd, avg_stu, precision, recall, f1

    def train(self, num_epochs):
        """
        Train for multiple epochs.
        
        Args:
            num_epochs: Number of epochs to train
            
        Returns:
            history: Dictionary containing training history
        """
        history = {
            "train_loss": [], "val_loss": [],
            "train_kd": [], "val_kd": [],
            "train_stu": [], "val_stu": [],
            "train_precision": [], "train_recall": [], "train_f1": [],
            "val_precision": [], "val_recall": [], "val_f1": []
        }

        best_val_f1 = 0.0
        
        for epoch in range(1, num_epochs + 1):
            print(f"\n{'='*60}")
            print(f"EPOCH {epoch}/{num_epochs}")
            print(f"{'='*60}")
            
            train_loss, train_kd, train_stu, train_prec, train_rec, train_f1 = self.train_epoch()
            val_loss, val_kd, val_stu, val_prec, val_rec, val_f1 = self.validate()

            print(f"\nTrain Loss: {train_loss:.4f} (KD: {train_kd:.4f}, Student: {train_stu:.4f})")
            print(f"Val Loss:   {val_loss:.4f} (KD: {val_kd:.4f}, Student: {val_stu:.4f})")
            print(f"\nTrain Metrics -> P: {train_prec:.4f}, R: {train_rec:.4f}, F1: {train_f1:.4f}")
            print(f"Val Metrics   -> P: {val_prec:.4f}, R: {val_rec:.4f}, F1: {val_f1:.4f}")
            
            # Track best validation F1
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                print(f"✓ New best Val F1: {best_val_f1:.4f}")

            # Store history
            history["train_loss"].append(train_loss)
            history["train_kd"].append(train_kd)
            history["train_stu"].append(train_stu)
            history["val_loss"].append(val_loss)
            history["val_kd"].append(val_kd)
            history["val_stu"].append(val_stu)

            history["train_precision"].append(train_prec)
            history["train_recall"].append(train_rec)
            history["train_f1"].append(train_f1)
            history["val_precision"].append(val_prec)
            history["val_recall"].append(val_rec)
            history["val_f1"].append(val_f1)

        print(f"\n{'='*60}")
        print(f"Training Complete!")
        print(f"Best Val F1: {best_val_f1:.4f}")
        print(f"{'='*60}\n")

        return history

In [54]:
# one batch diagnostic
batch = next(iter(train_loader))
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]

# forward student (no KD/teacher)
out = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
logits = out["logits"]
loss = out.get("loss", None)

# count valid tokens
valid_mask = (attention_mask.bool()) & (labels != -100)
num_valid = valid_mask.sum().item()

In [72]:
print(logits.view(-1, 35))
print(labels.view(-1))

tensor([[ 0.1143, -0.2218,  0.1262,  ..., -0.1069,  0.2699, -0.1671],
        [-0.1747, -0.3804,  0.2549,  ..., -0.3591,  0.1842,  0.0710],
        [-0.4831, -0.2506,  0.4335,  ..., -0.3707,  0.0520,  0.2112],
        ...,
        [-0.0063, -0.1429,  0.0920,  ..., -0.0759,  0.0116, -0.0576],
        [-0.0807, -0.0690, -0.0154,  ..., -0.1664, -0.0500,  0.0161],
        [-0.1625,  0.0667, -0.1272,  ...,  0.0073, -0.2190, -0.0056]],
       grad_fn=<ViewBackward0>)
tensor([-100,    9, -100,  ..., -100, -100, -100])


In [55]:
print("batch_size, seq_len:", input_ids.shape)
print("num_valid_tokens:", num_valid)

if loss is not None:
    print("raw loss scalar:", float(loss.item()))
    if num_valid > 0:
        print("loss per valid token:", float(loss.item()) / num_valid)

batch_size, seq_len: torch.Size([16, 128])
num_valid_tokens: 60
raw loss scalar: 3.5641229152679443
loss per valid token: 0.05940204858779907


In [None]:
log_likelihood_sum =   # whatever internal return
print("crf -log_likelihood_sum:", float(-log_likelihood_sum.item()))

In [15]:
# import torch.optim as optim

# optimizer = optim.AdamW(student_bilstm.parameters(), lr=2e-5)

# trainer = KDTrainer(
#     teacher_model=None,
#     student_model=student_bilstm,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     optimizer=optimizer,
#     alpha=0,
#     beta=1,
#     temperature=2.0
# )

In [16]:
# history = trainer.train(num_epochs=10)

## Baseline Model

In [24]:
"""
Comprehensive Baseline Training Script for NER Models
Trains 4 models × 2 datasets × 3 learning rates × 2 CRF settings = 48 experiments
"""

import json
import os
import torch
import torch.optim as optim
from datetime import datetime
from pathlib import Path

# Import your existing modules (assumed to be available)
# from your_module import (
#     NERDataset, create_dataloaders, load_label_info,
#     QueryNERTeacher, DistilBERTStudent, TinyBertStudent, BiLSTMStudent,
#     KDTrainer
# )


def create_experiment_config():
    """Generate all 48 experiment configurations"""
    
    # Configuration space
    models = [
        # ("teacher", "QueryNERTeacher", "bltlab/queryner-augmented-data-bert-base-uncased"),
        ("distilbert", "DistilBERTStudent", "distilbert-base-uncased"),
        ("tinybert", "TinyBertStudent", "huawei-noah/TinyBERT_General_4L_312D"),
        ("bilstm", "BiLSTMStudent", "bert-base-uncased")
    ]
    
    datasets = [
        ("processed", {
            "train": r"D:\Dafa\Project\queryner-kd\data\processed\train.json",
            "val": r"D:\Dafa\Project\queryner-kd\data\processed\validation.json",
            "test": r"D:\Dafa\Project\queryner-kd\data\processed\test.json"
        }),
        ("raw", {
            "train": r"D:\Dafa\Project\queryner-kd\data\raw\train.json",
            "val": r"D:\Dafa\Project\queryner-kd\data\raw\validation.json",
            "test": r"D:\Dafa\Project\queryner-kd\data\raw\test.json"
        })
    ]
    
    learning_rates = [
        (2e-5, "2e-5"),   # Conservative, stable
        (5e-5, "5e-5"),   # Balanced speed/stability
        (1e-4, "1e-4")    # Faster convergence
    ]
    
    crf_settings = [
        (True, "crf"),
        (False, "nocrf")
    ]
    
    # Generate all combinations
    experiments = []
    exp_id = 1
    
    for model_name, model_class, model_path in models:
        for data_name, data_paths in datasets:
            for lr_value, lr_name in learning_rates:
                for use_crf, crf_name in crf_settings:
                    exp = {
                        "id": exp_id,
                        "model_name": model_name,
                        "model_class": model_class,
                        "model_path": model_path,
                        "data_name": data_name,
                        "data_paths": data_paths,
                        "learning_rate": lr_value,
                        "lr_name": lr_name,
                        "use_crf": use_crf,
                        "crf_name": crf_name,
                        "exp_name": f"{model_name}_{data_name}_{lr_name}_{crf_name}"
                    }
                    experiments.append(exp)
                    exp_id += 1
    
    return experiments


def instantiate_model(model_class, model_path, label_info, use_crf, device):
    """Instantiate the correct model based on class name"""
    
    if model_class == "QueryNERTeacher":
        model = QueryNERTeacher(
            model_name=model_path,
            label_info=label_info,
            use_crf=use_crf
        )
    elif model_class == "DistilBERTStudent":
        model = DistilBERTStudent(
            model_name=model_path,
            label_info=label_info,
            use_crf=use_crf
        )
    elif model_class == "TinyBertStudent":
        model = TinyBertStudent(
            model_name=model_path,
            label_info=label_info,
            use_crf=use_crf
        )
    elif model_class == "BiLSTMStudent":
        model = BiLSTMStudent(
            num_labels=label_info["num_labels"],
            use_crf=use_crf,
            model_name_for_vocab=model_path,
            label_info=label_info
        )
    else:
        raise ValueError(f"Unknown model class: {model_class}")
    
    return model.to(device)


def train_single_experiment(exp, label_info, device, num_epochs=10, batch_size=16, max_length=128):
    """Train a single experiment configuration"""
    
    print(f"\n{'='*80}")
    print(f"Experiment {exp['id']}/48: {exp['exp_name']}")
    print(f"{'='*80}")
    print(f"Model: {exp['model_name']}")
    print(f"Dataset: {exp['data_name']}")
    print(f"Learning Rate: {exp['lr_name']}")
    print(f"CRF: {exp['use_crf']}")
    print(f"{'='*80}\n")
    
    # Create dataloaders
    print("Loading data...")
    train_loader, val_loader, test_loader = create_dataloaders(
        train_path=exp['data_paths']['train'],
        val_path=exp['data_paths']['val'],
        test_path=exp['data_paths']['test'],
        model_name="bert-base-uncased",  # tokenizer
        batch_size=batch_size,
        max_length=max_length
    )
    
    # Instantiate model
    print(f"Instantiating model: {exp['model_class']}...")
    model = instantiate_model(
        model_class=exp['model_class'],
        model_path=exp['model_path'],
        label_info=label_info,
        use_crf=exp['use_crf'],
        device=device
    )
    
    # Create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=exp['learning_rate'])
    
    # Create trainer (using KD trainer with alpha=0, beta=1 for baseline)
    trainer = KDTrainer(
        teacher_model=None,  # No teacher for baseline
        student_model=model,
        train_loader=train_loader,
        scheduler=scheduler,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        alpha=0.0,  # No KD loss
        beta=1.0,   # Only student loss
        temperature=2.0  # Not used when alpha=0
    )
    
    # Train
    print(f"Training for {num_epochs} epochs...")
    history = trainer.train(num_epochs=num_epochs)
    
    # Save results
    save_experiment_results(exp, history)
    
    # Clear memory
    del model, trainer, optimizer
    torch.cuda.empty_cache()
    
    return history


def save_experiment_results(exp, history):
    """Save experiment results in organized structure"""
    
    # Create directory structure
    base_dir = Path("results/baseline")
    json_dir = base_dir / "json"
    img_dir = base_dir / "img"
    
    json_dir.mkdir(parents=True, exist_ok=True)
    img_dir.mkdir(parents=True, exist_ok=True)
    
    # Save history as JSON
    json_path = json_dir / f"{exp['exp_name']}.json"
    
    result_data = {
        "experiment": exp,
        "history": history,
        "timestamp": datetime.now().isoformat()
    }
    
    with open(json_path, "w") as f:
        json.dump(result_data, f, indent=4)
    
    print(f"✓ Results saved to: {json_path}")


def run_all_baselines(
    device="cuda",
    num_epochs=10,
    batch_size=16,
    max_length=128,
    start_from=1,
    end_at=48
):
    """Run all 48 baseline experiments"""
    
    print("\n" + "="*80)
    print("BASELINE TRAINING: 48 EXPERIMENTS")
    print("="*80)
    print(f"Device: {device}")
    print(f"Epochs: {num_epochs}")
    print(f"Batch Size: {batch_size}")
    print(f"Max Length: {max_length}")
    print(f"Running experiments {start_from} to {end_at}")
    print("="*80 + "\n")
    
    # Load label info (use teacher model config)
    print("Loading label information...")
    label_info = load_label_info("bltlab/queryner-augmented-data-bert-base-uncased")
    print(f"Number of labels: {label_info['num_labels']}")
    
    # Generate all experiment configs
    experiments = create_experiment_config()
    
    # Filter experiments based on start_from and end_at
    experiments = [exp for exp in experiments if start_from <= exp['id'] <= end_at]
    
    print(f"\nTotal experiments to run: {len(experiments)}\n")
    
    # Track results
    all_results = []
    failed_experiments = []
    
    # Run each experiment
    for i, exp in enumerate(experiments, 1):
        try:
            history = train_single_experiment(
                exp=exp,
                label_info=label_info,
                device=device,
                num_epochs=num_epochs,
                batch_size=batch_size,
                max_length=max_length
            )
            
            # Store summary
            final_metrics = {
                "exp_name": exp['exp_name'],
                "val_f1": history['val_f1'][-1],
                "val_precision": history['val_precision'][-1],
                "val_recall": history['val_recall'][-1],
                "best_val_f1": max(history['val_f1'])
            }
            all_results.append(final_metrics)
            
            print(f"\n✓ Experiment {exp['id']} completed successfully!")
            print(f"Final Val F1: {final_metrics['val_f1']:.4f}")
            print(f"Best Val F1: {final_metrics['best_val_f1']:.4f}\n")
            
        except Exception as e:
            print(f"\n✗ Experiment {exp['id']} FAILED!")
            print(f"Error: {str(e)}\n")
            failed_experiments.append({
                "exp_id": exp['id'],
                "exp_name": exp['exp_name'],
                "error": str(e)
            })
            continue
    
    # Save summary
    save_summary(all_results, failed_experiments)
    
    print("\n" + "="*80)
    print("ALL EXPERIMENTS COMPLETED")
    print("="*80)
    print(f"Successful: {len(all_results)}")
    print(f"Failed: {len(failed_experiments)}")
    print("="*80 + "\n")


def save_summary(all_results, failed_experiments):
    """Save summary of all experiments"""
    
    summary_dir = Path("results/baseline")
    summary_dir.mkdir(parents=True, exist_ok=True)
    
    # Save results summary
    summary_path = summary_dir / "summary.json"
    with open(summary_path, "w") as f:
        json.dump({
            "successful_experiments": all_results,
            "failed_experiments": failed_experiments,
            "timestamp": datetime.now().isoformat()
        }, f, indent=4)
    
    print(f"\n✓ Summary saved to: {summary_path}")
    
    # Print top performing models
    if all_results:
        print("\nTop 5 Models by Best Val F1:")
        sorted_results = sorted(all_results, key=lambda x: x['best_val_f1'], reverse=True)
        for i, result in enumerate(sorted_results[:5], 1):
            print(f"{i}. {result['exp_name']}: F1={result['best_val_f1']:.4f}")


# Example usage
if __name__ == "__main__":
    # Check device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    
    # Run all baselines
    # You can also run in batches:
    # run_all_baselines(device=device, start_from=1, end_at=12)  # First 12
    # run_all_baselines(device=device, start_from=13, end_at=24)  # Next 12
    # run_all_baselines(device=device, start_from=25, end_at=36)  # Next 12
    # run_all_baselines(device=device, start_from=37, end_at=48)  # Last 12
    
    run_all_baselines(
        device=device,
        num_epochs=10,
        batch_size=16,
        max_length=128,
        start_from=1,
        end_at=3
    )

Using device: cuda

BASELINE TRAINING: 48 EXPERIMENTS
Device: cuda
Epochs: 10
Batch Size: 16
Max Length: 128
Running experiments 1 to 3

Loading label information...
Number of labels: 35

Total experiments to run: 3


Experiment 1/48: distilbert_processed_2e-5_crf
Model: distilbert
Dataset: processed
Learning Rate: 2e-5
CRF: True

Loading data...
Instantiating model: DistilBERTStudent...

✗ Experiment 1 FAILED!
Error: name 'DistilBERTStudent' is not defined


Experiment 2/48: distilbert_processed_2e-5_nocrf
Model: distilbert
Dataset: processed
Learning Rate: 2e-5
CRF: False

Loading data...
Instantiating model: DistilBERTStudent...

✗ Experiment 2 FAILED!
Error: name 'DistilBERTStudent' is not defined


Experiment 3/48: distilbert_processed_5e-5_crf
Model: distilbert
Dataset: processed
Learning Rate: 5e-5
CRF: True

Loading data...
Instantiating model: DistilBERTStudent...

✗ Experiment 3 FAILED!
Error: name 'DistilBERTStudent' is not defined


✓ Summary saved to: results\baseline\summ

# Evaluate Teacher Model on Test Set

In [19]:
def evaluate_on_test(model, test_loader, device):
    """
    Evaluate model on test set.
    
    Args:
        model: Trained model (should be in eval mode)
        test_loader: Test DataLoader
        device: Device to use
        
    Returns:
        test_loss, test_kd, test_stu, test_precision, test_recall, test_f1
    """
    from tqdm.auto import tqdm
    
    model.eval()
    total_loss = 0.0
    tp_acc, pred_acc, actual_acc = 0, 0, 0
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            batch_size, seq_len = input_ids.shape
            
            # Get model outputs
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Get loss
            loss = output.get("loss", torch.tensor(0.0, device=device))
            total_loss += float(loss.item())
            
            # Get predictions
            if "pred" in output:
                pred = output["pred"]
                if isinstance(pred, torch.Tensor):
                    pred_tensor = pred.to(device)
                else:
                    # Handle CRF list output
                    pred_tensor = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
                    for i, p in enumerate(pred):
                        if isinstance(p, torch.Tensor):
                            p = p.tolist()
                        L = len(p)
                        if L > 0:
                            pred_tensor[i, :L] = torch.tensor(p, dtype=torch.long, device=device)
            else:
                pred_tensor = output["logits"].argmax(dim=-1)
            
            # Compute metrics
            mask = attention_mask.bool()
            valid = mask & (labels != -100)
            
            if valid.sum().item() > 0:
                preds_flat = pred_tensor[valid].view(-1)
                labels_flat = labels[valid].view(-1)
                
                # Compute TP, predicted, actual counts
                tp = int((preds_flat == labels_flat).sum().item())
                pred_count = int(preds_flat.numel())
                actual_count = int(labels_flat.numel())
                
                tp_acc += tp
                pred_acc += pred_count
                actual_acc += actual_count
    
    # Compute averages
    avg_loss = total_loss / len(test_loader)
    
    # Compute final metrics
    precision = tp_acc / pred_acc if pred_acc > 0 else 0.0
    recall = tp_acc / actual_acc if actual_acc > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    
    return avg_loss, 0.0, avg_loss, precision, recall, f1

In [21]:
from transformers import AutoModelForTokenClassification

teacher = AutoModelForTokenClassification.from_pretrained("bltlab/queryner-augmented-data-bert-base-uncased")

In [22]:
teacher.to("cuda")
test_loss, test_kd, test_stu, test_precision, test_recall, test_f1 = evaluate_on_test(teacher, test_loader, device="cuda")

Testing:   0%|          | 0/63 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [23]:
print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test F1: {test_f1:.4f}\n")


Test Loss: 3.3270
Test Precision: 0.6327
Test Recall: 0.6327
Test F1: 0.6327



Raw
- Test Loss: 3.3919
- Test Precision: 0.6219
- Test Recall: 0.6219
- Test F1: 0.6219


Processed
- Test Loss: 3.3270
- Test Precision: 0.6327
- Test Recall: 0.6327
- Test F1: 0.6327

# Load Best Model

In [36]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# recreate model architecture exactly as during training
teacher = QueryNERTeacher(
    model_name="bltlab/queryner-augmented-data-bert-base-uncased",
    label_info=label_info,
    use_crf=True  # or False depending how you trained
)
teacher.to(device)

# load the state dict
ckpt_path = r"D:\Dafa\Project\queryner-kd\teacher_processed_5e-5_crf_best.pt"
state = torch.load(ckpt_path, map_location=device)  # state is likely a dict of tensors

# if you saved state_dict directly:
if all(isinstance(v, torch.Tensor) for v in state.values()):
    # probably you loaded state_dict directly
    teacher.load_state_dict(state)
else:
    # fallback: if the checkpoint contains keys like 'model_state_dict'
    if "model_state_dict" in state:
        teacher.load_state_dict(state["model_state_dict"])
    elif "state_dict" in state:
        teacher.load_state_dict(state["state_dict"])
    else:
        # try to find the sub-dict that looks like state_dict
        for k, v in state.items():
            if isinstance(v, dict) and any(isinstance(t, torch.Tensor) for t in v.values()):
                teacher.load_state_dict(v)
                break


Some weights of BertModel were not initialized from the model checkpoint at bltlab/queryner-augmented-data-bert-base-uncased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [35]:
teacher.to("cuda")
test_loss, test_kd, test_stu, test_precision, test_recall, test_f1 = evaluate_on_test(teacher, test_loader, device="cuda")

Testing:   0%|          | 0/63 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
