In [37]:
import os
import random
import math
import pandas as pd
import numpy as np
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW

from nltk.corpus import wordnet as wn
from nltk import pos_tag, word_tokenize
import nltk

for pkg in ['punkt', 'punkt_tab', 'averaged_perceptron_tagger', 'wordnet', 'omw-1.4']:
    nltk.download(pkg, quiet=True)

In [None]:
CSV_PATH = "../data/cleaned_courses.csv"  # adjust if needed
CSV_PATH_AUGMENTED_DATA = "../data/students_clean_train.csv"
TEXT_COLUMN = "TextForBERT"

MODEL_NAME = "bert-base-uncased"
MAX_LEN = 64
BATCH_SIZE = 16
EPOCHS = 5
LR = 2e-5
PROJ_DIM = 256
TEMPERATURE = 0.075
LAMBDA_ISO = 0.05  # weight for isotropy regularizer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

In [39]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(SEED)

In [40]:

# =========================
# Text Augmentations
#   - random deletion
#   - synonym replacement
#   - random insertion
#   - random swap
# =========================

df_2 = pd.read_csv(CSV_PATH_AUGMENTED_DATA)

def get_wordnet_pos(treebank_tag: str):
    """
    Map NLTK POS tags to WordNet POS tags.
    We mostly care about adjectives for your use case.
    """
    if treebank_tag.startswith('J'):
        return wn.ADJ
    elif treebank_tag.startswith('V'):
        return wn.VERB
    elif treebank_tag.startswith('N'):
        return wn.NOUN
    elif treebank_tag.startswith('R'):
        return wn.ADV
    else:
        return None


def get_synonym(word: str, wn_pos=None):
    """Get a random synonym for a word (if available)."""
    try:
        synsets = wn.synsets(word, pos=wn_pos) if wn_pos else wn.synsets(word)
        if not synsets:
            return None

        synset = random.choice(synsets)
        lemmas = [l.name().replace('_', ' ') for l in synset.lemmas() if l.name().lower() != word.lower()]
        if not lemmas:
            return None

        return random.choice(lemmas)
    except Exception:
        return None


def random_deletion(words: List[str], p: float = 0.1) -> List[str]:
    if len(words) == 1:
        return words
    kept = [w for w in words if random.random() > p]
    if not kept:
        kept = [random.choice(words)]
    return kept


def synonym_replacement(words: List[str], n: int = 1) -> List[str]:
    """Replace up to n words with synonyms (prefer adjectives)."""
    if len(words) == 0:
        return words

    new_words = words.copy()
    # POS tag to find adjectives etc.
    tagged = pos_tag(new_words)
    candidates = list(range(len(new_words)))  # indices

    random.shuffle(candidates)
    num_replaced = 0

    for idx in candidates:
        if num_replaced >= n:
            break

        word = new_words[idx]
        if not word.isalpha():
            continue

        _, tag = tagged[idx]
        wn_pos = get_wordnet_pos(tag)

        # Prefer adjectives but allow others
        synonym = get_synonym(word, wn_pos=wn_pos)
        if synonym is None:
            continue

        new_words[idx] = synonym
        num_replaced += 1

    return new_words


def random_swap(words: List[str], n: int = 1) -> List[str]:
    if len(words) < 2:
        return words
    new_words = words.copy()
    for _ in range(n):
        idx1, idx2 = random.sample(range(len(new_words)), 2)
        new_words[idx1], new_words[idx2] = new_words[idx2], new_words[idx1]
    return new_words


def random_insertion(words: List[str], n: int = 1) -> List[str]:
    """Insert n synonyms of random words at random positions."""
    new_words = words.copy()
    length = len(new_words)
    if length == 0:
        return new_words

    for _ in range(n):
        idx = random.randrange(length)
        word = new_words[idx]
        if not word.isalpha():
            continue

        # Try to get synonym (any POS)
        synonym = get_synonym(word)
        if synonym is None:
            continue

        insert_pos = random.randrange(len(new_words) + 1)
        new_words.insert(insert_pos, synonym)

    return new_words


