# Step 4A: Text Model Training (DistilBERT Specialist)

This notebook fine-tunes a pre-trained DistilBERT model for Brawl Stars sentiment classification.

**Goal**: Train the text specialist to predict post sentiment from title + text.

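**Outputs**:
- `models/text_specialist.pth` - Model weights after the final training epoch
- `models/text_specialist_best.pth` - Weights from the epoch with the best validation weighted F1 (used for all evaluation below)
- `results/text_model/confusion_matrix.png` - Confusion matrix visualization
- `results/text_model/training_curves.png` - Training/validation loss, accuracy, and validation F1 per epoch
- `results/text_model/f1_scores_per_class.png` - Per-class F1 bar chart
- `results/text_model/evaluation_metrics.json` - Detailed metrics (overall accuracy/F1/loss; per-class precision, recall, F1, support)
- `results/text_model/training_history.json` - Epoch-by-epoch training history
- `results/text_model/evaluation_report.txt` - Human-readable classification report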

## 1. Setup and Imports

In [None]:
import json
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_fscore_support,
)
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AdamW, DistilBertModel, DistilBertTokenizer, get_linear_schedule_with_warmup
# Suppress ALL Python warnings to keep cell output readable
# (note: this also hides deprecation and undefined-metric warnings).
warnings.filterwarnings('ignore')

# Seed the torch and NumPy RNGs so shuffling, weight init and dropout are repeatable.
RANDOM_SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(RANDOM_SEED)

# Prefer CUDA when available, otherwise fall back to CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Using device: {device}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# Paths (relative to the notebook's working directory)
TRAIN_DATA = "data/train_set.csv"
VAL_DATA = "data/validation_set.csv"
MODEL_DIR = "models"
RESULTS_DIR = "results/text_model"

# Make sure every output folder exists before later cells write into it.
for output_dir in (MODEL_DIR, RESULTS_DIR):
    Path(output_dir).mkdir(parents=True, exist_ok=True)

# Model / optimization hyperparameters
MODEL_NAME = 'distilbert-base-uncased'
MAX_LENGTH = 128      # tokens kept per post (title + text); longer inputs are truncated
BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 2e-5
WARMUP_STEPS = 100    # length of the linear LR warm-up, in optimizer steps

# Class vocabulary: position in LABELS is the integer class id.
LABELS = ['Anger', 'Joy', 'Neutral/Other', 'Sadness', 'Surprise']
LABEL_TO_ID = dict(zip(LABELS, range(len(LABELS))))
ID_TO_LABEL = dict(enumerate(LABELS))

print("Configuration:")
for setting_name, setting_value in (
    ("Model", MODEL_NAME),
    ("Max Length", MAX_LENGTH),
    ("Batch Size", BATCH_SIZE),
    ("Epochs", EPOCHS),
    ("Learning Rate", LEARNING_RATE),
    ("Labels", LABELS),
):
    print(f"  {setting_name}: {setting_value}")

## 3. Data Loading and Preprocessing

In [None]:
# Read the train / validation splits from CSV.
print("Loading datasets...")
train_df = pd.read_csv(TRAIN_DATA)
val_df = pd.read_csv(VAL_DATA)

splits = (("Train", train_df), ("Validation", val_df))

for split_name, split_df in splits:
    print(f"{split_name} set: {len(split_df):,} samples")

# Per-split label counts, to spot class imbalance before training.
for split_name, split_df in splits:
    print(f"\n{split_name} set sentiment distribution:")
    print(split_df['post_sentiment'].value_counts())

In [None]:
# Custom Dataset class
class BrawlStarsTextDataset(Dataset):
    """Torch dataset that tokenizes a post's title + text and pairs it with its sentiment id.

    Each item is a dict with:
      - 'input_ids' / 'attention_mask': 1-D tensors of length ``max_length``
        (padded / truncated by the tokenizer),
      - 'label': scalar ``torch.long`` tensor holding the class id of ``post_sentiment``.

    Args:
        dataframe: DataFrame with 'title', 'text' and 'post_sentiment' columns.
            Missing (NaN) titles/texts are treated as empty strings.
        tokenizer: Hugging Face-style tokenizer callable (e.g. DistilBertTokenizer).
        max_length: Fixed sequence length every example is padded/truncated to.
        label_to_id: Optional mapping sentiment name -> class id. When None, the
            notebook-level ``LABEL_TO_ID`` is used.

    Raises:
        ValueError: If any ``post_sentiment`` value is missing or not a key of
            ``label_to_id``. This is checked eagerly, so bad data fails here rather
            than as a KeyError in the middle of an epoch.
    """

    def __init__(self, dataframe, tokenizer, max_length, label_to_id=None):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_to_id = LABEL_TO_ID if label_to_id is None else label_to_id

        # NaN is never a key of label_to_id, so it is reported here as well.
        sentiments = self.data['post_sentiment']
        invalid = sentiments[~sentiments.isin(list(self.label_to_id))]
        if len(invalid):
            raise ValueError(
                f"{len(invalid)} row(s) have a post_sentiment outside {list(self.label_to_id)}; "
                f"examples: {invalid.unique()[:5].tolist()}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Model input is "<title> <text>"; NaN fields become "" so a missing title
        # or body never turns into the literal string "nan".
        title = str(row['title']) if pd.notna(row['title']) else ""
        text = str(row['text']) if pd.notna(row['text']) else ""
        combined_text = f"{title} {text}".strip()

        encoding = self.tokenizer(
            combined_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        label = self.label_to_id[row['post_sentiment']]

        # return_tensors='pt' yields a leading batch dim of 1; drop it so the
        # DataLoader can stack items into (batch, max_length).
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }

# The tokenizer must come from the same checkpoint as the encoder.
print("Loading tokenizer...")
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)

# Wrap each split; only the training loader shuffles, so validation order is stable.
train_dataset, val_dataset = (
    BrawlStarsTextDataset(train_df, tokenizer, MAX_LENGTH),
    BrawlStarsTextDataset(val_df, tokenizer, MAX_LENGTH),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

n_train_batches, n_val_batches = len(train_loader), len(val_loader)
print(f"✓ Created {n_train_batches} train batches and {n_val_batches} validation batches")

## 4. Model Definition

In [None]:
class TextSentimentClassifier(nn.Module):
    """DistilBERT encoder + dropout + linear head for multi-class post sentiment.

    The sentence representation is the encoder's last-layer hidden state at position 0
    (the [CLS] token). ``forward`` returns raw logits of shape (batch, n_classes); apply
    softmax / CrossEntropyLoss outside the model.

    Args:
        n_classes: Number of output classes.
        dropout: Dropout probability applied to the [CLS] vector before the head.
        model_name: Hugging Face checkpoint to load. When None, the notebook-level
            ``MODEL_NAME`` is used.
    """

    def __init__(self, n_classes=5, dropout=0.3, model_name=None):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name or MODEL_NAME)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)

    def _cls_embedding(self, input_ids, attention_mask):
        """Return the last-layer hidden state at token position 0, shape (batch, hidden)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, input_ids, attention_mask):
        """Compute class logits of shape (batch, n_classes)."""
        pooled_output = self.dropout(self._cls_embedding(input_ids, attention_mask))
        return self.classifier(pooled_output)

    def get_embedding(self, input_ids, attention_mask):
        """Extract the [CLS] embedding without the classification head (for Phase 2).

        Runs under ``torch.no_grad()`` and temporarily switches the module to eval mode,
        so any dropout inside the encoder is disabled and the result is deterministic.
        The caller's train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self._cls_embedding(input_ids, attention_mask)
        finally:
            self.train(was_training)

# Build the classifier (loads pretrained DistilBERT via from_pretrained) on the chosen device.
model = TextSentimentClassifier(n_classes=len(LABELS)).to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"✓ Model initialized with {n_params:,} parameters")

## 5. Training Setup

In [None]:
# Loss: plain (unweighted) cross-entropy over the sentiment classes.
criterion = nn.CrossEntropyLoss()

# transformers.AdamW is deprecated (transformers recommends torch.optim.AdamW) and is
# removed in newer releases. eps / weight_decay reproduce transformers.AdamW's old
# defaults (1e-6 / 0.0), so optimization matches the original setup.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

# Linear warm-up followed by linear decay to 0; stepped once per training batch.
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_steps
)

print(f"✓ Optimizer and scheduler configured")
print(f"  Total training steps: {total_steps}")
print(f"  Warmup steps: {WARMUP_STEPS}")

## 6. Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, max_grad_norm=None):
    """Run one training epoch over ``dataloader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns logits
            of shape (batch, n_classes).
        dataloader: Iterable of batch dicts with 'input_ids', 'attention_mask' and
            'label' tensors.
        criterion: Loss taking (logits, labels); assumed to return the batch *mean*
            (as nn.CrossEntropyLoss does by default).
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per batch after ``optimizer.step()``.
        device: Device the batch tensors are moved to.
        max_grad_norm: If set, clip the global gradient norm to this value before each
            optimizer step. None (the default) disables clipping.

    Returns:
        (avg_loss, accuracy): the per-sample mean loss and the fraction of correct
        predictions over the epoch. Both are 0.0 for an empty dataloader.
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        batch_size = labels.size(0)

        # Forward pass
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        # Backward pass + parameter / LR update
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Running metrics. Loss is weighted by batch size so the epoch average is a
        # true per-sample mean even when the final batch is smaller.
        preds = outputs.argmax(dim=1)
        correct_predictions += (preds == labels).sum().item()
        total_samples += batch_size
        total_loss += loss.item() * batch_size

        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{correct_predictions/total_samples:.4f}'
        })

    if total_samples == 0:
        return 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy


def eval_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns logits.
        dataloader: Iterable of batch dicts with 'input_ids', 'attention_mask', 'label'.
        criterion: Loss taking (logits, labels); assumed to return the batch mean.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy, weighted_f1, all_preds, all_labels):
        - ``avg_loss`` is the per-sample mean loss (weighted by batch size);
        - ``weighted_f1`` is sklearn's support-weighted F1;
        - ``all_preds`` / ``all_labels`` are flat lists of class ids in dataloader order.
        For an empty dataloader the three metrics are 0.0 and both lists are empty.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Validation')
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            batch_size = labels.size(0)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Weight by batch size so a short final batch doesn't skew the average.
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        return 0.0, 0.0, 0.0, all_preds, all_labels

    avg_loss = total_loss / total_samples
    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 gives the same 0.0 sklearn substitutes for never-predicted
    # classes, without emitting an UndefinedMetricWarning.
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return avg_loss, accuracy, f1, all_preds, all_labels

print("✓ Training functions defined")

## 7. Training Loop

In [None]:
# Per-epoch metrics; plotted below and saved to training_history.json.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': []
}

# Start below any achievable F1 (not 0) so the first epoch always writes a checkpoint.
# With 0, a run whose F1 never rises above 0 saves nothing, and the
# "Load Best Model" cell then fails with FileNotFoundError.
best_val_f1 = float('-inf')
best_epoch = 0
best_model_path = f"{MODEL_DIR}/text_specialist_best.pth"

print("=" * 80)
print("STARTING TRAINING")
print("=" * 80)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    print("-" * 80)

    # Train, then validate
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
    val_loss, val_acc, val_f1, _, _ = eval_model(model, val_loader, criterion, device)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    print(f"\nResults:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val F1: {val_f1:.4f}")

    # Model selection: checkpoint whenever validation weighted F1 improves.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"  ✓ New best model saved! (F1: {val_f1:.4f})")

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)
print(f"Best validation F1: {best_val_f1:.4f} (Epoch {best_epoch})")

## 8. Save Final Model

In [None]:
# Persist the weights from the LAST epoch; the best-F1 checkpoint is written
# separately, inside the training loop.
final_model_path = f"{MODEL_DIR}/text_specialist.pth"
torch.save(model.state_dict(), final_model_path)
print(f"✓ Final model saved to {final_model_path}")

# Persist the per-epoch curves so they can be re-plotted without retraining.
history_path = f"{RESULTS_DIR}/training_history.json"
with open(history_path, 'w') as history_file:
    json.dump(history, history_file, indent=2)
print(f"✓ Training history saved to {history_path}")

## 9. Load Best Model for Evaluation

In [None]:
# Restore the best-F1 checkpoint so every metric and plot below describes that model,
# not the last epoch. map_location remaps the saved tensors onto the current device,
# so this also works if the checkpoint was written on a different device
# (e.g. trained on GPU, reloaded on a CPU-only kernel).
best_state = torch.load(f"{MODEL_DIR}/text_specialist_best.pth", map_location=device)
model.load_state_dict(best_state)
print(f"✓ Loaded best model (Epoch {best_epoch}, F1: {best_val_f1:.4f})")

## 10. Final Evaluation and Metrics

In [None]:
# Re-run validation with the restored checkpoint; the predictions and labels
# feed the report, metrics and plots below.
print("Evaluating on validation set...")
val_loss, val_acc, val_f1, val_preds, val_labels = eval_model(model, val_loader, criterion, device)

print(f"\nFinal Validation Metrics:")
for metric_name, metric_value in (("Loss", val_loss), ("Accuracy", val_acc), ("Weighted F1", val_f1)):
    print(f"  {metric_name}: {metric_value:.4f}")

In [None]:
# Generate classification report.
# labels= pins the rows to LABELS order. Without it, sklearn infers the classes from the
# data and raises ValueError ("Number of classes ... does not match size of
# target_names") whenever some sentiment is absent from both val_labels and val_preds.
report = classification_report(
    val_labels,
    val_preds,
    labels=list(range(len(LABELS))),
    target_names=LABELS,
    digits=4,
    zero_division=0,
)
print("\nClassification Report:")
print(report)

# Save classification report
report_path = f"{RESULTS_DIR}/evaluation_report.txt"
with open(report_path, 'w', encoding='utf-8') as report_file:
    report_file.write("TEXT MODEL EVALUATION REPORT\n")
    report_file.write("=" * 80 + "\n\n")
    report_file.write(f"Best Epoch: {best_epoch}\n")
    report_file.write(f"Validation Loss: {val_loss:.4f}\n")
    report_file.write(f"Validation Accuracy: {val_acc:.4f}\n")
    report_file.write(f"Validation Weighted F1: {val_f1:.4f}\n\n")
    report_file.write("Classification Report:\n")
    report_file.write(report)

print(f"\n✓ Evaluation report saved to {report_path}")

In [None]:
# Generate detailed metrics JSON.
# labels= fixes the per-class order to LABELS and keeps an entry (support 0) for any
# class absent from this split; zero_division=0 reports 0.0 instead of warning.
precision, recall, f1_per_class, support = precision_recall_fscore_support(
    val_labels, val_preds, labels=list(range(len(LABELS))), zero_division=0
)

metrics = {
    "overall": {
        "accuracy": float(val_acc),
        "weighted_f1": float(val_f1),
        "loss": float(val_loss),
        # Unweighted mean over classes: exposes weak minority classes that the
        # support-weighted F1 can hide.
        "macro_f1": float(np.mean(f1_per_class)),
    },
    "per_class": {
        label: {
            "precision": float(precision[idx]),
            "recall": float(recall[idx]),
            "f1_score": float(f1_per_class[idx]),
            "support": int(support[idx]),
        }
        for idx, label in enumerate(LABELS)
    },
}

# Save metrics
metrics_path = f"{RESULTS_DIR}/evaluation_metrics.json"
with open(metrics_path, 'w') as metrics_file:
    json.dump(metrics, metrics_file, indent=2)

print(f"✓ Detailed metrics saved to {metrics_path}")

## 11. Visualizations

In [None]:
# Plot training curves (one point per epoch).
# x values are 1-based epoch numbers, matching the "Epoch N/EPOCHS" progress printout;
# plotting the bare lists would label the first epoch as 0.
epoch_numbers = list(range(1, len(history['train_loss']) + 1))

fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
loss_ax.plot(epoch_numbers, history['train_loss'], label='Train Loss', marker='o')
loss_ax.plot(epoch_numbers, history['val_loss'], label='Validation Loss', marker='s')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Accuracy / F1 plot
score_ax.plot(epoch_numbers, history['train_acc'], label='Train Accuracy', marker='o')
score_ax.plot(epoch_numbers, history['val_acc'], label='Validation Accuracy', marker='s')
score_ax.plot(epoch_numbers, history['val_f1'], label='Validation F1', marker='^')
score_ax.set_xlabel('Epoch')
score_ax.set_ylabel('Score')
score_ax.set_title('Training and Validation Metrics')
score_ax.legend()
score_ax.grid(True, alpha=0.3)

# Whole-epoch ticks only (no fractional epochs on the x axis).
for curve_ax in (loss_ax, score_ax):
    curve_ax.set_xticks(epoch_numbers)

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/training_curves.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training curves saved to {RESULTS_DIR}/training_curves.png")

In [None]:
# Plot confusion matrix.
# labels= forces a fixed len(LABELS) x len(LABELS) matrix in LABELS order. Without it,
# a class missing from both val_labels and val_preds is dropped from the matrix, and
# the LABELS tick labels below no longer line up with its rows/columns.
cm = confusion_matrix(val_labels, val_preds, labels=list(range(len(LABELS))))

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=LABELS,
    yticklabels=LABELS,
    cbar_kws={'label': 'Count'},
    ax=ax,
)
ax.set_title('Confusion Matrix - Text Model (Validation Set)', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Confusion matrix saved to {RESULTS_DIR}/confusion_matrix.png")

In [None]:
# Bar chart of the per-class F1 values stored in `metrics`, in LABELS order.
f1_scores = [metrics['per_class'][label]['f1_score'] for label in LABELS]
bar_colors = ['#e74c3c', '#f39c12', '#95a5a6', '#3498db', '#9b59b6']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(LABELS, f1_scores, color=bar_colors)
ax.set_xlabel('Sentiment', fontsize=12)
ax.set_ylabel('F1 Score', fontsize=12)
ax.set_title('Per-Class F1 Scores - Text Model', fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.0)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.grid(True, alpha=0.3, axis='y')

# Annotate each bar with its score, centred just above the bar top.
for bar, score in zip(bars, f1_scores):
    bar_center_x = bar.get_x() + bar.get_width() / 2.
    ax.text(bar_center_x, bar.get_height(), f'{score:.3f}',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/f1_scores_per_class.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Per-class F1 scores plot saved to {RESULTS_DIR}/f1_scores_per_class.png")

## 12. Summary

In [None]:
# Final console summary of generated artifacts and headline metrics.
banner = "=" * 80
print(banner)
print("TEXT MODEL TRAINING COMPLETE!")
print(banner)

generated_files = [
    ("Model weights", f"{MODEL_DIR}/text_specialist.pth"),
    ("Best model weights", f"{MODEL_DIR}/text_specialist_best.pth"),
    ("Training history", f"{RESULTS_DIR}/training_history.json"),
    ("Evaluation metrics", f"{RESULTS_DIR}/evaluation_metrics.json"),
    ("Evaluation report", f"{RESULTS_DIR}/evaluation_report.txt"),
    ("Training curves", f"{RESULTS_DIR}/training_curves.png"),
    ("Confusion matrix", f"{RESULTS_DIR}/confusion_matrix.png"),
    ("F1 scores per class", f"{RESULTS_DIR}/f1_scores_per_class.png"),
]
print("\nFiles Generated:")
for file_no, (description, file_path) in enumerate(generated_files, start=1):
    print(f"  {file_no}. {description}: {file_path}")

print("\nFinal Performance:")
print(f"  Accuracy: {val_acc:.4f}")
print(f"  Weighted F1: {val_f1:.4f}")
print(f"  Best Epoch: {best_epoch}/{EPOCHS}")

print("\nPer-Class F1 Scores:")
for label in LABELS:
    # Dedicated name (not `f1` / `support`) so this loop doesn't overwrite the
    # notebook-level `support` array computed in the metrics cell.
    class_metrics = metrics['per_class'][label]
    print(f"  {label:15s}: {class_metrics['f1_score']:.4f} (n={class_metrics['support']})")

print("\n" + banner)
print("Next Step: Train Image Model (CLIP) - Step 4B")
print(banner)