# ARMA Parameter Recovery Experiment

This notebook trains a small MLP head (three hidden layers) to recover ARMA model parameters from the latent representations of the pre-trained, frozen SimpleModel.

## Overview
- Load the SimpleModel pre-trained in forecast_arma.ipynb and freeze its weights
- Create an MLP head that outputs two tensors of `NUM_ARMA_PARAMS` floats each (default 8): the AR and the MA coefficients
- Train the head to recover the original ARMA parameters
- Evaluate performance during and after training


In [None]:
import os
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",  # order CUDA devices by PCI bus ID
    CUDA_VISIBLE_DEVICES="0",  # expose only device 0 (alternative: "1")
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from types import SimpleNamespace

# Import functions from codebase modules
from arma import generate_arma_batch
from network import SimpleModel

# Pick the compute device: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
# =============================================================================
# CONFIGURATION: ARMA Parameter Recovery Settings
# =============================================================================
# How many AR coefficients (and, separately, how many MA coefficients) the head
# predicts. Ground-truth coefficient vectors are zero-padded or truncated to it.
NUM_ARMA_PARAMS = 8
# =============================================================================
print(
    f"Configuration: NUM_ARMA_PARAMS = {NUM_ARMA_PARAMS}",
    "Change NUM_ARMA_PARAMS in this cell to control the number of AR/MA coefficients to recover",
    sep="\n",
)


