In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# ----------------------------
# Seed
# ----------------------------
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator. The CUDA generators are
            seeded only when ``torch.cuda.is_available()`` is true.
    """
    # Same three CPU-side seeders, applied in the same order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# ----------------------------
# Dataset
# ----------------------------
class ReviewDataset(Dataset):
    """Map-style dataset that tokenizes one review per item.

    Each item is a dict holding the tokenizer outputs (leading batch
    dimension squeezed away), the quality label under ``"yq"``, the
    relevance label under ``"yr"`` and, when ``meta`` was supplied, the
    matching feature row under ``"meta"`` as a float32 tensor.

    Args:
        texts: Sequence of review strings. It is copied into a list so that
            ``texts[i]`` is always positional (a pandas Series with a
            shuffled index would otherwise be looked up by label).
        y_quality: Quality labels, converted to an int64 array.
        y_relevance: Relevance labels, converted to an int64 array.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding="max_length", max_length=..., return_tensors="pt")``.
        max_len: Padding / truncation length passed to the tokenizer.
        meta: Optional per-sample numeric features. Array-likes and tensors
            are converted to a float32 NumPy array so ``.shape`` and
            positional row indexing behave uniformly.

    Raises:
        ValueError: If the labels or ``meta`` do not have one entry per text.
    """

    def __init__(self, texts, y_quality, y_relevance, tokenizer, max_len=128, meta=None):
        self.texts = list(texts)
        self.yq = np.asarray(y_quality, dtype=np.int64)
        self.yr = np.asarray(y_relevance, dtype=np.int64)
        self.tok = tokenizer
        self.max_len = max_len

        if meta is not None:
            if isinstance(meta, torch.Tensor):
                meta = meta.detach().cpu().numpy()
            meta = np.asarray(meta, dtype=np.float32)
        self.meta = meta

        # Fail early instead of hitting an IndexError (or silently ignoring
        # trailing labels) in the middle of an epoch.
        n_texts = len(self.texts)
        lengths = {"y_quality": len(self.yq), "y_relevance": len(self.yr)}
        if self.meta is not None:
            lengths["meta"] = len(self.meta)
        mismatched = {name: n for name, n in lengths.items() if n != n_texts}
        if mismatched:
            raise ValueError(
                f"Expected {n_texts} entries to match texts, got {mismatched}"
            )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        enc = self.tok(
            self.texts[i],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of size 1;
        # drop it so the DataLoader's collate step can stack items.
        item = {k: v.squeeze(0) for k, v in enc.items()}
        item["yq"] = torch.tensor(int(self.yq[i]))
        item["yr"] = torch.tensor(int(self.yr[i]))
        if self.meta is not None:
            item["meta"] = torch.tensor(self.meta[i], dtype=torch.float32)
        return item

# ----------------------------
# Mean Pooler
# ----------------------------
class MeanPooler(nn.Module):
    """Attention-mask-weighted mean over dimension 1 of the hidden states."""

    def forward(self, last_hidden_state, attention_mask):
        """Average the unmasked positions of ``last_hidden_state``.

        Args:
            last_hidden_state: Hidden states; dimension 1 is reduced.
                Presumably shaped (batch, seq, hidden) - confirm with caller.
            attention_mask: Mask with one entry per position; it is cast to
                float and given a trailing singleton dimension.

        Returns:
            The masked sum divided by the mask total (floored at 1e-9 so a
            fully masked row yields zeros instead of dividing by zero).
        """
        weights = attention_mask.float().unsqueeze(-1)
        total = torch.sum(last_hidden_state * weights, dim=1)
        count = torch.clamp(weights.sum(dim=1), min=1e-9)
        return total / count

# ----------------------------
# Dual-head Model
# ----------------------------
class ReviewMTL(nn.Module):
    """Pretrained encoder with a projection branch and two linear heads.

    The pooled encoder output feeds (a) a small MLP whose output is
    L2-normalised and returned as ``z`` and (b) two dropout+linear heads,
    one for quality and one for relevance. When ``meta_dim > 0`` the heads
    are sized to take the pooled embedding concatenated with ``meta_dim``
    extra features, which are batch-normalised first.

    Args:
        encoder_name: Identifier passed to ``AutoModel.from_pretrained``.
        n_quality: Output size of the quality head.
        n_relevance: Output size of the relevance head.
        proj_dim: Output size of the projection branch.
        dropout: Dropout probability used in the projection and both heads.
        meta_dim: Number of extra features appended to the pooled embedding
            for the heads; 0 disables them.
    """

    def __init__(self, encoder_name, n_quality=3, n_relevance=2, proj_dim=256, dropout=0.1, meta_dim=0):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.pool = MeanPooler()

        hidden = self.encoder.config.hidden_size
        has_meta = meta_dim > 0

        # Layers are created in the same order as before so seeded
        # initialisation and state_dict keys are unchanged.
        projection_layers = [
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, proj_dim),
        ]
        self.proj = nn.Sequential(*projection_layers)

        self.meta_bn = nn.BatchNorm1d(meta_dim) if has_meta else None

        head_width = hidden + meta_dim if has_meta else hidden
        self.head_quality = nn.Sequential(nn.Dropout(dropout), nn.Linear(head_width, n_quality))
        self.head_relev = nn.Sequential(nn.Dropout(dropout), nn.Linear(head_width, n_relevance))

    def forward(self, input_ids, attention_mask, meta=None):
        """Encode a batch and return ``(z, quality_logits, relevance_logits)``.

        ``z`` is computed from the pooled embedding alone; ``meta`` (when it
        is a non-empty tensor) only affects the inputs of the two heads.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = self.pool(encoded.last_hidden_state, attention_mask)
        z = nn.functional.normalize(self.proj(pooled), p=2, dim=-1)

        features = pooled
        use_meta = meta is not None and meta.numel() > 0
        if use_meta:
            normed = meta if self.meta_bn is None else self.meta_bn(meta)
            features = torch.cat([pooled, normed], dim=1)

        # Tuple elements evaluate left to right: quality head, then relevance.
        return z, self.head_quality(features), self.head_relev(features)

