# Berghain LSTM Policy Network Training

This notebook trains an LSTM neural network to learn optimal admission decisions for the Berghain nightclub challenge using supervised learning on historical game data.

## Features
- **Supervised Learning**: Trains on successful game strategies
- **Real-time Visualization**: Shows training progress with loss and accuracy plots
- **Model Checkpointing**: Saves best model automatically
- **GPU Acceleration**: Uses CUDA when available
- **Comprehensive Logging**: Detailed training metrics

## Setup Instructions
1. Upload your game logs to the `game_logs/` directory
2. Run all cells in order
3. Download the trained model from the `models/` directory

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install torch torchvision numpy matplotlib scikit-learn pyyaml requests

# Import standard libraries
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from datetime import datetime
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Any
from dataclasses import dataclass

# Configure logging once for the whole notebook and expose a module logger
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Report the runtime environment; query CUDA availability a single time
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name()}")

# Select the device every later cell moves tensors and the model to
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Training on: {device}")

## 2. Upload and Prepare Data

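Upload your game logs to Google Colab, either as individual JSON files or as a single ZIP archive. Only files named `game_*.json` (without `consolidated` in the name) are loaded, and only games whose `success` field is true are used for training.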

In [None]:
# Create the working directories. Plain-Python equivalent of `mkdir -p`, so the
# cell does not depend on IPython shell magics.
for directory in ('game_logs', 'models', 'data'):
    os.makedirs(directory, exist_ok=True)

# For Google Colab, upload files
from google.colab import files
import zipfile
import shutil

print("Please upload your game logs.")
print("You can either:")
print("1. Upload a ZIP file containing all game logs")
print("2. Upload individual JSON files")

# files.upload() blocks until the user has picked files; the loop below only
# needs the uploaded file names.
uploaded = files.upload()

# Process uploaded files
for filename in uploaded:
    if filename.endswith('.zip'):
        # Extract ZIP file
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('game_logs/')
        print(f"Extracted {filename} to game_logs/")
    elif filename.endswith('.json'):
        # Move JSON file to game_logs
        shutil.move(filename, f'game_logs/{filename}')
        print(f"Moved {filename} to game_logs/")

# List uploaded files with the same filter GameDataPreprocessor.load_game_logs
# applies ('game_' prefix, '.json' suffix, no 'consolidated'), so the count
# shown here is the number of files the loader will actually open.
json_files = sorted(f for f in os.listdir('game_logs') if f.endswith('.json'))
game_files = [
    f for f in json_files
    if f.startswith('game_') and 'consolidated' not in f
]
print(f"\nFound {len(game_files)} game log files")
if game_files:
    print("Sample files:")
    for f in game_files[:5]:
        print(f"  {f}")

# Make silently ignored uploads visible instead of letting them vanish.
ignored_count = len(json_files) - len(game_files)
if ignored_count:
    print(f"Note: {ignored_count} JSON file(s) in game_logs/ do not match "
          "'game_*.json' (or are consolidated logs) and will be skipped by the loader")

## 3. Define Model Architecture

LSTM-based policy network for sequential decision making.

In [None]:
class LSTMPolicyNetwork(nn.Module):
    """
    Recurrent policy/value network for the Berghain admission game.

    A stacked LSTM consumes a sequence of state feature vectors and two small
    MLP heads are applied to every time step of its output:
    - policy head: softmax over two classes, ordered [reject, accept]
    - value head: a single unbounded scalar per step
    """

    def __init__(
        self,
        input_dim: int = 8,
        hidden_dim: int = 128,
        lstm_layers: int = 2,
        dropout: float = 0.1
    ):
        """
        Build the LSTM backbone and both heads, then initialise all weights.

        Args:
            input_dim: Number of features per time step.
            hidden_dim: Size of the LSTM hidden state (input width of both heads).
            lstm_layers: Number of stacked LSTM layers.
            dropout: Dropout probability used inside the heads and, when more
                than one LSTM layer is requested, passed to the LSTM as well.
        """
        super().__init__()

        self.input_dim, self.hidden_dim, self.lstm_layers = input_dim, hidden_dim, lstm_layers

        # A single-layer LSTM is created with dropout disabled
        recurrent_dropout = dropout if lstm_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=recurrent_dropout,
        )

        # Policy head: shared MLP trunk followed by a softmax over [reject, accept]
        self.policy_head = nn.Sequential(
            *self._head_layers(hidden_dim, dropout, 2),
            nn.Softmax(dim=-1),
        )

        # Value head: same trunk, one raw scalar output
        self.value_head = nn.Sequential(*self._head_layers(hidden_dim, dropout, 1))

        self._init_weights()

    @staticmethod
    def _head_layers(hidden_dim: int, dropout: float, out_features: int) -> list:
        """Return the hidden_dim -> 64 -> 32 -> out_features layer stack shared by both heads."""
        return [
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, out_features),
        ]

    def _init_weights(self):
        """Xavier-uniform every parameter named '*weight*'; zero every '*bias*'."""
        for param_name, tensor in self.named_parameters():
            if 'weight' in param_name:
                nn.init.xavier_uniform_(tensor)
                continue
            if 'bias' in param_name:
                nn.init.zeros_(tensor)

    def forward(
        self,
        x: torch.Tensor,
        hidden: tuple = None
    ) -> Tuple[torch.Tensor, torch.Tensor, tuple]:
        """
        Run the LSTM over `x` and apply both heads to each time step.

        Args:
            x: Input features; the LSTM is built with batch_first=True, so a
                batched input is laid out as (batch, seq_len, input_dim).
            hidden: Optional LSTM state to continue from; None starts fresh.

        Returns:
            (policy, value, hidden_new): per-step class probabilities with a
            trailing dimension of 2, per-step values with a trailing dimension
            of 1, and the LSTM state after the last step.
        """
        sequence_out, next_hidden = self.lstm(x, hidden)
        return self.policy_head(sequence_out), self.value_head(sequence_out), next_hidden

    def set_training_history(self, history_dict: dict) -> None:
        """Attach a training-history dict to this model instance."""
        self.training_history = history_dict

    def get_training_history(self) -> dict:
        """Return the attached training history, or an empty dict if none was set."""
        try:
            return self.training_history
        except AttributeError:
            return {}

print("Model architecture defined successfully!")

## 4. Data Preprocessing

Convert game logs into training sequences for the LSTM.

In [None]:
class GameDataPreprocessor:
    """
    Preprocesses game log data for supervised learning.

    Converts decision sequences from game logs into training data where:
    - X: Sequential game state features (8 values per decision)
    - y: Expert decision labels (0=reject, 1=accept)
    """

    # Normalisation constants used when turning raw counters into features
    MAX_CAPACITY = 1000        # admitted count is divided by this
    MAX_REJECTIONS = 20000     # rejected count is divided by this
    MAX_PEOPLE = 25000         # person index is divided by this (then capped at 1.0)
    EARLY_PHASE_LIMIT = 300    # fewer admissions than this -> phase 0.0
    MID_PHASE_LIMIT = 700      # fewer admissions than this -> phase 0.5, else 1.0
    DEFAULT_MIN_COUNT = 600    # target used when a constraint is missing from the log

    def __init__(self, sequence_length: int = 50):
        """
        Args:
            sequence_length: Number of consecutive decisions per training window.

        Raises:
            ValueError: If sequence_length is smaller than 1.
        """
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.sequence_length = sequence_length
        self.feature_dim = 8  # Matches StateEncoder feature count

    def load_game_logs(self, log_directory: str) -> List[Dict[str, Any]]:
        """
        Load all individual game JSON files from directory.

        Only files named 'game_*.json' without 'consolidated' in the name are
        read, and only games whose 'success' field is truthy are returned.
        Unreadable or malformed files are logged and skipped.
        """
        games = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # game order (and therefore any seeded split built on it) reproducible.
        for filename in sorted(os.listdir(log_directory)):
            if not (filename.startswith('game_')
                    and filename.endswith('.json')
                    and 'consolidated' not in filename):
                continue

            filepath = os.path.join(log_directory, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    game_data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers both JSON and text-decoding failures
                logger.warning(f"Error loading {filename}: {e}")
                continue

            if not isinstance(game_data, dict):
                logger.warning(f"Error loading {filename}: top-level JSON value is not an object")
                continue

            # Only include successful games for supervised learning
            if game_data.get('success', False):
                games.append(game_data)

        logger.info(f"Loaded {len(games)} successful games from {log_directory}")
        return games

    @staticmethod
    def _constraint_target(constraints: List[Dict[str, Any]], attribute: str, default: int) -> int:
        """Return 'min_count' of the first constraint on `attribute`, or `default` if there is none."""
        return next((c['min_count'] for c in constraints if c['attribute'] == attribute), default)

    def extract_features_and_labels(self, game_data: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract feature sequences and decision labels from a single game.

        The game state is replayed decision by decision; each feature vector
        describes the state *before* the corresponding decision was applied.

        Feature order: well_dressed, young, young progress, well_dressed
        progress, capacity ratio, rejection ratio, game phase, person index.

        Returns:
            (features, labels): float32 array of shape (n_decisions, 8) and
            int64 array of shape (n_decisions,).

        Raises:
            KeyError: If 'decisions', 'constraints' or a per-decision
                'attributes'/'decision' entry is missing.
        """
        decisions = game_data['decisions']
        constraints = game_data['constraints']

        # Get constraint targets
        young_target = self._constraint_target(constraints, 'young', self.DEFAULT_MIN_COUNT)
        well_dressed_target = self._constraint_target(constraints, 'well_dressed', self.DEFAULT_MIN_COUNT)

        features = []
        labels = []

        # Running game state
        admitted_count = 0
        rejected_count = 0
        young_admitted = 0
        well_dressed_admitted = 0

        for i, decision in enumerate(decisions):
            person_attrs = decision['attributes']
            accepted = bool(decision['decision'])

            is_well_dressed = bool(person_attrs.get('well_dressed', False))
            is_young = bool(person_attrs.get('young', False))

            # Constraint progress, capped at 1.0; a non-positive target counts as satisfied
            constraint_progress_y = min(young_admitted / young_target, 1.0) if young_target > 0 else 1.0
            constraint_progress_w = min(well_dressed_admitted / well_dressed_target, 1.0) if well_dressed_target > 0 else 1.0

            # Capacity and rejection ratios
            capacity_ratio = admitted_count / float(self.MAX_CAPACITY)
            rejection_ratio = rejected_count / float(self.MAX_REJECTIONS)

            # Game phase
            if admitted_count < self.EARLY_PHASE_LIMIT:
                game_phase = 0.0  # Early
            elif admitted_count < self.MID_PHASE_LIMIT:
                game_phase = 0.5  # Mid
            else:
                game_phase = 1.0  # Late

            # Person index normalized
            person_index_norm = min(i / self.MAX_PEOPLE, 1.0)

            features.append(np.array([
                1.0 if is_well_dressed else 0.0,
                1.0 if is_young else 0.0,
                constraint_progress_y, constraint_progress_w,
                capacity_ratio, rejection_ratio, game_phase, person_index_norm
            ], dtype=np.float32))
            labels.append(1 if accepted else 0)

            # Update game state for next iteration
            if accepted:
                admitted_count += 1
                young_admitted += is_young
                well_dressed_admitted += is_well_dressed
            else:
                rejected_count += 1

        # reshape keeps a well-formed (0, feature_dim) array for games without decisions
        feature_matrix = np.array(features, dtype=np.float32).reshape(-1, self.feature_dim)
        return feature_matrix, np.array(labels, dtype=np.int64)

    def create_sequences(self, features: np.ndarray, labels: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create fixed-length sequences for LSTM training.

        Games no longer than `sequence_length` become one zero-padded window.
        Longer games are cut into windows that start every half window; one
        extra window anchored at the end is added when needed so the final
        decisions of the game are never dropped.

        Returns:
            (windows, window_labels) with shapes
            (n_windows, sequence_length, n_features) and (n_windows, sequence_length).
        """
        total_steps = len(features)

        if total_steps <= self.sequence_length:
            # Pad short sequences.
            # NOTE: padded steps get all-zero features and label 0 (the
            # 'reject' class); nothing here marks them as padding.
            pad_length = self.sequence_length - total_steps
            padded_features = np.pad(features, ((0, pad_length), (0, 0)), mode='constant')
            padded_labels = np.pad(labels, (0, pad_length), mode='constant')
            return padded_features[np.newaxis, :, :], padded_labels[np.newaxis, :]

        # Half-window stride; never 0, otherwise range() raises for sequence_length == 1
        stride = max(1, self.sequence_length // 2)
        last_start = total_steps - self.sequence_length
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            # The regular grid stopped short of the end: cover the tail explicitly
            starts.append(last_start)

        sequences_features = [features[s:s + self.sequence_length] for s in starts]
        sequences_labels = [labels[s:s + self.sequence_length] for s in starts]

        return np.array(sequences_features), np.array(sequences_labels)

    def prepare_dataset(self, games: List[Dict[str, Any]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Prepare complete dataset from list of games.

        Games that cannot be processed, or that contain no decisions, are
        logged and skipped so one bad log does not abort the whole run.

        Returns:
            Two parallel lists: float32 feature tensors and int64 label
            tensors, one pair per window.
        """
        all_sequences = []
        all_labels = []

        for game in games:
            try:
                features, labels = self.extract_features_and_labels(game)
                if len(features) == 0:
                    logger.warning(f"Skipping game {game.get('game_id', 'unknown')}: no decisions")
                    continue

                seq_features, seq_labels = self.create_sequences(features, labels)

                # Convert to tensors and add to lists
                for window, window_labels in zip(seq_features, seq_labels):
                    all_sequences.append(torch.tensor(window, dtype=torch.float32))
                    all_labels.append(torch.tensor(window_labels, dtype=torch.long))

            except Exception as e:  # deliberate best-effort: log and move on to the next game
                logger.warning(f"Error processing game {game.get('game_id', 'unknown')}: {e}")

        logger.info(f"Created {len(all_sequences)} training sequences")
        return all_sequences, all_labels


class SequenceDataset(Dataset):
    """Map-style dataset over (sequence, labels) pairs held in memory."""

    def __init__(self, data_tuples: List[Tuple[torch.Tensor, torch.Tensor]]):
        # Split the pairs into two parallel lists: element 0 -> sequences, element 1 -> labels
        self.sequences = list(map(lambda pair: pair[0], data_tuples))
        self.labels = list(map(lambda pair: pair[1], data_tuples))

    def __len__(self):
        """Number of stored pairs."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the (sequence, labels) pair stored at position `idx`."""
        sequence = self.sequences[idx]
        target = self.labels[idx]
        return sequence, target


# Prepare the training data
print("Loading and preprocessing game data...")
processor = GameDataPreprocessor(sequence_length=50)

# Load games
games = processor.load_game_logs('game_logs')
if not games:
    raise ValueError("No successful games found in game_logs directory")

# Split into train/validation.
# create_sequences() cuts each game into windows that overlap by half their
# length, so randomly splitting individual windows would place near-duplicate
# data on both sides and inflate validation accuracy. Splitting by game keeps
# every game entirely in one set.
if len(games) >= 2:
    train_games, val_games = train_test_split(
        games, test_size=0.2, random_state=42, shuffle=True
    )
    train_sequences, train_labels = processor.prepare_dataset(train_games)
    val_sequences, val_labels = processor.prepare_dataset(val_games)
    # Keep the combined lists available under their original names
    sequences = train_sequences + val_sequences
    labels = train_labels + val_labels
else:
    # A single game cannot be split by game; fall back to splitting windows
    logger.warning(
        "Only one successful game found: splitting its windows at random. "
        "Train and validation windows overlap, so validation metrics will be optimistic."
    )
    sequences, labels = processor.prepare_dataset(games)
    train_sequences, val_sequences, train_labels, val_labels = train_test_split(
        sequences, labels, test_size=0.2, random_state=42, shuffle=True
    )

if not train_sequences or not val_sequences:
    raise ValueError("Not enough usable sequences to build both a training and a validation set")

# Combine into tuples
train_data = list(zip(train_sequences, train_labels))
val_data = list(zip(val_sequences, val_labels))

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

# Create data loaders
batch_size = 32
train_dataset = SequenceDataset(train_data)
val_dataset = SequenceDataset(val_data)

# Drop the ragged last training batch only when at least one full batch
# exists; otherwise drop_last would leave the loader with zero batches.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(train_dataset) >= batch_size,
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 5. Training Configuration and Setup

In [None]:
# Training configuration
@dataclass
class TrainingConfig:
    """Hyper-parameters and output location for one training run."""

    # Model parameters (forwarded to the LSTMPolicyNetwork constructor)
    input_dim: int = 8        # features per time step
    hidden_dim: int = 128     # LSTM hidden state size
    lstm_layers: int = 2      # number of stacked LSTM layers
    dropout: float = 0.1      # dropout probability
    
    # Training parameters
    epochs: int = 50                # upper bound on the number of epochs
    learning_rate: float = 0.001    # Adam learning rate
    weight_decay: float = 1e-5      # Adam weight decay
    patience: int = 10  # Early stopping patience: epochs without a new best validation accuracy
    
    # Paths
    # Prefix for output files; '_best.pth', '_final.pth' and '_history.json' are appended to it
    save_path: str = 'models/lstm_berghain'

config = TrainingConfig()

# Build the network from the config values and move it to the selected device
model = LSTMPolicyNetwork(
    config.input_dim,
    config.hidden_dim,
    config.lstm_layers,
    config.dropout,
).to(device)

# Count parameters in a single pass over the model
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)
print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")

# Adam with the configured learning rate and weight decay; cross-entropy as the loss
optimizer = optim.Adam(
    model.parameters(),
    weight_decay=config.weight_decay,
    lr=config.learning_rate,
)
criterion = nn.CrossEntropyLoss()

# Per-epoch metric lists plus the best-epoch bookkeeping filled in by the training loop
training_history = {
    key: []
    for key in ('train_loss', 'val_loss', 'train_accuracy', 'val_accuracy', 'epochs')
}
training_history.update(best_epoch=0, best_val_accuracy=0.0)

print("Training setup completed!")

## 6. Training Loop with Real-time Visualization

In [None]:
# Training and validation functions
def _policy_loss(criterion, policy_flat, labels_flat):
    """
    Compute the classification loss from policy *probabilities*.

    The policy head ends in a Softmax, so `policy_flat` already holds
    probabilities, while nn.CrossEntropyLoss expects raw scores and applies
    log-softmax itself. Feeding probabilities directly would squash them a
    second time. Passing log-probabilities instead is exact, because
    log_softmax(log p) == log p whenever p sums to 1; the same input is also
    what nn.NLLLoss expects.
    """
    # Clamp before the log so a probability of exactly 0 cannot produce -inf
    log_probs = torch.log(policy_flat.clamp_min(1e-8))
    return criterion(log_probs, labels_flat)


def train_epoch(model, train_loader, optimizer, criterion, device):
    """
    Train for one epoch.

    Args:
        model: Network whose forward pass returns (policy, value, hidden);
            only the policy probabilities are used here.
        train_loader: Iterable of (sequences, labels) batches that supports len().
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking (scores, targets), e.g. nn.CrossEntropyLoss.
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy): mean loss per batch and per-step accuracy in percent.

    Raises:
        ValueError: If the loader yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty: no batches to train on")

    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for sequences, labels in train_loader:
        sequences = sequences.to(device)
        labels = labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        policy, _, _ = model(sequences)

        # Flatten batch and time so every step is one classification example
        policy_flat = policy.view(-1, 2)  # (batch * seq_len, 2)
        labels_flat = labels.view(-1)  # (batch * seq_len,)

        # Calculate loss
        loss = _policy_loss(criterion, policy_flat, labels_flat)

        # Backward pass with gradient-norm clipping at 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Track metrics
        total_loss += loss.item()
        predicted = policy_flat.argmax(dim=1)
        total_predictions += labels_flat.size(0)
        correct_predictions += (predicted == labels_flat).sum().item()

    avg_loss = total_loss / len(train_loader)
    accuracy = 100.0 * correct_predictions / total_predictions

    return avg_loss, accuracy


def validate_epoch(model, val_loader, criterion, device):
    """
    Validate for one epoch (no gradient tracking, model in eval mode).

    Args:
        model: Network whose forward pass returns (policy, value, hidden).
        val_loader: Iterable of (sequences, labels) batches that supports len().
        criterion: Loss taking (scores, targets), e.g. nn.CrossEntropyLoss.
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy): mean loss per batch and per-step accuracy in percent.

    Raises:
        ValueError: If the loader yields no batches.
    """
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty: no batches to validate on")

    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for sequences, labels in val_loader:
            sequences = sequences.to(device)
            labels = labels.to(device)

            # Forward pass
            policy, _, _ = model(sequences)

            # Flatten batch and time so every step is one classification example
            policy_flat = policy.view(-1, 2)
            labels_flat = labels.view(-1)

            # Calculate loss (same formulation as in training, so curves are comparable)
            loss = _policy_loss(criterion, policy_flat, labels_flat)
            total_loss += loss.item()

            # Track accuracy
            predicted = policy_flat.argmax(dim=1)
            total_predictions += labels_flat.size(0)
            correct_predictions += (predicted == labels_flat).sum().item()

    avg_loss = total_loss / len(val_loader)
    accuracy = 100.0 * correct_predictions / total_predictions

    return avg_loss, accuracy


def plot_training_progress(history):
    """
    Draw loss and accuracy curves side by side.

    Does nothing when no epoch has been recorded yet. `history` must provide
    'epochs' plus the four metric lists 'train_loss', 'val_loss',
    'train_accuracy' and 'val_accuracy'.
    """
    epochs = history['epochs']
    if not epochs:
        return

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per panel: axis, title, y-label, then (history key, legend label, colour) per curve
    panels = (
        (loss_ax, 'Training and Validation Loss', 'Loss', (
            ('train_loss', 'Train Loss', 'blue'),
            ('val_loss', 'Validation Loss', 'red'),
        )),
        (acc_ax, 'Training and Validation Accuracy', 'Accuracy (%)', (
            ('train_accuracy', 'Train Accuracy', 'green'),
            ('val_accuracy', 'Validation Accuracy', 'orange'),
        )),
    )

    for axis, title, y_label, curves in panels:
        for key, legend_label, colour in curves:
            axis.plot(epochs, history[key], label=legend_label, color=colour)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


# Main training loop
print("Starting LSTM training pipeline...")
print("Preparing training data...")

# Make sure the checkpoint directory exists before the first torch.save
os.makedirs(os.path.dirname(config.save_path) or '.', exist_ok=True)

best_val_accuracy = 0.0
patience_counter = 0
start_time = datetime.now()

for epoch in range(1, config.epochs + 1):
    epoch_start = datetime.now()

    # Train and validate
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Update history
    training_history['train_loss'].append(train_loss)
    training_history['val_loss'].append(val_loss)
    training_history['train_accuracy'].append(train_acc)
    training_history['val_accuracy'].append(val_acc)
    training_history['epochs'].append(epoch)

    # Check for best model. The first epoch always counts as best, so a
    # '_best.pth' checkpoint exists even if validation accuracy never rises
    # above its initial 0.0 (the evaluation cell loads that file).
    is_best = training_history['best_epoch'] == 0 or val_acc > best_val_accuracy
    if is_best:
        best_val_accuracy = val_acc
        training_history['best_epoch'] = epoch
        training_history['best_val_accuracy'] = val_acc
        patience_counter = 0

        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_history': training_history,
            'config': config.__dict__
        }, f"{config.save_path}_best.pth")

        print(f"✅ New best model saved with validation accuracy: {val_acc:.2f}%")
    else:
        patience_counter += 1

    # Log progress
    epoch_time = (datetime.now() - epoch_start).total_seconds()
    print(f"Epoch {epoch:2d}/{config.epochs} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.1f}% | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.1f}% | "
          f"Time: {epoch_time:.1f}s")

    # Plot progress every 10 epochs
    if epoch % 10 == 0:
        plot_training_progress(training_history)

    # Early stopping: stop after `patience` consecutive epochs without a new best
    if patience_counter >= config.patience:
        print(f"🛑 Early stopping triggered after {epoch} epochs")
        break

# Training completed
total_time = (datetime.now() - start_time).total_seconds()
print(f"\n🎉 Training completed in {total_time/60:.1f} minutes")
print(f"📊 Best validation accuracy: {best_val_accuracy:.2f}% at epoch {training_history['best_epoch']}")

# Set training history in model
model.set_training_history(training_history)

# Save final model (weights of the last epoch run, not necessarily the best one)
torch.save({
    'model_state_dict': model.state_dict(),
    'training_history': training_history,
    'config': config.__dict__
}, f"{config.save_path}_final.pth")

print(f"💾 Model saved to {config.save_path}_final.pth")
print("📈 Training history saved in model")

# Final training plot
plot_training_progress(training_history)

## 7. Model Evaluation and Testing

In [None]:
# Load the best model
checkpoint = torch.load(f"{config.save_path}_best.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Testing model inference...")

# Test model on validation set
test_correct = 0
test_total = 0
accept_correct = 0
accept_total = 0
reject_correct = 0
reject_total = 0

with torch.no_grad():
    for sequences, labels in val_loader:
        sequences = sequences.to(device)
        labels = labels.to(device)

        # Forward pass; flatten batch and time so each step is one prediction
        policy, _, _ = model(sequences)
        policy_flat = policy.view(-1, 2)
        labels_flat = labels.view(-1)

        predicted = policy_flat.argmax(dim=1)
        hits = predicted == labels_flat

        # Overall accuracy
        test_total += labels_flat.size(0)
        test_correct += hits.sum().item()

        # Per-class accuracy: restrict the hit mask to steps of each true class.
        # An empty class simply contributes zero to both counters.
        accept_mask = labels_flat == 1
        reject_mask = labels_flat == 0

        accept_total += accept_mask.sum().item()
        accept_correct += (hits & accept_mask).sum().item()

        reject_total += reject_mask.sum().item()
        reject_correct += (hits & reject_mask).sum().item()

# Calculate metrics (every ratio is guarded against an empty denominator)
overall_accuracy = 100.0 * test_correct / test_total if test_total > 0 else 0
accept_accuracy = 100.0 * accept_correct / accept_total if accept_total > 0 else 0
reject_accuracy = 100.0 * reject_correct / reject_total if reject_total > 0 else 0

print("\n📊 Final Model Performance:")
print(f"  Overall Accuracy: {overall_accuracy:.2f}%")
print(f"  Accept Accuracy: {accept_accuracy:.2f}% ({accept_correct}/{accept_total})")
print(f"  Reject Accuracy: {reject_accuracy:.2f}% ({reject_correct}/{reject_total})")
print(f"  Total Test Samples: {test_total:,}")

# Display training summary
print("\n📈 Training Summary:")
print(f"  Best Epoch: {training_history['best_epoch']}")
print(f"  Best Validation Accuracy: {training_history['best_val_accuracy']:.2f}%")
print(f"  Final Training Loss: {training_history['train_loss'][-1]:.4f}")
print(f"  Final Validation Loss: {training_history['val_loss'][-1]:.4f}")
print(f"  Total Epochs Trained: {len(training_history['epochs'])}")

print("\nTraining complete! Model is ready for inference. 🎉")

## 8. Final Training Visualization

In [None]:
# Plot comprehensive training history
def plot_comprehensive_history(history):
    """
    Plot comprehensive training history on a 2x2 grid.

    Panels: loss curves, accuracy curves, validation-minus-training loss, and
    a text summary. Returns without drawing when no epoch has been recorded,
    matching plot_training_progress.

    Besides `history`, the summary panel reads the notebook globals
    `train_data`, `val_data`, `total_params` and `device`.
    """
    epochs = history['epochs']
    if not epochs:
        # Nothing to plot; the summary panel would otherwise index empty lists with [-1]
        return

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    best_epoch = history['best_epoch']

    # Plot losses
    ax1.plot(epochs, history['train_loss'], label='Train Loss', color='blue', linewidth=2)
    ax1.plot(epochs, history['val_loss'], label='Validation Loss', color='red', linewidth=2)
    ax1.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7, label='Best Epoch')
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot accuracies
    ax2.plot(epochs, history['train_accuracy'], label='Train Accuracy', color='green', linewidth=2)
    ax2.plot(epochs, history['val_accuracy'], label='Validation Accuracy', color='orange', linewidth=2)
    ax2.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7, label='Best Epoch')
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot overfitting indicator: a growing positive gap means validation loss
    # is pulling away from training loss
    loss_diff = [v - t for t, v in zip(history['train_loss'], history['val_loss'])]
    ax3.plot(epochs, loss_diff, label='Val - Train Loss', color='purple', linewidth=2)
    ax3.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    ax3.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7)
    ax3.set_title('Overfitting Indicator (Val Loss - Train Loss)', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Loss Difference')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Summary statistics, written top to bottom in axes coordinates
    summary_lines = [
        f"Best Epoch: {best_epoch}",
        f"Best Val Accuracy: {history['best_val_accuracy']:.2f}%",
        f"Final Train Loss: {history['train_loss'][-1]:.4f}",
        f"Final Val Loss: {history['val_loss'][-1]:.4f}",
        f"Total Epochs: {len(history['epochs'])}",
        f"Training Samples: {len(train_data):,}",
        f"Validation Samples: {len(val_data):,}",
        f"Model Parameters: {total_params:,}",
        f"Device: {device}",
    ]
    for row, text in enumerate(summary_lines):
        style = {'fontweight': 'bold'} if row == 0 else {}  # only the first line is bold
        ax4.text(0.05, 0.90 - 0.10 * row, text, transform=ax4.transAxes, fontsize=12, **style)
    ax4.set_title('Training Summary', fontsize=14, fontweight='bold')
    ax4.axis('off')

    plt.suptitle('Berghain LSTM Training Results', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Generate final comprehensive plot
plot_comprehensive_history(training_history)

# Save training history as JSON for later analysis
history_path = f"{config.save_path}_history.json"
with open(history_path, 'w') as f:
    json.dump(training_history, f, indent=2)

print(f"📁 Training history saved to {history_path}")

## 9. Download Trained Model

Download the trained model files to use in your local environment.

In [None]:
from google.colab import files
import zipfile
import os

# Create a ZIP file with all model artifacts
zip_filename = 'berghain_lstm_trained_model.zip'

# (path on disk, name inside the archive); files that do not exist are skipped
artifacts = (
    (f"{config.save_path}_best.pth", "lstm_berghain_best.pth"),
    (f"{config.save_path}_final.pth", "lstm_berghain_final.pth"),
    (f"{config.save_path}_history.json", "training_history.json"),
)

with zipfile.ZipFile(zip_filename, 'w') as zipf:
    for source_path, archive_name in artifacts:
        if os.path.exists(source_path):
            zipf.write(source_path, archive_name)

print(f"📦 Model files packaged in {zip_filename}")
print("\n🎯 Download includes:")
print("  • lstm_berghain_best.pth - Best performing model checkpoint")
print("  • lstm_berghain_final.pth - Final training checkpoint")
print("  • training_history.json - Complete training metrics")

# Download the ZIP file
files.download(zip_filename)

print("\n✅ Download started! Check your browser's download folder.")
print("\n🚀 Model is ready to use for Berghain game inference!")
print("\n💡 Usage in your local environment:")
print("```python")
print("import torch")
print("from berghain.training.lstm_policy import LSTMPolicyNetwork")
print("")
print("# Load the trained model")
print("model = LSTMPolicyNetwork()")
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine
print("checkpoint = torch.load('lstm_berghain_best.pth', map_location='cpu')")
print("model.load_state_dict(checkpoint['model_state_dict'])")
print("")
print("# Access training history")
print("history = checkpoint['training_history']")
print("print(f'Best accuracy: {history[\"best_val_accuracy\"]:.2f}%')")
print("```")