# MNIST PMFlow Benchmark - Interactive Notebook

This notebook compares a baseline MLP classifier with a PMFlow-enhanced variant on MNIST classification. It is a convenient workload for testing GPU acceleration and parallel data loading on a Jetson Nano.

## Features
- **GPU Acceleration**: Runs on CUDA when available (e.g. on a Jetson Nano), otherwise on the CPU
- **Interactive Progress**: Real-time training visualization
- **Side-by-Side Comparison**: Both models are trained on the same data each epoch (one after the other) and compared
- **PMFlow Integration**: Pushing-medium flow blocks in neural networks

In [None]:
# GPU Detection and Environment Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import time
from IPython.display import clear_output
import threading

# Pick the compute device once, then report what hardware was found.
_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"
print(f"Using device: {device}")

if not _cuda_ok:
    print("CUDA not available - using CPU")
else:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_props.total_memory / 1e9:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")

ModuleNotFoundError: No module named 'torchvision'

: 

: 

In [None]:
# PMFlow Block Implementation
class PMFlow(nn.Module):
    """
    Pushing-Medium Flow block that simulates gravitational effects in latent space.

    A set of learnable point "masses" (``centers`` with strengths ``mus``)
    defines a refractive index n(z) = 1 + sum_k mu_k / |z - c_k|. Each forward
    pass moves every latent vector ``steps`` times along grad(ln n) with step
    size ``dt`` (explicit Euler steps).

    Args:
        latent_dim: Dimensionality of the latent vectors; only used to build
            the default random centers when ``centers`` is None.
        centers: Optional initial centers, anything ``torch.as_tensor``
            accepts, shaped (n_centers, latent_dim). Defaults to ``n_centers``
            random points scaled by 0.5.
        mus: Optional initial strengths, one per center. Defaults to 0.5 each.
        steps: Number of Euler steps per forward pass.
        dt: Step size of each Euler step.
        n_centers: Number of random centers created when ``centers`` is None.

    Raises:
        ValueError: If ``centers`` is not 2-D or ``mus`` does not have exactly
            one entry per center.
    """
    def __init__(self, latent_dim=16, centers=None, mus=None, steps=3, dt=0.1, n_centers=4):
        super().__init__()
        if centers is None:
            centers = torch.randn(n_centers, latent_dim) * 0.5
        # as_tensor + detach/clone (instead of torch.tensor(<tensor>)) avoids
        # PyTorch's copy-construct warning and gives the parameter its own storage.
        centers = torch.as_tensor(centers, dtype=torch.float32).detach().clone()
        if centers.dim() != 2:
            raise ValueError(
                f"centers must be 2-D (n_centers, latent_dim), got shape {tuple(centers.shape)}"
            )
        if mus is None:
            mus = torch.ones(centers.size(0)) * 0.5
        mus = torch.as_tensor(mus, dtype=torch.float32).detach().clone()
        if mus.shape != (centers.size(0),):
            raise ValueError(
                f"mus must have shape ({centers.size(0)},), got {tuple(mus.shape)}"
            )

        self.centers = nn.Parameter(centers)
        self.mus = nn.Parameter(mus)
        self.steps = steps
        self.dt = dt

    def forward(self, z):
        """Apply ``steps`` pushing-medium flow steps to a batch of latents.

        Args:
            z: Latent batch of shape (batch, latent_dim).

        Returns:
            Tensor with the same shape as ``z``.
        """
        mus = self.mus
        for _ in range(self.steps):
            # Displacement from every center to every sample: (batch, n_centers, dim).
            rvec = z.unsqueeze(1) - self.centers.unsqueeze(0)
            # Small epsilon keeps 1/r finite when a sample sits on a center.
            r = torch.norm(rvec, dim=2) + 1e-4  # (batch, n_centers)
            # Refractive index n = 1 + sum_k mu_k / r_k  -> (batch,)
            n = 1 + (mus / r).sum(dim=1)
            # grad n = sum_k -mu_k * rvec_k / r_k^3  -> (batch, dim)
            grad = -(mus.unsqueeze(-1) * rvec / r.unsqueeze(-1) ** 3).sum(dim=1)
            # Flow step: z += dt * grad(ln n), with grad(ln n) = grad(n) / n
            z = z + self.dt * (grad / n.unsqueeze(1))
        return z

