# V5.1 SciBERT - Stabilized Training

LR: 3e-5 | Class weights: 1.4 | Dropout: 0.30 | Batch: 32

In [None]:
!pip install -q transformers datasets scikit-learn matplotlib seaborn torch

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
from transformers import get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle
import os

# Pick the compute device once; the model and every batch are moved onto it later.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Device: {device}")
if has_cuda:
    # Report which GPU was allocated and its total memory in decimal gigabytes.
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem_gb:.1f}GB)")

## Upload Dataset
Upload `arxiv_papers_augmented.csv` or `arxiv_papers_raw.csv`. If both are present, the augmented file is used and the augmentation step is skipped.

In [None]:
# Choose the input CSV. A pre-augmented file takes precedence over the raw one
# because it lets the slow back-translation step be skipped.
if os.path.exists('arxiv_papers_augmented.csv'):
    DATA_PATH = 'arxiv_papers_augmented.csv'
    SKIP_AUGMENTATION = True
    print("Using pre-augmented dataset")
elif os.path.exists('arxiv_papers_raw.csv'):
    DATA_PATH = 'arxiv_papers_raw.csv'
    SKIP_AUGMENTATION = False
    print("Will augment dataset (~30-40 min)")
else:
    # Fail here instead of printing and carrying on: later cells read
    # DATA_PATH and SKIP_AUGMENTATION and would otherwise die with an
    # unrelated-looking NameError.
    raise FileNotFoundError(
        "ERROR: No dataset found. Upload arxiv_papers_raw.csv or arxiv_papers_augmented.csv")

## Data Augmentation (skip if using pre-augmented)

In [None]:
class BackTranslationAugmenter:
    """Paraphrase English text by translating it to Spanish and back.

    Loads the two Helsinki-NLP Marian checkpoints (en->es and es->en) onto
    ``device`` in eval mode. ``device`` defaults to CUDA when available,
    otherwise CPU.
    """

    def __init__(self, device=None):
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_en_es = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-es').to(self.device).eval()
        self.tokenizer_en_es = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es')
        self.model_es_en = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-es-en').to(self.device).eval()
        self.tokenizer_es_en = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-es-en')

    def translate(self, text, model, tokenizer, max_length=512):
        """Translate one string with the given model/tokenizer pair.

        The input is truncated to ``max_length`` tokens, decoded with a
        4-beam search capped at ``max_length`` output tokens, and the first
        generated sequence is returned as plain text.
        """
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def back_translate(self, text, max_length=512):
        """Return ``text`` round-tripped en -> es -> en.

        Best-effort: if either translation step fails, a warning is issued
        and the original ``text`` is returned unchanged so a long
        augmentation run is not aborted by one bad sample.
        """
        try:
            spanish = self.translate(text, self.model_en_es, self.tokenizer_en_es, max_length)
            return self.translate(spanish, self.model_es_en, self.tokenizer_es_en, max_length)
        except Exception as exc:
            # `Exception` rather than a bare `except:` so that Ctrl-C
            # (KeyboardInterrupt) and SystemExit can still stop the run.
            import warnings
            warnings.warn(
                f"Back-translation failed ({type(exc).__name__}: {exc}); keeping original text",
                RuntimeWarning)
            return text

    def augment_dataset(self, df, target_category='cs.AI', max_samples=450):
        """Append back-translated copies of one category's rows to ``df``.

        Up to ``max_samples`` rows whose ``category`` equals
        ``target_category`` are chosen (sampled with a fixed seed when there
        are more). Each copy keeps the original title and category and gets
        a back-translated abstract. The originals plus the copies are
        returned shuffled with a fixed seed and a fresh index.
        """
        target_samples = df[df['category'] == target_category].copy()
        if len(target_samples) > max_samples:
            target_samples = target_samples.sample(n=max_samples, random_state=42)

        augmented_samples = []
        for _, row in tqdm(target_samples.iterrows(), total=len(target_samples), desc="Augmenting"):
            augmented_samples.append({
                'title': row['title'],
                'abstract': self.back_translate(row['abstract']),
                'category': row['category']
            })

        final_df = pd.concat([df, pd.DataFrame(augmented_samples)], ignore_index=True)
        return final_df.sample(frac=1, random_state=42).reset_index(drop=True)

