# Multisession Training: Linear AE vs Flow AE Comparison

This notebook compares the performance of two autoencoder architectures:
1. **LinearChannelAE**: Linear tied-weight autoencoder (like PCA)
2. **FlowChannelAE**: Normalizing flow autoencoder (invertible, nonlinear)

Both are tested on the same multisession neural data.

**Logging**: All output is saved to `logs/ae_comparison_TIMESTAMP.log`

In [1]:
%load_ext autoreload
%autoreload 2

import os
import random
import sys
import logging
import time
from datetime import datetime, timedelta

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

from pathlib import Path
from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
import tqdm
import torch
import matplotlib.pyplot as plt

from tbfm import film
from tbfm import multisession
from tbfm import utils

DATA_DIR = "/home/danmuir/Projects/tbfm_multisession/data"
sys.path.append(DATA_DIR)
from tbfm import dataset
meta = dataset.load_meta(DATA_DIR)

OUT_DIR = "data"
EMBEDDING_REST_SUBDIR = "embedding_rest"

conf_dir = Path("./conf").resolve()

# Initialize Hydra with the configuration directory
with initialize_config_dir(config_dir=str(conf_dir), version_base=None):
    cfg = compose(config_name="config")

DEVICE = "cuda"
WINDOW_SIZE = cfg.data.trial_len
NUM_HELD_OUT_SESSIONS = cfg.training.num_held_out_sessions

## Setup Logging

