# V5.1 SciBERT Training: Stabilized Cross-Attention + Back-Translation

**Google Colab Optimized Version**

## Problem & Solution

**V5.0 Issue:**
- Epoch 1: 57.60% val acc (PEAK)
- Epoch 2: 46.84% val acc (CRASH -10.76%)
- Root cause: LR too high + class weights too aggressive

**V5.1 Fixes:**
- LR: 5e-5 → 3e-5 (40% reduction)
- Class weights: 2.0 → 1.4 (30% reduction)
- Dropout: 0.35 → 0.30
- Expected: 59-60% accuracy, stable training

## Configuration

- Architecture: CrossAttentionSciBERT
- Dataset: Augmented (450 cs.AI samples duplicated via back-translation)
- Hardware: Colab T4 GPU (16GB VRAM)
- Time: ~30-40 minutes (much faster than M2!)

## 1. Setup & Imports

In [None]:
# Install the notebook's third-party dependencies (-q keeps pip's output quiet).
# The leading "!" is the IPython/Colab shell escape, so this line only runs inside a notebook.
!pip install -q transformers datasets scikit-learn matplotlib seaborn torch

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
from transformers import get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle
import os

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
_cuda_ready = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ready else torch.device("cpu")
print(f"Using device: {device}")

# When a GPU is present, report its name and total memory (in decimal GB).
if _cuda_ready:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {_gpu_props.total_memory / 1e9:.2f} GB")

## 2. Upload Dataset

**Option A:** Upload `arxiv_papers_raw.csv` (will augment here)

**Option B:** Upload `arxiv_papers_augmented.csv` (skip augmentation, faster)

Click the folder icon on the left → Upload file

In [None]:
# Check which dataset is available
import os

# Sets two notebook-level flags used by the cells below:
#   DATA_PATH          - CSV file that will be loaded
#   SKIP_AUGMENTATION  - True when the pre-augmented CSV is already present
# The augmented file is checked first so an existing augmentation run is reused.
if os.path.exists('arxiv_papers_augmented.csv'):
    print("‚úì Augmented dataset found!")
    print("  Will skip augmentation step.")
    DATA_PATH = 'arxiv_papers_augmented.csv'
    SKIP_AUGMENTATION = True
elif os.path.exists('arxiv_papers_raw.csv'):
    print("‚úì Raw dataset found!")
    print("  Will perform augmentation (adds ~30-40 min).")
    DATA_PATH = 'arxiv_papers_raw.csv'
    SKIP_AUGMENTATION = False
else:
    print("‚ùå No dataset found!")
    print("  Please upload arxiv_papers_raw.csv or arxiv_papers_augmented.csv")
    print("  Click folder icon ‚Üí Upload")
    # Fail fast: without this, DATA_PATH / SKIP_AUGMENTATION stay undefined (or
    # keep stale values from an earlier run) and a later cell fails with a
    # confusing NameError or silently reads the wrong file.
    raise FileNotFoundError(
        "No dataset found: upload arxiv_papers_raw.csv or "
        "arxiv_papers_augmented.csv to the working directory"
    )

## 3. Data Augmentation (Skip if using pre-augmented dataset)