In [None]:
# Only runs when the raw CSV was selected; a pre-augmented file skips this cell entirely.
if not SKIP_AUGMENTATION:
    df_raw = pd.read_csv(DATA_PATH)
    augmenter = BackTranslationAugmenter()
    # Adds back-translated copies of up to 450 rows of the default target category ('cs.AI').
    df_augmented = augmenter.augment_dataset(df_raw, max_samples=450)
    # Write the result to disk so a later run can pick up the augmented file directly.
    df_augmented.to_csv('arxiv_papers_augmented.csv', index=False)
    # From here on, downstream cells read the augmented file instead of the raw one.
    DATA_PATH = 'arxiv_papers_augmented.csv'
    # Drop the translation models and release cached GPU memory before SciBERT is loaded.
    del augmenter
    torch.cuda.empty_cache()
    print(f"Augmented: {len(df_augmented)} samples")

## Model Architecture

In [None]:
class CrossAttentionSciBERT(nn.Module):
    """SciBERT classifier in which title and abstract attend to each other.

    Title and abstract are encoded separately by the same SciBERT encoder.
    Two multi-head cross-attention layers (title->abstract and
    abstract->title) enrich each sequence with the other, each followed by a
    residual connection and LayerNorm. Each enriched sequence is reduced to
    one vector by learned attention pooling; the two vectors are
    concatenated and passed through a 3-layer fusion MLP and a linear
    classifier.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout rate of the fusion MLP (the last MLP layer uses
            ``dropout * 0.8``).
        freeze_bert_layers: Number of encoder layers, counted from the
            bottom, whose parameters are frozen. ``0`` freezes none.
    """

    def __init__(self, num_classes=4, dropout=0.30, freeze_bert_layers=3):
        super().__init__()
        self.bert = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
        hidden_size = self.bert.config.hidden_size

        # Freeze only the lowest encoder layers; the embeddings and the
        # remaining layers stay trainable.
        if freeze_bert_layers > 0:
            for layer in self.bert.encoder.layer[:freeze_bert_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        self.embedding_dropout = nn.Dropout(0.1)
        self.cross_attn_title_to_abstract = nn.MultiheadAttention(hidden_size, 8, dropout=0.1, batch_first=True)
        self.cross_attn_abstract_to_title = nn.MultiheadAttention(hidden_size, 8, dropout=0.1, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        # One scalar score per token, used by attention_pool().
        self.title_attention = nn.Linear(hidden_size, 1)
        self.abstract_attention = nn.Linear(hidden_size, 1)

        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 2, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(dropout * 0.8)
        )
        self.classifier = nn.Linear(128, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the fusion MLP and classifier Linear layers.

        Weights are drawn from N(0, 0.02^2) and biases are zeroed. The
        encoder, cross-attention and pooling layers keep their own
        initialisation.
        """
        for m in list(self.fusion.modules()) + [self.classifier]:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def attention_pool(self, hidden_states, attention_layer, mask):
        """Collapse a token sequence into one vector with learned weights.

        Args:
            hidden_states: Token representations, (batch, seq, hidden).
            attention_layer: Callable mapping ``hidden_states`` to one score
                per token, (batch, seq, 1).
            mask: Optional (batch, seq) mask; positions equal to 0 are
                excluded from pooling. ``None`` pools over every position.

        Returns:
            ``(pooled, weights)``: the weighted sum (batch, hidden) and the
            softmax weights (batch, seq).
        """
        attention_weights = attention_layer(hidden_states).squeeze(-1)
        if mask is not None:
            # Use the most negative value the dtype can hold. A literal -1e9
            # cannot be represented in float16 and makes masked_fill raise
            # under half precision; in float32 both give a weight of exactly 0.
            fill_value = torch.finfo(attention_weights.dtype).min
            attention_weights = attention_weights.masked_fill(mask == 0, fill_value)
        attention_weights = torch.softmax(attention_weights, dim=1)
        return torch.bmm(attention_weights.unsqueeze(1), hidden_states).squeeze(1), attention_weights

    def forward(self, title_input_ids, title_attention_mask, abstract_input_ids, abstract_attention_mask):
        """Return class logits, (batch, num_classes), for a batch of papers.

        The attention masks use 1 for real tokens and 0 for padding; padding
        is ignored both as cross-attention keys and during pooling.
        """
        title_hidden = self.embedding_dropout(self.bert(title_input_ids, title_attention_mask).last_hidden_state)
        abstract_hidden = self.embedding_dropout(self.bert(abstract_input_ids, abstract_attention_mask).last_hidden_state)

        # Title tokens query the abstract; residual + LayerNorm afterwards.
        title_enhanced, _ = self.cross_attn_title_to_abstract(
            title_hidden, abstract_hidden, abstract_hidden, key_padding_mask=(abstract_attention_mask == 0))
        title_enhanced = self.layer_norm1(title_hidden + title_enhanced)

        # Abstract tokens query the title; residual + LayerNorm afterwards.
        abstract_enhanced, _ = self.cross_attn_abstract_to_title(
            abstract_hidden, title_hidden, title_hidden, key_padding_mask=(title_attention_mask == 0))
        abstract_enhanced = self.layer_norm2(abstract_hidden + abstract_enhanced)

        title_pooled, _ = self.attention_pool(title_enhanced, self.title_attention, title_attention_mask)
        abstract_pooled, _ = self.attention_pool(abstract_enhanced, self.abstract_attention, abstract_attention_mask)

        return self.classifier(self.fusion(torch.cat([title_pooled, abstract_pooled], dim=1)))

In [None]:
class SciBERTDataset(Dataset):
    """Map-style dataset that tokenizes (title, abstract, label) triples on demand.

    Titles and abstracts are tokenized separately, each padded/truncated to
    its own fixed length, so the default collate function can stack items.
    """

    def __init__(self, titles, abstracts, labels, tokenizer, max_title_len=32, max_abstract_len=128):
        self.titles, self.abstracts, self.labels = titles, abstracts, labels
        self.tokenizer = tokenizer
        self.max_title_len, self.max_abstract_len = max_title_len, max_abstract_len

    def __len__(self):
        return len(self.labels)

    def _encode(self, text, max_length):
        """Tokenize one string; return (input_ids, attention_mask) without the batch axis."""
        encoded = self.tokenizer(text, max_length=max_length,
                                 padding='max_length', truncation=True, return_tensors='pt')
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx`` as a dict keyed like the model's forward arguments."""
        title_ids, title_mask = self._encode(self.titles[idx], self.max_title_len)
        abstract_ids, abstract_mask = self._encode(self.abstracts[idx], self.max_abstract_len)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return {
            'title_input_ids': title_ids,
            'title_attention_mask': title_mask,
            'abstract_input_ids': abstract_ids,
            'abstract_attention_mask': abstract_mask,
            'label': label,
        }

## Data Preparation

In [None]:
# Load the papers table and turn category names into integer class ids.
df = pd.read_csv(DATA_PATH)
le = LabelEncoder()
df['label'] = le.fit_transform(df['category'])
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

# Stratified 70/15/15 split. The test set is carved off first; the validation
# fraction is then rescaled so it is still 15% of the *full* dataset.
# NOTE(review): back-translated rows are appended to the CSV before this
# split, so a paraphrase and its source abstract can land in different
# splits -- confirm this is acceptable for the reported test metrics.
TEST_FRACTION = 0.15
VAL_FRACTION = 0.15
X, y = df[['title', 'abstract']], df['label']
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=TEST_FRACTION, random_state=42, stratify=y)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=VAL_FRACTION / (1 - TEST_FRACTION), random_state=42, stratify=y_temp)


