In [25]:
import os
import random
import math
import pandas as pd
import numpy as np
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW

from nltk.corpus import wordnet as wn
from nltk import pos_tag, word_tokenize
import nltk

for pkg in ['punkt', 'punkt_tab', 'averaged_perceptron_tagger', 'wordnet', 'omw-1.4']:
    nltk.download(pkg, quiet=True)

In [26]:
CSV_PATH = "../data/cleaned_courses.csv"  # adjust if needed
TEXT_COLUMN = "TextForBERT"

MODEL_NAME = "bert-base-uncased"
MAX_LEN = 64
BATCH_SIZE = 16
EPOCHS = 5
LR = 2e-5
PROJ_DIM = 256
TEMPERATURE = 0.05

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

In [27]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(SEED)

In [28]:

# =========================
# Text Augmentations
#   - random deletion
#   - synonym replacement
#   - random insertion
#   - random swap
# =========================

def get_wordnet_pos(treebank_tag: str):
    """
    Map NLTK POS tags to WordNet POS tags.
    We mostly care about adjectives for your use case.
    """
    if treebank_tag.startswith('J'):
        return wn.ADJ
    elif treebank_tag.startswith('V'):
        return wn.VERB
    elif treebank_tag.startswith('N'):
        return wn.NOUN
    elif treebank_tag.startswith('R'):
        return wn.ADV
    else:
        return None


def get_synonym(word: str, wn_pos=None):
    """Get a random synonym for a word (if available)."""
    try:
        synsets = wn.synsets(word, pos=wn_pos) if wn_pos else wn.synsets(word)
        if not synsets:
            return None

        synset = random.choice(synsets)
        lemmas = [l.name().replace('_', ' ') for l in synset.lemmas() if l.name().lower() != word.lower()]
        if not lemmas:
            return None

        return random.choice(lemmas)
    except Exception:
        return None


def random_deletion(words: List[str], p: float = 0.1) -> List[str]:
    if len(words) == 1:
        return words
    kept = [w for w in words if random.random() > p]
    if not kept:
        kept = [random.choice(words)]
    return kept


def synonym_replacement(words: List[str], n: int = 1) -> List[str]:
    """Replace up to n words with synonyms (prefer adjectives)."""
    if len(words) == 0:
        return words

    new_words = words.copy()
    # POS tag to find adjectives etc.
    tagged = pos_tag(new_words)
    candidates = list(range(len(new_words)))  # indices

    random.shuffle(candidates)
    num_replaced = 0

    for idx in candidates:
        if num_replaced >= n:
            break

        word = new_words[idx]
        if not word.isalpha():
            continue

        _, tag = tagged[idx]
        wn_pos = get_wordnet_pos(tag)

        # Prefer adjectives but allow others
        synonym = get_synonym(word, wn_pos=wn_pos)
        if synonym is None:
            continue

        new_words[idx] = synonym
        num_replaced += 1

    return new_words


def random_swap(words: List[str], n: int = 1) -> List[str]:
    if len(words) < 2:
        return words
    new_words = words.copy()
    for _ in range(n):
        idx1, idx2 = random.sample(range(len(new_words)), 2)
        new_words[idx1], new_words[idx2] = new_words[idx2], new_words[idx1]
    return new_words


def random_insertion(words: List[str], n: int = 1) -> List[str]:
    """Insert n synonyms of random words at random positions."""
    new_words = words.copy()
    length = len(new_words)
    if length == 0:
        return new_words

    for _ in range(n):
        idx = random.randrange(length)
        word = new_words[idx]
        if not word.isalpha():
            continue

        # Try to get synonym (any POS)
        synonym = get_synonym(word)
        if synonym is None:
            continue

        insert_pos = random.randrange(len(new_words) + 1)
        new_words.insert(insert_pos, synonym)

    return new_words


def augment_text(text: str) -> str:
    """
    Create one augmented view.
    We apply a couple of augmentations with some probability,
    trying not to completely destroy semantics.
    """
    # Tokenize by words (nltk tokenizer is a bit smarter than split)
    words = word_tokenize(text)

    if len(words) == 0:
        return text

    # 1. Random deletion
    if random.random() < 0.8:
        words = random_deletion(words, p=0.1)

    # 2. Synonym replacement (1–2 words)
    if random.random() < 0.6:
        words = synonym_replacement(words, n=random.choice([1, 2]))

    # 3. Random insertion (0–2)
    if random.random() < 0.4:
        words = random_insertion(words, n=random.choice([1, 2]))

    # 4. Random swap (once)
    if random.random() < 0.3:
        words = random_swap(words, n=1)

    return " ".join(words)


def make_two_views(text: str) -> Tuple[str, str]:
    """Return 2 independently augmented views of the same base text."""
    view1 = augment_text(text)
    view2 = augment_text(text)
    return view1, view2

In [29]:
# =========================
# Dataset
# =========================

class ContrastiveCourseDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_len: int):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        base_text = self.texts[idx]
        view1, view2 = make_two_views(base_text)

        encoded1 = self.tokenizer(
            view1,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        encoded2 = self.tokenizer(
            view2,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        item = {
            "input_ids_a": encoded1["input_ids"].squeeze(0),
            "attention_mask_a": encoded1["attention_mask"].squeeze(0),
            "input_ids_b": encoded2["input_ids"].squeeze(0),
            "attention_mask_b": encoded2["attention_mask"].squeeze(0),
        }
        return item

In [30]:
# =========================
# Model: BERT + Projection Head
# =========================

class CourseEncoder(nn.Module):
    def __init__(self, base_model_name: str = MODEL_NAME, proj_dim: int = PROJ_DIM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(base_model_name)
        hidden = self.bert.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, proj_dim)
        )

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # CLS pooling
        cls = out.last_hidden_state[:, 0, :]  # (batch, hidden)
        z = self.proj(cls)
        # Normalize for cosine similarity
        z = nn.functional.normalize(z, p=2, dim=-1)
        return z

In [31]:
# =========================
# Contrastive Loss (InfoNCE / NT-Xent)
# =========================

class NTXentLoss(nn.Module):
    def __init__(self, temperature: float = 0.05):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """
        z_i, z_j: (batch, dim)
        Positive pairs are (i, i), negatives are all others in the batch.
        """
        batch_size = z_i.size(0)

        logits = torch.matmul(z_i, z_j.T) / self.temperature  # (N, N)
        labels = torch.arange(batch_size, device=z_i.device)

        loss_i = self.criterion(logits, labels)
        loss_j = self.criterion(logits.T, labels)

        loss = (loss_i + loss_j) / 2.0
        return loss

In [32]:
# =========================
# Training Loop
# =========================

def train():
    # ---- Load data ----
    df = pd.read_csv(CSV_PATH)
    if TEXT_COLUMN not in df.columns:
        raise ValueError(f"Column {TEXT_COLUMN} not found in CSV.")

    texts = df[TEXT_COLUMN].astype(str).tolist()
    print(f"Loaded {len(texts)} course descriptions.")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    dataset = ContrastiveCourseDataset(texts, tokenizer, MAX_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # ---- Model, optimizer, scheduler, loss ----
    model = CourseEncoder().to(DEVICE)

    # Optionally: freeze some lower BERT layers if you want
    # for name, param in model.bert.named_parameters():
    #     if "encoder.layer." in name:
    #         layer_num = int(name.split("encoder.layer.")[-1].split(".")[0])
    #         if layer_num < 8:  # freeze first 8 layers, for example
    #             param.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=LR)
    total_steps = len(dataloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    criterion = NTXentLoss(temperature=TEMPERATURE)

    model.train()
    step = 0

    for epoch in range(EPOCHS):
        running_loss = 0.0
        for batch in dataloader:
            step += 1
            input_ids_a = batch["input_ids_a"].to(DEVICE)
            attention_mask_a = batch["attention_mask_a"].to(DEVICE)
            input_ids_b = batch["input_ids_b"].to(DEVICE)
            attention_mask_b = batch["attention_mask_b"].to(DEVICE)

            optimizer.zero_grad()

            z_i = model(input_ids_a, attention_mask_a)  # (N, D)
            z_j = model(input_ids_b, attention_mask_b)  # (N, D)

            loss = criterion(z_i, z_j)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

            if step % 10 == 0:
                avg_loss = running_loss / 10
                print(f"Epoch {epoch+1}/{EPOCHS} | Step {step} | Loss: {avg_loss:.4f}")
                running_loss = 0.0

    # ---- Save model ----
    save_dir = "course_encoder_contrastive"
    os.makedirs(save_dir, exist_ok=True)

    # Save both BERT + projection head
    # Easiest is to save state_dict and tokenizer separately
    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(save_dir)

    # Also save some config for reloading later
    with open(os.path.join(save_dir, "config.txt"), "w") as f:
        f.write(f"MODEL_NAME={MODEL_NAME}\n")
        f.write(f"PROJ_DIM={PROJ_DIM}\n")
        f.write(f"MAX_LEN={MAX_LEN}\n")

    print(f"Model saved to {save_dir}")


In [33]:
train()

Loaded 527 course descriptions.
Epoch 1/5 | Step 10 | Loss: 1.4313
Epoch 1/5 | Step 20 | Loss: 0.3182
Epoch 1/5 | Step 30 | Loss: 0.0502
Epoch 2/5 | Step 40 | Loss: 0.0180
Epoch 2/5 | Step 50 | Loss: 0.0098
Epoch 2/5 | Step 60 | Loss: 0.0090
Epoch 3/5 | Step 70 | Loss: 0.0028
Epoch 3/5 | Step 80 | Loss: 0.0160
Epoch 3/5 | Step 90 | Loss: 0.0022
Epoch 4/5 | Step 100 | Loss: 0.0001
Epoch 4/5 | Step 110 | Loss: 0.0041
Epoch 4/5 | Step 120 | Loss: 0.0036
Epoch 4/5 | Step 130 | Loss: 0.0019
Epoch 5/5 | Step 140 | Loss: 0.0025
Epoch 5/5 | Step 150 | Loss: 0.0170
Epoch 5/5 | Step 160 | Loss: 0.0058
Model saved to course_encoder_contrastive