In [None]:
# Back-Translation Augmentation Class
class BackTranslationAugmenter:
    """Paraphrase English text by translating it to Spanish and back (EN→ES→EN).

    Both Helsinki-NLP MarianMT checkpoints are loaded in ``__init__``, moved to
    ``self.device`` and switched to eval mode.
    """

    def __init__(self, device=None):
        """Load the EN→ES and ES→EN models.

        Args:
            device: torch device for the models. When ``None``, CUDA is used
                if available, otherwise the CPU.
        """
        if device is not None:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading translation models on {self.device}...")

        # English → Spanish
        print("  Loading EN‚ÜíES model...")
        en_es_checkpoint = 'Helsinki-NLP/opus-mt-en-es'
        self.model_en_es = MarianMTModel.from_pretrained(en_es_checkpoint)
        self.tokenizer_en_es = MarianTokenizer.from_pretrained(en_es_checkpoint)
        self.model_en_es.to(self.device).eval()

        # Spanish → English
        print("  Loading ES‚ÜíEN model...")
        es_en_checkpoint = 'Helsinki-NLP/opus-mt-es-en'
        self.model_es_en = MarianMTModel.from_pretrained(es_en_checkpoint)
        self.tokenizer_es_en = MarianTokenizer.from_pretrained(es_en_checkpoint)
        self.model_es_en.to(self.device).eval()

        print("‚úì Translation models loaded\n")

    def translate(self, text, model, tokenizer, max_length=512):
        """Translate ``text`` with one model/tokenizer pair.

        The input is truncated to ``max_length`` tokens and decoded with
        4-beam search. Only the first generated sequence is decoded.

        Returns:
            The translated string with special tokens stripped.
        """
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        on_device = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            generated = model.generate(
                **on_device,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

        return tokenizer.decode(generated[0], skip_special_tokens=True)

    def back_translate(self, text, max_length=512):
        """Return the EN→ES→EN round trip of ``text``.

        Best effort: if anything in the round trip raises, a warning is
        printed and the original ``text`` is returned unchanged.
        """
        try:
            pivot = self.translate(text, self.model_en_es, self.tokenizer_en_es, max_length)
            return self.translate(pivot, self.model_es_en, self.tokenizer_es_en, max_length)
        except Exception as e:
            print(f"Warning: Back-translation failed: {e}")
            return text

    def augment_dataset(self, df, target_category='cs.AI', augment_factor=1, max_samples=450):
        """Append back-translated copies of one category's rows and shuffle.

        Args:
            df: DataFrame with ``title``, ``abstract`` and ``category`` columns.
            target_category: category whose rows are paraphrased.
            augment_factor: accepted but not used by this implementation.
            max_samples: cap on the number of rows to paraphrase; when more
                are available a fixed-seed sample (random_state=42) is drawn.

        Returns:
            A new DataFrame holding the original rows plus the augmented ones,
            shuffled with random_state=42 and re-indexed. Augmented rows carry
            only ``title``, ``abstract`` and ``category``.
        """
        rule = "=" * 70
        print(rule)
        print("DATASET AUGMENTATION")
        print(rule)

        is_target = df['category'] == target_category
        target_samples = df[is_target].copy()
        target_count = len(target_samples)

        print(f"\nOriginal: {len(df)} total, {target_count} {target_category}")
        print(f"Augmenting: {min(max_samples, target_count)} samples\n")

        if target_count > max_samples:
            target_samples = target_samples.sample(n=max_samples, random_state=42)

        progress = tqdm(
            target_samples.iterrows(),
            total=len(target_samples),
            desc="Back-translating",
        )
        # Each augmented row keeps the original title and category; only the
        # abstract is replaced by its back-translation.
        augmented_samples = [
            {
                'title': row['title'],
                'abstract': self.back_translate(row['abstract']),
                'category': row['category'],
            }
            for _, row in progress
        ]

        combined = pd.concat([df, pd.DataFrame(augmented_samples)], ignore_index=True)
        final_df = combined.sample(frac=1, random_state=42).reset_index(drop=True)

        target_after = int((final_df['category'] == target_category).sum())
        print(f"\n‚úì Final: {len(final_df)} total, {target_after} {target_category}")
        print(rule + "\n")

        return final_df

In [None]:
# Run augmentation if needed
if SKIP_AUGMENTATION:
    print("‚úì Skipping augmentation (using pre-augmented dataset)")
else:
    print("Starting data augmentation...")
    print("This will take ~30-40 minutes on Colab T4 GPU\n")

    _augmented_csv = 'arxiv_papers_augmented.csv'

    df_raw = pd.read_csv(DATA_PATH)
    augmenter = BackTranslationAugmenter()
    df_augmented = augmenter.augment_dataset(df_raw, max_samples=450)

    # Persist the result and point the rest of the notebook at it, so later
    # cells load the augmented file instead of the raw one.
    df_augmented.to_csv(_augmented_csv, index=False)
    DATA_PATH = _augmented_csv
    print("‚úì Augmented dataset saved!")

    # Drop the translation models and release cached GPU memory before training.
    del augmenter
    torch.cuda.empty_cache()

## 4. Cross-Attention Model Architecture

In [None]:
class CrossAttentionSciBERT(nn.Module):
    """SciBERT classifier with bidirectional title↔abstract cross-attention.

    Title and abstract are encoded separately by the same SciBERT backbone,
    each attends to the other through a multi-head attention layer, both are
    attention-pooled to one vector, and the concatenation is classified by an
    MLP.
    """

    def __init__(self, num_classes=4, dropout=0.30, freeze_bert_layers=3):
        """Build the network.

        Args:
            num_classes: size of the output logit vector.
            dropout: dropout probability in the fusion MLP (the last stage
                uses ``dropout * 0.8``).
            freeze_bert_layers: number of leading encoder layers whose
                parameters are excluded from training; ``0`` freezes none.
        """
        super().__init__()

        # Load SciBERT
        self.bert = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
        hidden_size = self.bert.config.hidden_size  # 768

        # Freeze first N layers
        if freeze_bert_layers > 0:
            frozen_layers = self.bert.encoder.layer[:freeze_bert_layers]
            for param in (p for layer in frozen_layers for p in layer.parameters()):
                param.requires_grad = False

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(0.1)

        # Cross-Attention layers (bidirectional); both share the same settings.
        attn_settings = dict(
            embed_dim=hidden_size, num_heads=8, dropout=0.1, batch_first=True
        )
        self.cross_attn_title_to_abstract = nn.MultiheadAttention(**attn_settings)
        self.cross_attn_abstract_to_title = nn.MultiheadAttention(**attn_settings)

        # Layer normalization applied after each residual cross-attention step
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)

        # Attention pooling: one scalar score per token
        self.title_attention = nn.Linear(hidden_size, 1)
        self.abstract_attention = nn.Linear(hidden_size, 1)

        # Fusion network: (2 * hidden) -> 512 -> 256 -> 128
        self.fusion = nn.Sequential(
            *self._fusion_stage(hidden_size * 2, 512, dropout),
            *self._fusion_stage(512, 256, dropout),
            *self._fusion_stage(256, 128, dropout * 0.8),
        )

        # Classifier
        self.classifier = nn.Linear(128, num_classes)

        # Initialize weights
        self._init_weights()

    @staticmethod
    def _fusion_stage(in_features, out_features, drop_prob):
        """Return the Linear → LayerNorm → GELU → Dropout modules of one MLP stage."""
        return (
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(drop_prob),
        )

    def _init_weights(self):
        """Re-initialise every Linear layer of the fusion MLP and the classifier.

        Weights are drawn from N(0, 0.02²); biases are set to zero.
        """
        for head in (self.fusion, self.classifier):
            linear_layers = [m for m in head.modules() if isinstance(m, nn.Linear)]
            for linear in linear_layers:
                nn.init.normal_(linear.weight, std=0.02)
                if linear.bias is not None:
                    nn.init.zeros_(linear.bias)

    def attention_pool(self, hidden_states, attention_layer, mask):
        """Collapse a token sequence into one vector with learned weights.

        Args:
            hidden_states: tensor of shape (batch, seq, hidden).
            attention_layer: module mapping each token vector to one score.
            mask: (batch, seq) tensor where 0 marks positions to ignore, or
                ``None`` to use every position.

        Returns:
            ``(pooled, weights)``: the (batch, hidden) weighted sum and the
            (batch, seq) softmax weights.
        """
        scores = attention_layer(hidden_states).squeeze(-1)

        if mask is not None:
            # A large negative score makes masked positions vanish in softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = scores.softmax(dim=1)
        pooled = torch.bmm(weights.unsqueeze(1), hidden_states).squeeze(1)
        return pooled, weights

    def _encode(self, input_ids, attention_mask):
        """Run SciBERT on one text field and apply embedding dropout."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.embedding_dropout(encoded.last_hidden_state)

    def forward(self, title_input_ids, title_attention_mask,
                abstract_input_ids, abstract_attention_mask):
        """Return class logits of shape (batch, num_classes)."""
        # Encode title, then abstract, with the shared backbone.
        title_hidden = self._encode(title_input_ids, title_attention_mask)
        abstract_hidden = self._encode(abstract_input_ids, abstract_attention_mask)

        # key_padding_mask expects True at padded positions.
        title_is_pad = title_attention_mask == 0
        abstract_is_pad = abstract_attention_mask == 0

        # Cross-attention: title ← abstract (residual + layer norm)
        title_context, _ = self.cross_attn_title_to_abstract(
            query=title_hidden,
            key=abstract_hidden,
            value=abstract_hidden,
            key_padding_mask=abstract_is_pad,
        )
        title_enhanced = self.layer_norm1(title_hidden + title_context)

        # Cross-attention: abstract ← title (residual + layer norm)
        abstract_context, _ = self.cross_attn_abstract_to_title(
            query=abstract_hidden,
            key=title_hidden,
            value=title_hidden,
            key_padding_mask=title_is_pad,
        )
        abstract_enhanced = self.layer_norm2(abstract_hidden + abstract_context)

        # Attention pooling
        title_pooled, _ = self.attention_pool(
            title_enhanced, self.title_attention, title_attention_mask
        )
        abstract_pooled, _ = self.attention_pool(
            abstract_enhanced, self.abstract_attention, abstract_attention_mask
        )

        # Concatenate and classify
        combined = torch.cat([title_pooled, abstract_pooled], dim=1)
        return self.classifier(self.fusion(combined))

## 5. Dataset Class

In [None]:
class SciBERTDataset(Dataset):
    """Map-style dataset that tokenizes title and abstract separately.

    Each item is a dict with the token ids and attention mask of both fields
    plus the integer label, ready for the dual-input model.
    """

    def __init__(self, titles, abstracts, labels, tokenizer,
                 max_title_len=32, max_abstract_len=128):
        """Store the raw fields; tokenization happens lazily in ``__getitem__``.

        Args:
            titles, abstracts, labels: index-aligned sequences.
            tokenizer: callable used to encode one text at a time.
            max_title_len: fixed token length for titles.
            max_abstract_len: fixed token length for abstracts.
        """
        self.tokenizer = tokenizer
        self.titles = titles
        self.abstracts = abstracts
        self.labels = labels
        self.max_title_len = max_title_len
        self.max_abstract_len = max_abstract_len

    def __len__(self):
        """Number of samples (taken from ``labels``)."""
        return len(self.labels)

    def _encode(self, text, length):
        """Tokenize ``text``, padding/truncating it to exactly ``length`` tokens."""
        return self.tokenizer(
            text,
            max_length=length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx`` as a dict of tensors."""
        title_enc = self._encode(self.titles[idx], self.max_title_len)
        abstract_enc = self._encode(self.abstracts[idx], self.max_abstract_len)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # The tokenizer returns a leading batch dimension of 1; drop it so the
        # DataLoader can stack samples itself.
        return {
            'title_input_ids': title_enc['input_ids'].squeeze(0),
            'title_attention_mask': title_enc['attention_mask'].squeeze(0),
            'abstract_input_ids': abstract_enc['input_ids'].squeeze(0),
            'abstract_attention_mask': abstract_enc['attention_mask'].squeeze(0),
            'label': label,
        }

