In [7]:
# ==========================================================
# 1. SETUP: Load Preprocessed Features
# ==========================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import copy

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tensors we saved in the Preprocessing notebook.
# map_location="cpu" keeps the load working on a CPU-only machine even if the
# tensors were saved from a GPU; batches and the model are moved to `device`
# explicitly later on.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load this file if it comes from a trusted source.
data = torch.load("preprocessed_features.pth", map_location="cpu", weights_only=False)
X_train = data["X_train"] # [N, 768]
y_all = data["y_all"]     # [N, 531] - Full hierarchy
y_core = data["y_core"]   # [N, 531] - Just keywords
ancestor_list = data["ancestors"]

# The label embeddings may have been stored as a NumPy array; normalise to a tensor.
E_label_768 = data["E_label_768"]
if isinstance(E_label_768, np.ndarray):
    E_label_768 = torch.from_numpy(E_label_768)

class MultiLabelDataset(Dataset):
    """Index-aligned view over features, full label targets and core label targets.

    Each item is a dict with the keys ``"X"``, ``"y"`` and ``"y_core"`` holding
    the ``idx``-th entry of the corresponding container passed to the constructor.
    """

    def __init__(self, X, y, y_c):
        # Containers are stored as given; no copy or conversion is made.
        self.X, self.y, self.y_c = X, y, y_c

    def __len__(self):
        """Number of samples, taken from the feature container."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a dict of feature / label entries."""
        sample = {}
        sample["X"] = self.X[idx]
        sample["y"] = self.y[idx]
        sample["y_core"] = self.y_c[idx]
        return sample

# Shuffled mini-batches of 64 samples over the preprocessed training features.
train_loader = DataLoader(
    dataset=MultiLabelDataset(X=X_train, y=y_all, y_c=y_core),
    batch_size=64,
    shuffle=True,
)