def _build_dataset(features, targets):
    """Wrap one split's feature frame and label series in a SciBERTDataset."""
    return SciBERTDataset(features['title'].tolist(), features['abstract'].tolist(),
                          targets.tolist(), tokenizer)


train_dataset = _build_dataset(X_train, y_train)
val_dataset = _build_dataset(X_val, y_val)
test_dataset = _build_dataset(X_test, y_test)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

## Training Setup

In [None]:
# Hyperparameters for this run.
FREEZE_BERT_LAYERS = 3
DROPOUT = 0.30
BATCH_SIZE = 32
EPOCHS = 10
LR = 3e-5
WEIGHT_DECAY = 0.01
PATIENCE = 3

# Loss weights, built in label-encoder order so the 1.4 weight always lands
# on cs.AI rather than on whichever category happens to be encoded as 0.
UPWEIGHTED_CATEGORY = 'cs.AI'
UPWEIGHT = 1.4
CLASS_WEIGHTS = [UPWEIGHT if name == UPWEIGHTED_CATEGORY else 1.0 for name in le.classes_]
# Size the classifier head from the data instead of assuming four categories.
NUM_CLASSES = len(le.classes_)

# Pinned host memory only helps host->GPU copies; on CPU it just triggers a warning.
PIN_MEMORY = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