## 6. Data Preparation

In [None]:
# Load data
print("Loading dataset...")
df = pd.read_csv(DATA_PATH)
print(f"‚úì Loaded {len(df)} samples\n")

# Rows missing a title, abstract or category cannot be tokenized or labelled
# (pandas reads empty CSV cells as float NaN), so drop them up front instead
# of crashing inside the tokenizer in the middle of an epoch.
_rows_loaded = len(df)
df = df.dropna(subset=['title', 'abstract', 'category']).reset_index(drop=True)
if len(df) < _rows_loaded:
    print(f"Dropped {_rows_loaded - len(df)} rows with missing title/abstract/category\n")

print("Class distribution:")
print(df['category'].value_counts())
print()

# Encode labels
le = LabelEncoder()
df['label'] = le.fit_transform(df['category'])

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

# Split data
X = df[['title', 'abstract']]
y = df['label']

# Split by paper (unique title), not by row. The back-translated rows added by
# augmentation keep the original title, so a row-level split can place a paper
# in train and its paraphrase in val/test, which leaks test content into
# training and inflates the reported scores. Splitting the unique titles and
# then pulling in every row with that title keeps each paper and its
# paraphrase on the same side. Proportions are therefore 70/15/15 over papers,
# and only approximately so over rows.
# NOTE(review): val/test still contain synthetic (back-translated) rows; the
# CSV has no flag that would allow filtering them out here.
papers = df.drop_duplicates(subset='title')[['title', 'label']]