In [2]:
class ExperimentLogger:
    """Logger that writes to both console and file with timing information."""
    
    def __init__(self, log_dir="logs", experiment_name="ae_comparison"):
        # Create logs directory
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        
        # Create log file with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = self.log_dir / f"{experiment_name}_{timestamp}.log"
        
        # Setup logger
        self.logger = logging.getLogger(experiment_name)
        self.logger.setLevel(logging.INFO)
        
        # Remove existing handlers
        self.logger.handlers = []
        
        # File handler
        fh = logging.FileHandler(self.log_file)
        fh.setLevel(logging.INFO)
        
        # Console handler
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        
        # Formatter
        formatter = logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        
        self.logger.addHandler(fh)
        self.logger.addHandler(ch)
        
        # Timing
        self.start_time = None
        self.phase_start = None
        
        self.info(f"Logging to: {self.log_file}")
        self.info("="*80)
    
    def info(self, message):
        self.logger.info(message)
    
    def start_phase(self, phase_name):
        """Start timing a phase."""
        self.phase_start = time.time()
        self.info(f"\n{'='*80}")
        self.info(f"Starting: {phase_name}")
        self.info(f"{'='*80}")
    
    def end_phase(self, phase_name):
        """End timing a phase and report duration."""
        if self.phase_start is not None:
            duration = time.time() - self.phase_start
            self.info(f"\nCompleted: {phase_name}")
            self.info(f"Duration: {self.format_duration(duration)}")
            self.info(f"{'='*80}")
            self.phase_start = None
    
    @staticmethod
    def format_duration(seconds):
        """Format seconds as HH:MM:SS."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        else:
            return f"{minutes:02d}:{secs:02d}"
    
    def start_training(self, num_epochs, phase_name="Training"):
        """Start training with progress tracking."""
        self.training_start = time.time()
        self.num_epochs = num_epochs
        self.start_phase(phase_name)
    
    def log_progress(self, epoch, num_epochs, train_loss, test_loss=None, train_r2=None, test_r2=None):
        """Log training progress with ETA."""
        elapsed = time.time() - self.training_start
        epochs_done = epoch + 1
        time_per_epoch = elapsed / epochs_done
        eta_seconds = time_per_epoch * (num_epochs - epochs_done)
        
        progress_pct = (epochs_done / num_epochs) * 100
        
        msg = f"Epoch {epoch}/{num_epochs} ({progress_pct:.1f}%) | "
        msg += f"Train Loss: {train_loss:.6f}"
        
        if test_loss is not None:
            msg += f" | Test Loss: {test_loss:.6f}"
        if train_r2 is not None:
            msg += f" | Train R²: {train_r2:.6f}"
        if test_r2 is not None:
            msg += f" | Test R²: {test_r2:.6f}"
        
        msg += f" | Elapsed: {self.format_duration(elapsed)}"
        msg += f" | ETA: {self.format_duration(eta_seconds)}"
        
        self.info(msg)
    
    def end_training(self, phase_name="Training"):
        """End training and report total time."""
        total_time = time.time() - self.training_start
        self.info(f"\n{phase_name} completed in {self.format_duration(total_time)}")
        self.end_phase(phase_name)

# Initialize logger
logger = ExperimentLogger(log_dir="logs", experiment_name="ae_comparison")
logger.info("Experiment: Linear AE vs Flow AE Comparison")
logger.info(f"Device: {DEVICE}")
logger.info(f"Window size: {WINDOW_SIZE}")

2025-10-27 18:01:12 - Logging to: logs/ae_comparison_20251027_180112.log


2025-10-27 18:01:12 - Experiment: Linear AE vs Flow AE Comparison
2025-10-27 18:01:12 - Device: cuda
2025-10-27 18:01:12 - Window size: 184


## Load Data

In [3]:
logger.start_phase("Data Loading")

# Session selection
held_in_session_ids=["MonkeyG_20150925_Session2_S1"]

num_sessions = len(held_in_session_ids)
MAX_BATCH_SIZE = 62500 // 2
batch_size = (MAX_BATCH_SIZE // num_sessions) * num_sessions

d, held_out_session_ids = multisession.load_stim_batched(
    window_size=WINDOW_SIZE,
    session_subdir="torchraw",
    data_dir=DATA_DIR,
    unpack_stiminds=True,
    held_in_session_ids=held_in_session_ids,
    batch_size=batch_size,
    num_held_out_sessions=NUM_HELD_OUT_SESSIONS,
)
data_train, data_test = d.train_test_split(5000, test_cut=2500)

held_in_session_ids = data_train.session_ids

# Load cached rest embeddings
embeddings_rest = multisession.load_rest_embeddings(held_in_session_ids, device=DEVICE)

logger.info(f"Loaded {len(held_in_session_ids)} sessions")
logger.info(f"Training batch size: {batch_size}")
logger.info(f"Sessions: {held_in_session_ids}")

# Check batch shapes
b = next(iter(data_train))
k0 = list(b.keys())[0]
logger.info(f"Train batch shape: {b[k0][0].shape}")

b = next(iter(data_test))
logger.info(f"Test batch shape: {b[k0][0].shape}")

logger.end_phase("Data Loading")

2025-10-27 18:01:15 - 
2025-10-27 18:01:15 - Starting: Data Loading
2025-10-27 18:01:16 - Loaded 1 sessions
2025-10-27 18:01:16 - Training batch size: 31250
2025-10-27 18:01:16 - Sessions: ['MonkeyG_20150925_Session2_S1']
2025-10-27 18:01:16 - Train batch shape: torch.Size([5000, 20, 60])
2025-10-27 18:01:16 - Test batch shape: torch.Size([2500, 20, 60])
2025-10-27 18:01:16 - 
Completed: Data Loading
2025-10-27 18:01:16 - Duration: 00:01


## Configure Training Parameters

In [4]:
def cfg_base(cfg, dim):
    """Base configuration for both AE types"""
    cfg.ae.training.coadapt = False
    cfg.ae.warm_start_is_identity = True
    cfg.tbfm.module.use_film_bases = False
    cfg.tbfm.module.num_bases = 12
    cfg.tbfm.module.latent_dim = 2
    cfg.latent_dim = dim
    cfg.training.epochs = 7001
    cfg.normalizers.module._target_ = "tbfm.normalizers.ScalerZscore"
    cfg.tbfm.training.lambda_fro = 0.03
    return cfg

logger.info("Configuration: epochs=7001, latent_dim=50")

2025-10-27 18:01:18 - Configuration: epochs=7001, latent_dim=50


## Training Helper Function

In [5]:
# Define training wrapper function that adds logging
import builtins

original_train = multisession.train_from_cfg

def train_with_logging(*args, **kwargs):
    """Wrapper for train_from_cfg that logs progress to logger."""
    # Store original print function
    original_print = builtins.print
    
    def logging_print(*print_args, **print_kwargs):
        # Capture the print output
        msg = ' '.join(map(str, print_args))
        
        # Check if it's a training progress line
        if msg.startswith('----'):
            parts = msg.split()
            if len(parts) >= 5:
                try:
                    epoch = int(parts[1])
                    train_loss = float(parts[2])
                    test_loss = float(parts[3])
                    train_r2 = float(parts[4])
                    test_r2 = float(parts[5]) if len(parts) > 5 else None
                    
                    logger.log_progress(epoch, kwargs.get('epochs', 7001), train_loss, test_loss, train_r2, test_r2)
                    return
                except (ValueError, IndexError):
                    pass
        
        # For other messages, just log normally
        if msg.strip() and not msg.startswith('Building') and not msg.startswith('BOOM'):
            logger.info(msg)
        original_print(*print_args, **print_kwargs)
    
    # Replace print temporarily
    builtins.print = logging_print
    
    try:
        result = original_train(*args, **kwargs)
    finally:
        builtins.print = original_print
    
    return result

logger.info("Training helper function defined")

2025-10-27 18:01:20 - Training helper function defined


## Experiment 1: Linear Autoencoder (Baseline)

In [10]:
logger.start_phase("EXPERIMENT 1: Linear Autoencoder (Baseline)")

# Load config with linear AE
with initialize_config_dir(config_dir=str(conf_dir), version_base=None):
    cfg_linear = compose(config_name="config")  # Default is linear

cfg_linear = cfg_base(cfg_linear, dim=50)

# Build model
logger.info("Building Linear AE model...")
ms_linear = multisession.build_from_cfg(cfg_linear, data_train, device=DEVICE, quiet=True)
model_optims_linear = multisession.get_optims(cfg_linear, ms_linear)

logger.info(f"Autoencoder type: {type(ms_linear.ae.instances[held_in_session_ids[0]]).__name__}")
logger.info(f"Latent dim: {cfg_linear.latent_dim}")

2025-10-27 17:13:52 - 
2025-10-27 17:13:52 - Starting: EXPERIMENT 1: Linear Autoencoder (Baseline)
2025-10-27 17:13:53 - Building Linear AE model...
2025-10-27 17:13:53 - Autoencoder type: LinearChannelAE
2025-10-27 17:13:53 - Latent dim: 50


In [11]:
# Train Linear AE with logging
logger.start_training(cfg_linear.training.epochs, "Linear AE Training")

embeddings_stim_linear, results_linear = train_with_logging(
    cfg_linear,
    ms_linear,
    data_train,
    model_optims_linear,
    embeddings_rest,
    data_test=data_test,
    test_interval=1000,
    epochs=cfg_linear.training.epochs
)

logger.end_training("Linear AE Training")

2025-10-27 17:13:53 - 
2025-10-27 17:13:53 - Starting: Linear AE Training
2025-10-27 17:13:53 - Epoch 0/7001 (0.0%) | Train Loss: 1.187163 | Test Loss: 1.151420 | Train R²: 0.254855 | Test R²: 0.393432 | Elapsed: 00:00 | ETA: 44:35
2025-10-27 17:14:36 - Epoch 1000/7001 (14.3%) | Train Loss: 0.658903 | Test Loss: 1.034887 | Train R²: 0.363734 | Test R²: 0.454519 | Elapsed: 00:43 | ETA: 04:18
2025-10-27 17:15:22 - Epoch 2000/7001 (28.6%) | Train Loss: 0.618455 | Test Loss: 0.991228 | Train R²: 0.404007 | Test R²: 0.477522 | Elapsed: 01:29 | ETA: 03:42
2025-10-27 17:16:13 - Epoch 3000/7001 (42.9%) | Train Loss: 0.589484 | Test Loss: 0.957529 | Train R²: 0.433221 | Test R²: 0.495303 | Elapsed: 02:20 | ETA: 03:06
2025-10-27 17:17:03 - Epoch 4000/7001 (57.1%) | Train Loss: 0.567299 | Test Loss: 0.934708 | Train R²: 0.455795 | Test R²: 0.507340 | Elapsed: 03:10 | ETA: 02:22
2025-10-27 17:17:53 - Epoch 5000/7001 (71.4%) | Train Loss: 0.533861 | Test Loss: 0.891246 | Train R²: 0.488432 | Test R

Final: 0.8666365742683411 0.5432320237159729


In [12]:
# Custom training loop with logging
logger.start_training(cfg_linear.training.epochs, "Linear AE Training")

# Monkey-patch the train function to add logging
import types

original_train = multisession.train_from_cfg

def train_with_logging(*args, **kwargs):
    # Store original print function
    import builtins
    original_print = builtins.print
    
    # Track epoch info
    epoch_info = {'last_epoch': 0, 'last_train_loss': 0, 'last_test_loss': 0, 'last_train_r2': 0, 'last_test_r2': 0}
    
    def logging_print(*print_args, **print_kwargs):
        # Capture the print output
        msg = ' '.join(map(str, print_args))
        
        # Check if it's a training progress line
        if msg.startswith('----'):
            parts = msg.split()
            if len(parts) >= 5:
                try:
                    epoch = int(parts[1])
                    train_loss = float(parts[2])
                    test_loss = float(parts[3])
                    train_r2 = float(parts[4])
                    test_r2 = float(parts[5]) if len(parts) > 5 else None
                    
                    logger.log_progress(epoch, kwargs.get('epochs', 7001), train_loss, test_loss, train_r2, test_r2)
                    return
                except (ValueError, IndexError):
                    pass
        
        # For other messages, just log normally
        if msg.strip() and not msg.startswith('Building') and not msg.startswith('BOOM'):
            logger.info(msg)
        original_print(*print_args, **print_kwargs)
    
    # Replace print temporarily
    builtins.print = logging_print
    
    try:
        result = original_train(*args, **kwargs)
    finally:
        builtins.print = original_print
    
    return result

embeddings_stim_linear, results_linear = train_with_logging(
    cfg_linear,
    ms_linear,
    data_train,
    model_optims_linear,
    embeddings_rest,
    data_test=data_test,
    test_interval=1000,
    epochs=cfg_linear.training.epochs
)

logger.end_training("Linear AE Training")

2025-10-27 17:19:33 - 
2025-10-27 17:19:33 - Starting: Linear AE Training
2025-10-27 17:19:33 - Epoch 0/7001 (0.0%) | Train Loss: 0.484650 | Test Loss: 0.867595 | Train R²: 0.534223 | Test R²: 0.542716 | Elapsed: 00:00 | ETA: 08:51
2025-10-27 17:20:23 - Epoch 1000/7001 (14.3%) | Train Loss: 0.480487 | Test Loss: 0.865982 | Train R²: 0.537112 | Test R²: 0.543538 | Elapsed: 00:50 | ETA: 05:03
2025-10-27 17:21:13 - Epoch 2000/7001 (28.6%) | Train Loss: 0.502539 | Test Loss: 0.868884 | Train R²: 0.517390 | Test R²: 0.542014 | Elapsed: 01:40 | ETA: 04:11
2025-10-27 17:22:03 - Epoch 3000/7001 (42.9%) | Train Loss: 0.504160 | Test Loss: 0.861968 | Train R²: 0.518835 | Test R²: 0.545701 | Elapsed: 02:30 | ETA: 03:20
2025-10-27 17:22:53 - Epoch 4000/7001 (57.1%) | Train Loss: 0.508790 | Test Loss: 0.857332 | Train R²: 0.515552 | Test R²: 0.548139 | Elapsed: 03:20 | ETA: 02:30
2025-10-27 17:23:43 - Epoch 5000/7001 (71.4%) | Train Loss: 0.502248 | Test Loss: 0.855573 | Train R²: 0.522419 | Test R

Final: 0.84388667345047 0.5552590489387512


In [25]:
# Clean up Linear AE from GPU memory
logger.info("\nCleaning up Linear AE from GPU memory...")

# Move model to CPU (needed later for latent space analysis)
# if ms_linear is not None:
#     ms_linear = ms_linear.cpu()

# Delete optimizer (not needed anymore)
# del model_optims_linear

# Move result tensors to CPU
for key in results_linear.get('y_test', {}).keys():
    if isinstance(results_linear['y_test'][key], (list, tuple)):
        results_linear['y_test'][key] = [t.cpu() if torch.is_tensor(t) else t for t in results_linear['y_test'][key]]
    elif torch.is_tensor(results_linear['y_test'][key]):
        results_linear['y_test'][key] = results_linear['y_test'][key].cpu()
        
for key in results_linear.get('y_hat_test', {}).keys():
    if torch.is_tensor(results_linear['y_hat_test'][key]):
        results_linear['y_hat_test'][key] = results_linear['y_hat_test'][key].cpu()

# Move embeddings to CPU
embeddings_stim_linear = {k: v.cpu() if torch.is_tensor(v) else v for k, v in embeddings_stim_linear.items()}

# Clear CUDA cache
torch.cuda.empty_cache()
logger.info(f"GPU memory freed. Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")

2025-10-27 17:45:57 - 
Cleaning up Linear AE from GPU memory...


AttributeError: 'NoneType' object has no attribute 'items'

## Experiment 2: Flow Autoencoder (Invertible)

In [None]:
logger.start_phase("EXPERIMENT 2: Flow Autoencoder (Invertible)")

# Load config with flow AE
with initialize_config_dir(config_dir=str(conf_dir), version_base=None):
    cfg_flow = compose(
        config_name="config",
        overrides=["ae=flow"]  # Switch to flow autoencoder
    )

cfg_flow = cfg_base(cfg_flow, dim=50)

# Configure flow-specific parameters
cfg_flow.ae.module.num_flow_layers = 3
cfg_flow.ae.module.hidden_dim = 64

# Build model
logger.info("Building Flow AE model...")
ms_flow = multisession.build_from_cfg(cfg_flow, data_train, device=DEVICE, quiet=True)
model_optims_flow = multisession.get_optims(cfg_flow, ms_flow)

logger.info(f"Autoencoder type: {type(ms_flow.ae.instances[held_in_session_ids[0]]).__name__}")
logger.info(f"Latent dim: {cfg_flow.latent_dim}")
logger.info(f"Flow layers: {cfg_flow.ae.module.num_flow_layers}")
logger.info(f"Hidden dim: {cfg_flow.ae.module.hidden_dim}")

2025-10-27 18:01:30 - 
2025-10-27 18:01:30 - Starting: EXPERIMENT 2: Flow Autoencoder (Invertible)
2025-10-27 18:01:30 - Building Flow AE model...
2025-10-27 18:01:30 - Autoencoder type: FlowChannelAE
2025-10-27 18:01:30 - Latent dim: 50
2025-10-27 18:01:30 - Flow layers: 4
2025-10-27 18:01:30 - Hidden dim: 128


In [None]:
# Test invertibility for Flow AE
test_batch = next(iter(data_train))
session_id = held_in_session_ids[0]
x_test = test_batch[session_id][0][:10].to(DEVICE)

ae_flow = ms_flow.ae.instances[session_id]
mask = torch.arange(x_test.shape[-1]).to(DEVICE)

with torch.no_grad():
    z = ae_flow.encode(x_test, mask)
    x_recon = ae_flow.decode(z, mask)

error = (x_test - x_recon).abs().max().item()
mse = ((x_test - x_recon)**2).mean().item()
logger.info(f"\nFlow AE Reconstruction Test:")
logger.info(f"  Max absolute error: {error:.2e}")
logger.info(f"  MSE: {mse:.2e}")
if error < 1e-5:
    logger.info(f"  ✓ INVERTIBLE (error < 1e-5)")
else:
    logger.info(f"  ✗ Not perfectly invertible (error = {error:.2e})")

2025-10-27 18:01:33 - 
Flow AE Reconstruction Test:
2025-10-27 18:01:33 -   Max absolute error: 4.93e-05
2025-10-27 18:01:33 -   MSE: 3.82e-11
2025-10-27 18:01:33 -   ✗ Not perfectly invertible (error = 4.93e-05)


In [None]:
# Train Flow AE with logging
logger.start_training(cfg_flow.training.epochs, "Flow AE Training")

embeddings_stim_flow, results_flow = train_with_logging(
    cfg_flow,
    ms_flow,
    data_train,
    model_optims_flow,
    embeddings_rest,
    data_test=data_test,
    test_interval=1000,
    epochs=cfg_flow.training.epochs
)

logger.end_training("Flow AE Training")

2025-10-27 18:03:34 - 
2025-10-27 18:03:34 - Starting: Flow AE Training
2025-10-27 18:03:34 - Epoch 0/7001 (0.0%) | Train Loss: 0.953443 | Test Loss: 1.074933 | Train R²: 0.323728 | Test R²: 0.433621 | Elapsed: 00:00 | ETA: 40:19
2025-10-27 18:07:44 - Epoch 1000/7001 (14.3%) | Train Loss: 0.622717 | Test Loss: 0.992339 | Train R²: 0.400810 | Test R²: 0.476995 | Elapsed: 04:09 | ETA: 24:58
2025-10-27 18:11:52 - Epoch 2000/7001 (28.6%) | Train Loss: 0.589073 | Test Loss: 0.956695 | Train R²: 0.434432 | Test R²: 0.495777 | Elapsed: 08:18 | ETA: 20:45


## Results Comparison

In [None]:
logger.start_phase("Results Analysis")

# Summary table
logger.info("\n" + "="*80)
logger.info("RESULTS SUMMARY")
logger.info("="*80)
logger.info(f"{'Method':<20} {'Test R2':<15} {'Test Loss':<15} {'Train R2':<15}")
logger.info("-"*80)

linear_test_r2 = results_linear['final_test_r2']
linear_test_loss = results_linear['final_test_loss']
linear_train_r2 = results_linear['train_r2s'][-1][1]

flow_test_r2 = results_flow['final_test_r2']
flow_test_loss = results_flow['final_test_loss']
flow_train_r2 = results_flow['train_r2s'][-1][1]

logger.info(f"{'Linear AE':<20} {linear_test_r2:<15.6f} {linear_test_loss:<15.6f} {linear_train_r2:<15.6f}")
logger.info(f"{'Flow AE':<20} {flow_test_r2:<15.6f} {flow_test_loss:<15.6f} {flow_train_r2:<15.6f}")

# Calculate improvement
r2_improvement = ((flow_test_r2 - linear_test_r2) / abs(linear_test_r2)) * 100
loss_improvement = ((linear_test_loss - flow_test_loss) / linear_test_loss) * 100

logger.info("-"*80)
logger.info(f"{'Improvement':<20} {r2_improvement:+.2f}% {'':>6} {loss_improvement:+.2f}%")
logger.info("="*80)

# Per-session results
logger.info("\nPer-session test R2:")
logger.info(f"{'Session':<40} {'Linear AE':<15} {'Flow AE':<15} {'Δ':<15}")
logger.info("-"*80)
for session_id in results_linear['final_test_r2s'].keys():
    linear_r2 = results_linear['final_test_r2s'][session_id]
    flow_r2 = results_flow['final_test_r2s'][session_id]
    delta = flow_r2 - linear_r2
    logger.info(f"{session_id:<40} {linear_r2:<15.6f} {flow_r2:<15.6f} {delta:+.6f}")

logger.end_phase("Results Analysis")

## Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Training loss
ax = axes[0, 0]
linear_epochs = [x[0] for x in results_linear['train_losses']]
linear_losses = [x[1] for x in results_linear['train_losses']]
flow_epochs = [x[0] for x in results_flow['train_losses']]
flow_losses = [x[1] for x in results_flow['train_losses']]

ax.plot(linear_epochs, linear_losses, label='Linear AE', alpha=0.7)
ax.plot(flow_epochs, flow_losses, label='Flow AE', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training Loss')
ax.set_title('Training Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Test loss
ax = axes[0, 1]
linear_test_epochs = [x[0] for x in results_linear['test_losses']]
linear_test_losses = [x[1] for x in results_linear['test_losses']]
flow_test_epochs = [x[0] for x in results_flow['test_losses']]
flow_test_losses = [x[1] for x in results_flow['test_losses']]

ax.plot(linear_test_epochs, linear_test_losses, label='Linear AE', marker='o', alpha=0.7)
ax.plot(flow_test_epochs, flow_test_losses, label='Flow AE', marker='s', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Loss')
ax.set_title('Test Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Train R2
ax = axes[1, 0]
linear_r2_epochs = [x[0] for x in results_linear['train_r2s']]
linear_r2_values = [x[1] for x in results_linear['train_r2s']]
flow_r2_epochs = [x[0] for x in results_flow['train_r2s']]
flow_r2_values = [x[1] for x in results_flow['train_r2s']]

ax.plot(linear_r2_epochs, linear_r2_values, label='Linear AE', marker='o', alpha=0.7)
ax.plot(flow_r2_epochs, flow_r2_values, label='Flow AE', marker='s', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Train R²')
ax.set_title('Train R²')
ax.legend()
ax.grid(True, alpha=0.3)

# Test R2
ax = axes[1, 1]
linear_test_r2_epochs = [x[0] for x in results_linear['test_r2s']]
linear_test_r2_values = [x[1] for x in results_linear['test_r2s']]
flow_test_r2_epochs = [x[0] for x in results_flow['test_r2s']]
flow_test_r2_values = [x[1] for x in results_flow['test_r2s']]

ax.plot(linear_test_r2_epochs, linear_test_r2_values, label='Linear AE', marker='o', alpha=0.7)
ax.plot(flow_test_r2_epochs, flow_test_r2_values, label='Flow AE', marker='s', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test R²')
ax.set_title('Test R²')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

logger.info("Generated training curves plot")

In [None]:
# Plot example predictions
session_id = held_in_session_ids[0]
channel_idx = 30
trial_idx = 0

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Linear AE
ax = axes[0]
y_true = results_linear['y_test'][session_id][2][trial_idx, :, channel_idx].cpu()
y_pred = results_linear['y_hat_test'][session_id][trial_idx, :, channel_idx].cpu()
ax.plot(y_true, label='True', linewidth=2)
ax.plot(y_pred, label='Predicted', linewidth=2, alpha=0.7)
ax.set_xlabel('Time')
ax.set_ylabel('Activity')
ax.set_title(f'Linear AE - Channel {channel_idx}')
ax.legend()
ax.grid(True, alpha=0.3)

# Flow AE
ax = axes[1]
y_true = results_flow['y_test'][session_id][2][trial_idx, :, channel_idx].cpu()
y_pred = results_flow['y_hat_test'][session_id][trial_idx, :, channel_idx].cpu()
ax.plot(y_true, label='True', linewidth=2)
ax.plot(y_pred, label='Predicted', linewidth=2, alpha=0.7)
ax.set_xlabel('Time')
ax.set_ylabel('Activity')
ax.set_title(f'Flow AE - Channel {channel_idx}')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

logger.info("Generated prediction comparison plot")

## Latent Space Analysis

In [None]:
# Compare latent representations
session_id = held_in_session_ids[0]
test_batch = next(iter(data_test))
x = test_batch[session_id][0][:100].to(DEVICE)
mask = torch.arange(x.shape[-1]).to(DEVICE)

# Get latent representations
with torch.no_grad():
    z_linear = ms_linear.ae.instances[session_id].encode(x, mask).cpu().numpy()
    z_flow = ms_flow.ae.instances[session_id].encode(x, mask).cpu().numpy()

logger.info(f"\nLatent Space Analysis:")
logger.info(f"  Linear AE latent shape: {z_linear.shape}")
logger.info(f"  Flow AE latent shape: {z_flow.shape}")
logger.info(f"  Linear AE latent std: {z_linear.std(axis=0).mean():.4f}")
logger.info(f"  Flow AE latent std: {z_flow.std(axis=0).mean():.4f}")

In [None]:
# Visualize first 2 latent dimensions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Linear AE
ax = axes[0]
scatter = ax.scatter(z_linear[:, 0], z_linear[:, 1], 
                     c=range(len(z_linear)), cmap='viridis', alpha=0.6)
ax.set_xlabel('Latent Dim 1')
ax.set_ylabel('Latent Dim 2')
ax.set_title('Linear AE Latent Space')
ax.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax, label='Sample Index')

# Flow AE
ax = axes[1]
scatter = ax.scatter(z_flow[:, 0], z_flow[:, 1], 
                     c=range(len(z_flow)), cmap='viridis', alpha=0.6)
ax.set_xlabel('Latent Dim 1')
ax.set_ylabel('Latent Dim 2')
ax.set_title('Flow AE Latent Space')
ax.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax, label='Sample Index')

plt.tight_layout()
plt.show()

logger.info("Generated latent space visualization")

## Save Results

In [None]:
logger.start_phase("Saving Results")

# Save comparison results
comparison_results = {
    'linear': results_linear,
    'flow': results_flow,
    'config_linear': OmegaConf.to_container(cfg_linear, resolve=True),
    'config_flow': OmegaConf.to_container(cfg_flow, resolve=True),
    'summary': {
        'linear_test_r2': linear_test_r2,
        'flow_test_r2': flow_test_r2,
        'r2_improvement_pct': r2_improvement,
        'linear_test_loss': linear_test_loss,
        'flow_test_loss': flow_test_loss,
        'loss_improvement_pct': loss_improvement,
    },
    'log_file': str(logger.log_file),
}

results_path = 'ae_comparison_results.torch'
torch.save(comparison_results, results_path)
logger.info(f"Results saved to {results_path}")

logger.end_phase("Saving Results")

logger.info("\n" + "="*80)
logger.info("EXPERIMENT COMPLETED SUCCESSFULLY")
logger.info(f"Log file: {logger.log_file}")
logger.info("="*80)

## Conclusions

### Key Findings:
1. **Invertibility**: Flow AE achieves near-perfect reconstruction (error < 1e-5), while Linear AE has reconstruction loss
2. **Performance**: See comparison table above and log file for detailed metrics
3. **Training**: Both models converge successfully with time tracking

### When to Use Each:
- **Linear AE**: Faster, more interpretable (PCA-like), good baseline
- **Flow AE**: Invertible, more expressive, handles nonlinear structure

### Logs:
- All training progress, metrics, and timing information saved to the log file
- Check the `logs/` folder for detailed session logs