# Smoke test: push a random batch through a fresh PMFlow block and check
# that the latent shape comes out unchanged.
print("Testing PMFlow block...")
pmflow = PMFlow(latent_dim=8)
pmflow = pmflow.to(device)
test_input = torch.randn(32, 8)
test_input = test_input.to(device)
test_output = pmflow(test_input)
print(f"Input shape: {test_input.shape}, Output shape: {test_output.shape}")
print("PMFlow block initialized successfully!")

In [None]:
# Neural Network Models
class PMNet(nn.Module):
    """MNIST classifier: MLP encoder, optional PMFlow block, linear head.

    Args:
        use_flow: If True, the encoded latent is passed through a PMFlow
            block before classification; if False that step is skipped
            (the baseline model).
        latent_dim: Width of the latent vector produced by the encoder.
    """
    def __init__(self, use_flow=True, latent_dim=16):
        super().__init__()
        self.use_flow = use_flow

        # 28x28 image -> flat 784 vector -> 256 hidden units -> latent_dim
        encoder_layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        # The flow block only exists in the PMFlow variant.
        if use_flow:
            self.flow = PMFlow(latent_dim=latent_dim)
        else:
            self.flow = None

        # Linear read-out to the 10 digit classes.
        self.head = nn.Linear(latent_dim, 10)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        latent = self.enc(x)
        if self.flow is not None:
            latent = self.flow(latent)
        return self.head(latent)

# Build both models with the same latent width, so the only architectural
# difference between them is the PMFlow block.
print("Creating models...")
_shared_latent_dim = 16
model_baseline = PMNet(latent_dim=_shared_latent_dim, use_flow=False).to(device)
model_pmflow = PMNet(latent_dim=_shared_latent_dim, use_flow=True).to(device)

# Count parameters
def count_parameters(model):
    """Return how many trainable (``requires_grad``) scalar weights ``model`` has."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Compare model sizes; the gap is the PMFlow block's own parameters.
baseline_params, pmflow_params = (
    count_parameters(model_baseline),
    count_parameters(model_pmflow),
)
_overhead = pmflow_params - baseline_params

print(f"Baseline model parameters: {baseline_params:,}")
print(f"PMFlow model parameters: {pmflow_params:,}")
print(f"PMFlow overhead: {_overhead:,} parameters")

In [None]:
# Data Loading with Parallel Workers
def setup_data_loaders(batch_size=128, num_workers=2, *, root="./data", download=True):
    """Setup MNIST data loaders with parallel processing.

    Images are converted to float tensors with ``transforms.ToTensor()``;
    no further normalization is applied.

    Args:
        batch_size: Training batch size. The test loader uses ``2 * batch_size``.
        num_workers: Worker processes per loader for parallel data loading.
        root: Directory where MNIST is stored (and downloaded to).
        download: Download the dataset into ``root`` if it is not there yet.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles; the test
        loader keeps a fixed order.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    # Download and load datasets
    print("Loading MNIST dataset...")
    train_dataset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=download, transform=transform)

    # Settings shared by both loaders. Pinned host memory speeds up
    # host->GPU copies, so it is only requested when CUDA is present.
    common = dict(num_workers=num_workers, pin_memory=torch.cuda.is_available())

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **common)
    test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, **common)

    return train_loader, test_loader

# Build the train/test loaders. num_workers sets how many processes load
# data in parallel; on a Jetson Nano 2-4 is the suggested range.
train_loader, test_loader = setup_data_loaders(num_workers=2, batch_size=128)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")
print(f"Total training samples: {len(train_loader.dataset)}")
print(f"Total test samples: {len(test_loader.dataset)}")

