In [14]:
import os, json, random
from dataclasses import dataclass
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm

# ------------------- CONFIG -------------------
# JSONL triplet files. Every row is expected to carry the keys read by
# CACTTripletDataset: "anchor", "positive_para", "positive_style", "label".
TRAIN_PATH = "data/imdb_triplets_train.jsonl"   # training split; edit to match your data location
VAL_PATH   = "data/imdb_triplets_val.jsonl"     # validation split; edit to match your data location

# Checkpoint name handed to AutoTokenizer / AutoModel.from_pretrained.
MODEL_NAME = "albert-base-v2"
# Number of triplets per step; used for both the train and the val DataLoader.
BATCH_SIZE = 32
# Number of full passes over the training DataLoader.
EPOCHS = 2
# Learning rate given to AdamW (the linear warmup/decay schedule scales it).
LR = 2e-5
# Tokenizer max_length; longer texts are truncated to this many tokens.
MAX_LEN = 256

# Temperature: similarities are divided by TAU inside two_pos_infonce.
TAU = 0.07
# Weight of the contrastive term: loss = L_ce + LAMBDA_CONT * L_cont.
LAMBDA_CONT = 0.5
CE_ON_ALL_VIEWS = True   # if True, average the CE loss over anchor + para + style views
# Seed applied to Python's `random` and to torch in main().
SEED = 42
# ---------------------------------------------

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def read_jsonl(path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of parsed objects.

    The file is read as UTF-8. Each line is stripped of surrounding
    whitespace; lines that are empty after stripping are skipped and every
    remaining line is parsed with ``json.loads``.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The parsed records, in file order.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # Materialise inside the `with` so the generator is consumed
        # while the file is still open.
        return [json.loads(text) for text in stripped_lines if text]

class CACTTripletDataset(Dataset):
    """Map-style dataset over pre-loaded triplet rows.

    Each row is a dict with the keys ``"anchor"``, ``"positive_para"``,
    ``"positive_style"`` and ``"label"``. Items are returned as dicts with
    the shorter keys ``"anchor"``, ``"para"``, ``"style"`` and an ``int``
    ``"label"``.
    """

    def __init__(self, rows: List[Dict]):
        # The list is kept by reference; nothing is copied or validated here.
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx: int):
        record = self.rows[idx]
        # (output key, source column) pairs, in the order the item is built.
        view_columns = (
            ("anchor", "anchor"),
            ("para", "positive_para"),
            ("style", "positive_style"),
        )
        sample = {view: record[column] for view, column in view_columns}
        sample["label"] = int(record["label"])
        return sample

@dataclass
class Batch:
    """One collated mini-batch of tokenized triplets, as produced by `collate`.

    Each of the three views (anchor, paraphrase, style) is tokenized in its
    own tokenizer call with ``padding=True``, so the padded sequence length
    may differ between views within the same batch.
    """
    # Anchor texts: tokenizer `input_ids` and `attention_mask`.
    a_ids: torch.Tensor
    a_attn: torch.Tensor
    # Paraphrase positives ("para"): tokenizer `input_ids` and `attention_mask`.
    p_ids: torch.Tensor
    p_attn: torch.Tensor
    # Style positives ("style"): tokenizer `input_ids` and `attention_mask`.
    s_ids: torch.Tensor
    s_attn: torch.Tensor
    # Class labels, one per triplet, built as a torch.long tensor.
    y: torch.Tensor