papers_temp, papers_test = train_test_split(
    papers, test_size=0.15, random_state=42, stratify=papers['label']
)

papers_train, papers_val = train_test_split(
    papers_temp, test_size=0.15/(1-0.15), random_state=42,
    stratify=papers_temp['label']
)


def _rows_for(paper_subset):
    """Return (X, y) restricted to rows whose title is in ``paper_subset``."""
    in_subset = df['title'].isin(paper_subset['title'])
    return X[in_subset], y[in_subset]


def _make_dataset(features, targets):
    """Wrap one split in a SciBERTDataset using the shared tokenizer."""
    return SciBERTDataset(
        features['title'].tolist(),
        features['abstract'].tolist(),
        targets.tolist(),
        tokenizer
    )


X_temp, y_temp = _rows_for(papers_temp)
X_train, y_train = _rows_for(papers_train)
X_val, y_val = _rows_for(papers_val)
X_test, y_test = _rows_for(papers_test)

# Create datasets
train_dataset = _make_dataset(X_train, y_train)
val_dataset = _make_dataset(X_val, y_val)
test_dataset = _make_dataset(X_test, y_test)

print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

## 7. Training Configuration (V5.1 Stabilized)

In [None]:
# V5.1 Configuration: STABILIZED
# Hyper-parameters read by the model, optimizer and training-loop cells below.
FREEZE_BERT_LAYERS = 3
DROPOUT = 0.30  # Reduced from 0.35
BATCH_SIZE = 32  # Increased for GPU (was 12 on M2)
EPOCHS = 10
LR = 3e-5  # REDUCED from 5e-5 (prevents crash)
WEIGHT_DECAY = 0.01
# One loss weight per class id, in LabelEncoder order.
# presumably index 0 is cs.AI - verify against le.classes_
CLASS_WEIGHTS = [1.4, 1.0, 1.0, 1.0]  # SOFTENED from [2.0, 1.0, 1.0, 1.0]
PATIENCE = 3