model = CrossAttentionSciBERT(num_classes=NUM_CLASSES, dropout=DROPOUT, freeze_bert_layers=FREEZE_BERT_LAYERS).to(device)

class_weights_tensor = torch.tensor(CLASS_WEIGHTS, dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=class_weights_tensor)

# Two parameter groups: the pretrained encoder gets the base LR, while the
# freshly initialised layers get 5x the LR and 2x the weight decay.
# Frozen parameters are left out of the optimizer altogether.
bert_params = [p for n, p in model.named_parameters() if 'bert' in n and p.requires_grad]
classifier_params = [p for n, p in model.named_parameters() if 'bert' not in n and p.requires_grad]
optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': LR, 'weight_decay': WEIGHT_DECAY},
    {'params': classifier_params, 'lr': LR * 5, 'weight_decay': WEIGHT_DECAY * 2}
])

# Linear warmup over the first 10% of optimizer steps, then linear decay.
# The scheduler is stepped once per batch, so the total is batches * epochs.
num_training_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps // 10,
                                           num_training_steps=num_training_steps)

print(f"Config: LR={LR} | Weights={CLASS_WEIGHTS} | Batch={BATCH_SIZE} | Steps={num_training_steps}")

## Training

In [None]:
def _to_device(batch):
    """Move one collated batch onto `device`.

    Returns ``(inputs, labels)`` where ``inputs`` is the 4-tuple (title ids,
    title mask, abstract ids, abstract mask) in the order the model's
    forward() expects.
    """
    inputs = tuple(batch[key].to(device) for key in (
        'title_input_ids', 'title_attention_mask',
        'abstract_input_ids', 'abstract_attention_mask'))
    return inputs, batch['label'].to(device)