def augment_text(text: str, mode: str = "light") -> str:
    """
    Create one augmented view.
    mode = "light" or "heavy"

    light  -> small changes, stays very close to original
    heavy  -> larger perturbations, but try to keep semantics
    """
    words = word_tokenize(text)

    if len(words) == 0:
        return text

    # ---- Hyperparams by mode ----
    if mode == "light":
        del_prob = 0.05          # lower deletion prob
        syn_prob = 0.4           # fewer synonym ops
        syn_n_choices = [1]      # replace at most 1 word
        ins_prob = 0.2           # rarely insert
        ins_n_choices = [1]
        swap_prob = 0.2
    elif mode == "heavy":
        del_prob = 0.15          # more deletion
        syn_prob = 0.8           # more synonym operations
        syn_n_choices = [1, 2, 3]
        ins_prob = 0.5           # more insertion
        ins_n_choices = [1, 2]
        swap_prob = 0.4
    else:
        # fallback to light if unknown
        del_prob = 0.05
        syn_prob = 0.4
        syn_n_choices = [1]
        ins_prob = 0.2
        ins_n_choices = [1]
        swap_prob = 0.2

    # 1. Random deletion
    if random.random() < 0.8:  # keep same probability of *using* deletion
        words = random_deletion(words, p=del_prob)

    # 2. Synonym replacement
    if random.random() < syn_prob:
        words = synonym_replacement(words, n=random.choice(syn_n_choices))

    # 3. Random insertion
    if random.random() < ins_prob:
        words = random_insertion(words, n=random.choice(ins_n_choices))

    # 4. Random swap
    if random.random() < swap_prob:
        words = random_swap(words, n=1)

    return " ".join(words)


def make_two_views(text: str, course_code:str) -> Tuple[str, str]:
    """
    Return 2 differently augmented views of the same base text.

    view1: light augmentation (close to original)
    view2: heavy augmentation (more aggressively perturbed)
    """
    rows = df_2.loc[df_2["LikedCourses"].str.strip() == course_code, "StudentText"]
    if not rows.empty:
        view1 = rows.iloc[0]
    else:
        view1 = augment_text(text, mode="light")
    view2 = augment_text(text, mode="heavy")
    return view1, view2

