# Phase 2: Foundation Model Fine-Tuning

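This notebook provides a template for fine-tuning pre-trained foundation models (MOMENT, Chronos, Moirai) for EV charging demand prediction. The configuration, dataset loading, and model loading cells are functional; the data loader and training cells are placeholders that still need model-specific code.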

**Objectives**:
- Fine-tune the MOMENT model on the Shenzhen dataset
- Compare against the baseline models from Phase 1
- Target: >10% MAE improvement over the best baseline
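
To make the target concrete, the sketch below reads ">10% MAE improvement" as a relative reduction in MAE versus the best Phase 1 baseline. That reading is an assumption, and the helper names (`mae`, `mae_improvement`, `meets_target`) are hypothetical; they are not part of `api.utils`.

```python
import numpy as np


def mae(y_true, y_pred):
    """Mean absolute error over all sites and horizon steps."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))


def mae_improvement(baseline_mae, model_mae):
    """Relative MAE reduction versus the baseline (0.10 means 10% lower error)."""
    if baseline_mae <= 0:
        raise ValueError("baseline_mae must be positive")
    return (baseline_mae - model_mae) / baseline_mae


def meets_target(baseline_mae, model_mae, target=0.10):
    """True when the relative MAE reduction strictly exceeds the target."""
    return mae_improvement(baseline_mae, model_mae) > target
```

Both MAE values should come from the same test split and the same (unscaled) units, otherwise the ratio is not meaningful.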

**Strategy**:
1. Start with MOMENT (best multivariate support)
2. Progressive unfreezing (head → last 2 layers → full model)
3. Incorporate auxiliary features (weather, pricing)
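
The three unfreezing stages in step 2 can be sketched on a toy model. `ToyForecaster`, `set_unfreeze_stage`, and the attribute names `encoder_layers` and `head` are hypothetical stand-ins; the real `MOMENTForecaster` may organize its parameters differently, so treat this as an illustration of the idea, not of the project's API.

```python
import torch.nn as nn


class ToyForecaster(nn.Module):
    """Hypothetical stand-in for a foundation model: encoder blocks plus a head."""

    def __init__(self, context_length=168, hidden=32, n_layers=4, horizon=24):
        super().__init__()
        self.embed = nn.Linear(context_length, hidden)
        self.encoder_layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU()) for _ in range(n_layers)]
        )
        self.head = nn.Linear(hidden, horizon)

    def forward(self, x):
        h = self.embed(x)
        for layer in self.encoder_layers:
            h = layer(h)
        return self.head(h)


def set_unfreeze_stage(model, stage, n_last_layers=2):
    """Stage 0: head only. Stage 1: head + last n encoder layers. Stage 2: full model."""
    # Start from "everything trainable" only in the final stage.
    for p in model.parameters():
        p.requires_grad = (stage == 2)
    # The head is trainable in every stage.
    for p in model.head.parameters():
        p.requires_grad = True
    if stage == 1:
        for layer in model.encoder_layers[-n_last_layers:]:
            for p in layer.parameters():
                p.requires_grad = True


def count_trainable(model):
    """Number of parameters that will receive gradient updates."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Calling `set_unfreeze_stage(model, 0)`, then `1`, then `2` at chosen epochs grows the trainable set step by step, which `count_trainable` makes easy to verify.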

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directory to path
sys.path.insert(0, os.path.abspath('..'))

from api.dataset.common import EVDataset
from api.model.foundation import MOMENTForecaster, load_foundation_model
from api.utils import calculate_regression_metrics
from experiment.utils.experiment_tracking import ExperimentTracker

print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Configuration

In [None]:
# Configuration
CONFIG = {
    'city': 'SZH',  # Shenzhen - largest dataset
    'model_name': 'moment',  # Start with MOMENT
    'model_size': 'small',  # Use small for faster training initially
    'feature': 'volume',
    'auxiliary': 'all',  # Use all auxiliary features
    
    # Data configuration
    'context_length': 168,  # 1 week of hourly data
    'forecast_horizon': 24,  # 1 day ahead prediction
    'use_segments': False,  # Set True for segment-based (MOMENT-specific)
    
    # Training configuration
    'batch_size': 16,  # Smaller batch for foundation models
    'epochs': 30,
    'learning_rate_head': 1e-3,
    'learning_rate_encoder': 1e-4,
    'weight_decay': 1e-5,
    
    # Progressive unfreezing
    'freeze_encoder': True,  # Start with frozen encoder
    'unfreeze_at_epoch': 10,  # Unfreeze last 2 layers after epoch 10
    'unfreeze_layers': 2,
    
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'seed': 42
}

# Set random seeds for reproducibility
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['seed'])

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## Load Dataset

In [None]:
# Load EV dataset
data_path = f'../data/{CONFIG["city"]}_remove_zero/'
print(f"Loading dataset from {data_path}...")

ev_dataset = EVDataset(
    feature=CONFIG['feature'],
    auxiliary=CONFIG['auxiliary'],
    data_path=data_path,
    pred_type='site',
    seq_l=CONFIG['context_length'],
    pre_len=CONFIG['forecast_horizon']
)

# Split data (80% train, 10% val, 10% test)
ev_dataset.split_cross_validation(
    fold=1,
    total_fold=6,
    train_ratio=0.8,
    valid_ratio=0.1
)

print("\n✓ Dataset loaded:")
print(f"  Sites: {ev_dataset.feat.shape[1]}")
print(f"  Total timesteps: {ev_dataset.feat.shape[0]}")
print(f"  Training samples: {len(ev_dataset.train_feat)}")
print(f"  Validation samples: {len(ev_dataset.valid_feat)}")
print(f"  Test samples: {len(ev_dataset.test_feat)}")

if ev_dataset.extra_feat.size > 0:
    n_aux = ev_dataset.extra_feat.shape[1]
    print(f"  Auxiliary features: {n_aux}")
else:
    n_aux = 0
    print("  Auxiliary features: None")

## Initialize Foundation Model

In [None]:
# Initialize experiment tracker
experiment_name = f"foundation_{CONFIG['model_name']}_{CONFIG['city']}"
tracker = ExperimentTracker(experiment_name, save_dir='../results/foundation_models')

tracker.log_hyperparameters(CONFIG)

# Load foundation model
print(f"\nLoading {CONFIG['model_name']} model...")

try:
    model = load_foundation_model(
        model_name=CONFIG['model_name'],
        model_size=CONFIG['model_size'],
        n_aux_features=n_aux,
        prediction_length=CONFIG['forecast_horizon'],
        context_length=CONFIG['context_length']
    )
    
    model = model.to(CONFIG['device'])
    
    # Freeze encoder if specified
    if CONFIG['freeze_encoder']:
        model.freeze_encoder()
    
    print("\n✓ Model loaded successfully")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Frozen parameters: {total_params - trainable_params:,}")
    
except Exception as e:
    print(f"\n❌ Error loading model: {str(e)}")
    print("\nNote: This notebook requires foundation model libraries.")
    print("Install with:")
    print("  pip install momentfm  # for MOMENT")
    print("  pip install chronos-forecasting  # for Chronos")
    print("  pip install uni2ts  # for Moirai")
    raise

## Prepare Data Loaders

In [None]:
# NOTE: No data preparation is implemented in this cell yet.
# Each foundation model expects its own input format, so this step
# must be written for the chosen model before training can run.

print("Preparing data loaders...")
print("\n⚠️  Note: This is a placeholder implementation.")
print("Each foundation model (MOMENT/Chronos/Moirai) has specific")
print("input format requirements that need to be implemented based")
print("on their respective APIs.")

# Placeholder for data preparation
# In practice, you'll need to:
# 1. Format data according to model requirements
# 2. Handle auxiliary features appropriately
# 3. Create proper sliding windows
# 4. Normalize/scale data

print("\n⚠️  No data loaders were created; this cell is a placeholder.")
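
As a starting point for the four steps listed in the cell above, here is a minimal sketch of sliding windows with per-site scaling. It assumes the training split is a `(timesteps, sites)` NumPy array and treats every site as an independent univariate series; auxiliary features are left out. `make_windows` and `build_loader` are hypothetical helpers, and the tensor layout a given foundation model expects may differ.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_windows(series, context_length, horizon):
    """Slice a (timesteps, sites) array into (context, target) pairs.

    Returns arrays of shape (n_windows * sites, context_length) and
    (n_windows * sites, horizon).
    """
    n_steps, _ = series.shape
    n_windows = n_steps - context_length - horizon + 1
    if n_windows <= 0:
        raise ValueError("series is too short for the requested window sizes")
    contexts, targets = [], []
    for start in range(n_windows):
        mid = start + context_length
        contexts.append(series[start:mid].T)         # (sites, context_length)
        targets.append(series[mid:mid + horizon].T)  # (sites, horizon)
    return np.concatenate(contexts), np.concatenate(targets)


def build_loader(train_series, context_length=168, horizon=24, batch_size=16):
    """Scale with training statistics, window the data, and wrap it in a DataLoader."""
    # Statistics come from the training split only, to avoid leaking test data.
    mean = train_series.mean(axis=0, keepdims=True)
    std = train_series.std(axis=0, keepdims=True) + 1e-8
    scaled = (train_series - mean) / std
    x, y = make_windows(scaled, context_length, horizon)
    dataset = TensorDataset(torch.from_numpy(x).float(), torch.from_numpy(y).float())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader, (mean, std)
```

The returned `(mean, std)` pair should be reused to scale the validation and test splits and to map predictions back to the original units before computing MAE.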

## Training Loop with Progressive Unfreezing

In [None]:
print("\n" + "="*60)
print("TRAINING SETUP")
print("="*60)

print("\n⚠️  This is a template implementation.")
print("\nThe actual training loop needs to be customized for each foundation model:")
print("\n1. MOMENT: Uses MOMENTPipeline with specific forecast() method")
print("2. Chronos: Requires ChronosPipeline.predict() with context")
print("3. Moirai: Uses GluonTS dataset format with MoiraiForecast")

print("\nFor now, please refer to:")
print("  - example/test_moment_all.py for MOMENT integration")
print("  - example/test_chronos_all.py for Chronos integration")
print("  - example/test_moirai_all.py for Moirai integration")

print("\nThese scripts show how to properly use each model's API.")
print("You'll need to adapt them for fine-tuning with PyTorch optimizers.")
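
Until the model-specific loop exists, the overall shape of the training procedure can be sketched with plain PyTorch. The sketch mirrors the `CONFIG` values (separate head and encoder learning rates, weight decay, unfreezing the last 2 encoder layers at epoch 10) but runs on a hypothetical `TinyForecaster`; the attribute names `encoder` and `head`, the L1 loss, and the AdamW optimizer are assumptions, not something the foundation model wrappers are known to provide.

```python
import torch
import torch.nn as nn


class TinyForecaster(nn.Module):
    """Hypothetical stand-in model with an `encoder` stack and a `head`."""

    def __init__(self, context_length=168, hidden=16, horizon=24):
        super().__init__()

        def block(n_in, n_out):
            return nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())

        self.encoder = nn.Sequential(
            block(context_length, hidden), block(hidden, hidden), block(hidden, hidden)
        )
        self.head = nn.Linear(hidden, horizon)

    def forward(self, x):
        return self.head(self.encoder(x))


def build_optimizer(model, lr_head=1e-3, lr_encoder=1e-4, weight_decay=1e-5):
    """AdamW with one parameter group per learning rate."""
    return torch.optim.AdamW(
        [
            {"params": model.head.parameters(), "lr": lr_head},
            {"params": model.encoder.parameters(), "lr": lr_encoder},
        ],
        weight_decay=weight_decay,
    )


def train(model, loader, epochs=30, unfreeze_at_epoch=10, unfreeze_layers=2, device="cpu"):
    """Train the head first, then unfreeze the last encoder layers at a given epoch."""
    model.to(device)
    for p in model.encoder.parameters():
        p.requires_grad = False  # frozen parameters get no gradient, so AdamW skips them
    optimizer = build_optimizer(model)
    loss_fn = nn.L1Loss()  # L1 loss matches the MAE target
    history = []
    for epoch in range(epochs):
        if epoch == unfreeze_at_epoch:
            for layer in list(model.encoder.children())[-unfreeze_layers:]:
                for p in layer.parameters():
                    p.requires_grad = True
        model.train()
        total, n = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(x)
            n += len(x)
        history.append(total / n)  # mean training loss for the epoch
    return history
```

Registering the encoder parameters with the optimizer up front keeps the loop simple: they are ignored while frozen and start updating at the lower learning rate once `requires_grad` is switched on.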

## Next Steps

To complete this notebook, you need to:

### 1. Study the existing test scripts
- `example/test_moment_all.py` - Shows MOMENT data preparation and inference
- `example/test_chronos_all.py` - Shows Chronos usage
- `example/test_moirai_all.py` - Shows Moirai with GluonTS

### 2. Implement model-specific training
Each foundation model has a different API:
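- **MOMENT**: The model is loaded as a PyTorch module, so fine-tuning most likely requires a custom training loop (check the installed `momentfm` version for any built-in training helper)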
- **Chronos**: T5-based, may need HuggingFace Trainer
- **Moirai**: GluonTS integration, custom training needed

### 3. Alternative: Use existing scripts as baseline
You can:
1. Run the test scripts to get zero-shot baseline performance
2. Compare with Phase 1 baseline results
3. Then implement fine-tuning for the best-performing model

### 4. Recommended Approach
Start with the **simplest path**:
1. Run `python example/test_moment_all.py` to get MOMENT zero-shot results
2. Compare with baseline MAE from Phase 1
3. If MOMENT already beats baselines → document this!
4. If not → implement fine-tuning in next iteration

This is a common pattern in research: first establish baselines, then refine.

In [None]:
# Placeholder for recording that this notebook was run
tracker.log_metrics({
    'status': 'template_created',
    'note': 'Awaiting model-specific implementation'
})

tracker.print_summary()

print("\n" + "="*60)
print("PHASE 2 TEMPLATE COMPLETE")
print("="*60)
print("\nThis notebook provides the structure for foundation model fine-tuning.")
print("See 'Next Steps' section above for implementation guidance.")
print("\nRecommended: Start by running existing test scripts to establish")
print("zero-shot baseline performance before implementing fine-tuning.")