# ----------------------------
# Losses
# ----------------------------
def supcon_loss(z, y, temperature=0.07):
    """Supervised contrastive loss over one batch of embeddings.

    For every anchor row, the other rows with the same label are positives.
    The loss is the negative mean log-probability of the positives under a
    softmax over all non-self similarities. Anchors without any positive
    contribute zero and are still included in the batch mean.

    Args:
        z: 2-D embedding tensor, one row per sample; similarities are plain
            dot products, so rows are presumably L2-normalised by the caller.
        y: 1-D label tensor with one entry per row of ``z``.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar loss tensor. An empty batch yields 0 (still connected to
        ``z``'s graph) instead of NaN.
    """
    B = z.size(0)
    if B == 0:
        return z.sum() * 0.0

    sim = z @ z.t()
    # Under CUDA autocast the matmul result is fp16, where the -1e9 mask
    # value below overflows and raises; do the rest in float32.
    if sim.dtype in (torch.float16, torch.bfloat16):
        sim = sim.float()
    sim = sim / temperature

    eye = torch.eye(B, device=z.device, dtype=torch.bool)
    # Large negative (not -inf) so the masked entries times 0 stay finite.
    sim = sim.masked_fill(eye, -1e9)

    pos = (y.unsqueeze(0) == y.unsqueeze(1)) & (~eye)
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    pos_counts = pos.sum(dim=1).clamp(min=1)
    loss_per = -(log_prob * pos).sum(dim=1) / pos_counts
    return loss_per.mean()

def ce_smooth_weighted(logits, targets, n_classes, smoothing=0.0, class_weight=None):
    """Cross-entropy with optional label smoothing and per-class weights.

    With smoothing, the target class gets probability ``1 - smoothing`` and
    the remaining mass is split evenly over the other ``n_classes - 1``
    classes. With ``class_weight``, each sample's loss is scaled by the
    weight of its target class and the result is normalised by the sum of
    those weights.

    Args:
        logits: 2-D tensor of unnormalised scores, one row per sample.
        targets: 1-D integer tensor of class indices.
        n_classes: Number of classes, used to spread the smoothing mass.
        smoothing: Label-smoothing amount; values <= 0 disable smoothing.
        class_weight: Optional per-class weights (tensor or array-like). It
            is moved to the loss's device and dtype before indexing.

    Returns:
        Scalar loss tensor.
    """
    if smoothing <= 0.0 and class_weight is None:
        return nn.functional.cross_entropy(logits, targets)

    # A non-positive smoothing means "off", even on the weighted path.
    smoothing = max(float(smoothing), 0.0)
    # With a single class there are no "other" classes to spread mass over.
    off_value = smoothing / (n_classes - 1) if n_classes > 1 else 0.0

    logp = nn.functional.log_softmax(logits, dim=-1)
    with torch.no_grad():
        true = torch.full_like(logp, off_value)
        true.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
    loss_per = -(true * logp).sum(dim=1)

    if class_weight is not None:
        weights = torch.as_tensor(class_weight, dtype=loss_per.dtype, device=loss_per.device)
        w = weights[targets]
        loss_per = loss_per * w
        return loss_per.sum() / w.sum().clamp(min=1e-9)
    return loss_per.mean()