In [8]:
# ==========================================================
# 2. ARCHITECTURE: Multi-Head GATv2 + Path
# ==========================================================
class MultiHeadPathLabelAttn(nn.Module):
    """Multi-head, GATv2-style attention of every label over its ancestor labels.

    For a label with ancestors, each head scores every ancestor with
    ``v_attn(leaky_relu(q_label + k_ancestor))``, softmax-normalises the scores
    over the ancestors, and aggregates the ancestors' value projections. The
    aggregated message is blended with the label's own embedding through a
    single learnable scalar ``gate``. A label without ancestors keeps its
    embedding unchanged. LayerNorm is applied to the result.

    Args:
        dim: Label embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Accepted for interface compatibility; currently not used by
            this module.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=4, dropout=0.2):
        super().__init__()
        if dim % num_heads != 0:
            # Fail early with a clear message instead of a shape error in forward().
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.v_attn = nn.Linear(self.head_dim, 1, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.ones(1) * 0.5)

    @staticmethod
    def _pad_ancestors(ancestors, num_labels, device):
        """Turn per-label ancestor lists into a padded index tensor plus validity mask.

        Returns ``(idx, mask)``, both of shape ``[num_labels, width]`` where
        ``width`` is the longest ancestor list (at least 1). Padded slots point
        at index 0 and are marked ``False`` in ``mask``.
        """
        rows = []
        for c in range(num_labels):
            anc = ancestors[c]
            # None / empty means "no ancestors"; entries are normalised to plain
            # ints so lists, tuples, arrays and tensors are all accepted.
            rows.append([] if anc is None else [int(a) for a in anc])
        width = max(max((len(r) for r in rows), default=0), 1)
        padded = [r + [0] * (width - len(r)) for r in rows]
        idx = torch.tensor(padded, dtype=torch.long, device=device).reshape(num_labels, width)
        lengths = torch.tensor([len(r) for r in rows], dtype=torch.long, device=device)
        mask = torch.arange(width, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        return idx, mask

    def forward(self, E, ancestors):
        """Refine label embeddings with information from their ancestors.

        Args:
            E: Label embedding matrix of shape ``[L, dim]``.
            ancestors: Indexable by label id ``0..L-1``; each entry is an
                iterable of ancestor label ids (``None`` or empty for none).

        Returns:
            Tensor of shape ``[L, dim]`` (layer-normalised).

        Note:
            All labels are processed in one batched computation whose
            intermediate tensors have shape ``[L, max_ancestors, num_heads,
            head_dim]``.
        """
        L, d = E.shape
        Q = self.w_q(E).view(L, self.num_heads, self.head_dim)
        K = self.w_k(E).view(L, self.num_heads, self.head_dim)
        V = self.v_proj(E).view(L, self.num_heads, self.head_dim)

        idx, mask = self._pad_ancestors(ancestors, L, E.device)
        has_anc = mask.any(dim=1)

        # Labels without ancestors get one dummy valid slot so their softmax
        # stays finite (no NaNs in forward or backward); their attention result
        # is discarded by the torch.where below.
        safe_mask = mask.clone()
        safe_mask[:, 0] = mask[:, 0] | ~has_anc

        # GATv2 dynamic scoring: [L, 1, H, hd] + [L, A, H, hd] -> scores [L, A, H]
        scores = self.v_attn(F.leaky_relu(Q.unsqueeze(1) + K[idx], 0.2)).squeeze(-1)
        scores = scores.masked_fill(~safe_mask.unsqueeze(-1), float("-inf"))
        attn = F.softmax(scores, dim=1)  # normalise over the ancestor axis

        msg = (attn.unsqueeze(-1) * V[idx]).sum(dim=1).reshape(L, d)
        # Gated residual to balance original text vs hierarchy info
        mixed = (self.gate * E) + ((1.0 - self.gate) * msg)
        out = torch.where(has_anc.unsqueeze(-1), mixed, E)
        return self.ln(out)

class ProposedClassifier(nn.Module):
    """Scores documents against hierarchy-aware label embeddings.

    Document features are projected to ``emb_dim``; the fixed label text
    embeddings are projected to the same width, refined by
    ``MultiHeadPathLabelAttn`` using the ancestor lists, and the logits are the
    dot products between document and label representations.

    Args:
        input_dim: Width of the document feature vectors.
        num_labels: Not used by this class; the number of output logits equals
            the number of rows of ``E_label_768``.
        emb_dim: Shared width of document and label representations.
        E_label_768: Label text embedding matrix, one row per label (tensor or
            NumPy array). Its second dimension sets the label projection's
            input width.
        ancestors: Per-label ancestor ids, passed through to the attention
            module on every forward call.

    Raises:
        ValueError: If ``E_label_768`` is not a 2-D matrix.
    """

    def __init__(self, input_dim, num_labels, emb_dim, E_label_768, ancestors):
        super().__init__()
        # Accept NumPy arrays as well as tensors; for a float32 tensor this is a no-op.
        label_text = torch.as_tensor(E_label_768).float()
        if label_text.dim() != 2:
            raise ValueError(
                f"E_label_768 must be 2-D [num_labels, text_dim], got shape {tuple(label_text.shape)}"
            )
        self.doc_proj = nn.Sequential(nn.Linear(input_dim, emb_dim), nn.ReLU(), nn.Dropout(0.3))
        # Input width follows the supplied embeddings instead of a hard-coded 768.
        self.label_proj = nn.Linear(label_text.shape[1], emb_dim, bias=False)
        # Buffer: moves with the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer("E_text", label_text)
        self.label_attn = MultiHeadPathLabelAttn(emb_dim, num_heads=4)
        self.ancestors = ancestors

    def forward(self, X):
        """Return logits of shape ``[batch, number of label rows]`` for features ``X``."""
        h = self.doc_proj(X)
        # Label representations are recomputed on every call because the
        # projection and attention weights are trainable.
        E = self.label_attn(self.label_proj(self.E_text), self.ancestors)
        return h @ E.t()


In [6]:
# ==========================================================
# 3. TRAINING LOOP: Core-Aware & Gradient Clipping
# ==========================================================
# NOTE: this cell reads X_train, y_all, E_label_768, ancestor_list, train_loader,
# device and ProposedClassifier, so the setup and architecture cells must have
# been run in the current kernel first (otherwise it fails with a NameError).

# TWEAK: Set weight for auxiliary (expanded) labels
W_AUX = 0.45
# Single source of truth for the epoch count: it drives both the training loop
# and the length of the cosine learning-rate schedule so the two cannot drift.
NUM_EPOCHS = 10
EMB_DIM = 256

# Feature width and label count are taken from the loaded data rather than hard-coded.
model = ProposedClassifier(
    X_train.shape[1], y_all.shape[1], EMB_DIM, E_label_768, ancestor_list
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    running_loss = torch.zeros((), device=device)
    num_batches = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        Xb = batch["X"].to(device)
        # BCE-with-logits needs floating-point targets; .float() is a no-op
        # when the labels are already float32.
        y_a = batch["y"].to(device).float()
        y_c = batch["y_core"].to(device).float()

        logits = model(Xb)
        loss_elem = F.binary_cross_entropy_with_logits(logits, y_a, reduction='none')

        # --- CORE-AWARE WEIGHTING ---
        # Positions that are positive in the full targets but not in the core
        # targets get weight W_AUX; every other position keeps weight 1.
        y_aux = (y_a - y_c).clamp(0, 1)
        weight = torch.ones_like(loss_elem) + (W_AUX - 1.0) * y_aux

        loss = (loss_elem * weight).mean()

        optimizer.zero_grad()
        loss.backward()
        # PREVENT EXPLODING GRADIENTS
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Accumulate on-device; converted to a Python float once per epoch.
        running_loss += loss.detach()
        num_batches += 1
    # The cosine schedule advances once per epoch.
    scheduler.step()
    if num_batches:
        print(f"Epoch {epoch}: mean loss {running_loss.item() / num_batches:.4f}")

torch.save(model.state_dict(), "stage1_teacher_model.pth")

NameError: name 'E_label_768' is not defined