def collate(tok, items: List[Dict]) -> Batch:
    """Turn a list of dataset items into a tokenized `Batch`.

    Args:
        tok: Tokenizer callable; invoked once per view with a list of
            strings and must return a mapping holding ``"input_ids"`` and
            ``"attention_mask"``.
        items: Dataset items with the keys ``"anchor"``, ``"para"``,
            ``"style"`` and ``"label"``.

    Returns:
        A `Batch` holding ids/attention masks for the anchor, paraphrase and
        style views (in that order) plus the labels as a ``torch.long`` tensor.
    """
    # Gather the raw strings for every view first, then the labels, so that
    # a malformed item fails before any tokenization happens.
    texts_per_view = [
        [item[view] for item in items] for view in ("anchor", "para", "style")
    ]
    labels = torch.tensor([item["label"] for item in items], dtype=torch.long)

    tensors = []
    for texts in texts_per_view:
        # Pad to the longest text of this view and truncate at MAX_LEN tokens.
        encoding = tok(
            texts,
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        tensors += [encoding["input_ids"], encoding["attention_mask"]]

    # tensors == [a_ids, a_attn, p_ids, p_attn, s_ids, s_attn]
    return Batch(*tensors, y=labels)

class CACTModel(nn.Module):
    """Pretrained encoder with a 2-way linear classification head.

    Attributes:
        enc: Backbone loaded through ``AutoModel.from_pretrained``.
        cls: ``nn.Linear`` from the backbone's hidden size to 2 logits.
    """

    def __init__(self, backbone: str):
        super().__init__()
        self.enc = AutoModel.from_pretrained(backbone)
        self.cls = nn.Linear(self.enc.config.hidden_size, 2)

    def embed(self, input_ids, attn_mask):
        """Return the L2-normalized hidden state of the first token.

        Position 0 of ``last_hidden_state`` is taken as the sequence
        representation ([CLS]); normalizing it makes dot products between
        embeddings equal to cosine similarities.
        """
        hidden_states = self.enc(
            input_ids=input_ids, attention_mask=attn_mask
        ).last_hidden_state
        first_token = hidden_states[:, 0]
        return F.normalize(first_token, dim=-1)

    def logits(self, h):
        """Apply the classification head to embeddings ``h``."""
        return self.cls(h)

def two_pos_infonce(h_a, h_p, h_s, tau: float):
    """InfoNCE loss with two positives per anchor, computed in log-space.

    For every anchor ``i`` the positives are its paraphrase ``(a_i, p_i)``
    and its style variant ``(a_i, s_i)``; the negatives are all *other*
    anchors in the batch. The per-anchor loss is::

        -log( (e^{ap} + e^{as}) / (e^{ap} + e^{as} + sum_{j != i} e^{a_i a_j}) )

    where every similarity is a dot product divided by ``tau``.

    The ratio is evaluated as a difference of ``logsumexp`` terms, so the
    result stays finite for small ``tau`` or reduced precision, where
    exponentiating ``sim / tau`` directly overflows to ``inf`` and yields NaN.

    Args:
        h_a: Anchor embeddings, ``[B, D]``.
        h_p: Paraphrase embeddings, ``[B, D]``, row-aligned with ``h_a``.
        h_s: Style embeddings, ``[B, D]``, row-aligned with ``h_a``.
        tau: Temperature the similarities are divided by.

    Returns:
        Scalar tensor: the loss averaged over the batch.

    Note:
        Inputs are expected to be L2-normalized (as ``CACTModel.embed``
        returns them) so the dot products are cosine similarities; this
        function does not normalize them itself.
    """
    B = h_a.size(0)

    sim_aa = (h_a @ h_a.t()) / tau                         # [B,B]
    sim_ap = (h_a * h_p).sum(dim=-1, keepdim=True) / tau   # [B,1]
    sim_as = (h_a * h_s).sum(dim=-1, keepdim=True) / tau   # [B,1]

    # An anchor is not its own negative: -inf contributes exactly 0 inside
    # logsumexp and, unlike a large finite sentinel, is valid in any dtype.
    self_mask = torch.eye(B, dtype=torch.bool, device=h_a.device)
    sim_aa = sim_aa.masked_fill(self_mask, float("-inf"))

    positives = torch.cat([sim_ap, sim_as], dim=1)         # [B,2]
    log_num = torch.logsumexp(positives, dim=1)            # [B]
    log_den = torch.logsumexp(torch.cat([positives, sim_aa], dim=1), dim=1)  # [B]

    # -log(num / den) == log(den) - log(num)
    return (log_den - log_num).mean()

@torch.no_grad()
def evaluate(model, dl, tok):
    """Compute classification accuracy on the anchor view of ``dl``.

    The model is put into eval mode for the duration of the loop and is
    returned to whatever mode it was in on entry (also when an exception is
    raised), so evaluating a model that was already in eval mode does not
    silently switch it back to training mode.

    Args:
        model: Object exposing ``embed(input_ids, attn_mask)``,
            ``logits(h)``, ``parameters()``, ``eval()``, ``train(mode)`` and
            a ``training`` flag (e.g. ``CACTModel``).
        dl: Iterable of batches exposing ``a_ids``, ``a_attn`` and ``y``.
        tok: Unused; kept so existing callers keep working.

    Returns:
        Fraction of correctly classified anchors as a float; ``0.0`` when
        ``dl`` yields no examples.
    """
    was_training = model.training
    # Run on the device the weights actually live on; fall back to the
    # module-level DEVICE only for a model without parameters.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else DEVICE

    correct, total = 0, 0
    model.eval()
    try:
        for batch in dl:
            h = model.embed(batch.a_ids.to(device), batch.a_attn.to(device))
            # Labels stay on the CPU, so bring predictions back before comparing.
            pred = model.logits(h).argmax(dim=-1).cpu()
            correct += (pred == batch.y).sum().item()
            total += batch.y.size(0)
    finally:
        model.train(was_training)
    # max(..., 1) guards against division by zero on an empty loader.
    return correct / max(total, 1)

def main():
    """Train CACTModel on the triplet data and save its weights.

    Steps: seed the RNGs, load the train/val JSONL files, build the
    tokenizer and DataLoaders, then optimise ``CE + LAMBDA_CONT * InfoNCE``
    with AdamW under a linear warmup/decay schedule. Validation accuracy is
    printed after every epoch and the final ``state_dict`` is written to
    ``ckpts/cact_albert.pt``.
    """
    random.seed(SEED)
    torch.manual_seed(SEED)

    print("DEVICE:", DEVICE)
    print("Loading data...")
    train_rows, val_rows = read_jsonl(TRAIN_PATH), read_jsonl(VAL_PATH)
    print("train:", len(train_rows), "val:", len(val_rows))

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)

    def collate_with_tok(items):
        # Bind the tokenizer so DataLoader can call collate with one argument.
        return collate(tok, items)

    def make_loader(rows, shuffle):
        return DataLoader(
            CACTTripletDataset(rows),
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            collate_fn=collate_with_tok,
            num_workers=0,
        )

    train_dl = make_loader(train_rows, shuffle=True)
    val_dl = make_loader(val_rows, shuffle=False)

    model = CACTModel(MODEL_NAME).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    ce = nn.CrossEntropyLoss()

    # One scheduler step per optimizer step; the first 6% of steps are warmup.
    total_steps = EPOCHS * len(train_dl)
    warmup_steps = int(0.06 * total_steps)
    sched = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    for epoch_no in range(1, EPOCHS + 1):
        pbar = tqdm(train_dl, desc=f"epoch {epoch_no}/{EPOCHS}")
        for batch in pbar:
            y = batch.y.to(DEVICE)

            # Encode the three views in a fixed order: anchor, paraphrase, style.
            h_a = model.embed(batch.a_ids.to(DEVICE), batch.a_attn.to(DEVICE))
            h_p = model.embed(batch.p_ids.to(DEVICE), batch.p_attn.to(DEVICE))
            h_s = model.embed(batch.s_ids.to(DEVICE), batch.s_attn.to(DEVICE))

            # Cross-entropy (label preservation), optionally averaged over all views.
            ce_anchor = ce(model.logits(h_a), y)
            if CE_ON_ALL_VIEWS:
                ce_para = ce(model.logits(h_p), y)
                ce_style = ce(model.logits(h_s), y)
                L_ce = (ce_anchor + ce_para + ce_style) / 3.0
            else:
                L_ce = ce_anchor

            # Contrastive invariance to paraphrase and style rewrites.
            L_cont = two_pos_infonce(h_a, h_p, h_s, tau=TAU)

            loss = L_ce + LAMBDA_CONT * L_cont

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()

            pbar.set_postfix({
                "L": loss.item(),
                "CE": L_ce.item(),
                "C": L_cont.item(),
            })

        print("val_acc:", evaluate(model, val_dl, tok))

    os.makedirs("ckpts", exist_ok=True)
    torch.save(model.state_dict(), "ckpts/cact_albert.pt")
    print("saved: ckpts/cact_albert.pt")

# Entry point: start training only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


DEVICE: cuda
Loading data...
train: 4500 val: 500


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

epoch 1/2: 100%|██████████| 141/141 [1:14:52<00:00, 31.86s/it, L=0.685, CE=0.68, C=0.0109]  


val_acc: 0.73


epoch 2/2: 100%|██████████| 141/141 [1:16:21<00:00, 32.49s/it, L=0.661, CE=0.658, C=0.00637]


val_acc: 0.768
saved: ckpts/cact_albert.pt


In [None]:
import pandas as pd
import numpy as np
