# Step 4B: Image Model Training v2 (Improved)

This is an **improved version** of the image model training with better handling of class imbalance.

## 🆕 What's New in v2:

### 1. **Class Weights** (CRITICAL - Fixes 0% F1 on Sadness/Surprise!)
   - Automatically computed from training data
   - Sadness & Surprise get ~5x more weight

### 2. **Data Augmentation** (Increases effective dataset size)
   - Random horizontal flip (50% chance)
   - Random rotation (±15 degrees)
   - Color jitter (brightness, contrast, saturation)
   - Only during training, not validation

### 3. **Early Stopping**
   - Stops when validation F1 stops improving
   - Patience: 4 epochs (more than the text model, because the image task is harder)

### 4. **Better Learning Rate Scheduling**
   - ReduceLROnPlateau with patience=3
   - Reduces LR when stuck
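
The plateau logic can be sketched in plain Python. This is a simplified illustration, not PyTorch's implementation: it ignores the scheduler's `threshold` and `cooldown` options, and the F1 values are hypothetical.

```python
def simulate_reduce_on_plateau(scores, lr, factor=0.5, patience=3):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best = float("-inf")
    num_bad_epochs = 0
    lrs = []
    for score in scores:
        if score > best:          # mode='max': higher F1 is better
            best = score
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # The LR is cut only once the bad-epoch count EXCEEDS patience
        if num_bad_epochs > patience:
            lr *= factor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs


# Hypothetical validation F1 per epoch (not real results)
hypothetical_f1 = [0.40, 0.45, 0.44, 0.44, 0.43, 0.44, 0.46]
lrs = simulate_reduce_on_plateau(hypothetical_f1, lr=1e-5)
print(lrs)
```

With `patience=3`, the cut arrives on the fourth consecutive non-improving epoch, which is worth keeping in mind when choosing the early-stopping patience.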

### 5. **Gradient Clipping**
   - Max norm: 1.0
   - Prevents exploding gradients
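
As a rough sketch of what clipping by global norm does, the following plain-Python function rescales a flat list of gradient values so that their L2 norm does not exceed `max_norm`. The function name and the toy gradient values are illustrative assumptions; the notebook itself relies on `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one shared factor if their L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Gradients are only ever scaled down, never up
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


# Toy gradient with norm 5.0 (hypothetical values)
clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes.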

### 6. **Label Smoothing**
   - Smoothing: 0.1
   - Better calibration
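
To make the smoothing concrete, here is a minimal sketch of the target distribution used by this notebook's loss: the true class gets `1 - smoothing`, and the remaining mass is split evenly over the other classes. The helper name `smoothed_targets` is hypothetical.

```python
def smoothed_targets(target, n_classes, smoothing=0.1):
    """Build a smoothed target distribution for one sample."""
    off_value = smoothing / (n_classes - 1)   # mass given to each wrong class
    dist = [off_value] * n_classes
    dist[target] = 1.0 - smoothing            # mass kept on the true class
    return dist


# Five classes, true class index 3 ('Sadness' in this notebook's LABELS order)
dist = smoothed_targets(target=3, n_classes=5, smoothing=0.1)
print(dist)
```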

### 7. **More Epochs**
   - 12 epochs (up from 5)
   - Early stopping prevents overfitting

**Expected Improvements:**
- **Sadness: 0% F1 → 20-30% F1** (huge improvement!)
- **Surprise: 0% F1 → 18-28% F1** (huge improvement!)
- Overall: 45.7% acc → 52-58% acc

## 1. Setup and Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
from sklearn.utils.class_weight import compute_class_weight  # NEW: For class weights
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageEnhance  # NEW: For data augmentation
import json
import os
from pathlib import Path
import warnings
import random
warnings.filterwarnings('ignore')

# Try to import tqdm for notebooks
try:
    from tqdm.notebook import tqdm
    print("✓ Using notebook progress bars")
except ImportError:
    from tqdm import tqdm
    print("✓ Using terminal progress bars")

# Set random seeds for reproducibility
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# Paths
TRAIN_DATA = "data/train_set.csv"
VAL_DATA = "data/validation_set.csv"
MODEL_DIR = "models"
RESULTS_DIR = "results/image_model_v2"  # NEW: Separate results directory

# Create directories
Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

# Model configuration
MODEL_NAME = 'openai/clip-vit-base-patch32'
BATCH_SIZE = 32  # Increased for better GPU utilization (safe for RTX 2080 Super 8GB)
EPOCHS = 12  # NEW: Increased epochs (early stopping will prevent overfitting)
LEARNING_RATE = 1e-5

# NEW: Training improvements configuration
LABEL_SMOOTHING = 0.1
GRADIENT_CLIP_NORM = 1.0
EARLY_STOP_PATIENCE = 4  # More patience for image task
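# NOTE: ReduceLROnPlateau lowers the LR only after MORE than `patience` non-improving
# epochs, so with 3 here the first cut typically lands on the same epoch at which
# EARLY_STOP_PATIENCE = 4 stops training. Lower this value to give the reduced LR time to act.
LR_SCHEDULER_PATIENCE = 3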
LR_SCHEDULER_FACTOR = 0.5

# NEW: Data augmentation configuration
AUG_ROTATION_DEGREES = 15  # Random rotation range
AUG_FLIP_PROB = 0.5  # Horizontal flip probability
AUG_COLOR_JITTER = 0.2  # Brightness/contrast variation

# Sentiment labels
LABELS = ['Anger', 'Joy', 'Neutral/Other', 'Sadness', 'Surprise']
LABEL_TO_ID = {label: idx for idx, label in enumerate(LABELS)}
ID_TO_LABEL = {idx: label for label, idx in LABEL_TO_ID.items()}

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Max Epochs: {EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Labels: {LABELS}")
print(f"\n🆕 v2 Improvements:")
print(f"  Label Smoothing: {LABEL_SMOOTHING}")
print(f"  Gradient Clipping: {GRADIENT_CLIP_NORM}")
print(f"  Early Stop Patience: {EARLY_STOP_PATIENCE} epochs")
print(f"  LR Scheduler: ReduceLROnPlateau (patience={LR_SCHEDULER_PATIENCE})")
print(f"  Data Augmentation:")
print(f"    - Random rotation: ±{AUG_ROTATION_DEGREES}°")
print(f"    - Random horizontal flip: {AUG_FLIP_PROB*100}%")
print(f"    - Color jitter: {AUG_COLOR_JITTER}")

## 3. Data Loading and Class Weight Computation
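
The cells below rely on scikit-learn's `'balanced'` mode, which weights each class by `n_samples / (n_classes * class_count)`. A minimal plain-Python sketch of that formula follows; the class counts are hypothetical and chosen only to show how a rare class ends up with a large weight.

```python
def balanced_class_weights(counts):
    """Weight each class by n_samples / (n_classes * class_count)."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * n) for label, n in counts.items()}


# Hypothetical class counts (NOT the real dataset distribution)
hypothetical_counts = {
    "Anger": 400,
    "Joy": 1000,
    "Neutral/Other": 1200,
    "Sadness": 100,
    "Surprise": 100,
}
weights = balanced_class_weights(hypothetical_counts)
for label, weight in weights.items():
    print(f"{label:15s}: {weight:.4f}")
```

The actual ratio between minority and majority weights depends entirely on the real class counts printed by the data-loading cell.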

In [None]:
# Load data
print("Loading datasets...")
train_df = pd.read_csv(TRAIN_DATA)
val_df = pd.read_csv(VAL_DATA)

# Filter for image posts only
train_df = train_df[train_df['media_type'] == 'image'].reset_index(drop=True)
val_df = val_df[val_df['media_type'] == 'image'].reset_index(drop=True)

print(f"Train set (images only): {len(train_df):,} samples")
print(f"Validation set (images only): {len(val_df):,} samples")

# Display sentiment distribution
print("\nTrain set sentiment distribution:")
train_counts = train_df['post_sentiment'].value_counts()
print(train_counts)
print("\nPercentages:")
print((train_counts / len(train_df) * 100).round(2))

print("\nValidation set sentiment distribution:")
val_counts = val_df['post_sentiment'].value_counts()
print(val_counts)
print("\nPercentages:")
print((val_counts / len(val_df) * 100).round(2))

In [None]:
# NEW: Compute class weights for handling severe imbalance
print("\n" + "="*80)
print("COMPUTING CLASS WEIGHTS (NEW in v2)")
print("="*80)

# Convert labels to numeric
train_labels_numeric = train_df['post_sentiment'].map(LABEL_TO_ID).values

# Compute balanced class weights
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(LABELS)),
    y=train_labels_numeric
)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

print("\nClass weights (higher = model cares more):")
for label, weight in zip(LABELS, class_weights):
    count = train_counts.get(label, 0)
    print(f"  {label:15s}: {weight:.4f} (n={count:,})")

print("\n💡 Critical for Image Model:")
print("   Sadness and Surprise currently have 0% F1 in v1!")
print("   With ~5x weight, the model will actually learn these classes.")

## 4. NEW: Data Augmentation Functions

In [None]:
def apply_augmentation(image):
    """
    Apply random augmentations to image for training.
    Increases dataset diversity without collecting more data.
    """
    # Random horizontal flip
    if random.random() < AUG_FLIP_PROB:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    
    # Random rotation
    if random.random() < 0.5:
        angle = random.uniform(-AUG_ROTATION_DEGREES, AUG_ROTATION_DEGREES)
        image = image.rotate(angle, fillcolor=(255, 255, 255))
    
    # Random brightness adjustment
    if random.random() < 0.5:
        factor = random.uniform(1 - AUG_COLOR_JITTER, 1 + AUG_COLOR_JITTER)
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(factor)
    
    # Random contrast adjustment
    if random.random() < 0.5:
        factor = random.uniform(1 - AUG_COLOR_JITTER, 1 + AUG_COLOR_JITTER)
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(factor)
    
    # Random saturation adjustment
    if random.random() < 0.5:
        factor = random.uniform(1 - AUG_COLOR_JITTER, 1 + AUG_COLOR_JITTER)
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(factor)
    
    return image

print("✓ Data augmentation functions defined")

## 5. Dataset and DataLoader with Augmentation

In [None]:
# NEW: Dataset with optional augmentation
class BrawlStarsImageDataset(Dataset):
    def __init__(self, dataframe, processor, augment=False):
        self.data = dataframe.reset_index(drop=True)
        self.processor = processor
        self.augment = augment  # NEW: Enable augmentation for training only
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Normalize path (handle Windows backslashes)
        image_path = str(row['local_media_path']).replace('\\', '/')
        
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            
            # NEW: Apply augmentation during training
            if self.augment:
                image = apply_augmentation(image)
            
            # Process image with CLIP processor
            inputs = self.processor(images=image, return_tensors="pt")
            
            # Get label
            label = LABEL_TO_ID[row['post_sentiment']]
            
            return {
                'pixel_values': inputs['pixel_values'].squeeze(0),
                'label': torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a dummy sample if image fails to load
            dummy_image = Image.new('RGB', (224, 224), color='black')
            inputs = self.processor(images=dummy_image, return_tensors="pt")
            label = LABEL_TO_ID[row['post_sentiment']]
            return {
                'pixel_values': inputs['pixel_values'].squeeze(0),
                'label': torch.tensor(label, dtype=torch.long)
            }

# Initialize CLIP processor
print("Loading CLIP processor...")
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

# Create datasets (augmentation only for training!)
train_dataset = BrawlStarsImageDataset(train_df, processor, augment=True)
val_dataset = BrawlStarsImageDataset(val_df, processor, augment=False)

# Create dataloaders - Windows optimized
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE,  # 32 for better GPU utilization
    shuffle=True, 
    num_workers=0,          # Windows-safe (avoid multiprocessing issues)
    pin_memory=True         # Faster GPU transfer
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=0,          # Windows-safe
    pin_memory=True
)

print(f"✓ Created {len(train_loader)} train batches and {len(val_loader)} validation batches")
print(f"✓ Training data will be augmented (rotation, flip, color jitter)")
print(f"💡 Windows mode: batch_size={BATCH_SIZE} + pin_memory for GPU optimization")
print(f"   Expected GPU usage: 10% → 40-60%")

## 6. Model Definition

In [None]:
class ImageSentimentClassifier(nn.Module):
    def __init__(self, n_classes=5):
        super(ImageSentimentClassifier, self).__init__()
        self.clip = CLIPModel.from_pretrained(MODEL_NAME)
        
        # Get CLIP's vision model output dimension
        self.vision_embed_dim = self.clip.vision_model.config.hidden_size
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.vision_embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_classes)
        )
        
    def forward(self, pixel_values):
        # Get image features from CLIP's vision encoder
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        # Use pooled output (CLS token)
        image_embeds = vision_outputs.pooler_output
        # Classification
        logits = self.classifier(image_embeds)
        return logits
    
    def get_embedding(self, pixel_values):
        """Extract embedding without classification head (for Phase 2)"""
        with torch.no_grad():
            vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
            image_embeds = vision_outputs.pooler_output
        return image_embeds

# Initialize model
model = ImageSentimentClassifier(n_classes=len(LABELS))
model = model.to(device)

print(f"✓ Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")

## 7. Training Setup with Improvements
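
The loss defined in the next cell combines label smoothing with per-sample class weighting. As a rough single-sample illustration in plain Python (the logits and class weights are hypothetical, and the helper names are not part of the notebook):

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax for a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def smoothed_weighted_loss(logits, target, class_weights, smoothing=0.1):
    """Label-smoothed cross entropy for one sample, scaled by its target's class weight."""
    n_classes = len(logits)
    log_probs = log_softmax(logits)
    off_value = smoothing / (n_classes - 1)
    true_dist = [off_value] * n_classes
    true_dist[target] = 1.0 - smoothing
    loss = -sum(t * lp for t, lp in zip(true_dist, log_probs))
    return loss * class_weights[target]


# Hypothetical weights in LABELS order: Anger, Joy, Neutral/Other, Sadness, Surprise
hypothetical_weights = [1.4, 0.56, 0.47, 5.6, 5.6]
logits = [2.0, 0.5, 0.1, -1.0, 0.0]

loss_sadness = smoothed_weighted_loss(logits, 3, hypothetical_weights)
loss_joy = smoothed_weighted_loss(logits, 1, hypothetical_weights)
print(f"Sadness sample: {loss_sadness:.4f} | Joy sample: {loss_joy:.4f}")
```

A misclassified minority-class sample therefore contributes far more to the batch mean than a majority-class one. PyTorch's built-in `nn.CrossEntropyLoss` also accepts `weight` and `label_smoothing` arguments, but its smoothing and normalization details differ from the custom class, so the two are not expected to produce identical values.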

In [None]:
# NEW: Loss function with class weights AND label smoothing
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross entropy with label smoothing for better calibration"""
    def __init__(self, weight=None, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight
        
    def forward(self, pred, target):
        n_classes = pred.size(-1)
        log_pred = torch.log_softmax(pred, dim=-1)
        
        # Apply label smoothing
        with torch.no_grad():
            true_dist = torch.zeros_like(log_pred)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        
        # Compute loss per sample
        loss = torch.sum(-true_dist * log_pred, dim=-1)
        
        # Apply class weights to each sample based on its target class
        if self.weight is not None:
            loss = loss * self.weight[target]
            
        return torch.mean(loss)

criterion = LabelSmoothingCrossEntropy(weight=class_weights_tensor, smoothing=LABEL_SMOOTHING)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# NEW: ReduceLROnPlateau scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
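    mode='max',  # Maximize F1
    factor=LR_SCHEDULER_FACTOR,
    patience=LR_SCHEDULER_PATIENCE
    # `verbose` is omitted: it is deprecated in recent PyTorch releases, and the training loop prints the LR each epoch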
)

print(f"✓ Optimizer and scheduler configured")
print(f"\n🆕 Loss function: Cross Entropy + Class Weights + Label Smoothing")
print(f"  This is CRITICAL for fixing the 0% F1 on Sadness/Surprise!")

## 8. Training Functions
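
`eval_model` below reports the support-weighted F1, which is dominated by the large classes. The following plain-Python sketch, on hypothetical toy labels, shows how weighted F1 can look healthy while a rare class is never predicted; macro F1 (the unweighted mean over classes) exposes that more directly. The helper names are illustrative, not part of the notebook.

```python
def per_class_f1(y_true, y_pred, n_classes):
    """Return per-class F1 scores and supports (F1 is 0 when undefined)."""
    scores, supports = [], []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
        supports.append(tp + fn)
    return scores, supports


def weighted_and_macro_f1(y_true, y_pred, n_classes):
    scores, supports = per_class_f1(y_true, y_pred, n_classes)
    weighted = sum(s * n for s, n in zip(scores, supports)) / sum(supports)
    macro = sum(scores) / n_classes
    return weighted, macro


# Toy data: class 0 is common, class 1 is rare and never predicted
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 10
weighted, macro = weighted_and_macro_f1(y_true, y_pred, n_classes=2)
print(f"weighted F1: {weighted:.4f} | macro F1: {macro:.4f}")
```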

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0
    
    progress_bar = tqdm(dataloader, desc='Training')
    for batch in progress_bar:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)
        
        # Forward pass
        outputs = model(pixel_values)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # NEW: Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_NORM)
        
        optimizer.step()
        
        # Calculate accuracy
        _, preds = torch.max(outputs, dim=1)
        correct_predictions += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        total_loss += loss.item()
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{correct_predictions/total_samples:.4f}'
        })
    
    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    
    return avg_loss, accuracy


def eval_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Validation')
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(pixel_values)
            loss = criterion(outputs, labels)
            
            _, preds = torch.max(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    
    return avg_loss, accuracy, f1, all_preds, all_labels

print("✓ Training functions defined with gradient clipping")

## 9. Training Loop with Early Stopping
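
The stopping rule used in the loop below can be isolated as a small pure-Python function. This sketch replays a hypothetical sequence of validation F1 scores to show which epoch is kept as best and when training halts; the function name and the scores are assumptions for illustration.

```python
def run_with_early_stopping(val_f1_per_epoch, patience):
    """Return (best_epoch, last_epoch_run), both 1-indexed."""
    best_f1 = 0.0
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, val_f1 in enumerate(val_f1_per_epoch, start=1):
        if val_f1 > best_f1:                  # strict improvement only
            best_f1, best_epoch = val_f1, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return best_epoch, epoch          # stop early
    return best_epoch, len(val_f1_per_epoch)


# Hypothetical validation F1 per epoch (not real results)
hypothetical_f1 = [0.40, 0.45, 0.47, 0.46, 0.47, 0.46, 0.45, 0.48]
best_epoch, stopped_at = run_with_early_stopping(hypothetical_f1, patience=4)
print(best_epoch, stopped_at)
```

Note that a tie with the best score counts as no improvement, and that the final epoch in this toy sequence is never reached because training has already stopped.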

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': [],
    'learning_rates': []
}

best_val_f1 = 0
best_epoch = 0
epochs_without_improvement = 0

print("=" * 80)
print("STARTING TRAINING (v2 with improvements)")
print("=" * 80)
print(f"\n🆕 Key improvements enabled:")
print(f"  - Class weights (Sadness/Surprise get ~5x more weight)")
print(f"  - Data augmentation (flip, rotation, color jitter)")
print(f"  - Early stopping (patience={EARLY_STOP_PATIENCE} epochs)")
print("\n")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    print("-" * 80)
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc, val_f1, _, _ = eval_model(model, val_loader, criterion, device)
    
    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)
    history['learning_rates'].append(current_lr)
    
    # Print metrics
    print(f"\nResults:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val F1: {val_f1:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")
    
    # Update scheduler
    scheduler.step(val_f1)
    
    # Save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_epoch = epoch + 1
        epochs_without_improvement = 0
        torch.save(model.state_dict(), f"{MODEL_DIR}/image_specialist_v2_best.pth")
        print(f"  ✓ New best model saved! (F1: {val_f1:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"  No improvement for {epochs_without_improvement} epoch(s)")
    
    # Early stopping check
    if epochs_without_improvement >= EARLY_STOP_PATIENCE:
        print(f"\n🛑 Early stopping triggered! No improvement for {EARLY_STOP_PATIENCE} epochs.")
        print(f"   Best F1: {best_val_f1:.4f} at epoch {best_epoch}")
        break

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)
print(f"Best validation F1: {best_val_f1:.4f} (Epoch {best_epoch})")
print(f"Total epochs run: {epoch + 1}/{EPOCHS}")

## 10. Save Final Model

In [None]:
# Save final model
torch.save(model.state_dict(), f"{MODEL_DIR}/image_specialist_v2.pth")
print(f"✓ Final model saved to {MODEL_DIR}/image_specialist_v2.pth")

# Save training history
with open(f"{RESULTS_DIR}/training_history.json", 'w') as f:
    json.dump(history, f, indent=2)
print(f"✓ Training history saved to {RESULTS_DIR}/training_history.json")

## 11. Load Best Model for Evaluation

In [None]:
# Load best model
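# map_location keeps the load working if the checkpoint was saved on a different device
model.load_state_dict(torch.load(f"{MODEL_DIR}/image_specialist_v2_best.pth", map_location=device))
print(f"✓ Loaded best model (Epoch {best_epoch}, F1: {best_val_f1:.4f})")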

## 12. Final Evaluation

In [None]:
# Evaluate on validation set
print("Evaluating on validation set...")
val_loss, val_acc, val_f1, val_preds, val_labels = eval_model(model, val_loader, criterion, device)

print(f"\nFinal Validation Metrics:")
print(f"  Loss: {val_loss:.4f}")
print(f"  Accuracy: {val_acc:.4f}")
print(f"  Weighted F1: {val_f1:.4f}")

In [None]:
# Generate classification report
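# Pass labels explicitly so the report still lines up with LABELS if a class is absent from the validation set
report = classification_report(
    val_labels, val_preds,
    labels=list(range(len(LABELS))), target_names=LABELS,
    digits=4, zero_division=0
)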
print("\nClassification Report:")
print(report)

# Save classification report
with open(f"{RESULTS_DIR}/evaluation_report.txt", 'w') as f:
    f.write("IMAGE MODEL (CLIP) v2 EVALUATION REPORT\n")
    f.write("=" * 80 + "\n\n")
    f.write("Improvements in v2:\n")
    f.write("  - Class weights for severe imbalance (Sadness/Surprise ~5x)\n")
    f.write("  - Data augmentation (rotation, flip, color jitter)\n")
    f.write("  - Label smoothing (0.1)\n")
    f.write("  - Gradient clipping (1.0)\n")
    f.write("  - Early stopping (patience=4)\n")
    f.write("  - ReduceLROnPlateau scheduler\n\n")
    f.write(f"Best Epoch: {best_epoch}\n")
    f.write(f"Validation Loss: {val_loss:.4f}\n")
    f.write(f"Validation Accuracy: {val_acc:.4f}\n")
    f.write(f"Validation Weighted F1: {val_f1:.4f}\n\n")
    f.write("Classification Report:\n")
    f.write(report)

print(f"\n✓ Evaluation report saved to {RESULTS_DIR}/evaluation_report.txt")

In [None]:
# Generate detailed metrics JSON
from sklearn.metrics import precision_recall_fscore_support

precision, recall, f1_per_class, support = precision_recall_fscore_support(
    val_labels, val_preds, labels=list(range(len(LABELS))), zero_division=0
)

metrics_v2 = {
    "version": "v2",
    "improvements": [
        "Class weights (Sadness/Surprise ~5x)",
        "Data augmentation (rotation, flip, color jitter)",
        "Label smoothing (0.1)",
        "Gradient clipping (1.0)",
        "Early stopping (patience=4)",
        "ReduceLROnPlateau scheduler"
    ],
    "overall": {
        "accuracy": float(val_acc),
        "weighted_f1": float(val_f1),
        "loss": float(val_loss),
        "best_epoch": best_epoch,
        "total_epochs": epoch + 1
    },
    "per_class": {}
}

for idx, label in enumerate(LABELS):
    metrics_v2["per_class"][label] = {
        "precision": float(precision[idx]),
        "recall": float(recall[idx]),
        "f1_score": float(f1_per_class[idx]),
        "support": int(support[idx])
    }

# Save metrics
with open(f"{RESULTS_DIR}/evaluation_metrics.json", 'w') as f:
    json.dump(metrics_v2, f, indent=2)

print(f"✓ Detailed metrics saved to {RESULTS_DIR}/evaluation_metrics.json")

## 13. 🆕 Comparison with v1 (CRITICAL!)
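
The comparison cell reports both absolute and relative F1 changes. Relative change is undefined when the v1 score is 0, which is exactly the Sadness/Surprise situation, so it needs separate handling. A minimal sketch of that calculation, with a hypothetical helper name and made-up scores:

```python
def f1_change(v1_f1, v2_f1):
    """Return (absolute change, relative change in percent or None if v1 was 0)."""
    absolute = v2_f1 - v1_f1
    relative_pct = None if v1_f1 == 0 else absolute / v1_f1 * 100
    return absolute, relative_pct


# Hypothetical scores (not real results)
print(f1_change(0.0, 0.25))   # zero baseline: relative change is undefined
print(f1_change(0.50, 0.55))  # non-zero baseline: roughly a 10% relative gain
```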

In [None]:
# Try to load v1 metrics for comparison
v1_metrics_path = "results/image_model/evaluation_metrics.json"
comparison_available = os.path.exists(v1_metrics_path)

if comparison_available:
    with open(v1_metrics_path, 'r') as f:
        metrics_v1 = json.load(f)
    
    print("=" * 80)
    print("📊 COMPARISON: v1 vs v2 - IMAGE MODEL")
    print("=" * 80)
    
    print("\n🎯 Overall Metrics:")
    print(f"{'Metric':<20} {'v1':<12} {'v2':<12} {'Change':<12}")
    print("-" * 56)
    
    v1_acc = metrics_v1['overall']['accuracy']
    v2_acc = metrics_v2['overall']['accuracy']
    acc_change = v2_acc - v1_acc
    acc_symbol = "✅" if acc_change > 0 else "⚠️" if acc_change < 0 else "➖"
    print(f"{'Accuracy':<20} {v1_acc:<12.4f} {v2_acc:<12.4f} {acc_symbol} {acc_change:+.4f}")
    
    v1_f1 = metrics_v1['overall']['weighted_f1']
    v2_f1 = metrics_v2['overall']['weighted_f1']
    f1_change = v2_f1 - v1_f1
    f1_symbol = "✅" if f1_change > 0 else "⚠️" if f1_change < 0 else "➖"
    print(f"{'Weighted F1':<20} {v1_f1:<12.4f} {v2_f1:<12.4f} {f1_symbol} {f1_change:+.4f}")
    
    print("\n🎭 Per-Class F1 Scores (Focus on Sadness & Surprise!):")
    print(f"{'Class':<20} {'v1 F1':<12} {'v2 F1':<12} {'Change':<12} {'Improvement'}")
    print("-" * 80)
    
    for label in LABELS:
        v1_class_f1 = metrics_v1['per_class'][label]['f1_score']
        v2_class_f1 = metrics_v2['per_class'][label]['f1_score']
        class_f1_change = v2_class_f1 - v1_class_f1
        
        # Special highlighting for Sadness/Surprise (were 0% in v1!)
        if label in ['Sadness', 'Surprise'] and v2_class_f1 > 0.15:
            symbol = "🚀🚀🚀"  # Triple rocket for huge improvement!
            if v1_class_f1 == 0:
                improvement = "∞ (was 0%!)"
            else:
                pct_change = (class_f1_change / max(v1_class_f1, 0.001)) * 100
                improvement = f"{pct_change:+.0f}%"  # signed format, so a drop is not printed as "+-"
        elif class_f1_change > 0.05:
            symbol = "🚀"
            pct_change = (class_f1_change / max(v1_class_f1, 0.001)) * 100
            improvement = f"+{pct_change:.1f}%"
        elif class_f1_change > 0:
            symbol = "✅"
            pct_change = (class_f1_change / max(v1_class_f1, 0.001)) * 100
            improvement = f"+{pct_change:.1f}%"
        elif class_f1_change < -0.05:
            symbol = "⚠️"
            pct_change = (class_f1_change / max(v1_class_f1, 0.001)) * 100
            improvement = f"{pct_change:.1f}%"
        else:
            symbol = "➖"
            improvement = "~0%"
        
        print(f"{label:<20} {v1_class_f1:<12.4f} {v2_class_f1:<12.4f} {symbol} {class_f1_change:+.4f}    {improvement}")
    
    print("\n" + "=" * 80)
    print("💡 Key Success Metrics:")
    print("   🎯 Did Sadness go from 0% to >15% F1?")
    print("   🎯 Did Surprise go from 0% to >15% F1?")
    print("   🎯 Did overall F1 increase by >5%?")
    print("=" * 80)
    
    # Save comparison
    comparison = {
        "v1": metrics_v1,
        "v2": metrics_v2,
        "improvements": {
            "accuracy_change": float(acc_change),
            "f1_change": float(f1_change),
            "sadness_f1_change": float(metrics_v2['per_class']['Sadness']['f1_score'] - metrics_v1['per_class']['Sadness']['f1_score']),
            "surprise_f1_change": float(metrics_v2['per_class']['Surprise']['f1_score'] - metrics_v1['per_class']['Surprise']['f1_score'])
        }
    }
    
    with open(f"{RESULTS_DIR}/comparison_v1_vs_v2.json", 'w') as f:
        json.dump(comparison, f, indent=2)
    
    print(f"\n✓ Comparison saved to {RESULTS_DIR}/comparison_v1_vs_v2.json")
else:
    print("⚠️ v1 metrics not found. Run the original notebook first for comparison.")

## 14. Visualizations

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Validation Loss', marker='s')
axes[0].axvline(best_epoch - 1, color='r', linestyle='--', alpha=0.5, label=f'Best Epoch ({best_epoch})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss (v2)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Accuracy', marker='o')
axes[1].plot(history['val_acc'], label='Validation Accuracy', marker='s')
axes[1].plot(history['val_f1'], label='Validation F1', marker='^')
axes[1].axvline(best_epoch - 1, color='r', linestyle='--', alpha=0.5, label=f'Best Epoch ({best_epoch})')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Score')
axes[1].set_title('Training and Validation Metrics (v2)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate plot
axes[2].plot(history['learning_rates'], marker='o', color='purple')
axes[2].axvline(best_epoch - 1, color='r', linestyle='--', alpha=0.5, label=f'Best Epoch ({best_epoch})')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule (v2)')
axes[2].set_yscale('log')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{RESULTS_DIR}/training_curves.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training curves saved to {RESULTS_DIR}/training_curves.png")

In [None]:
# Plot confusion matrix
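# Fix the label order so rows and columns match LABELS even if a class is missing
cm = confusion_matrix(val_labels, val_preds, labels=list(range(len(LABELS))))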

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Greens',
    xticklabels=LABELS,
    yticklabels=LABELS,
    cbar_kws={'label': 'Count'}
)
plt.title('Confusion Matrix - Image Model v2 (Validation Set)', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(f"{RESULTS_DIR}/confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Confusion matrix saved to {RESULTS_DIR}/confusion_matrix.png")

In [None]:
# Plot per-class F1 scores with v1 comparison
f1_scores_v2 = [metrics_v2['per_class'][label]['f1_score'] for label in LABELS]

fig, ax = plt.subplots(figsize=(12, 6))

if comparison_available:
    f1_scores_v1 = [metrics_v1['per_class'][label]['f1_score'] for label in LABELS]
    
    x = np.arange(len(LABELS))
    width = 0.35
    
    bars1 = ax.bar(x - width/2, f1_scores_v1, width, label='v1 (Original)', color='#95a5a6', alpha=0.7)
    bars2 = ax.bar(x + width/2, f1_scores_v2, width, label='v2 (Improved)', 
                   color=['#e74c3c', '#f39c12', '#95a5a6', '#3498db', '#9b59b6'])
    
    ax.set_xlabel('Sentiment', fontsize=12)
    ax.set_ylabel('F1 Score', fontsize=12)
    ax.set_title('Per-Class F1 Scores - Image Model v1 vs v2 Comparison\n(Watch Sadness & Surprise!)', 
                fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(LABELS, rotation=45, ha='right')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{height:.3f}',
                   ha='center', va='bottom', fontsize=9)
else:
    bars = ax.bar(LABELS, f1_scores_v2, color=['#e74c3c', '#f39c12', '#95a5a6', '#3498db', '#9b59b6'])
    ax.set_xlabel('Sentiment', fontsize=12)
    ax.set_ylabel('F1 Score', fontsize=12)
    ax.set_title('Per-Class F1 Scores - Image Model v2', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 1.0)
    plt.xticks(rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    
    for bar, score in zip(bars, f1_scores_v2):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{score:.3f}',
               ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{RESULTS_DIR}/f1_scores_comparison.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ F1 scores comparison saved to {RESULTS_DIR}/f1_scores_comparison.png")

## 15. Summary

In [None]:
print("=" * 80)
print("IMAGE MODEL v2 TRAINING COMPLETE!")
print("=" * 80)

print("\n🆕 Improvements Applied:")
print("  1. ✅ Class weights (Sadness/Surprise ~5x more weight)")
print("  2. ✅ Data augmentation (rotation, flip, color jitter)")
print("  3. ✅ Label smoothing (0.1)")
print("  4. ✅ Gradient clipping (1.0)")
print("  5. ✅ Early stopping (patience=4)")
print("  6. ✅ ReduceLROnPlateau scheduler")

print("\nFiles Generated:")
print(f"  1. Model weights: {MODEL_DIR}/image_specialist_v2.pth")
print(f"  2. Best model weights: {MODEL_DIR}/image_specialist_v2_best.pth")
print(f"  3. Training history: {RESULTS_DIR}/training_history.json")
print(f"  4. Evaluation metrics: {RESULTS_DIR}/evaluation_metrics.json")
print(f"  5. Evaluation report: {RESULTS_DIR}/evaluation_report.txt")
print(f"  6. Training curves: {RESULTS_DIR}/training_curves.png")
print(f"  7. Confusion matrix: {RESULTS_DIR}/confusion_matrix.png")
print(f"  8. F1 scores comparison: {RESULTS_DIR}/f1_scores_comparison.png")
if comparison_available:
    print(f"  9. v1 vs v2 comparison: {RESULTS_DIR}/comparison_v1_vs_v2.json")

print("\nFinal Performance (v2):")
print(f"  Accuracy: {val_acc:.4f}")
print(f"  Weighted F1: {val_f1:.4f}")
print(f"  Best Epoch: {best_epoch}/{epoch + 1}")

print("\nPer-Class F1 Scores (v2):")
for label in LABELS:
    f1 = metrics_v2['per_class'][label]['f1_score']
    support = metrics_v2['per_class'][label]['support']
    print(f"  {label:15s}: {f1:.4f} (n={support})")

if comparison_available:
    print("\n📊 Comparison with v1:")
    print(f"  Overall Accuracy: {acc_change:+.4f}")
    print(f"  Overall F1: {f1_change:+.4f}")
    
    sadness_v1 = metrics_v1['per_class']['Sadness']['f1_score']
    sadness_v2 = metrics_v2['per_class']['Sadness']['f1_score']
    surprise_v1 = metrics_v1['per_class']['Surprise']['f1_score']
    surprise_v2 = metrics_v2['per_class']['Surprise']['f1_score']
    
    print(f"\n  🎯 Critical Improvements (were 0% in v1!):")
    print(f"    Sadness:  {sadness_v1:.4f} → {sadness_v2:.4f} ({sadness_v2-sadness_v1:+.4f})")
    print(f"    Surprise: {surprise_v1:.4f} → {surprise_v2:.4f} ({surprise_v2-surprise_v1:+.4f})")

print("\n" + "=" * 80)
print("Next Steps:")
print("  1. Check if Sadness/Surprise F1 improved from 0%")
print("  2. If v2 is better, use image_specialist_v2_best.pth for fusion")
print("  3. Create video model v2 with same improvements")
print("=" * 80)