In [35]:
# ==========================================================
# 1. SETUP: Imports, Devices, and Data Loading
# ==========================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import copy

device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 531 # Ensure this is defined globally

# Load the artifacts from the Preprocessing notebook
data = torch.load("preprocessed_features2.pth", weights_only=False)

X_train = data["X_train"] 
y_all = data["y_all"]     
y_core = data["y_core"]   
ancestor_list = data["ancestors"]
E_label_768 = data["E_label_768"] 

# Convert label embeddings to tensor if they were saved as numpy
if isinstance(E_label_768, np.ndarray):
    E_label_768 = torch.from_numpy(E_label_768)

In [36]:
# ==========================================================
# 2. DATA LOADERS & SPLITTING
# ==========================================================
class MultiLabelDataset(Dataset):
    def __init__(self, X, y, y_c):
        self.X = X
        self.y = y
        self.y_c = y_c
    def __len__(self): 
        return len(self.X)
    def __getitem__(self, idx):
        return {
            "X": self.X[idx], 
            "y": self.y[idx], 
            "y_core": self.y_c[idx]
        }

# Split indices for validation (15%)
train_idx, val_idx = train_test_split(
    np.arange(len(X_train)), 
    test_size=0.15, 
    random_state=42, 
    shuffle=True
)

train_ds = MultiLabelDataset(X_train[train_idx], y_all[train_idx], y_core[train_idx])
val_ds   = MultiLabelDataset(X_train[val_idx],   y_all[val_idx],   y_core[val_idx])

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)

print(f" Data ready: {len(train_ds)} train, {len(val_ds)} val samples.")

 Data ready: 25063 train, 4424 val samples.


In [37]:
# ==========================================================
# 3. ARCHITECTURE: Multi-Head GATv2 + Path Attention
# ==========================================================
class MultiHeadPathLabelAttn(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.2): # Increased to 8 heads for Strategy B
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.v_attn = nn.Linear(self.head_dim, 1, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        
        self.ln = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        # Learnable gate to balance self-information vs. hierarchy information
        self.gate = nn.Parameter(torch.ones(1) * 0.5)

    def forward(self, E, ancestors):
        L, d = E.shape
        Q = self.w_q(E).view(L, self.num_heads, self.head_dim)
        K = self.w_k(E).view(L, self.num_heads, self.head_dim)
        V = self.v_proj(E).view(L, self.num_heads, self.head_dim)
        
        out = torch.zeros_like(E)
        for c in range(L):
            anc = ancestors[c]
            if not anc:
                out[c] = E[c]
                continue
            
            # GATv2 Dynamic Attention
            q_c = Q[c:c+1] 
            k_a = K[anc]
            
            # Compute scores: LeakyReLU(Q + K) is more expressive than standard GAT
            scores = self.v_attn(F.leaky_relu(q_c + k_a, 0.2)).squeeze(-1)
            attn = F.softmax(scores, dim=0) 
            
            # Aggregate hierarchy message
            msg = (attn.unsqueeze(-1) * V[anc]).sum(dim=0).view(d)
            
            # Apply Gating
            out[c] = (self.gate * E[c]) + ((1.0 - self.gate) * msg)
            
        # Strategy B: Added Residual Connection (out + E) before LayerNorm
        # This ensures specific label meanings aren't "washed out" by broad parent info
        return self.ln(self.dropout(out) + E)

class ProposedClassifier(nn.Module):
    def __init__(self, input_dim, num_labels, emb_dim, E_label_768, ancestors):
        super().__init__()
        self.doc_proj = nn.Sequential(
            nn.Linear(input_dim, emb_dim), 
            nn.ReLU(), 
            nn.Dropout(0.3)
        )
        self.label_proj = nn.Linear(768, emb_dim, bias=False)
        self.register_buffer("E_text", E_label_768.float())
        
        # Consistent num_heads with the Attention class
        self.label_attn = MultiHeadPathLabelAttn(emb_dim, num_heads=8)
        self.ancestors = ancestors

    def forward(self, X):
        # 1. Project product BERT features
        h = self.doc_proj(X)
        
        # 2. Refine label embeddings using hierarchy attention
        # Label proj maps BERT label space (768) to joint emb_dim (256)
        E = self.label_attn(self.label_proj(self.E_text), self.ancestors)
        
        # 3. Compute compatibility (Dot Product)
        return h @ E.t()

In [38]:
# ==========================================================
# 4. EVALUATION HELPER
# ==========================================================
def evaluate_f1(model, loader, thr=0.35):
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in loader:
            Xb, yb = batch["X"].to(device), batch["y"].to(device)
            probs = torch.sigmoid(model(Xb)).cpu().numpy()
            for p in probs:
                idx = np.argsort(p)[::-1]
                chosen = [i for i in idx[:50] if p[i] >= thr]
                if len(chosen) < 2: chosen = idx[:2].tolist()
                
                vec = np.zeros(NUM_CLASSES)
                vec[chosen[:3]] = 1 # Top 3 max constraint
                all_preds.append(vec)
            all_targets.append(yb.cpu().numpy())
    return f1_score(np.vstack(all_targets), np.vstack(all_preds), average='samples', zero_division=0)

In [None]:
# ==========================================================
# 5. STAGE 1 TRAINING: Core-Aware & Gradient Clipping (REFINED)
# ==========================================================
# Strategy C: Lower W_AUX even further (e.g., 0.1 to 0.3) 
# This tells the model: "Core labels are 10x more important than expanded ancestors"
W_AUX = 0.2  
best_val_f1 = 0.0

# Re-instantiate with the new 8-head Architecture
model = ProposedClassifier(768, NUM_CLASSES, 256, E_label_768, ancestor_list).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=12)