# ----------------------------
# Training
# ----------------------------
def train_model(train_ds, val_ds, n_quality=3, n_relevance=2, cfg=None):
    """Jointly train the two classification heads and the contrastive branch.

    The per-batch objective is
    ``CE(quality) + CE(relevance) + cfg.lambda_contrastive * SupCon(z, quality)``,
    optimised with AdamW at ``cfg.lr_heads`` for ``cfg.epochs_joint`` epochs.
    Mixed precision (fp16 autocast plus gradient scaling) is enabled only
    when the configured device is a CUDA device.

    Args:
        train_ds: Training dataset yielding dicts with ``input_ids``,
            ``attention_mask``, ``yq``, ``yr`` and optionally ``meta``; its
            ``meta`` attribute (``None`` or a 2-D array) sets the model's
            ``meta_dim``.
        val_ds: Accepted for interface compatibility; not used during
            training.
        n_quality: Number of quality classes.
        n_relevance: Number of relevance classes.
        cfg: Configuration object. Attributes read: ``encoder_name``,
            ``device``, ``classes_per_batch``, ``samples_per_class``,
            ``lr_heads``, ``epochs_joint``, ``label_smoothing``,
            ``temperature``, ``lambda_contrastive``.

    Returns:
        ``(model, tokenizer)`` - the trained model and the tokenizer loaded
        for ``cfg.encoder_name``.

    Raises:
        ValueError: If ``cfg`` is not provided.
    """
    if cfg is None:
        raise ValueError("train_model requires a cfg object; got cfg=None")

    set_seed(42)

    # Normalise the device once: comparing cfg.device to the string "cuda"
    # misses "cuda:0" and similar spellings, which would silently turn
    # mixed precision off.
    device = torch.device(cfg.device)
    use_amp = device.type == "cuda"

    meta_dim = train_ds.meta.shape[1] if train_ds.meta is not None else 0
    tok = AutoTokenizer.from_pretrained(cfg.encoder_name, use_fast=True)
    model = ReviewMTL(
        cfg.encoder_name, n_quality=n_quality, n_relevance=n_relevance, meta_dim=meta_dim
    ).to(device)

    batch_size = cfg.classes_per_batch * cfg.samples_per_class
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr_heads)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for _ in range(cfg.epochs_joint):
        model.train()
        for b in train_dl:
            opt.zero_grad(set_to_none=True)

            # Move each tensor to the device once per batch.
            input_ids = b["input_ids"].to(device)
            attention_mask = b["attention_mask"].to(device)
            yq = b["yq"].to(device)
            yr = b["yr"].to(device)
            meta = b.get("meta")
            meta = meta.to(device) if isinstance(meta, torch.Tensor) else None

            with torch.autocast(
                device_type=("cuda" if use_amp else "cpu"),
                dtype=torch.float16,
                enabled=use_amp,
            ):
                z, q_logits, r_logits = model(input_ids, attention_mask, meta)
                Jq = ce_smooth_weighted(
                    q_logits, yq, n_classes=n_quality, smoothing=cfg.label_smoothing
                )
                Jr = ce_smooth_weighted(r_logits, yr, n_classes=n_relevance)
                # The contrastive term groups samples by quality label.
                Jc = supcon_loss(z, yq, temperature=cfg.temperature)
                loss = Jq + Jr + cfg.lambda_contrastive * Jc

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
    return model, tok