In [None]:
# Rebuild the backbone with the checkpoint's hyperparameters and restore its
# weights; map_location loads the tensors straight onto `device`.
model = SimpleModel(C=4, H=1024, W=32)
state_dict = torch.load('trained_simple_model_H1024.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device).eval()  # inference mode for any dropout/normalization layers
print("Pre-trained model loaded successfully")

# The backbone is only a frozen feature extractor: disable grads on every parameter.
model.requires_grad_(False)
print("Pre-trained model parameters frozen")


In [None]:
class ParameterRecoveryHead(nn.Module):
    """MLP head that regresses ARMA coefficients from flattened latent features.

    Maps ``[B, T, C*H]`` to a pair ``([B, T, num_arma_params], [B, T, num_arma_params])``:
    one AR and one MA coefficient vector per time position.

    Args:
        C: Number of latent channels that were flattened into the feature axis.
        H: Per-channel latent width.
        hidden_dim: Width of the three shared hidden layers.
        num_arma_params: Number of AR (and, separately, MA) coefficients to predict.
        input_dim: Optional explicit feature width. Defaults to ``C * H``, the
            width produced by flattening the channel and latent axes.
    """

    def __init__(self, C=4, H=64, hidden_dim=64, num_arma_params=8, input_dim=None):
        super().__init__()
        # The input is the flattened C*H representation, so the first layer must
        # accept C*H features. Using only H mismatches the input whenever C > 1.
        if input_dim is None:
            input_dim = C * H
        self.input_dim = input_dim

        # Shared trunk applied position-wise: [B, T, input_dim] -> [B, T, hidden_dim]
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.CELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.CELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.CELU(),
            nn.Dropout(0.1),
        )

        # Separate linear read-outs for AR and MA coefficients
        # (num_arma_params values each per (B, T) position).
        self.ar_head = nn.Linear(hidden_dim, num_arma_params)
        self.ma_head = nn.Linear(hidden_dim, num_arma_params)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, T, input_dim] (input_dim = C*H by default).
        Returns:
            ar_params: AR parameters [batch_size, T, num_arma_params], in (-1, 1)
            ma_params: MA parameters [batch_size, T, num_arma_params], in (-1, 1)
        """
        # nn.Linear acts on the last axis, so this is vectorized over B and T.
        shared_features = self.shared_layers(x)

        # tanh bounds each coefficient to (-1, 1).
        # NOTE(review): coefficients outside that range (possible for stationary
        # AR(2), e.g. phi1=1.5, phi2=-0.6) cannot be represented — confirm against
        # the sampling range of generate_arma_batch.
        ar_predictions = torch.tanh(self.ar_head(shared_features))
        ma_predictions = torch.tanh(self.ma_head(shared_features))

        return ar_predictions, ma_predictions

# Build the recovery head, intended for a C*H = 4*256 = 1024-wide flattened
# latent, and place it on the same device as the backbone.
_head_kwargs = dict(C=4, H=256, hidden_dim=256, num_arma_params=NUM_ARMA_PARAMS)
param_head = ParameterRecoveryHead(**_head_kwargs).to(device)
print(f"Parameter recovery head initialized with C=4, H=256 (input_dim=1024), num_params={NUM_ARMA_PARAMS}")


In [None]:
def extract_latent_features(model, x):
    """Run the frozen backbone and flatten its latent per time step.

    ``model(x)`` returns ``(h_hat, h)``; only the second element ``h`` is kept
    and reshaped from ``[B, T, C, H]`` to ``[B, T, C*H]``. Runs under
    ``torch.no_grad()`` so no autograd graph is recorded.
    """
    with torch.no_grad():
        _, latent = model(x)  # drop the forecast h_hat, keep the encoded h
        batch, steps, channels, width = latent.shape
        return latent.reshape(batch, steps, channels * width)

def parameter_loss(pred_ar, pred_ma, true_ar, true_ma):
    """MSE between time-averaged predicted coefficients and the true ones.

    Args:
        pred_ar: [B, T, P] per-time-step AR predictions.
        pred_ma: [B, T, P] per-time-step MA predictions.
        true_ar: [B, P] ground-truth AR coefficients.
        true_ma: [B, P] ground-truth MA coefficients.

    Returns:
        Tuple ``(ar_loss + ma_loss, ar_loss, ma_loss)`` of scalar tensors.
    """
    # Unpacking enforces a 3-D prediction tensor, per the contract above.
    _batch, _steps, _params = pred_ar.shape

    # Pool each sequence's per-step estimates into one by averaging over T,
    # then score against the per-sequence ground truth.
    ar_loss, ma_loss = [
        F.mse_loss(per_step.mean(dim=1), target)
        for per_step, target in ((pred_ar, true_ar), (pred_ma, true_ma))
    ]
    return ar_loss + ma_loss, ar_loss, ma_loss


In [None]:
def prepare_training_data(batch_size=32, T_raw=4096, C=4, seed=None):
    """Sample an ARMA batch and return it together with its true coefficients.

    Each ``(ar_poly, ma_poly)`` pair from ``generate_arma_batch`` is in lag-polynomial
    form with a leading constant term: AR ``1 - phi1*L - phi2*L^2 - ...`` and
    MA ``1 + theta1*L + theta2*L^2 + ...``. The constant is dropped, the AR part is
    negated to obtain phi_k, and both vectors are zero-padded or truncated to
    NUM_ARMA_PARAMS entries.

    Returns:
        x: the generated batch, moved to ``device``.
        true_ar: float32 tensor [batch_size, NUM_ARMA_PARAMS] of phi_k on ``device``.
        true_ma: float32 tensor [batch_size, NUM_ARMA_PARAMS] of theta_k on ``device``.
    """
    x, parameters = generate_arma_batch(batch_size=batch_size, T_raw=T_raw, C=C, seed=seed)
    x = x.to(device)

    def fit_length(coeffs):
        # Right-pad with zeros (never a negative pad), then clip to the target length.
        shortfall = max(0, NUM_ARMA_PARAMS - len(coeffs))
        return np.pad(coeffs, (0, shortfall), mode='constant')[:NUM_ARMA_PARAMS]

    def to_tensor(rows):
        return torch.tensor(np.array(rows), dtype=torch.float32).to(device)

    ar_rows, ma_rows = [], []
    for ar_poly, ma_poly in parameters:
        ar_rows.append(fit_length(-ar_poly[1:]))  # drop constant, flip sign -> phi_k
        ma_rows.append(fit_length(ma_poly[1:]))   # drop constant -> theta_k

    return x, to_tensor(ar_rows), to_tensor(ma_rows)

# Smoke-test the data pipeline on a small, seeded batch.
x_test, ar_test, ma_test = prepare_training_data(batch_size=4, seed=42)
print(f"Test data shapes: x={x_test.shape}, ar={ar_test.shape}, ma={ma_test.shape}")
first_ar, first_ma = ar_test[0], ma_test[0]
print(f"Sample AR params: {first_ar}")
print(f"Sample MA params: {first_ma}")
print(f"AR L1 norm: {torch.norm(first_ar, p=1):.4f}")
print(f"MA L1 norm: {torch.norm(first_ma, p=1):.4f}")


In [None]:
def train_parameter_recovery(param_head, model, num_epochs=100, batch_size=32, lr=1e-3,
                             *, val_seed=0, train_seed_offset=100_000, log_every=10):
    """Train the parameter recovery head on freshly sampled ARMA batches.

    Each "epoch" is a single Adam step on one newly generated batch, so
    ``num_epochs`` is the number of gradient steps. Only ``param_head``'s
    parameters are given to the optimizer; ``model`` is used solely through
    ``extract_latent_features``.

    Args:
        param_head: Module mapping latent features to (AR, MA) predictions.
        model: Pre-trained backbone passed to ``extract_latent_features``.
        num_epochs: Number of training steps (one batch each).
        batch_size: Sequences per training batch and in the validation batch.
        lr: Adam learning rate.
        val_seed: Seed of the fixed validation batch.
        train_seed_offset: Training step ``k`` uses seed ``train_seed_offset + k``.
            Previously step 0 reused seed 0, i.e. the validation batch itself, and
            training seeds covered the evaluation seeds used later (1000+, 2000+).
            The default keeps them disjoint. NOTE(review): whether a batch_size=1
            draw overlaps a batch_size=32 draw with the same seed depends on
            generate_arma_batch; disjoint seeds sidestep the question.
        log_every: Print progress every this many steps; 0/None disables logging.

    Returns:
        (train_losses, val_losses, ar_losses, ma_losses): per-step lists of floats.
        ``ar_losses``/``ma_losses`` are the components of the training loss.

    Raises:
        ValueError: if ``val_seed`` falls inside the training seed range.
    """
    train_seeds = range(train_seed_offset, train_seed_offset + num_epochs)
    if val_seed in train_seeds:
        raise ValueError(
            f"val_seed={val_seed} is also a training seed "
            f"(training uses [{train_seeds.start}, {train_seeds.stop})); "
            "choose a different train_seed_offset or val_seed"
        )

    optimizer = optim.Adam(param_head.parameters(), lr=lr)

    # Training history
    train_losses = []
    val_losses = []
    ar_losses = []
    ma_losses = []

    # Fixed validation batch; its features never change, so extract them once.
    x_val, ar_val, ma_val = prepare_training_data(batch_size=batch_size, seed=val_seed)
    h_val = extract_latent_features(model, x_val)

    print("Starting parameter recovery training...")

    for epoch, seed in enumerate(train_seeds):
        # Training step (dropout active)
        param_head.train()
        optimizer.zero_grad()

        x_train, ar_train, ma_train = prepare_training_data(batch_size=batch_size, seed=seed)
        h_train = extract_latent_features(model, x_train)

        pred_ar, pred_ma = param_head(h_train)
        loss, ar_loss, ma_loss = parameter_loss(pred_ar, pred_ma, ar_train, ma_train)

        loss.backward()
        optimizer.step()

        # Validation step: eval mode (dropout off), no autograd graph.
        param_head.eval()
        with torch.no_grad():
            pred_ar_val, pred_ma_val = param_head(h_val)
            val_loss, _, _ = parameter_loss(pred_ar_val, pred_ma_val, ar_val, ma_val)

        # Record metrics (values are from before this step's update).
        train_value = loss.item()
        val_value = val_loss.item()
        ar_value = ar_loss.item()
        ma_value = ma_loss.item()
        train_losses.append(train_value)
        val_losses.append(val_value)
        ar_losses.append(ar_value)
        ma_losses.append(ma_value)

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}: "
                  f"Train Loss: {train_value:.6f}, "
                  f"Val Loss: {val_value:.6f}, "
                  f"AR Loss: {ar_value:.6f}, "
                  f"MA Loss: {ma_value:.6f}")

    return train_losses, val_losses, ar_losses, ma_losses


In [None]:
# Launch training: 30k single-batch steps, 32 sequences per batch, lr=1e-3.
training_history = train_parameter_recovery(
    param_head,
    model,
    lr=1e-3,
    batch_size=32,
    num_epochs=30_000,
)
train_losses, val_losses, ar_losses, ma_losses = training_history

In [None]:
# Plot training curves: total loss (train vs. val), then the AR and MA training components.
_curve_panels = [
    ('Total Loss', [(train_losses, 'Train Loss'), (val_losses, 'Val Loss')]),
    ('AR Parameter Loss', [(ar_losses, 'AR Loss')]),
    ('MA Parameter Loss', [(ma_losses, 'MA Loss')]),
]

plt.figure(figsize=(15, 5))
for _panel_idx, (_panel_title, _panel_series) in enumerate(_curve_panels, start=1):
    plt.subplot(1, 3, _panel_idx)
    for _series_values, _series_label in _panel_series:
        plt.plot(_series_values, label=_series_label, alpha=0.8)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(_panel_title)
    plt.legend()
    plt.yscale('log')  # log axis to make the loss decay readable
plt.tight_layout()
plt.show()


In [None]:
# Evaluate final performance
def evaluate_parameter_recovery(param_head, model, num_samples=100):
    """Score the head on ``num_samples`` single-sequence batches.

    Sample ``i`` is generated with seed ``i + 1000``. For each sample this
    records the AR and MA MSE of the time-averaged prediction (via
    ``parameter_loss``), their sum, and the summed MSE of an all-zero
    baseline. A sample that raises is reported and skipped.

    Returns:
        dict with per-sample lists ('ar_errors', 'ma_errors', 'total_errors',
        'baseline_errors') followed by 'mean_*_error' and 'std_*_error'
        summaries for each of the four.

    Raises:
        RuntimeError: if no sample could be evaluated.
    """
    param_head.eval()

    metric_names = ('ar', 'ma', 'total', 'baseline')
    collected = {name: [] for name in metric_names}

    with torch.no_grad():
        for i in range(num_samples):
            try:
                x_sample, ar_true, ma_true = prepare_training_data(batch_size=1, seed=i+1000)
                latent = extract_latent_features(model, x_sample)

                # Per-step predictions [1, T, NUM_ARMA_PARAMS]; parameter_loss averages over T.
                pred_ar, pred_ma = param_head(latent)
                _, ar_err_t, ma_err_t = parameter_loss(pred_ar, pred_ma, ar_true, ma_true)
                ar_err = ar_err_t.item()
                ma_err = ma_err_t.item()

                # Reference point: always predicting zero coefficients.
                zero_ar = F.mse_loss(torch.zeros_like(ar_true), ar_true).item()
                zero_ma = F.mse_loss(torch.zeros_like(ma_true), ma_true).item()
            except Exception as e:
                print(f"Error in sample {i}: {e}")
            else:
                sample_values = (ar_err, ma_err, ar_err + ma_err, zero_ar + zero_ma)
                for name, value in zip(metric_names, sample_values):
                    collected[name].append(value)

    if not collected['ar']:
        raise RuntimeError("No valid samples processed during evaluation")

    summary = {f'{name}_errors': values for name, values in collected.items()}
    summary.update({f'mean_{name}_error': np.mean(values) for name, values in collected.items()})
    summary.update({f'std_{name}_error': np.std(values) for name, values in collected.items()})
    return summary

# Run evaluation on 200 samples and compare against the all-zero baseline.
results = evaluate_parameter_recovery(param_head, model, num_samples=200)

print("\n=== Parameter Recovery Performance ===")
for _label, _key in (('AR', 'ar'), ('MA', 'ma'), ('Total', 'total')):
    print(f"Mean {_label} Error: {results[f'mean_{_key}_error']:.6f} ± {results[f'std_{_key}_error']:.6f}")
print("\n=== Baseline (predicting 0) ===")
print(f"Mean Baseline Error: {results['mean_baseline_error']:.6f} ± {results['std_baseline_error']:.6f}")
_ratio = results['mean_baseline_error'] / results['mean_total_error']
print(f"Improvement Ratio: {_ratio:.2f}x better than baseline")


In [None]:
# Visualize parameter recovery quality
def plot_parameter_comparison(param_head, model, num_examples=5):
    """Bar-plot true vs. predicted AR/MA coefficients for a few samples.

    Example ``i`` is generated with seed ``i + 2000``; per-step predictions are
    averaged over the time axis before plotting. One row per example, AR on
    the left and MA on the right, with true and predicted bars side by side.

    Raises:
        ValueError: if ``num_examples`` is less than 1.
    """
    if num_examples < 1:
        raise ValueError(f"num_examples must be >= 1, got {num_examples}")

    param_head.eval()

    # squeeze=False keeps `axes` 2-D even when num_examples == 1, so axes[i, col]
    # is always valid (with the default squeeze, a single row gives a 1-D array).
    _, axes = plt.subplots(num_examples, 2, figsize=(12, 3*num_examples), squeeze=False)
    bar_width = 0.4  # two bars per coefficient index, drawn side by side

    with torch.no_grad():
        for i in range(num_examples):
            x_sample, ar_true, ma_true = prepare_training_data(batch_size=1, seed=i+2000)
            latent = extract_latent_features(model, x_sample)

            # Predictions are [1, T, NUM_ARMA_PARAMS]; average over T for one vector each.
            pred_ar, pred_ma = param_head(latent)
            panels = (
                ('AR', ar_true[0], pred_ar[0].mean(dim=0)),
                ('MA', ma_true[0], pred_ma[0].mean(dim=0)),
            )

            for col, (kind, true_vals, pred_vals) in enumerate(panels):
                true_np = true_vals.cpu().numpy()
                pred_np = pred_vals.cpu().numpy()
                positions = np.arange(len(true_np))
                ax = axes[i, col]
                ax.bar(positions - bar_width / 2, true_np, width=bar_width,
                       alpha=0.7, label='True', color='blue')
                ax.bar(positions + bar_width / 2, pred_np, width=bar_width,
                       alpha=0.7, label='Predicted', color='red')
                ax.set_xticks(positions)
                ax.set_title(f'{kind} Parameters - Sample {i+1}')
                ax.set_xlabel('Parameter Index')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot parameter comparisons
plot_parameter_comparison(param_head, model, num_examples=5)


In [None]:
# Error distribution analysis: per-sample MSE histograms with the mean marked.
_hist_specs = [
    ('ar_errors', 'mean_ar_error', 'blue', 'AR'),
    ('ma_errors', 'mean_ma_error', 'green', 'MA'),
    ('total_errors', 'mean_total_error', 'purple', 'Total'),
]

plt.figure(figsize=(15, 5))
for _pos, (_errors_key, _mean_key, _color, _name) in enumerate(_hist_specs, start=1):
    plt.subplot(1, 3, _pos)
    plt.hist(results[_errors_key], bins=30, alpha=0.7, color=_color)
    plt.xlabel(f'{_name} Parameter MSE')
    plt.ylabel('Frequency')
    plt.title(f'{_name} Parameter Error Distribution')
    _mean_value = results[_mean_key]
    plt.axvline(_mean_value, color='red', linestyle='--', label=f'Mean: {_mean_value:.4f}')
    plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Persist only the head's weights (its state_dict); the backbone is not saved here.
_head_path = 'parameter_recovery_head.pth'
torch.save(param_head.state_dict(), _head_path)
print(f"Parameter recovery head saved to '{_head_path}'")

# Summary of results
print("\n=== Experiment Summary ===")
if not (train_losses and val_losses):
    print("Training not completed - no loss values available")
else:
    print(f"Final training loss: {train_losses[-1]:.6f}")
    print(f"Final validation loss: {val_losses[-1]:.6f}")

# At notebook/module top level, globals() is the same namespace locals() returns.
if 'results' in globals() and results:
    for _metric_label, _metric_key in (
        ('AR parameter', 'mean_ar_error'),
        ('MA parameter', 'mean_ma_error'),
        ('total parameter', 'mean_total_error'),
    ):
        print(f"Mean {_metric_label} recovery error: {results[_metric_key]:.6f}")
else:
    print("Evaluation not completed - no results available")