In [None]:
# Training and Evaluation Functions
def train_epoch(model, optimizer, loader, device, desc="Training"):
    """Train ``model`` for one pass over ``loader`` with a progress bar.

    Uses cross-entropy loss on the model's outputs.

    Args:
        model: Classifier whose output is passed to ``F.cross_entropy``.
        optimizer: Optimizer over ``model``'s parameters.
        loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to.
        desc: Label shown on the progress bar.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the per-sample average loss
        over the epoch and ``accuracy`` is the fraction of correct predictions.
        Both are ``nan`` if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    # Context manager closes the bar even if training is interrupted.
    with tqdm(loader, desc=desc, leave=False) as progress_bar:
        for data, target in progress_bar:
            # non_blocking can overlap the copy with compute when the loader
            # pins memory; it is safe (just synchronous) otherwise.
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            # Read the loss scalar once and reuse it below.
            batch_loss = loss.item()
            batch_size = target.size(0)
            # Weight by batch size so a short final batch doesn't skew the mean.
            total_loss += batch_loss * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            total += batch_size

            progress_bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate_model(model, loader, device, desc="Evaluating"):
    """Return the accuracy of ``model`` on ``loader``, with a progress bar.

    Runs under ``torch.no_grad()`` after switching the model to eval mode.
    The model is left in eval mode.

    Args:
        model: Classifier; the predicted class is the argmax over dim 1.
        loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to.
        desc: Label shown on the progress bar.

    Returns:
        Fraction of correct predictions, or ``nan`` if ``loader`` yields no
        samples.
    """
    model.eval()
    correct = 0
    total = 0

    # Context manager closes the bar even if evaluation is interrupted.
    with tqdm(loader, desc=desc, leave=False) as progress_bar:
        for data, target in progress_bar:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)

            progress_bar.set_postfix({'Acc': f'{100.*correct/total:.2f}%'})

    if total == 0:
        return float("nan")
    return correct / total

print("Training and evaluation functions ready!")

In [None]:
# Real-time Visualization Setup
class LivePlotter:
    """Real-time plotting for training progress.

    Keeps per-epoch accuracy lists for both models and redraws two panels
    (train and test accuracy) on every ``update``.

    The figure is removed from pyplot's figure manager right after creation
    and shown explicitly with ``IPython.display.display``. Under Jupyter's
    inline backend, pyplot closes its figures at the end of each cell. After
    that, ``plt.tight_layout()`` / ``plt.show()`` (which act on pyplot's
    *current* figure) would lay out and show a new, empty figure instead of
    this one.
    """
    def __init__(self):
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(12, 4))
        # Detach from pyplot so the still-empty figure is not auto-displayed
        # at the end of the creating cell; update() displays it explicitly.
        plt.close(self.fig)
        self.baseline_train_acc = []
        self.baseline_test_acc = []
        self.pmflow_train_acc = []
        self.pmflow_test_acc = []
        self.epochs = []

    def update(self, epoch, baseline_train, baseline_test, pmflow_train, pmflow_test):
        """Record one epoch of accuracies and redisplay the figure.

        Args:
            epoch: x-axis value for the new points.
            baseline_train: Baseline model's training accuracy.
            baseline_test: Baseline model's test accuracy.
            pmflow_train: PMFlow model's training accuracy.
            pmflow_test: PMFlow model's test accuracy.
        """
        # Local import: IPython is only needed when actually redrawing.
        from IPython.display import display

        self.epochs.append(epoch)
        self.baseline_train_acc.append(baseline_train)
        self.baseline_test_acc.append(baseline_test)
        self.pmflow_train_acc.append(pmflow_train)
        self.pmflow_test_acc.append(pmflow_test)

        # Clear and redraw
        self.ax1.clear()
        self.ax2.clear()

        # Training accuracy
        self.ax1.plot(self.epochs, self.baseline_train_acc, 'bo--', label='Baseline Train', alpha=0.7)
        self.ax1.plot(self.epochs, self.pmflow_train_acc, 'r^--', label='PMFlow Train', alpha=0.7)
        self.ax1.set_title("Training Accuracy")
        self.ax1.set_xlabel("Epoch")
        self.ax1.set_ylabel("Accuracy")
        self.ax1.legend()
        self.ax1.grid(True, alpha=0.3)

        # Test accuracy
        self.ax2.plot(self.epochs, self.baseline_test_acc, 'bo-', label='Baseline Test')
        self.ax2.plot(self.epochs, self.pmflow_test_acc, 'r^-', label='PMFlow Test')
        self.ax2.set_title("Test Accuracy")
        self.ax2.set_xlabel("Epoch")
        self.ax2.set_ylabel("Accuracy")
        self.ax2.legend()
        self.ax2.grid(True, alpha=0.3)

        self.fig.suptitle("MNIST: Baseline vs PMFlow - Live Training Progress")
        # Lay out and show *this* figure rather than pyplot's current one.
        self.fig.tight_layout()
        display(self.fig)

# Create the single shared plotter that the training loop redraws.
plotter = LivePlotter()
print("Live plotting system ready!")

In [None]:
# Parallel Training Configuration
class ParallelTrainer:
    """Train and evaluate the baseline and PMFlow models side by side.

    Despite the name, the two models are trained one after the other within
    each epoch. They use the same loaders and identical Adam settings, which
    keeps the comparison fair and the per-model timings separable.

    Args:
        model_baseline: Model without the PMFlow block.
        model_pmflow: Model with the PMFlow block.
        device: Device batches are moved to.
        lr: Adam learning rate used for both models (default 1e-3).
    """
    def __init__(self, model_baseline, model_pmflow, device, lr=1e-3):
        self.model_baseline = model_baseline
        self.model_pmflow = model_pmflow
        self.device = device

        # Optimizers (identical settings for both models)
        self.opt_baseline = torch.optim.Adam(model_baseline.parameters(), lr=lr)
        self.opt_pmflow = torch.optim.Adam(model_pmflow.parameters(), lr=lr)

        # Training history: one accuracy value per epoch in each list
        self.history = {
            'baseline_train': [], 'baseline_test': [],
            'pmflow_train': [], 'pmflow_test': []
        }

    def _sync(self):
        """Wait for queued CUDA work so wall-clock timings cover all of it."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _run_model(self, model, optimizer, name, train_loader, test_loader, epoch):
        """Train ``model`` one epoch, then evaluate; return (train_acc, test_acc, seconds)."""
        self._sync()
        start_time = time.perf_counter()
        _, train_acc = train_epoch(
            model, optimizer, train_loader,
            self.device, desc=f"{name} Epoch {epoch}"
        )
        test_acc = evaluate_model(
            model, test_loader, self.device,
            desc=f"{name} Test {epoch}"
        )
        self._sync()
        return train_acc, test_acc, time.perf_counter() - start_time

    def train_epoch_parallel(self, train_loader, test_loader, epoch):
        """Train and evaluate both models for one epoch and record the results.

        Returns:
            ``(baseline_train_acc, baseline_test_acc, pmflow_train_acc, pmflow_test_acc)``.
        """
        print(f"\n=== Epoch {epoch} ===")

        train_acc_b, test_acc_b, baseline_time = self._run_model(
            self.model_baseline, self.opt_baseline, "Baseline",
            train_loader, test_loader, epoch
        )
        train_acc_p, test_acc_p, pmflow_time = self._run_model(
            self.model_pmflow, self.opt_pmflow, "PMFlow",
            train_loader, test_loader, epoch
        )

        # Store results
        self.history['baseline_train'].append(train_acc_b)
        self.history['baseline_test'].append(test_acc_b)
        self.history['pmflow_train'].append(train_acc_p)
        self.history['pmflow_test'].append(test_acc_p)

        # Print results
        print(f"Baseline: Train={train_acc_b:.4f}, Test={test_acc_b:.4f}, Time={baseline_time:.1f}s")
        print(f"PMFlow:   Train={train_acc_p:.4f}, Test={test_acc_p:.4f}, Time={pmflow_time:.1f}s")

        return train_acc_b, test_acc_b, train_acc_p, test_acc_p

