# SAINT Training for Bike Sharing Regression

This notebook demonstrates the training and evaluation of a **SAINT (Self-Attention and Intersample Attention Transformer)** model for bike sharing demand prediction.

## SAINT Architecture Overview

SAINT is a transformer-based architecture specifically designed for tabular data that combines:
- **Self-attention mechanisms** to capture feature interactions within samples
- **Intersample attention** to learn patterns across different samples in a batch (so a row's prediction can depend on the other rows in the same batch)
- **Feature embeddings** for numerical features
- **Positional encoding** to maintain feature order information

## Key Features
- Comprehensive logging to file
- Model checkpointing (both .pth and .pkl formats)
- Detailed training and evaluation plots
- Performance metrics tracking
- Early stopping with patience
- Learning rate scheduling

## 1. Setup and Imports

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import warnings
warnings.filterwarnings('ignore')

# Plot appearance for every figure in this notebook
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Report the PyTorch build and GPU availability so runs can be compared
cuda_ok = torch.cuda.is_available()
print("📚 Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Import the SAINT training helpers by name rather than with `import *`, so
# every helper used later in this notebook has a visible origin. A helper
# that is missing then fails here, at import time, instead of as a NameError
# in a later cell.
try:
    from saint_training_functions import (
        create_evaluation_plots,
        create_saint_model,
        create_training_plots,
        evaluate_model,
        load_preprocessed_data,
        prepare_data_for_training,
        save_results,
        setup_logging,
        setup_training,
        train_saint_model,
    )
    print("✅ SAINT training functions imported successfully")
except ImportError as e:
    print("❌ SAINT training functions not available.")
    print("Please ensure saint_training_functions.py is in the same directory.")
    print(f"Error: {e}")
    raise

## 2. Configuration and Device Setup

In [None]:
# Set device and configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Using device: {device}")

# Training configuration (every tunable knob for this notebook lives here)
config = {
    'data_path': './bike_sharing_preprocessed_data.pkl',
    'device': device,
    'batch_size': 256,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'd_model': 128,
    'n_heads': 8,
    'n_layers': 6,
    'save_dir': './Section2_Model_Training',
    'seed': 42,
}

# Seed the global RNGs before the model and DataLoaders are created in later
# cells. Anything that draws from the torch/numpy global generators (e.g.
# weight initialisation, default DataLoader shuffling) then repeats on
# Restart & Run All. Some CUDA kernels may still be non-deterministic.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

print("⚙️ Configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

## 3. Data Loading and Preparation

In [None]:
# Load the preprocessed train/val/test splits from the pickle at config['data_path']
print("📊 Loading and preparing data...")

(X_train_scaled, X_val_scaled, X_test_scaled,
 y_train, y_val, y_test, feature_names, data_summary) = load_preprocessed_data(config['data_path'])

# Sanity-check the splits before training. Each feature matrix must pair 1:1
# with its target vector and carry exactly one column per feature name.
for split_name, X_split, y_split in (
    ('train', X_train_scaled, y_train),
    ('val', X_val_scaled, y_val),
    ('test', X_test_scaled, y_test),
):
    assert X_split.shape[0] == len(y_split), f"{split_name}: X/y length mismatch"
    assert X_split.shape[1] == len(feature_names), f"{split_name}: feature count mismatch"

print(f"\n📈 Data Summary:")
print(f"   Training samples: {X_train_scaled.shape[0]:,}")
print(f"   Validation samples: {X_val_scaled.shape[0]:,}")
print(f"   Test samples: {X_test_scaled.shape[0]:,}")
print(f"   Features: {len(feature_names)}")
print(f"   Target range: [{y_train.min():.0f}, {y_train.max():.0f}]")

In [None]:
# List every input feature alongside its 1-based position
print("🏷️ Feature Names:")
for position, feature in enumerate(feature_names, start=1):
    print(f"   {position:2d}. {feature}")

In [None]:
# Hand the scaled splits to the helper. Judging by the names unpacked below, it
# returns the train/val/test loaders followed by tensor versions of each split.
prepared_data = prepare_data_for_training(
    X_train_scaled, X_val_scaled, X_test_scaled,
    y_train, y_val, y_test, feature_names,
    config['device'], config['batch_size'],
)
(train_loader, val_loader, test_loader,
 X_train_tensor, X_val_tensor, X_test_tensor,
 y_train_tensor, y_val_tensor, y_test_tensor) = prepared_data
del prepared_data

## 4. Model Creation and Architecture

In [None]:
# Build SAINT sized to the feature count, using the transformer
# hyperparameters (width, heads, depth) from `config`.
arch_kwargs = {name: config[name] for name in ('d_model', 'n_heads', 'n_layers')}
model, total_params = create_saint_model(
    n_features=len(feature_names),
    device=config['device'],
    **arch_kwargs,
)

# Rough parameter memory, assuming 4 bytes (float32) per parameter
approx_size_mb = total_params * 4 / 1024 / 1024

print(f"\n🏗️ SAINT Model Architecture:")
print(f"   Input features: {len(feature_names)}")
print(f"   Model dimension: {arch_kwargs['d_model']}")
print(f"   Attention heads: {arch_kwargs['n_heads']}")
print(f"   Transformer layers: {arch_kwargs['n_layers']}")
print(f"   Total parameters: {total_params:,}")
print(f"   Model size: ~{approx_size_mb:.2f} MB")

## 5. Training Setup

In [None]:
# Build the loss, optimizer and LR scheduler, plus the loop settings
# (epoch budget, early-stopping patience), from the shared config.
criterion, optimizer, scheduler, training_config = setup_training(
    model,
    learning_rate=config['learning_rate'],
    weight_decay=config['weight_decay']
)

print(f"\n🔧 Training Configuration:")
print(f"   Learning rate: {training_config['learning_rate']}")
print(f"   Weight decay: {training_config['weight_decay']}")
print(f"   Max epochs: {training_config['n_epochs']}")
print(f"   Early stopping patience: {training_config['patience']}")
# Report the classes setup_training actually returned rather than hard-coded
# names, so this summary cannot drift from the implementation.
print(f"   Loss function: {type(criterion).__name__}")
print(f"   Optimizer: {type(optimizer).__name__}")
print(f"   Scheduler: {type(scheduler).__name__}")

## 6. Model Training

In [None]:
# Create the run logger under save_dir and record the start of the run
logger = setup_logging(config['save_dir'])
logger.info("Starting SAINT training from notebook")

separator = "-" * 80
for banner_line in (
    "🚀 Starting SAINT model training...",
    "📝 Training progress will be logged to file and displayed here.",
    separator,
):
    print(banner_line)

In [None]:
# Train the model. The helper returns the trained model, the per-epoch
# history, the 0-based index of the best epoch, and the training time in seconds.
model, history, best_epoch, training_time = train_saint_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    training_config, config['device'], logger
)

val_r2_per_epoch = history['val_r2']
epochs_run = len(val_r2_per_epoch)

print(f"\n✅ Training completed!")
print(f"   Best epoch: {best_epoch + 1}")
print(f"   Epochs run: {epochs_run}")
print(f"   Training time: {training_time:.2f} seconds")
# `best_epoch` selects the best (not the last) epoch, so label it as such.
# The last-epoch value is shown too, because early stopping with patience
# means the two generally differ.
print(f"   Best validation R²: {val_r2_per_epoch[best_epoch]:.4f}")
print(f"   Last-epoch validation R²: {val_r2_per_epoch[epochs_run - 1]:.4f}")

## 7. Model Evaluation

In [None]:
# Score the trained model on the held-out test tensors; yields predictions and a metrics dict
eval_outputs = evaluate_model(
    model, X_test_tensor, y_test_tensor, config['device'], logger
)
predictions, metrics = eval_outputs

report_lines = (
    f"\n📊 Final Test Performance:",
    f"   R² Score: {metrics['r2_score']:.4f}",
    f"   RMSE: {metrics['rmse']:.4f}",
    f"   MAE: {metrics['mae']:.4f}",
    f"   MAPE: {metrics['mape']:.2f}%",
    f"   Explained Variance: {metrics['explained_variance']:.4f}",
)
print(*report_lines, sep="\n")

## 8. Training Visualization

In [None]:
# Plot the recorded training history (best epoch passed along) into save_dir
plots_dir = config['save_dir']
create_training_plots(history, best_epoch, plots_dir)

## 9. Evaluation Visualization

In [None]:
# Plot test-set targets against model predictions into save_dir
eval_plots_dir = config['save_dir']
create_evaluation_plots(y_test, predictions, eval_plots_dir)

## 10. Results Analysis

In [None]:
# Per-epoch training history as a table. It is indexed by 1-based epoch
# number so rows line up with the 1-based "Best epoch" printed after training.
history_df = pd.DataFrame(history)
history_df.index = pd.RangeIndex(start=1, stop=len(history_df) + 1, name='epoch')
print("📈 Training History (last 10 epochs):")
print(history_df.tail(10).round(4))

In [None]:
# Analyze predictions.
# Flatten both inputs to 1-D arrays first. This builds the frame positionally,
# with no pandas index alignment if y_test is a Series. It also avoids an
# (n, n) broadcast if predictions come back shaped (n, 1).
actual_values = np.asarray(y_test).ravel()
predicted_values = np.asarray(predictions).ravel()
assert actual_values.shape == predicted_values.shape, "y_test / predictions length mismatch"

residual_values = actual_values - predicted_values
predictions_df = pd.DataFrame({
    'actual': actual_values,
    'predicted': predicted_values,
    'residuals': residual_values,
    'absolute_error': np.abs(residual_values),
})

print("🔍 Prediction Analysis:")
print(predictions_df.describe().round(2))

In [None]:
# Error distribution analysis, in three panels: residual histogram,
# absolute-error histogram, and absolute error vs. true count (to spot errors
# that grow with demand). Uses explicit Axes rather than the pyplot state machine.
fig, (ax_resid, ax_abs, ax_scatter) = plt.subplots(1, 3, figsize=(12, 4))

ax_resid.hist(predictions_df['residuals'], bins=30, alpha=0.7, color='blue')
# Reference line at zero makes any systematic over/under-prediction visible
ax_resid.axvline(0, color='black', linewidth=1)
ax_resid.set(title='Residuals Distribution', xlabel='Residuals', ylabel='Frequency')

ax_abs.hist(predictions_df['absolute_error'], bins=30, alpha=0.7, color='red')
ax_abs.set(title='Absolute Error Distribution', xlabel='Absolute Error', ylabel='Frequency')

ax_scatter.scatter(predictions_df['actual'], predictions_df['absolute_error'], alpha=0.6)
ax_scatter.set(title='Error vs Actual Values', xlabel='Actual Bike Count', ylabel='Absolute Error')

for panel in (ax_resid, ax_abs, ax_scatter):
    panel.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## 11. Save Results

In [None]:
# Persist the model, history, metrics and predictions into save_dir
results_dir = config['save_dir']
save_results(
    model, history, metrics, predictions, y_test, feature_names,
    training_time, total_params, results_dir, logger,
)

saved_artifacts = (
    "   - Training history CSV",
    "   - Evaluation metrics CSV",
    "   - Predictions CSV",
    "   - Model checkpoints (.pth and .pkl)",
    "   - Training and evaluation plots",
    "   - Complete training log",
)
print("\n💾 All results saved successfully!")
print(f"📁 Check the '{results_dir}' directory for:")
for artifact_line in saved_artifacts:
    print(artifact_line)

## 12. Model Summary

In [None]:
# One-screen recap of test metrics and run configuration
summary_lines = (
    "🎉 SAINT Training Summary",
    "=" * 50,
    "📊 Model Performance:",
    f"   R² Score: {metrics['r2_score']:.4f}",
    f"   RMSE: {metrics['rmse']:.4f}",
    f"   MAE: {metrics['mae']:.4f}",
    f"   MAPE: {metrics['mape']:.2f}%",
    "\n⚙️ Model Configuration:",
    "   Architecture: SAINT Transformer",
    f"   Parameters: {total_params:,}",
    f"   Training time: {training_time:.2f} seconds",
    f"   Best epoch: {best_epoch + 1}",
    "\n🚀 Model ready for deployment and comparison!",
)
print(*summary_lines, sep="\n")

logger.info("SAINT training notebook completed successfully")

## 13. Quick Model Loading Test

In [None]:
# Round-trip check: reload the saved .pkl checkpoint (presumably written by
# save_results above) and confirm it agrees with this run's in-memory metrics.
import os
import pickle

print("🔄 Testing model loading from saved checkpoint...")

checkpoint_path = os.path.join(config['save_dir'], 'saint_model.pkl')
# pickle.load can execute arbitrary code; only load checkpoints this notebook produced.
with open(checkpoint_path, 'rb') as f:
    saved_model_data = pickle.load(f)

saved_r2 = saved_model_data['metrics']['r2_score']
# Loose tolerance in case the saved metrics were rounded when written
assert np.isclose(saved_r2, metrics['r2_score'], atol=1e-4), (
    f"Saved R² {saved_r2} does not match in-memory R² {metrics['r2_score']}"
)

print("✅ Model loaded successfully!")
print(f"   Saved metrics: R² = {saved_r2:.4f}")
print(f"   Model architecture: {saved_model_data['model_architecture']}")
print(f"   Feature names: {len(saved_model_data['feature_names'])} features")
print(f"   Training time: {saved_model_data['training_time']:.2f} seconds")