In [41]:
class ContrastiveCourseDataset(Dataset):
    def __init__(self, texts: List[str], labels: List[int], tokenizer, full_codes: List[str], max_len: int):
        """
        texts:  list of course descriptions (len = N)
        labels: list of int faculty IDs aligned with texts (len = N)
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.course_codes = full_codes

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        base_text = self.texts[idx]
        view1, view2 = make_two_views(base_text, self.course_codes[idx])

        encoded1 = self.tokenizer(
            view1,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        encoded2 = self.tokenizer(
            view2,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        item = {
            "input_ids_a":      encoded1["input_ids"].squeeze(0),
            "attention_mask_a": encoded1["attention_mask"].squeeze(0),
            "input_ids_b":      encoded2["input_ids"].squeeze(0),
            "attention_mask_b": encoded2["attention_mask"].squeeze(0),
            "label":            torch.tensor(self.labels[idx], dtype=torch.long),
        }
        return item


In [42]:
# =========================
# Model: BERT + Projection Head
# =========================

class CourseEncoder(nn.Module):
    def __init__(self, base_model_name: str = MODEL_NAME, proj_dim: int = PROJ_DIM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(base_model_name)
        hidden = self.bert.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, proj_dim)
        )

    def forward(self, input_ids, attention_mask):
        out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False
        )
        last_hidden = out.last_hidden_state  # (B, L, H)

        # mean pooling
        mask = attention_mask.unsqueeze(-1).float()  # (B, L, 1)
        masked = last_hidden * mask
        summed = masked.sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = summed / counts                    # (B, H)

        z = self.proj(pooled)
        z = nn.functional.normalize(z, p=2, dim=-1)
        return z

In [43]:
# =========================
# Contrastive Loss (InfoNCE / NT-Xent)
# =========================

class SupervisedNTXentLoss(nn.Module):
    """
    Multi-positive supervised contrastive loss.

    - Uses two views z_i, z_j: each (B, D), already L2-normalized.
    - labels: (B,) integer labels (e.g., faculty IDs).
    - Positives for each anchor = all embeddings in the batch (both views)
      that share the same label (excluding itself).
    """
    def __init__(self, temperature: float = 0.05):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j, labels):
        device = z_i.device
        batch_size = z_i.size(0)

        # Stack the two views: (2B, D)
        z = torch.cat([z_i, z_j], dim=0)
        z = nn.functional.normalize(z, p=2, dim=-1)

        # Duplicate labels for both views: (2B,)
        labels = labels.to(device)
        labels_all = torch.cat([labels, labels], dim=0)

        # Similarity matrix: (2B, 2B)
        sim = torch.matmul(z, z.T) / self.temperature

        # Mask to remove self-similarity
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)

        # Label equality mask: (2B, 2B)
        label_eq = labels_all.unsqueeze(0) == labels_all.unsqueeze(1)
        # Positives: same label, not self
        positive_mask = label_eq & (~self_mask)

        # For numerical stability
        sim_max, _ = sim.max(dim=1, keepdim=True)
        sim = sim - sim_max.detach()

        exp_sim = torch.exp(sim)

        # Denominator: sum over all j != i
        exp_sim = exp_sim.masked_fill(self_mask, 0.0)
        denom = exp_sim.sum(dim=1) + 1e-12

        # Numerator: sum over positives
        pos_exp = exp_sim * positive_mask.float()
        numer = pos_exp.sum(dim=1) + 1e-12

        # Only anchors that have at least one positive
        valid = positive_mask.sum(dim=1) > 0
        loss = -torch.log(numer / denom)
        loss = loss[valid].mean()

        return loss

    
# =========================
# Isotropy regularizer  <<< NEW
# =========================
def isotropy_regularizer(z: torch.Tensor) -> torch.Tensor:
    """
    z: (batch, dim), assumed already L2-normalized along dim=-1.

    Encourages:
      - feature-wise mean ≈ 0
      - feature-wise variance ≈ 1
    """
    # feature means across the batch
    mean = z.mean(dim=0)                    # (dim,)
    # feature variances across the batch
    var = z.var(dim=0, unbiased=False)      # (dim,)

    mean_loss = (mean ** 2).mean()          # want mean -> 0
    var_loss  = ((var - 1.0) ** 2).mean()   # want var -> 1

    return mean_loss + var_loss

In [44]:
# =========================
# Training Loop
# =========================

def train():
    # ---- Load data ----
    df = pd.read_csv(CSV_PATH)
    if TEXT_COLUMN not in df.columns:
        raise ValueError(f"Column {TEXT_COLUMN} not found in CSV.")

    texts = df[TEXT_COLUMN].astype(str).tolist()

    print(f"Loaded {len(texts)} course descriptions.")

    # ---- NEW: build faculty labels ----
    if "Faculty" not in df.columns:
        raise ValueError("Expected a 'Faculty' column in the CSV for labels.")

    faculties = df["Faculty"].astype(str).tolist()
    full_codes = df["Faculty"].astype(str) + " " + df["Code"].astype(str)
    unique_faculties = sorted(set(faculties))
    fac2id = {fac: i for i, fac in enumerate(unique_faculties)}
    labels = [fac2id[f] for f in faculties]    # list[int], same length as texts

    print(f"Found {len(unique_faculties)} unique faculties.")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    dataset = ContrastiveCourseDataset(texts, labels, tokenizer, full_codes, MAX_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # ---- Model, optimizer, scheduler, loss ----
    model = CourseEncoder().to(DEVICE)

    # Optionally: freeze some lower BERT layers if you want
    # for name, param in model.bert.named_parameters():
    #     if "encoder.layer." in name:
    #         layer_num = int(name.split("encoder.layer.")[-1].split(".")[0])
    #         if layer_num < 8:  # freeze first 8 layers, for example
    #             param.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=LR)
    total_steps = len(dataloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    criterion = SupervisedNTXentLoss(temperature=TEMPERATURE)

    model.train()
    step = 0

    for epoch in range(EPOCHS):
        running_loss = 0.0
        running_ctr = 0.0      # for logging
        running_iso = 0.0      # for logging

        for batch in dataloader:
            step += 1
            input_ids_a      = batch["input_ids_a"].to(DEVICE)
            attention_mask_a = batch["attention_mask_a"].to(DEVICE)
            input_ids_b      = batch["input_ids_b"].to(DEVICE)
            attention_mask_b = batch["attention_mask_b"].to(DEVICE)
            labels_batch     = batch["label"].to(DEVICE)   # <<< NEW

            optimizer.zero_grad()

            z_i = model(input_ids_a, attention_mask_a)  # (N, D) normalized
            z_j = model(input_ids_b, attention_mask_b)  # (N, D) normalized

            # 1) supervised contrastive loss (multi-positive by faculty)
            loss_contrastive = criterion(z_i, z_j, labels_batch)

            # 2) isotropy regularizer on all embeddings in the batch
            z_all = torch.cat([z_i, z_j], dim=0)        # (2N, D)
            loss_iso = isotropy_regularizer(z_all)

            # 3) total loss
            loss = loss_contrastive + LAMBDA_ISO * loss_iso

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            running_ctr  += loss_contrastive.item()
            running_iso  += loss_iso.item()

            if step % 10 == 0:
                avg_loss = running_loss / 10
                avg_ctr  = running_ctr  / 10
                avg_iso  = running_iso  / 10
                print(
                    f"Epoch {epoch+1}/{EPOCHS} | Step {step} | "
                    f"Loss: {avg_loss:.4f} "
                    f"(ctr: {avg_ctr:.4f}, iso: {avg_iso:.4f})"
                )
                running_loss = 0.0
                running_ctr  = 0.0
                running_iso  = 0.0

    # ---- Save model ----
    save_dir = "model_v6_temp_0.75"
    os.makedirs(save_dir, exist_ok=True)

    # Save both BERT + projection head
    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(save_dir)

    # Also save some config for reloading later
    with open(os.path.join(save_dir, "config.txt"), "w") as f:
        f.write(f"MODEL_NAME={MODEL_NAME}\n")
        f.write(f"PROJ_DIM={PROJ_DIM}\n")
        f.write(f"MAX_LEN={MAX_LEN}\n")

    print(f"Model saved to {save_dir}")


In [45]:
train()

Loaded 518 course descriptions.
Found 7 unique faculties.
Epoch 1/5 | Step 10 | Loss: 1.1491 (ctr: 1.0990, iso: 1.0012)
Epoch 1/5 | Step 20 | Loss: 0.3983 (ctr: 0.3485, iso: 0.9976)
Epoch 1/5 | Step 30 | Loss: 0.2067 (ctr: 0.1569, iso: 0.9953)
Epoch 2/5 | Step 40 | Loss: 0.1270 (ctr: 0.0922, iso: 0.6962)
Epoch 2/5 | Step 50 | Loss: 0.1592 (ctr: 0.1095, iso: 0.9943)
Epoch 2/5 | Step 60 | Loss: 0.1160 (ctr: 0.0663, iso: 0.9940)
Epoch 3/5 | Step 70 | Loss: 0.0262 (ctr: 0.0064, iso: 0.3975)
Epoch 3/5 | Step 80 | Loss: 0.1051 (ctr: 0.0554, iso: 0.9935)
Epoch 3/5 | Step 90 | Loss: 0.0723 (ctr: 0.0227, iso: 0.9935)
Epoch 4/5 | Step 100 | Loss: 0.0052 (ctr: 0.0002, iso: 0.0993)
Epoch 4/5 | Step 110 | Loss: 0.0925 (ctr: 0.0428, iso: 0.9934)
Epoch 4/5 | Step 120 | Loss: 0.0729 (ctr: 0.0232, iso: 0.9934)
Epoch 4/5 | Step 130 | Loss: 0.0663 (ctr: 0.0166, iso: 0.9933)
Epoch 5/5 | Step 140 | Loss: 0.0580 (ctr: 0.0183, iso: 0.7946)
Epoch 5/5 | Step 150 | Loss: 0.0739 (ctr: 0.0243, iso: 0.9933)
Epoch 