# Per-epoch metrics, used for the plots and the final summary.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'val_f1': []}
best_val_acc = 0
best_model_state = None
patience_counter = 0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # ---- training pass ----
    model.train()
    train_loss = 0
    all_train_preds, all_train_labels = [], []

    for batch in tqdm(train_loader, desc='Train'):
        inputs, labels = _to_device(batch)

        optimizer.zero_grad()
        outputs = model(*inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the LR schedule advances per batch, not per epoch

        train_loss += loss.item()
        all_train_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        all_train_labels.extend(labels.cpu().numpy())

    train_acc = accuracy_score(all_train_labels, all_train_preds)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    all_val_preds, all_val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Val'):
            inputs, labels = _to_device(batch)

            outputs = model(*inputs)
            val_loss += criterion(outputs, labels).item()
            all_val_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(all_val_labels, all_val_preds)
    val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted')

    print(f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | F1: {val_f1:.4f} | Gap: {abs(train_acc-val_acc):.4f}")

    history['train_loss'].append(train_loss / len(train_loader))
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss / len(val_loader))
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    # ---- checkpointing / early stopping on validation accuracy ----
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Clone every tensor. state_dict() returns references to the live
        # parameters, so a shallow dict copy would keep changing as training
        # continues and the "best" snapshot would end up being the last epoch.
        # The snapshot is kept on the CPU so it does not occupy GPU memory.
        best_model_state = {name: tensor.detach().cpu().clone()
                            for name, tensor in model.state_dict().items()}
        patience_counter = 0
        print(f"‚úì Best: {val_acc:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stop")
            break

# Restore the best-scoring weights. No snapshot exists if validation accuracy
# never rose above 0; in that case the current weights are kept.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
print(f"\nBest val: {best_val_acc:.4f}")

## Evaluation

In [None]:
# Score the held-out test split with the weights currently loaded in `model`.
model.eval()
all_test_preds, all_test_labels = [], []

input_keys = ('title_input_ids', 'title_attention_mask',
              'abstract_input_ids', 'abstract_attention_mask')
with torch.no_grad():
    for batch in tqdm(test_loader, desc='Test'):
        title_ids, title_mask, abstract_ids, abstract_mask = [batch[key].to(device) for key in input_keys]
        labels = batch['label'].to(device)

        outputs = model(title_ids, title_mask, abstract_ids, abstract_mask)
        batch_preds = torch.argmax(outputs, dim=1)
        all_test_preds.extend(batch_preds.cpu().numpy())
        all_test_labels.extend(labels.cpu().numpy())

# Aggregate metrics, plus per-class recall/precision arrays.
test_acc = accuracy_score(all_test_labels, all_test_preds)
test_f1 = f1_score(all_test_labels, all_test_preds, average='weighted')
recalls = recall_score(all_test_labels, all_test_preds, average=None)
precisions = precision_score(all_test_labels, all_test_preds, average=None, zero_division=0)

# Look up cs.AI's encoded label to pick its entry out of the per-class recalls.
cs_ai_idx = list(le.classes_).index('cs.AI')
cs_ai_recall = recalls[cs_ai_idx]

print(f"\nTest Acc: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"Test F1: {test_f1:.4f}")
print(f"cs.AI Recall: {cs_ai_recall:.4f} ({cs_ai_recall*100:.2f}%)")
report = classification_report(all_test_labels, all_test_preds, target_names=le.classes_, digits=4)
print(f"\n{report}")

# Run targets: at least 60% accuracy and more than 30% recall on cs.AI.
acc_met = test_acc >= 0.60
cs_ai_met = cs_ai_recall > 0.30
print(f"\nAcc ‚â•60%: {'‚úÖ' if acc_met else '‚ùå'} | cs.AI >30%: {'‚úÖ' if cs_ai_met else '‚ùå'}")
if acc_met and cs_ai_met:
    print("üéâ BOTH OBJECTIVES MET!")

## Plots

In [None]:
# Training curves: one panel per metric, train vs. validation for each epoch.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = (
    ('train_loss', 'val_loss', 'Loss'),
    ('train_acc', 'val_acc', 'Accuracy'),
)
for ax, (train_key, val_key, panel_title) in zip(axes, panels):
    ax.plot(history[train_key], label='Train', marker='o')
    ax.plot(history[val_key], label='Val', marker='o')
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)
plt.tight_layout()
plt.savefig('v5_1_history.png', dpi=150, bbox_inches='tight')
plt.show()

# Test-set confusion matrix: rows are true categories, columns are predictions.
test_confusion = confusion_matrix(all_test_labels, all_test_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(test_confusion, annot=True, fmt='d', cmap='Blues',
            xticklabels=le.classes_, yticklabels=le.classes_)
plt.title(f'V5.1 Confusion Matrix | Acc: {test_acc:.3f}')
plt.ylabel('True')
plt.xlabel('Predicted')
plt.tight_layout()
plt.savefig('v5_1_confusion.png', dpi=150, bbox_inches='tight')
plt.show()

## Save & Download

In [None]:
import json

# Persist the model weights and the fitted label encoder.
torch.save(model.state_dict(), 'best_v5_1.pth')
with open('label_encoder.pkl', 'wb') as encoder_file:
    pickle.dump(le, encoder_file)

# Headline metrics, cast to plain floats so they serialise as JSON numbers.
summary = {
    'test_accuracy': float(test_acc),
    'test_f1': float(test_f1),
    'cs_ai_recall': float(cs_ai_recall),
    'best_val_acc': float(best_val_acc),
}
with open('results.json', 'w') as results_file:
    json.dump(summary, results_file, indent=2)

print("Saved: best_v5_1.pth, label_encoder.pkl, results.json, v5_1_history.png, v5_1_confusion.png")

In [None]:
from google.colab import files

# Send each saved artifact to the browser as a download.
artifact_names = ('best_v5_1.pth', 'label_encoder.pkl', 'results.json',
                  'v5_1_history.png', 'v5_1_confusion.png')
for artifact_name in artifact_names:
    files.download(artifact_name)