_rule = "=" * 60
print("V5.1 STABILIZED CONFIGURATION")
print(_rule)
print(f"Freeze layers: {FREEZE_BERT_LAYERS}")
print(f"Dropout: {DROPOUT}")
print(f"Batch size: {BATCH_SIZE} (GPU optimized)")
print(f"Learning rate: {LR}")
print(f"Class weights: {CLASS_WEIGHTS}")
print(_rule)

# Create dataloaders
# All three loaders share batch size and worker settings; only the training
# loader reshuffles its samples every epoch.
_loader_options = {
    'batch_size': BATCH_SIZE,
    'num_workers': 2,
    'pin_memory': True,
}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## 8. Model & Optimizer Setup

In [None]:
# Create model
print("Creating Cross-Attention model...")

# Size the output layer from the fitted label encoder rather than hard-coding
# 4, and make sure the per-class loss weights line up with it. A mismatch would
# otherwise surface later as an opaque shape or CUDA error inside the loss.
num_classes = len(le.classes_)
if len(CLASS_WEIGHTS) != num_classes:
    raise ValueError(
        f"CLASS_WEIGHTS has {len(CLASS_WEIGHTS)} entries but the dataset has "
        f"{num_classes} classes: {list(le.classes_)}"
    )

model = CrossAttentionSciBERT(
    num_classes=num_classes,
    dropout=DROPOUT,
    freeze_bert_layers=FREEZE_BERT_LAYERS
)
model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print()