# Shared trainer: owns both optimizers and the per-epoch accuracy history.
trainer = ParallelTrainer(
    model_baseline,
    model_pmflow,
    device,
)
print("Parallel trainer ready!")

In [None]:
# Interactive Training Loop
def run_training(epochs=10, live_plot=True):
    """Run interactive training with real-time visualization.

    Uses objects created by earlier cells: ``trainer``, ``plotter``,
    ``train_loader``, ``test_loader`` and ``device``. Accuracies accumulate in
    ``trainer.history``. Epoch numbers continue from any earlier call, so
    repeated runs extend the same curves instead of restarting at 1.

    Args:
        epochs: Number of (additional) epochs to train both models.
        live_plot: If True, clear the cell output and redraw ``plotter``
            after every epoch.

    Returns:
        ``trainer.history``: dict of per-epoch accuracy lists.
    """
    print(f"Starting training for {epochs} epochs on {device}")
    print(f"Dataset size: {len(train_loader.dataset)} training, {len(test_loader.dataset)} test")

    # Continue numbering after epochs recorded by a previous call.
    first_epoch = len(trainer.history['baseline_test']) + 1
    for epoch in range(first_epoch, first_epoch + epochs):
        # Train both models
        baseline_train, baseline_test, pmflow_train, pmflow_test = trainer.train_epoch_parallel(
            train_loader, test_loader, epoch
        )

        # Update live plot
        if live_plot:
            clear_output(wait=True)
            plotter.update(epoch, baseline_train, baseline_test, pmflow_train, pmflow_test)

        # GPU memory check (for Jetson Nano monitoring). Two decimals, because
        # these small models would otherwise always show 0.0 GB.
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1e9
            memory_total = torch.cuda.get_device_properties(0).total_memory / 1e9
            print(f"GPU Memory: {memory_used:.2f}/{memory_total:.1f} GB")

    print("\n=== Training Complete! ===")

    if not trainer.history['baseline_test']:
        print("No epochs were run - nothing to report.")
        return trainer.history

    # Final results
    final_baseline = trainer.history['baseline_test'][-1]
    final_pmflow = trainer.history['pmflow_test'][-1]
    improvement = final_pmflow - final_baseline

    print("Final Test Accuracy:")
    print(f"  Baseline: {final_baseline:.4f}")
    print(f"  PMFlow:   {final_pmflow:.4f}")
    # Accuracies are fractions, so x100 is a difference in percentage points.
    print(f"  Improvement: {improvement:+.4f} ({improvement*100:+.2f} percentage points)")

    return trainer.history