for epoch in range(1, 13): # Increased epochs slightly for the new architecture
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        Xb, y_a, y_c = batch["X"].to(device), batch["y"].to(device), batch["y_core"].to(device)
        
        logits = model(Xb)
        loss_elem = F.binary_cross_entropy_with_logits(logits, y_a, reduction='none')
        
        # Strategy C Logic: 
        # y_aux are labels that are in y_all but NOT in y_core
        y_aux = (y_a - y_c).clamp(0, 1)
        
        # We assign a weight of 1.0 to Core labels (y_c) 
        # and a weight of W_AUX (0.2) to auxiliary labels (y_aux)
        weight = torch.ones_like(loss_elem)
        weight = torch.where(y_aux == 1, torch.tensor(W_AUX).to(device), weight)
        
        loss = (loss_elem * weight).mean()
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 
        optimizer.step()
        total_loss += loss.item()
    
    scheduler.step()
    
    current_f1 = evaluate_f1(model, val_loader)
    print(f"Epoch {epoch} | Loss: {total_loss/len(train_loader):.4f} | Val F1: {current_f1:.4f}")
    
    if current_f1 > best_val_f1:
        best_val_f1 = current_f1
        torch.save(model.state_dict(), "stage1_teacher_model1.pth")
        print("  ⭐ Best model saved!")

print(f"✅ Stage 1 Complete. Best Val F1: {best_val_f1:.4f}")

Epoch 1: 100%|██████████| 392/392 [03:22<00:00,  1.94it/s]


Epoch 1 | Loss: 0.0165 | Val F1: 0.1163
  ⭐ Best model saved!


Epoch 2: 100%|██████████| 392/392 [03:23<00:00,  1.92it/s]


Epoch 2 | Loss: 0.0099 | Val F1: 0.1592
  ⭐ Best model saved!


Epoch 3: 100%|██████████| 392/392 [03:24<00:00,  1.92it/s]


Epoch 3 | Loss: 0.0095 | Val F1: 0.1395


Epoch 4: 100%|██████████| 392/392 [03:23<00:00,  1.92it/s]


Epoch 4 | Loss: 0.0090 | Val F1: 0.1577


Epoch 5: 100%|██████████| 392/392 [03:25<00:00,  1.91it/s]


Epoch 5 | Loss: 0.0088 | Val F1: 0.1809
  ⭐ Best model saved!


Epoch 6: 100%|██████████| 392/392 [03:23<00:00,  1.93it/s]


Epoch 6 | Loss: 0.0086 | Val F1: 0.1763


Epoch 7: 100%|██████████| 392/392 [03:22<00:00,  1.94it/s]


Epoch 7 | Loss: 0.0084 | Val F1: 0.1974
  ⭐ Best model saved!


Epoch 8: 100%|██████████| 392/392 [03:25<00:00,  1.90it/s]


Epoch 8 | Loss: 0.0082 | Val F1: 0.1953


Epoch 9: 100%|██████████| 392/392 [03:23<00:00,  1.93it/s]


Epoch 9 | Loss: 0.0079 | Val F1: 0.2038
  ⭐ Best model saved!


Epoch 10: 100%|██████████| 392/392 [03:20<00:00,  1.95it/s]


Epoch 10 | Loss: 0.0078 | Val F1: 0.2077
  ⭐ Best model saved!


Epoch 11: 100%|██████████| 392/392 [03:24<00:00,  1.92it/s]