# Loss function with class weights
class_weights_tensor = torch.tensor(CLASS_WEIGHTS, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(
    label_smoothing=0.1,
    weight=class_weights_tensor
)

# Optimizer with differential learning rates: the pretrained backbone trains at
# LR, the freshly initialised layers at 5x LR with twice the weight decay.
# Frozen parameters are left out of both groups.
bert_params = []
classifier_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Match on the attribute prefix so a new layer whose name merely contains
    # "bert" is not mistaken for the backbone.
    if name.startswith('bert.'):
        bert_params.append(param)
    else:
        classifier_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': LR, 'weight_decay': WEIGHT_DECAY},
    {'params': classifier_params, 'lr': LR * 5, 'weight_decay': WEIGHT_DECAY * 2}
])

# Learning rate scheduler with warmup: linear ramp-up over the first 10% of
# steps, then linear decay. The scheduler is stepped once per batch.
num_training_steps = len(train_loader) * EPOCHS
num_warmup_steps = num_training_steps // 10
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
print("‚úì Model and optimizer ready!")

## 9. Training Loop

In [None]:
# Training history: one entry per completed epoch.
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [], 'val_f1': []
}

best_val_acc = 0
best_model_state = None
patience_counter = 0


def _unpack_batch(batch):
    """Move one DataLoader batch to ``device``.

    Returns ``(inputs, labels)`` where ``inputs`` is the tuple of tensors in
    the positional order expected by ``model.forward``.
    """
    inputs = (
        batch['title_input_ids'].to(device),
        batch['title_attention_mask'].to(device),
        batch['abstract_input_ids'].to(device),
        batch['abstract_attention_mask'].to(device),
    )
    return inputs, batch['label'].to(device)


