In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from types import SimpleNamespace

# Import functions from codebase modules
from arma import generate_arma_batch
from loss import contrastive_latent_loss
from network import *


In [None]:
# Data generation wrapper for backward compatibility
def generate_random_walk(batch_size=16, T_raw=4096, C=4, mean=0.0, std=1.0, seed=None):
    """Return a batch of multivariate series produced by `generate_arma_batch`.

    Thin compatibility wrapper: every argument is forwarded unchanged by keyword, and
    only the first element of the returned pair is kept (the second is discarded).
    Callers in this notebook index the result as [batch, time, channel] and call
    `.to(device)` on it, so it is presumably a torch tensor — confirm in `arma`.
    """
    series, _discarded = generate_arma_batch(
        batch_size=batch_size,
        T_raw=T_raw,
        C=C,
        mean=mean,
        std=std,
        seed=seed,
    )
    return series


In [None]:
# Generate sample data and initialize model.
# Seeds are fixed so that re-running the notebook reproduces the sample series
# (and, assuming SimpleModel draws its initial weights from torch's RNG, the weights).
x = generate_random_walk(batch_size=32, T_raw=4096, seed=0)
print(f"Data shape: {x.shape}")

torch.manual_seed(0)
model = SimpleModel(C=4, H=64, W=32)
print(f"Model initialized: {model.__class__.__name__}")

In [None]:
# Visualize sample data: per-channel cumulative sum of the first series in the batch.
# `x` is converted explicitly to a CPU NumPy array; handing a torch tensor straight to
# np.cumsum relies on NumPy's fallback conversion and fails for tensors on a GPU.
sample = torch.as_tensor(x[0]).detach().cpu().numpy()  # [time, channel], as indexed below
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(np.cumsum(sample, axis=0))
ax.set_title('Sample ARMA Process (Cumulative Sum)')
ax.set_xlabel('Time Steps')
ax.set_ylabel('Cumulative Value')
ax.legend([f'Channel {i}' for i in range(sample.shape[1])])
fig.tight_layout()
plt.show()

In [None]:
# Test model forward pass.
# This is an inspection-only pass, so autograd tracking is disabled to avoid building a graph.
with torch.no_grad():
    h_hat, h = model(x)
print(f"Forecasted latent shape: {h_hat.shape}")
print(f"Original latent shape: {h.shape}")

In [None]:
# Test loss computation on the latents from the forward pass above.
# The latents are passed as (forecast, original), the same order `train_step` uses
# when it calls the loss during training.
# NOTE(review): confirm that `contrastive_latent_loss` expects (forecast, original) and that
# `model(x)` returns latents in the same [B, T, C, H] layout that `train_step` builds.
# The temperature here (0.05) differs from the training spec in `setup_training` (0.07).
spec = SimpleNamespace(
    train_configuration={
        'contrastive_divergence_temperature': 0.05,
        'contrastive_latent_noise': None,
        'loss_shape': 'cosine_similarity',
        # Same delay as the training spec, in case the loss reads this key.
        'contrastive_latent_delay': 0,
    }
)

# Only the scalar value is reported, so no gradient tracking is needed.
with torch.no_grad():
    loss = contrastive_latent_loss(
        predicted_position=[h_hat, h],
        validation=False,
        spec=spec,
        get_history=False,
    )

print(f"Contrastive latent loss: {loss.item():.4f}")



In [None]:
def create_training_state():
    """Return a fresh, empty training-state dict for `train_model`.

    Keys:
        steps: steps at which validation ran.
        train_metrics / val_metrics: one list per metric ('ff', 'fp', 'tp').
        current_step: last completed step (0 means nothing has been trained yet).
        model, optimizer, x_val, spec, cld: None until the first `train_model` call fills them in.
    Every call builds new lists, so separate states never share mutable data.
    """
    def _empty_metrics():
        # Fresh lists per call; the order ff, fp, tp matches the metric tuple order.
        return {name: [] for name in ('ff', 'fp', 'tp')}

    state = dict(
        steps=[],
        train_metrics=_empty_metrics(),
        val_metrics=_empty_metrics(),
        current_step=0,
    )
    # Slots populated lazily by train_model; all start as None.
    state.update(dict.fromkeys(('model', 'optimizer', 'x_val', 'spec', 'cld')))
    return state