# ----------------------------
# Prediction / Evaluation
# ----------------------------
@torch.no_grad()
def predict(model, dataset, cfg, batch_size=64):
    """Return softmax probabilities of both heads for every item in ``dataset``.

    The model is switched to eval mode and the dataset is read in order
    (no shuffling), so output rows line up with dataset indices.

    Args:
        model: Callable as ``model(input_ids, attention_mask, meta)`` and
            returning ``(z, quality_logits, relevance_logits)``.
        dataset: Dataset whose items are dicts with ``input_ids``,
            ``attention_mask`` and optionally ``meta``.
        cfg: Configuration object; only ``cfg.device`` is read.
        batch_size: Inference batch size (defaults to the previous
            hard-coded value of 64).

    Returns:
        ``(quality_probs, relevance_probs)`` as NumPy arrays with one row
        per dataset item.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    model.eval()
    q_chunks, r_chunks = [], []
    for b in DataLoader(dataset, batch_size=batch_size, shuffle=False):
        meta = b.get("meta")
        meta = meta.to(cfg.device) if isinstance(meta, torch.Tensor) else None
        _, q_logits, r_logits = model(
            b["input_ids"].to(cfg.device), b["attention_mask"].to(cfg.device), meta
        )
        q_chunks.append(q_logits.softmax(-1).cpu().numpy())
        r_chunks.append(r_logits.softmax(-1).cpu().numpy())

    if not q_chunks:
        # np.vstack([]) would raise a less helpful ValueError here.
        raise ValueError("predict() received an empty dataset; nothing to score")
    return np.vstack(q_chunks), np.vstack(r_chunks)

def evaluate_model_classification_report(model, dataset, cfg):
    """Print a classification report for each task on ``dataset``.

    Ground truth is read from ``dataset.yq`` and ``dataset.yr``; predicted
    classes are the argmax of the probabilities returned by ``predict``.
    Nothing is returned - both reports are printed.
    """
    # Read the labels before running inference, as the original did.
    quality_truth, relevance_truth = dataset.yq, dataset.yr
    quality_probs, relevance_probs = predict(model, dataset, cfg)
    quality_hat = quality_probs.argmax(axis=1)
    relevance_hat = relevance_probs.argmax(axis=1)

    reports = (
        ("Quality Classification Report", quality_truth, quality_hat),
        ("Relevance Classification Report", relevance_truth, relevance_hat),
    )
    for title, truth, guess in reports:
        print(title)
        print(classification_report(truth, guess, digits=4))

# ----------------------------
# Train/Test Split
# ----------------------------
# X_texts = list of your text data
# y_quality = list/array of quality labels
# y_relevance = list/array of relevance labels
# NOTE: `df` and `cfg` are not defined in this cell; they are expected to
# already exist in the notebook namespace.
X_texts   = df['reviewText'].tolist()      # list of review strings
y_quality = df['qualityLevel'].values            # numpy array of quality labels
y_relevance = df['isRelevant'].values        # numpy array of relevance labels
# 80/20 split with a fixed random_state; stratified on the quality label
# only, so the relevance class balance is not explicitly preserved.
X_train, X_val, yq_train, yq_val, yr_train, yr_val = train_test_split(
    X_texts, y_quality, y_relevance, test_size=0.2, random_state=42, stratify=y_quality
)

# One tokenizer instance is shared by both datasets.
tok = AutoTokenizer.from_pretrained(cfg.encoder_name, use_fast=True)
train_ds = ReviewDataset(X_train, yq_train, yr_train, tok, max_len=cfg.max_len)
val_ds = ReviewDataset(X_val, yq_val, yr_val, tok, max_len=cfg.max_len)

# ----------------------------
# Train and Evaluate
# ----------------------------
# train_model loads its own tokenizer and returns it, rebinding `tok` here.
model, tok = train_model(train_ds, val_ds, n_quality=3, n_relevance=2, cfg=cfg)
# Reports are printed for the held-out validation split.
evaluate_model_classification_report(model, val_ds, cfg)


In [None]:
from sklearn.metrics import classification_report

@torch.no_grad()
def evaluate_model_classification_report(model, dl, cfg):
    """Print classification reports for both tasks over batches from ``dl``.

    Args:
        model: Callable as ``model(input_ids, attention_mask, meta)`` and
            returning ``(z, quality_logits, relevance_logits)``.
        dl: Iterable of batch dicts with ``input_ids``, ``attention_mask``,
            ``yq``, ``yr`` and optionally ``meta`` (typically a DataLoader).
            A ``torch.utils.data.Dataset`` is also accepted and is wrapped
            in an unshuffled DataLoader, so calls written for the earlier
            dataset-based definition of this function keep working.
        cfg: Configuration object; only ``cfg.device`` is read.

    Raises:
        ValueError: If ``dl`` yields no batches.
    """
    model.eval()

    # Iterating a Dataset directly would yield single unbatched items, which
    # the model's forward pass cannot consume.
    if isinstance(dl, Dataset):
        dl = DataLoader(dl, batch_size=64, shuffle=False)

    all_q_preds, all_r_preds = [], []
    all_q_true, all_r_true = [], []

    for b in dl:
        meta = b.get("meta")
        meta = meta.to(cfg.device) if isinstance(meta, torch.Tensor) else None
        _, q_logits, r_logits = model(
            b["input_ids"].to(cfg.device),
            b["attention_mask"].to(cfg.device),
            meta
        )
        all_q_preds.append(q_logits.argmax(-1).cpu().numpy())
        all_r_preds.append(r_logits.argmax(-1).cpu().numpy())
        # .cpu() is a no-op for the usual CPU label tensors from a DataLoader.
        all_q_true.append(b["yq"].cpu().numpy())
        all_r_true.append(b["yr"].cpu().numpy())

    if not all_q_preds:
        # np.concatenate([]) would raise a less helpful ValueError here.
        raise ValueError("evaluate_model_classification_report received no batches")

    all_q_preds = np.concatenate(all_q_preds)
    all_r_preds = np.concatenate(all_r_preds)
    all_q_true = np.concatenate(all_q_true)
    all_r_true = np.concatenate(all_r_true)

    print("=== Quality Task ===")
    print(classification_report(all_q_true, all_q_preds, digits=4))

    print("=== Relevance Task ===")
    print(classification_report(all_r_true, all_r_preds, digits=4))