print("="*70)
print("TRAINING V5.1: STABILIZED")
print("="*70)
print()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 70)

    # Training phase
    model.train()
    train_loss = 0
    all_train_preds = []
    all_train_labels = []

    pbar = tqdm(train_loader, desc='Training')
    for batch in pbar:
        inputs, labels = _unpack_batch(batch)

        optimizer.zero_grad()
        outputs = model(*inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the schedule is defined per batch, not per epoch

        train_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        all_train_preds.extend(preds.cpu().numpy())
        all_train_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_train_loss = train_loss / len(train_loader)
    train_acc = accuracy_score(all_train_labels, all_train_preds)

    # Validation phase
    model.eval()
    val_loss = 0
    all_val_preds = []
    all_val_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            inputs, labels = _unpack_batch(batch)

            outputs = model(*inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_acc = accuracy_score(all_val_labels, all_val_preds)
    val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted')

    # Metrics
    gap = abs(train_acc - val_acc)

    print(f"\nResults:")
    print(f"  Train Loss: {avg_train_loss:.4f}  Train Acc: {train_acc:.4f} ({train_acc*100:.2f}%)")
    print(f"  Val Loss:   {avg_val_loss:.4f}  Val Acc:   {val_acc:.4f} ({val_acc*100:.2f}%)")
    print(f"  Val F1:     {val_f1:.4f}")
    print(f"  Gap (Overfit): {gap:.4f} ({gap*100:.2f}%)")

    # Save history
    history['train_loss'].append(avg_train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() hands back tensors that share storage with the live
        # parameters, and dict.copy() is shallow, so the snapshot must clone
        # every tensor. Otherwise later epochs overwrite the "best" weights
        # and the restore below becomes a no-op. The clones are kept on the
        # CPU so the snapshot does not occupy GPU memory.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        patience_counter = 0
        print(f"  ‚úì New best model! (val_acc: {val_acc:.4f})")
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{PATIENCE})")

    if patience_counter >= PATIENCE:
        print(f"\n‚ö† Early stopping triggered (patience={PATIENCE})")
        break

# Restore best model
print(f"\n‚úì Training complete!")
print(f"  Best val accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
if best_model_state is not None:
    model.load_state_dict(best_model_state)
else:
    # No epoch beat the initial best_val_acc of 0, so there is no snapshot.
    print("  No improving epoch was recorded; keeping the final weights.")

## 10. Training History Visualization

In [None]:
# Two side-by-side panels: loss on the left, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, train key, val key, train label, val label, y-axis label, panel title)
_panels = (
    (axes[0], 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    (axes[1], 'train_acc', 'val_acc', 'Train Acc', 'Val Acc',
     'Accuracy', 'Training and Validation Accuracy'),
)

for _ax, _train_key, _val_key, _train_label, _val_label, _ylabel, _title in _panels:
    _ax.plot(history[_train_key], label=_train_label, marker='o')
    _ax.plot(history[_val_key], label=_val_label, marker='o')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(True)

plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig('v5_1_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Training history saved: v5_1_history.png")

## 11. Final Evaluation on Test Set

In [None]:
print("="*70)
print("FINAL EVALUATION ON TEST SET")
print("="*70)
print()

# Run the model (with the weights restored after training) over the test split
# and collect predictions and ground-truth labels.
model.eval()
all_test_preds = []
all_test_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        title_ids = batch['title_input_ids'].to(device)
        title_mask = batch['title_attention_mask'].to(device)
        abstract_ids = batch['abstract_input_ids'].to(device)
        abstract_mask = batch['abstract_attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(title_ids, title_mask, abstract_ids, abstract_mask)
        preds = torch.argmax(outputs, dim=1)

        all_test_preds.extend(preds.cpu().numpy())
        all_test_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_test_labels, all_test_preds)
test_f1 = f1_score(all_test_labels, all_test_preds, average='weighted')

print(f"Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"Test F1: {test_f1:.4f}")
print()

# Pin the label set to the encoder's class ids. Without `labels`, scikit-learn
# derives it from the values present in y_true/y_pred, so a class missing from
# both would shift every later position and `recalls[cs_ai_idx]` would silently
# report another class (and `target_names` would no longer match).
label_ids = list(range(len(le.classes_)))

# Classification report
print("Classification Report:")
print(classification_report(all_test_labels, all_test_preds,
                           labels=label_ids,
                           target_names=le.classes_, digits=4))

# Per-class metrics, one entry per class id in `label_ids` order
recalls = recall_score(all_test_labels, all_test_preds,
                       labels=label_ids, average=None, zero_division=0)
precisions = precision_score(all_test_labels, all_test_preds,
                             labels=label_ids, average=None, zero_division=0)

# Raises ValueError if the dataset has no 'cs.AI' category.
cs_ai_idx = list(le.classes_).index('cs.AI')
cs_ai_recall = recalls[cs_ai_idx]

print(f"\ncs.AI specific metrics:")
print(f"  Recall: {cs_ai_recall:.4f} ({cs_ai_recall*100:.2f}%)")
print(f"  Precision: {precisions[cs_ai_idx]:.4f} ({precisions[cs_ai_idx]*100:.2f}%)")

## 12. Confusion Matrix

In [None]:
# Confusion matrix over the test predictions, drawn as an annotated heatmap
# whose axes are labelled with the encoder's class names.
cm = confusion_matrix(all_test_labels, all_test_preds)
_class_names = le.classes_
_figure_title = f'V5.1 Confusion Matrix (Stabilized Training)\nTest Acc: {test_acc:.3f}'

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=_class_names,
    yticklabels=_class_names,
)
plt.title(_figure_title, fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Prediction', fontsize=12)
plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig('v5_1_confusion.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Confusion matrix saved: v5_1_confusion.png")

## 13. Objectives Check

In [None]:
print("="*70)
print("OBJECTIVES CHECK")
print("="*70)

# Targets: test accuracy >= 60% and cs.AI recall > 30%.
# A gap is the shortfall below the target, clamped at zero. The accuracy gap
# used abs(), which reported a positive "gap" whenever accuracy exceeded 60%;
# it is now clamped the same way as the recall gap.
gap_acc = max(0, 0.60 - test_acc)
gap_cs_ai = max(0, 0.30 - cs_ai_recall)
gap_total = gap_acc + gap_cs_ai

acc_target_met = test_acc >= 0.60
cs_ai_target_met = cs_ai_recall > 0.30

print(f"\nTest Accuracy >= 60%: {'‚úÖ YES' if acc_target_met else '‚ùå NO'} ({test_acc*100:.2f}%)")
print(f"cs.AI Recall > 30%:   {'‚úÖ YES' if cs_ai_target_met else '‚ùå NO'} ({cs_ai_recall*100:.2f}%)")
print(f"\nGap Total: {gap_total:.4f} ({gap_total*100:.2f}%)")

if acc_target_met and cs_ai_target_met:
    print("\n" + "üéâ"*20)
    print("‚úÖ SUCCESS! BOTH OBJECTIVES MET!")
    print("üéâ"*20)
elif test_acc >= 0.59:
    # Reached when accuracy is within one point of the target, or when accuracy
    # passes but the cs.AI recall target does not.
    print("\n‚úÖ Very close to target! Excellent result!")
else:
    print("\n‚ö†Ô∏è  Gap remains, but progress made")

print("\n" + "="*70)

# Compare with baseline
print("\nCOMPARISON WITH V5.0 BASELINE")
print("="*70)

# Reference numbers for the previous version, hard-coded here.
baseline_acc = 0.5701
baseline_cs_ai = 0.4189

# Positive values mean V5.1 improved on the baseline.
improvement_acc = test_acc - baseline_acc
improvement_cs_ai = cs_ai_recall - baseline_cs_ai

print(f"\n{'Metric':<20} {'V5.0':<12} {'V5.1':<12} {'Change'}")
print("-"*60)
print(f"{'Test Accuracy':<20} {baseline_acc*100:>6.2f}% {test_acc*100:>11.2f}% {improvement_acc*100:>11.2f}%")
print(f"{'cs.AI Recall':<20} {baseline_cs_ai*100:>6.2f}% {cs_ai_recall*100:>11.2f}% {improvement_cs_ai*100:>11.2f}%")

print("\n" + "="*70)

## 14. Save Model & Results

In [None]:
import json

# Save model (weights only, as a state_dict)
torch.save(model.state_dict(), 'best_scibert_v5_1_colab.pth')
print("‚úì Model saved: best_scibert_v5_1_colab.pth")

# Save label encoder so class ids can be mapped back to category names.
# NOTE: pickle files execute code when loaded; only unpickle this file from a
# source you trust.
with open('scibert_label_encoder.pkl', 'wb') as f:
    pickle.dump(le, f)
print("‚úì Label encoder saved: scibert_label_encoder.pkl")

# Save results summary
results = {
    'test_accuracy': test_acc,
    'test_f1': test_f1,
    'cs_ai_recall': cs_ai_recall,
    'cs_ai_precision': precisions[cs_ai_idx],
    'best_val_acc': best_val_acc,
    'history': history,
    'classification_report': classification_report(all_test_labels, all_test_preds,
                                                   target_names=le.classes_, output_dict=True)
}


def _json_default(value):
    """Convert NumPy scalars/arrays to builtin types; reject anything else."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# `results` was assembled but only five scalar metrics were written, so the
# training history and the per-class report were lost. Round-trip the full
# dict through JSON to get a builtin-only copy, then write that. The five
# original top-level keys are unchanged.
results_json = json.loads(json.dumps(results, default=_json_default))

with open('v5_1_results.json', 'w') as f:
    json.dump(results_json, f, indent=2)

print("‚úì Results saved: v5_1_results.json")
print("\n" + "="*70)
print("‚úì ALL RESULTS SAVED!")
print("="*70)
print("\nDownload these files:")
print("  - best_scibert_v5_1_colab.pth (model)")
print("  - scibert_label_encoder.pkl (label encoder)")
print("  - v5_1_history.png (training curves)")
print("  - v5_1_confusion.png (confusion matrix)")
print("  - v5_1_results.json (metrics)")
print("  - arxiv_papers_augmented.csv (augmented dataset, if created)")

## 15. Download Files (Colab)

Run this cell to download all result files to your local machine.

In [None]:
from google.colab import files

# Download model and results
# Each call asks the browser to download one artifact, in this order.
_artifacts = (
    'best_scibert_v5_1_colab.pth',
    'scibert_label_encoder.pkl',
    'v5_1_history.png',
    'v5_1_confusion.png',
    'v5_1_results.json',
)
for _artifact in _artifacts:
    files.download(_artifact)

print("‚úì All files downloaded!")