def setup_training(model, C, H, W, batch_size, device, lr=1e-4):
    """Move the model to `device` and build the optimizer, a fixed validation set and the loss config.

    Args:
        model: module to train. It is moved to `device` before the optimizer is created,
            so the optimizer holds the on-device parameters.
        C: number of channels per series.
        H: latent size (unused here; kept for signature compatibility).
        W: window length; each raw series is cut into non-overlapping windows of W steps.
        batch_size: number of validation series to generate.
        device: device for the model and the validation data.
        lr: AdamW learning rate.

    Returns:
        (model, optimizer, xt_val, spec, cld):
        - xt_val has shape [Bv, T, C, W] with T = T_raw // W (trailing steps that do not
          fill a whole window are dropped).
        - spec.train_configuration holds the loss settings.
        - cld is `contrastive_latent_delay + 1`, the time offset used by `compute_metrics`.

    Raises:
        ValueError: if the generated series is shorter than one window.
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # Fixed (seeded) validation set so validation metrics are comparable across steps.
    x_val = generate_random_walk(batch_size, T_raw=4096, C=C, seed=0).to(device)
    Bv, Tr, _ = x_val.shape
    T = Tr // W
    if T == 0:
        raise ValueError(f"series length {Tr} is shorter than window length W={W}")
    # Truncate to whole windows: a plain view(Bv, T, W, C) fails whenever Tr % W != 0.
    xt_val = x_val[:, :T * W].reshape(Bv, T, W, C).permute(0, 1, 3, 2)  # [Bv, T, C, W]

    # Training configuration
    spec = SimpleNamespace(train_configuration={
        'contrastive_divergence_temperature': 0.07,
        'contrastive_latent_noise': None,
        'loss_shape': 'cosine_similarity',
        'contrastive_latent_delay': 0
    })
    cld = spec.train_configuration['contrastive_latent_delay'] + 1

    return model, optimizer, xt_val, spec, cld

def compute_metrics(f_lat, o_lat, cld):
    """Compute forecast-future, forecast-past and future-past cosine similarities.

    Both inputs are indexed with time on dim 1 and features on the last dim
    (in this notebook: [B, T, C, H]). For every t with t + cld < T:
      ff: forecast[t]        vs original[t + cld]  (forecast vs the future it targets)
      fp: forecast[t]        vs original[t]        (forecast vs the current/past latent)
      tp: original[t + cld]  vs original[t]        (future vs past baseline)

    Args:
        f_lat: forecast latents.
        o_lat: original latents, same shape as `f_lat`.
        cld: time offset between a forecast and its target.

    Returns:
        (ff, fp, tp) as Python floats. Each is the cosine similarity over the last dim,
        averaged over all remaining positions.

    Raises:
        ValueError: if `cld` is not in [0, T).
    """
    T = f_lat.shape[1]
    # Explicit bounds: the slice `[:-cld]` would be empty for cld == 0 (NaN metrics)
    # and misaligned for cld > T.
    if not 0 <= cld < T:
        raise ValueError(f"cld must satisfy 0 <= cld < T={T}, got {cld}")

    # Metrics are diagnostics only; don't extend the autograd graph.
    with torch.no_grad():
        fn = F.normalize(f_lat, p=2, dim=-1)
        on = F.normalize(o_lat, p=2, dim=-1)
        hyh = fn[:, :T - cld]  # forecast at t
        hyn = on[:, cld:]      # original at t + cld (future)
        hxn = on[:, :T - cld]  # original at t (past)

        ff = (hyh * hyn).sum(-1).mean().item()
        fp = (hyh * hxn).sum(-1).mean().item()
        tp = (hyn * hxn).sum(-1).mean().item()

    return ff, fp, tp

def train_step(model, optimizer, loss_fn, C, H, W, batch_size, device, spec):
    """Run one optimization step on a freshly generated training batch.

    Each call generates a new (unseeded) batch, cuts it into non-overlapping windows of
    length W, runs `model.transformer`, and takes one optimizer step on `loss_fn`.

    Args:
        model: module exposing `.transformer(xt)` that returns (forecast, original) latents.
        optimizer: optimizer over the model's parameters.
        loss_fn: called as loss_fn((f_lat, o_lat), validation=False, spec=spec).
        C, H, W: channels, latent size and window length.
        batch_size: number of training series.
        device: device the batch is moved to.
        spec: loss configuration forwarded to `loss_fn`.

    Returns:
        (loss_value, f_lat, o_lat): the loss as a Python float, and the forecast/original
        latents as [Bt, T, C, H] tensors. The latents are detached from the autograd graph
        because callers only use them for metrics.

    Raises:
        ValueError: if the generated series is shorter than one window.
    """
    model.train()
    optimizer.zero_grad()

    # Generate training batch [Bt, T_raw, C] and cut it into windows.
    x_train = generate_random_walk(batch_size, T_raw=4096, C=C).to(device)
    Bt, Tr, _ = x_train.shape
    T = Tr // W
    if T == 0:
        raise ValueError(f"series length {Tr} is shorter than window length W={W}")
    # Truncate to whole windows: a plain view(Bt, T, W, C) fails whenever Tr % W != 0.
    xt = x_train[:, :T * W].reshape(Bt, T, W, C).permute(0, 1, 3, 2)  # [Bt, T, C, W]

    # Forward pass. The flat outputs are assumed to be [Bt*C, T, H], ordered batch-major
    # and then channel, which is what the reshape below relies on.
    f_flat, o_flat = model.transformer(xt)
    f_lat = f_flat.reshape(Bt, C, T, H).permute(0, 2, 1, 3)  # [Bt, T, C, H]
    o_lat = o_flat.reshape(Bt, C, T, H).permute(0, 2, 1, 3)  # [Bt, T, C, H]

    # Compute loss and backprop
    loss = loss_fn((f_lat, o_lat), validation=False, spec=spec)
    loss.backward()
    optimizer.step()

    return loss.item(), f_lat.detach(), o_lat.detach()

def validation_step(model, x_val, loss_fn, spec, cld):
    """Compute ff/fp/tp metrics on the fixed validation windows, without gradient tracking.

    Args:
        model: module exposing `.transformer(x)`. It is left in eval mode; `train_step`
            switches it back with `model.train()`.
        x_val: validation windows of shape [Bv, T, C, W] (as built by `setup_training`).
        loss_fn, spec: currently unused; kept for signature compatibility.
        cld: time offset passed to `compute_metrics`.

    Returns:
        (ff, fp, tp) floats from `compute_metrics`.
    """
    model.eval()
    Bv, T, C, _ = x_val.shape
    with torch.no_grad():
        fv_flat, ov_flat = model.transformer(x_val)
        # Outputs are [Bv*C, T, H]. H is read from the output rather than from a `model.H`
        # attribute, so any model whose transformer has this layout works.
        H = fv_flat.shape[-1]
        fv = fv_flat.reshape(Bv, C, T, H).permute(0, 2, 1, 3)  # [Bv, T, C, H]
        ov = ov_flat.reshape(Bv, C, T, H).permute(0, 2, 1, 3)  # [Bv, T, C, H]

        return compute_metrics(fv, ov, cld)

def plot_training_curves(training_state):
    """Draw the per-step training cosine-similarity curves (ff, fp, tp).

    Training metrics are appended once per step, so the x-axis is 1..len(metrics).
    """
    plt.figure(figsize=(8, 4))
    metrics = training_state['train_metrics']
    xs = range(1, len(metrics['ff']) + 1)
    curves = (
        ('ff', 'Train Forecast vs Future'),
        ('fp', 'Train Forecast vs Past'),
        ('tp', 'Train Future vs Past'),
    )
    for key, label in curves:
        plt.plot(xs, metrics[key], label=label)
    plt.xlabel('Step')
    plt.ylabel('Mean Cosine Similarity')
    plt.title('Training Metrics')
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_validation_curves(training_state):
    """Draw the validation cosine-similarity curves (ff, fp, tp) against the steps at which validation ran."""
    plt.figure(figsize=(8, 4))
    val_steps = training_state['steps']
    metrics = training_state['val_metrics']
    curves = (
        ('ff', 'Val Forecast vs Future'),
        ('fp', 'Val Forecast vs Past'),
        ('tp', 'Val Future vs Past'),
    )
    for key, label in curves:
        plt.plot(val_steps, metrics[key], label=label)
    plt.xlabel('Step')
    plt.ylabel('Mean Cosine Similarity')
    plt.title('Validation Metrics')
    plt.legend()
    plt.tight_layout()
    plt.show()

def train_model(model, loss_fn, C, H, W, total_steps=500, batch_size=8,
                lr=1e-4, device='cpu', val_every=50, training_state=None,
                log_every=None):
    """
    Train `model` with re-entry capability.

    `training_state['current_step']` is updated after every completed step. An interrupted
    run (e.g. KeyboardInterrupt) can therefore be continued by passing the same
    `training_state` back in; the loop resumes at the next step.

    Args:
        model: module to train. Ignored when `training_state` already holds a model.
        loss_fn: called as loss_fn((f_lat, o_lat), validation=False, spec=spec).
        C, H, W: channels, latent size and window length.
        total_steps: absolute step number to train up to (not a count of additional steps).
        batch_size: training batch size; also the validation batch size on first setup.
        lr: learning rate, used only on first setup.
        device: device used on first setup. When resuming, batches follow the device of the
            stored validation data, so they land where the stored model lives.
        val_every: every `val_every` steps (and at the final step), run validation, print a
            summary and plot the curves.
        training_state: dict from `create_training_state()` or from a previous call;
            None starts a new one.
        log_every: if set, also print the training loss every `log_every` steps. The default
            (None) prints only at validation steps, to avoid flooding notebook output on long runs.

    Returns:
        (model, training_state)
    """
    # Create or use provided training state
    if training_state is None:
        training_state = create_training_state()

    if training_state['model'] is None:
        # First call for this state: build model/optimizer/validation data and store them.
        model, optimizer, x_val, spec, cld = setup_training(model, C, H, W, batch_size, device, lr)
        training_state.update({
            'model': model,
            'optimizer': optimizer,
            'x_val': x_val,
            'spec': spec,
            'cld': cld,
            'current_step': 0
        })
    else:
        # Resuming: reuse the stored objects; `model` and `lr` arguments are ignored.
        model = training_state['model']
        optimizer = training_state['optimizer']
        x_val = training_state['x_val']
        spec = training_state['spec']
        cld = training_state['cld']
        # Keep new training batches on the same device as the stored model/validation data.
        device = x_val.device

    train_metrics = training_state['train_metrics']
    val_metrics = training_state['val_metrics']
    last_plotted_step = None

    for step in range(training_state['current_step'] + 1, total_steps + 1):
        loss_val, f_lat, o_lat = train_step(model, optimizer, loss_fn, C, H, W, batch_size, device, spec)

        # Training metrics are recorded every step (the training plot uses 1..N as x-axis).
        train_ff, train_fp, train_tp = compute_metrics(f_lat, o_lat, cld)
        train_metrics['ff'].append(train_ff)
        train_metrics['fp'].append(train_fp)
        train_metrics['tp'].append(train_tp)

        if step % val_every == 0 or step == total_steps:
            val_ff, val_fp, val_tp = validation_step(model, x_val, loss_fn, spec, cld)
            val_metrics['ff'].append(val_ff)
            val_metrics['fp'].append(val_fp)
            val_metrics['tp'].append(val_tp)
            training_state['steps'].append(step)

            print(f"[Step {step}] train loss {loss_val:.4f} | "
                  f"train FF={train_ff:.4f}, FP={train_fp:.4f}, TP={train_tp:.4f} || "
                  f"val   FF={val_ff:.4f}, FP={val_fp:.4f}, TP={val_tp:.4f}")

            # Best-effort plotting: a plotting failure should not abort a long training run.
            try:
                plot_training_curves(training_state)
                plot_validation_curves(training_state)
                last_plotted_step = step
            except Exception as e:
                print(f"Plotting error: {e}")
        elif log_every and step % log_every == 0:
            print(f"[Step {step}] train loss {loss_val:.4f}")

        # Record progress after each completed step so an interrupted run resumes here.
        training_state['current_step'] = step

    # Final plots, unless the last iteration already drew them (the final step always validates).
    if last_plotted_step != total_steps:
        plot_training_curves(training_state)
        plot_validation_curves(training_state)

    return model, training_state

In [None]:
# Initialize model and start training.
# Use the GPU when one is available; otherwise fall back to CPU so the cell still runs.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TOTAL_STEPS = 500_000  # absolute step target; lower it for a quick smoke test

torch.manual_seed(0)  # seed torch's RNG before the model is built
model = SimpleModel(C=4, H=64, W=32)
trained_model, training_state = train_model(
    model,
    loss_fn=contrastive_latent_loss,
    C=4, H=64, W=32,
    total_steps=TOTAL_STEPS,
    batch_size=32,
    lr=1e-4,
    device=DEVICE,
    val_every=500,
)

In [None]:
# Example: Resume training from where you left off (re-entry capability).
# `train_model` records progress in `training_state` after every step, so an interrupted
# run (e.g. KeyboardInterrupt) can be continued by passing the same state back in.

print(f"Current step: {training_state['current_step']}")
print(f"Training metrics length: {len(training_state['train_metrics']['ff'])}")
# Only the most recent validation steps are shown, to keep the output short on long runs.
print(f"Validation steps: {training_state['steps'][-10:]} "
      f"(last {min(10, len(training_state['steps']))} of {len(training_state['steps'])})")

# To resume training from where you left off:
# `total_steps` is an absolute target, so it must exceed the current step or nothing runs.
# When resuming, the stored model, optimizer and validation data are reused; the `model`
# and `lr` arguments are ignored, and batches follow the stored data's device.
# resumed_model, updated_state = train_model(
#     model,
#     loss_fn=contrastive_latent_loss,
#     C=4, H=64, W=32,
#     total_steps=training_state['current_step'] + 1000,  # 1000 more steps
#     batch_size=32,
#     lr=1e-4,
#     device='cuda' if torch.cuda.is_available() else 'cpu',
#     val_every=500,
#     training_state=training_state,  # Pass the existing state
# )

# To start fresh training with a new state:
# new_model = SimpleModel(C=4, H=64, W=32)
# fresh_model, new_state = train_model(
#     new_model,
#     loss_fn=contrastive_latent_loss,
#     C=4, H=64, W=32,
#     total_steps=1000,
#     batch_size=32,
#     lr=1e-4,
#     device='cuda' if torch.cuda.is_available() else 'cpu',
#     val_every=500,
#     training_state=None,  # Creates a new state
# )