# --- Run configuration ---
# Number of epochs to train both models; lower it for quick checks.
EPOCHS = 10
# Redraw accuracy curves after every epoch; turn off if plotting misbehaves.
LIVE_PLOTTING = True

for _msg in (
    "Ready to start training!",
    "Run the next cell to begin the interactive training process.",
):
    print(_msg)

In [None]:
# 🚀 START TRAINING
# Running this cell trains both models and keeps the accuracy history
# for the analysis cell further down.
history = run_training(live_plot=LIVE_PLOTTING, epochs=EPOCHS)

In [None]:
# Post-Training Analysis and Visualization
def _plot_results(history):
    """Draw the 2x2 results figure: train/test curves, per-epoch gap, final bars."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    epochs = range(1, len(history['baseline_test']) + 1)

    # Training accuracy comparison
    ax1.plot(epochs, history['baseline_train'], 'bo--', label='Baseline', alpha=0.7)
    ax1.plot(epochs, history['pmflow_train'], 'r^--', label='PMFlow', alpha=0.7)
    ax1.set_title("Training Accuracy")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Test accuracy comparison
    ax2.plot(epochs, history['baseline_test'], 'bo-', label='Baseline')
    ax2.plot(epochs, history['pmflow_test'], 'r^-', label='PMFlow')
    ax2.set_title("Test Accuracy")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Per-epoch test-accuracy gap; positive means PMFlow is ahead.
    improvements = [p - b for p, b in zip(history['pmflow_test'], history['baseline_test'])]
    ax3.plot(epochs, improvements, 'go-', alpha=0.7)
    ax3.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax3.set_title("PMFlow Improvement vs Baseline")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Accuracy Difference")
    ax3.grid(True, alpha=0.3)

    # Final comparison bar chart, y-axis zoomed to +/-0.01 around the scores
    # so small differences stay visible.
    final_scores = [history['baseline_test'][-1], history['pmflow_test'][-1]]
    ax4.bar(['Baseline', 'PMFlow'], final_scores, color=['blue', 'red'], alpha=0.7)
    ax4.set_title("Final Test Accuracy")
    ax4.set_ylabel("Accuracy")
    ax4.set_ylim([min(final_scores) - 0.01, max(final_scores) + 0.01])

    # Add values on bars
    for i, v in enumerate(final_scores):
        ax4.text(i, v + 0.002, f'{v:.4f}', ha='center', va='bottom')

    # Title this figure *before* laying it out so tight_layout reserves room
    # for it, instead of placing it above the figure with y=1.02.
    fig.suptitle("MNIST PMFlow Benchmark Results", fontsize=16)
    fig.tight_layout()
    plt.show()


def _print_summary(history):
    """Print final/best test accuracies and both models' parameter counts.

    Parameter counts use the notebook-level ``model_baseline`` and
    ``model_pmflow`` via ``count_parameters``.
    """
    print("\n" + "="*50)
    print("PERFORMANCE SUMMARY")
    print("="*50)

    baseline_final = history['baseline_test'][-1]
    pmflow_final = history['pmflow_test'][-1]
    improvement = pmflow_final - baseline_final
    # Accuracies are fractions, so x100 gives percentage points.
    improvement_pct = improvement * 100

    print("Final Test Accuracy:")
    print(f"  Baseline: {baseline_final:.4f} ({baseline_final*100:.2f}%)")
    print(f"  PMFlow:   {pmflow_final:.4f} ({pmflow_final*100:.2f}%)")
    print(f"  Improvement: {improvement:+.4f} ({improvement_pct:+.2f} percentage points)")

    # Additional metrics
    baseline_best = max(history['baseline_test'])
    pmflow_best = max(history['pmflow_test'])

    print("\nBest Test Accuracy Achieved:")
    print(f"  Baseline: {baseline_best:.4f}")
    print(f"  PMFlow:   {pmflow_best:.4f}")

    baseline_count = count_parameters(model_baseline)
    pmflow_count = count_parameters(model_pmflow)
    print("\nModel Parameter Count:")
    print(f"  Baseline: {baseline_count:,}")
    print(f"  PMFlow:   {pmflow_count:,}")
    print(f"  Overhead: {pmflow_count - baseline_count:,}")


def analyze_results(history):
    """Comprehensive analysis of training results.

    Shows a 2x2 figure of the accuracy curves and prints a performance summary.

    Args:
        history: Dict with per-epoch accuracy lists under 'baseline_train',
            'baseline_test', 'pmflow_train' and 'pmflow_test' (as returned by
            ``run_training``).

    If no epochs were recorded, a notice is printed and nothing is plotted.
    """
    if not history.get('baseline_test'):
        print("No training history to analyze - run training first.")
        return
    _plot_results(history)
    _print_summary(history)

print("Analysis functions ready. Run analyze_results(history) after training completes.")

In [None]:
# 📊 ANALYZE RESULTS
# Run this after the training cell: it plots the curves and prints the
# summary for the `history` that cell produced.
analyze_results(history=history)

## 🔧 Jetson Nano Optimization Tips

### GPU Memory Management
- Monitor GPU memory usage during training
- Reduce batch size if you encounter OOM errors
- Use `torch.cuda.empty_cache()` between experiments

### Performance Tuning
- **Optimal batch size for Jetson Nano**: 64-128
- **Data loader workers**: 2-4 (adjust based on CPU cores)
- **Mixed precision**: Consider using `torch.cuda.amp` for faster training

### Parallel Processing
- The notebook uses parallel data loading
- Both models train sequentially for fair comparison
- Consider training on different GPU streams for true parallelism

### Troubleshooting
- If plots don't update: Set `LIVE_PLOTTING = False`
- If training is slow: Reduce epochs or use smaller models
- If you run out of memory: reduce the batch size, or fall back to the CPU by setting `device = "cpu"` in the setup cell and re-running all cells