# MICROnet Training - Comprehensive Model Comparison

This notebook trains 16 different microstructure prediction models using U-Net architecture with skip connections:

## CNN-LSTM Models (8 models):
### MSE Loss:
1. seq=2, MSE loss
2. seq=3, MSE loss
3. seq=4, MSE loss

### Combined Loss (seq=3):
4. T_solidus=1560, T_liquidus=1620, weights=70/30
5. T_solidus=1530, T_liquidus=1650, weights=70/30
6. T_solidus=1500, T_liquidus=1680, weights=70/30
7. T_solidus=1560, T_liquidus=1620, weights=50/50
8. T_solidus=1560, T_liquidus=1620, weights=30/70

## PredRNN Models (8 models):
### MSE Loss:
9. seq=2, MSE loss
10. seq=3, MSE loss
11. seq=4, MSE loss

### Combined Loss (seq=2):
12. T_solidus=1560, T_liquidus=1620, weights=70/30
13. T_solidus=1530, T_liquidus=1650, weights=70/30
14. T_solidus=1500, T_liquidus=1680, weights=70/30
15. T_solidus=1560, T_liquidus=1620, weights=50/50
16. T_solidus=1560, T_liquidus=1620, weights=30/70

All models use U-Net architecture with skip connections for improved gradient flow and feature reuse.

In [None]:
# Add project root to Python path so we can import lasernet
import sys
from pathlib import Path

# Treat the parent of the current working directory as the project root and
# make it importable, printing a notice only when the path is newly added.
project_root = Path.cwd().parent
root_entry = str(project_root)
if root_entry not in sys.path:
    sys.path.insert(0, root_entry)
    print(f"Added {project_root} to Python path")

In [None]:
import os

from pathlib import Path
import json
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from IPython.display import display, Image

from lasernet.micronet.train.trainer import (
    get_device,
    train_microstructure,
    load_model_and_predict,
    load_model_and_predict_cascaded,
    save_prediction_visualization,
    save_solidification_mask_visualization,
    save_cascaded_prediction_visualization,
)
from lasernet.micronet.model.MicrostructureCNN_LSTM import MicrostructureCNN_LSTM
from lasernet.micronet.model.MicrostructurePredRNN import MicrostructurePredRNN
from lasernet.micronet.model.losses import CombinedLoss
from lasernet.micronet.utils import plot_losses

## Set Random Seeds for Reproducibility

Set all random seeds to ensure reproducible results across runs.

In [None]:
import random
import numpy as np

# One seed shared by every random number generator seeded below.
RANDOM_SEED = 42

# Seed Python's `random`, NumPy and PyTorch (CPU generator, current CUDA
# device, then all CUDA devices) with the same value, in that order.
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(RANDOM_SEED)

# Request deterministic cuDNN kernels and disable the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Export the seed for hash randomisation as well.
# NOTE(review): PYTHONHASHSEED is documented as being read at interpreter
# start-up, so setting it here presumably only affects child processes -
# confirm whether the running kernel needs it set before launch.
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED} for reproducibility")
for seeded_component in (
    "Python random",
    "NumPy",
    "PyTorch (CPU and CUDA)",
    "cuDNN deterministic mode enabled",
):
    print(f"  ✓ {seeded_component}")

## Setup and Configuration

In [None]:
from dotenv import load_dotenv
from pathlib import Path
import os

# Look for a .env file in the parent of the working directory and, when it is
# present, load it so that its values override the current environment.
project_root = Path.cwd().parent
env_file = project_root / ".env"
if not env_file.exists():
    print(f"Warning: .env file not found at {env_file}")
else:
    load_dotenv(dotenv_path=env_file, override=True)
    print(f"Loaded .env from: {env_file}")

# BLACKHOLE is mandatory: stop here with a clear message when it is missing.
blackhole_path = os.environ.get("BLACKHOLE")
if not blackhole_path:
    raise ValueError("BLACKHOLE environment variable not set. Please create a .env file with BLACKHOLE=/path/to/data")
print(f"BLACKHOLE environment variable: {blackhole_path}")

# Global configuration shared by every training run in this notebook.
DEVICE = get_device()
BATCH_SIZE = 16        # batch size for every DataLoader built below
EPOCHS = 100           # passed as `epochs` to train_microstructure
LEARNING_RATE = 1e-3   # Adam learning rate
PATIENCE = 25          # passed as `patience` to train_microstructure
OUTPUT_DIR = Path("MICROnet_output")  # one sub-directory per model is created here
OUTPUT_DIR.mkdir(exist_ok=True)

# Prediction settings: the single timestep / slice / plane used for the
# prediction figures generated after training.
PRED_TIMESTEP = 23
PRED_SLICE = 47
PLANE = "xz"

print(f"Device: {DEVICE}")
print(f"Output directory: {OUTPUT_DIR}")
# Fixed: the plane field previously printed PRED_TIMESTEP instead of PLANE.
print(f"Prediction settings: timestep={PRED_TIMESTEP}, slice={PRED_SLICE}, plane={PLANE}")

## Load Datasets

Load datasets using fast loading from preprocessed files.

In [None]:
# Load datasets (same approach as notebook.ipynb)
print("Loading datasets...")

from pathlib import Path as PathLib
from lasernet.micronet.dataset import MicrostructureSequenceDataset
from lasernet.micronet.dataset.fast_loading import FastMicrostructureSequenceDataset
from torch.utils.data import DataLoader

# Dataset configuration: default context length and the train/val/test split,
# given as comma-separated integer weights.
SEQ_LENGTH = 4
SPLIT_RATIO = "12,6,6"

# Convert the weights into fractions of their total.
split_ratios = [int(weight) for weight in SPLIT_RATIO.split(",")]
split_total = sum(split_ratios)
train_ratio, val_ratio, test_ratio = (
    split_ratios[0] / split_total,
    split_ratios[1] / split_total,
    split_ratios[2] / split_total,
)

# The data root comes from the BLACKHOLE environment variable; refuse to
# continue without it.
blackhole = os.environ.get("BLACKHOLE")
if not blackhole:
    raise ValueError("BLACKHOLE environment variable not set. Please set it in the makefile or shell.")

print(f"BLACKHOLE directory: {blackhole}")

# Fast loading is possible only when every preprocessed tensor file exists.
processed_dir = PathLib(blackhole) / "processed" / "data"
required_files = ["coordinates.pt", "microstructure.pt", "temperature.pt"]
fast_loading_available = all(
    processed_dir.joinpath(file_name).exists() for file_name in required_files
)

# Arguments common to both dataset classes; "split" is replaced per dataset.
dataset_kwargs = {
    "plane": PLANE,
    "split": "train",
    "sequence_length": SEQ_LENGTH,
    "target_offset": 1,
    "train_ratio": train_ratio,
    "val_ratio": val_ratio,
    "test_ratio": test_ratio,
}
if fast_loading_available:
    print("✓ Preprocessed files found - using fast loading")
    DatasetClass = FastMicrostructureSequenceDataset
else:
    print("⚠ Preprocessed files not found - using standard loading")
    DatasetClass = MicrostructureSequenceDataset
    # Only the standard loader receives the preload flag.
    dataset_kwargs["preload"] = True

# Build the three splits in train, val, test order.
train_dataset, val_dataset, test_dataset = (
    DatasetClass(**{**dataset_kwargs, "split": split_name})
    for split_name in ("train", "val", "test")
)

print("\nDataset sizes:")
for size_label, split_dataset in (
    ("Train samples:", train_dataset),
    ("Val samples:", val_dataset),
    ("Test samples:", test_dataset),
):
    print(f"  {size_label:<15}{len(split_dataset)}")

# Report the tensor shapes of the first training sample.
sample = train_dataset[0]
print("\nSample dimensions:")
for shape_label, sample_key in (
    ("Context temp:", "context_temp"),
    ("Context micro:", "context_micro"),
    ("Future temp:", "future_temp"),
    ("Target micro:", "target_micro"),
):
    print(f"  {shape_label:<15}{sample[sample_key].shape}")

# Wrap each split in a DataLoader using identical settings.
# NOTE(review): the training loader is built with shuffle=False - confirm
# that a fixed sample order is intended for training.
print("\nCreating DataLoaders...")

loader_kwargs = {"batch_size": BATCH_SIZE, "shuffle": False, "num_workers": 0}
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

print(f"DataLoaders created with batch size: {BATCH_SIZE}")

## Model Configurations

Define all 16 model configurations.

In [None]:
# Define all model configurations.
#
# The 16 configurations are generated from two small factories instead of
# being written out by hand: the hand-written list had given the last entry
# the duplicate id 15, although its name marks it as model 16.


def _mse_config(model_id: int, model_type: str, seq_length: int) -> dict:
    """Return the configuration of a model trained with plain MSE loss.

    Args:
        model_id: 1-based model number; also used as the two-digit name prefix.
        model_type: "cnn_lstm" or "predrnn".
        seq_length: Number of context frames.
    """
    return {
        "id": model_id,
        "name": f"{model_id:02d}_MICROnet_{model_type}_seq{seq_length}_MSEloss",
        "model_type": model_type,
        "seq_length": seq_length,
        "loss_type": "mse",
        "t_solidus": None,
        "t_liquidus": None,
        "use_skip_connections": True,
    }


def _combined_config(
    model_id: int,
    model_type: str,
    seq_length: int,
    t_solidus: float,
    t_liquidus: float,
    solidification_weight: float,
    global_weight: float,
) -> dict:
    """Return the configuration of a model trained with the combined loss.

    The name encodes the temperature window and both loss weights (as whole
    percentages), e.g. "..._CombLoss_T1560-1620_s70_g30".
    """
    return {
        "id": model_id,
        "name": (
            f"{model_id:02d}_MICROnet_{model_type}_seq{seq_length}_CombLoss"
            f"_T{t_solidus:.0f}-{t_liquidus:.0f}"
            f"_s{solidification_weight * 100:.0f}_g{global_weight * 100:.0f}"
        ),
        "model_type": model_type,
        "seq_length": seq_length,
        "loss_type": "combined",
        "t_solidus": t_solidus,
        "t_liquidus": t_liquidus,
        "solidification_weight": solidification_weight,
        "global_weight": global_weight,
        "use_skip_connections": True,
    }


# Sequence lengths compared under MSE loss (one model each).
_MSE_SEQ_LENGTHS = (2, 3, 4)

# Combined-loss sweep: (t_solidus, t_liquidus, solidification_weight, global_weight).
_COMBINED_SWEEP = (
    (1560.0, 1620.0, 0.7, 0.3),
    (1530.0, 1650.0, 0.7, 0.3),
    (1500.0, 1680.0, 0.7, 0.3),
    (1560.0, 1620.0, 0.5, 0.5),
    (1560.0, 1620.0, 0.3, 0.7),
)


def _family_configs(first_id: int, model_type: str, combined_seq_length: int) -> list:
    """Return the 8 configurations of one architecture, numbered from first_id.

    The MSE models come first (one per entry of _MSE_SEQ_LENGTHS), followed by
    the combined-loss sweep, which uses `combined_seq_length` throughout.
    """
    family = [
        _mse_config(first_id + index, model_type, seq_length)
        for index, seq_length in enumerate(_MSE_SEQ_LENGTHS)
    ]
    first_combined_id = first_id + len(_MSE_SEQ_LENGTHS)
    family.extend(
        _combined_config(first_combined_id + index, model_type, combined_seq_length, *sweep_point)
        for index, sweep_point in enumerate(_COMBINED_SWEEP)
    )
    return family


MODEL_CONFIGS = (
    # CNN-LSTM models: MSE loss (ids 1-3), then combined loss with seq=3 (ids 4-8)
    _family_configs(1, "cnn_lstm", combined_seq_length=3)
    # PredRNN models: MSE loss (ids 9-11), then combined loss with seq=2 (ids 12-16)
    + _family_configs(9, "predrnn", combined_seq_length=2)
)

# Guard against the duplicate-id mistake this cell used to contain.
assert [cfg["id"] for cfg in MODEL_CONFIGS] == list(range(1, len(MODEL_CONFIGS) + 1)), \
    "model ids must be unique and consecutive"

# Print a one-line-per-model overview table of the configurations.
print(f"Configured {len(MODEL_CONFIGS)} models for training:")
table_header = f"{'ID':<4} {'Model':<10} {'Seq':<5} {'Loss':<40}"
print(f"\n{table_header}")
print("-" * 65)
for cfg in MODEL_CONFIGS:
    # Models with skip connections are tagged as U-Net variants.
    if cfg.get('use_skip_connections', False):
        skip_str = " (U-Net)"
    else:
        skip_str = ""
    # Combined-loss rows also show the solidus-liquidus temperature window.
    if cfg['loss_type'] == 'mse':
        loss_str = "MSE"
    else:
        loss_str = f"Combined T{cfg['t_solidus']:.0f}-{cfg['t_liquidus']:.0f}"
    print(f"{cfg['id']:<4} {cfg['model_type']:<10} {cfg['seq_length']:<5} {loss_str:<40}{skip_str}")

## Training Loop

Train each model. Skip if already trained.

In [None]:
def _validate_config(config: dict) -> None:
    """Reject configurations whose model or loss type is not supported.

    Raises:
        ValueError: If `model_type` is not "cnn_lstm"/"predrnn" or `loss_type`
            is not "mse"/"combined".
    """
    if config['model_type'] not in ('cnn_lstm', 'predrnn'):
        raise ValueError(
            f"Unknown model_type {config['model_type']!r} for model {config['name']!r}; "
            "expected 'cnn_lstm' or 'predrnn'"
        )
    if config['loss_type'] not in ('mse', 'combined'):
        raise ValueError(
            f"Unknown loss_type {config['loss_type']!r} for model {config['name']!r}; "
            "expected 'mse' or 'combined'"
        )


def _make_split_dataset(split: str, sequence_length: int):
    """Build the dataset of one split with the given context sequence length.

    Uses the fast preprocessed loader when `fast_loading_available` is true,
    otherwise the standard loader with `preload=True`.
    """
    common_kwargs = dict(
        plane=PLANE,
        split=split,
        sequence_length=sequence_length,
        target_offset=1,
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        test_ratio=test_ratio,
    )
    if fast_loading_available:
        return FastMicrostructureSequenceDataset(**common_kwargs)
    return MicrostructureSequenceDataset(preload=True, **common_kwargs)


def _build_model(config: dict):
    """Instantiate the network selected by `config['model_type']` on DEVICE."""
    model_classes = {
        'cnn_lstm': MicrostructureCNN_LSTM,
        'predrnn': MicrostructurePredRNN,
    }
    model_cls = model_classes[config['model_type']]
    return model_cls(
        input_channels=10,
        future_channels=1,
        output_channels=9,
        use_skip_connections=config.get('use_skip_connections', False),
    ).to(DEVICE)


def _build_criterion(config: dict):
    """Create the loss function selected by `config['loss_type']`."""
    if config['loss_type'] == 'mse':
        return nn.MSELoss()
    return CombinedLoss(
        solidification_weight=config.get('solidification_weight', 0.7),
        global_weight=config.get('global_weight', 0.3),
        T_solidus=config['t_solidus'],
        T_liquidus=config['t_liquidus'],
        weight_type="gaussian",
        weight_scale=0.1,
        base_weight=0.1,
        return_components=True,  # Enable component tracking for visualization
    )


def train_model(config: dict) -> dict:
    """
    Train a single model based on configuration.
    Skip if already trained.

    A run counts as already trained when
    `OUTPUT_DIR/<name>/checkpoints/best_model.pt` exists; in that case only the
    saved `history.json` (if any) is loaded.

    Otherwise the function builds datasets with this model's own sequence
    length, trains with `train_microstructure`, and writes `config.json`,
    `history.json`, `training_losses.png` and `checkpoints/final_model.pt`
    into the run directory.

    Args:
        config: One entry of MODEL_CONFIGS.

    Returns:
        dict with keys "status" ("skipped" or "trained"), "history" (the
        training history, or None when skipped without a saved history) and
        "run_dir" (Path of the run directory).

    Raises:
        ValueError: If the configuration names an unknown model or loss type.
            Previously any unknown model type silently trained a PredRNN and
            any unknown loss type silently used the combined loss.
    """
    run_dir = OUTPUT_DIR / config['name']
    checkpoint_path = run_dir / "checkpoints" / "best_model.pt"
    banner = '=' * 70

    # Check if already trained
    if checkpoint_path.exists():
        print(f"\n{banner}")
        print(f"Model {config['id']}: {config['name']}")
        print(f"{banner}")
        print("✓ Model already trained. Skipping training.")

        # Load history if available
        history_path = run_dir / "history.json"
        history = None
        if history_path.exists():
            with open(history_path, 'r') as f:
                history = json.load(f)

        return {"status": "skipped", "history": history, "run_dir": run_dir}

    # Fail before any directory or dataset is created if the config is invalid.
    _validate_config(config)

    use_skip_connections = config.get('use_skip_connections', False)
    is_combined = config['loss_type'] == 'combined'

    # Create directories
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / "checkpoints").mkdir(exist_ok=True)

    print(f"\n{banner}")
    print(f"Model {config['id']}: {config['name']}")
    print(f"{banner}")
    print(f"  Model type: {config['model_type']}")
    print(f"  Sequence length: {config['seq_length']}")
    print(f"  Loss type: {config['loss_type']}")
    print(f"  Skip connections: {use_skip_connections}")
    if is_combined:
        print(f"  T_solidus: {config['t_solidus']} K")
        print(f"  T_liquidus: {config['t_liquidus']} K")
        print(f"  Solidification weight: {config.get('solidification_weight', 0.7)}")
        print(f"  Global weight: {config.get('global_weight', 0.3)}")

    # Each model needs datasets built with its own sequence length; the
    # notebook-level loaders use the default SEQ_LENGTH and are not reused.
    print(f"\n  Creating datasets with sequence length {config['seq_length']}...")
    model_train_dataset = _make_split_dataset("train", config['seq_length'])
    model_val_dataset = _make_split_dataset("val", config['seq_length'])

    # NOTE(review): the training loader does not shuffle - confirm intended.
    model_train_loader = DataLoader(model_train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    model_val_loader = DataLoader(model_val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"  ✓ Datasets created: train={len(model_train_dataset)}, val={len(model_val_dataset)}")

    # Create model, loss function and optimizer
    model = _build_model(config)
    param_count = model.count_parameters()
    print(f"  Parameters: {param_count:,}")

    criterion = _build_criterion(config)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Save configuration
    config_dict = {
        "model": {
            "name": config['model_type'],
            "parameters": param_count,
            "sequence_length": config['seq_length'],
            "use_skip_connections": use_skip_connections,
        },
        "training": {
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "patience": PATIENCE,
            "loss_type": config['loss_type'],
        },
        "device": str(DEVICE),
    }

    if is_combined:
        config_dict['training']['t_solidus'] = config['t_solidus']
        config_dict['training']['t_liquidus'] = config['t_liquidus']
        config_dict['training']['solidification_weight'] = config.get('solidification_weight', 0.7)
        config_dict['training']['global_weight'] = config.get('global_weight', 0.3)

    with open(run_dir / "config.json", "w") as f:
        json.dump(config_dict, f, indent=2)

    # Train with the dataloaders built for this model's sequence length
    print("\nStarting training...")
    history = train_microstructure(
        model=model,
        train_loader=model_train_loader,
        val_loader=model_val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=EPOCHS,
        run_dir=run_dir,
        patience=PATIENCE,
    )

    # Save history
    with open(run_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)

    # Plot losses
    plot_losses(history, str(run_dir / "training_losses.png"))

    # Save final model (weights as they are after the last epoch that ran)
    torch.save({
        'epoch': len(history['train_loss']),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'history': history,
    }, run_dir / "checkpoints" / "final_model.pt")

    print("\nTraining complete!")
    print(f"  Final train loss: {history['train_loss'][-1]:.6f}")
    print(f"  Final val loss: {history['val_loss'][-1]:.6f}")
    print(f"  Epochs: {len(history['train_loss'])}")

    return {"status": "trained", "history": history, "run_dir": run_dir}


# Run (or skip) training for every configured model, keyed by model name.
results = {}
for config in MODEL_CONFIGS:
    outcome = train_model(config)
    results[config['name']] = outcome

separator = '=' * 70
print(f"\n{separator}")
print("All models processed!")
print(f"{separator}")

## Generate Cascaded Predictions

Generate predictions for timestep 23, slice 47 for all models using the cascaded pipeline (TempNet → MicroNet).

This implements the correct theory: given previous microstructure and temperature frames, we first predict the next temperature frame using TempNet, then use that prediction (not ground truth) to predict the next microstructure frame.

In [None]:
# Location of the trained TempNet checkpoint. The value below is a
# placeholder and has to be replaced with the path of a real run.
TEMPNET_CHECKPOINT = "runs/2024-XX-XX_XX-XX-XX/checkpoints/best_model.pt"  # TODO: Update this path

# Record whether the checkpoint exists; prediction generation is gated on it.
from pathlib import Path as PathLib
TEMPNET_AVAILABLE = PathLib(TEMPNET_CHECKPOINT).exists()
if TEMPNET_AVAILABLE:
    print(f"✓ TempNet checkpoint found: {TEMPNET_CHECKPOINT}")
else:
    print(f"⚠ WARNING: TempNet checkpoint not found at: {TEMPNET_CHECKPOINT}")
    print("Please update TEMPNET_CHECKPOINT path to your trained TempNet model")
    print("Cascaded predictions will not be generated without a trained TempNet model.")


def generate_cascaded_prediction(config: dict) -> None:
    """
    Generate and save the prediction figures for one trained model.

    Writes `pred_t{PRED_TIMESTEP}_s{PRED_SLICE}.png` into the model's run
    directory and, for combined-loss models, also
    `solidification_mask_t{PRED_TIMESTEP}_s{PRED_SLICE}.png`. Nothing is done
    when TempNet is unavailable, when the figures already exist, or when the
    model has no `best_model.pt` checkpoint.

    NOTE(review): despite its name, this function obtains its data from
    `load_model_and_predict`, so the temperature used is the `future_temp`
    returned by that call and TempNet is never invoked. The body previously
    passed an undefined `pred_temp` to the solidification-mask plot, so the
    NameError was swallowed by the `except` below and no mask was ever saved.
    It now passes `future_temp`, and the figure title and log messages no
    longer claim a predicted temperature. Switching to
    `load_model_and_predict_cascaded` needs its signature, which is not
    visible in this notebook.

    Args:
        config: One entry of MODEL_CONFIGS.
    """
    if not TEMPNET_AVAILABLE:
        print(f"Model {config['id']}: ✗ Skipping (TempNet not available)")
        return

    run_dir = OUTPUT_DIR / config['name']
    checkpoint_path = run_dir / "checkpoints" / "best_model.pt"
    pred_path = run_dir / f"pred_t{PRED_TIMESTEP}_s{PRED_SLICE}.png"
    solid_mask_path = run_dir / f"solidification_mask_t{PRED_TIMESTEP}_s{PRED_SLICE}.png"
    is_combined = config['loss_type'] == 'combined'

    # Skip when every figure this model needs is already on disk. Combined-loss
    # models need the solidification mask as well; if only the prediction
    # exists, both figures are regenerated.
    if pred_path.exists():
        if not is_combined:
            print(f"Model {config['id']}: ✓ Prediction already exists. Skipping.")
            return
        if solid_mask_path.exists():
            print(f"Model {config['id']}: ✓ Predictions already exist. Skipping.")
            return

    if not checkpoint_path.exists():
        print(f"Model {config['id']}: ✗ No checkpoint found")
        return

    print(f"Model {config['id']}: Generating prediction...")

    try:
        pred_micro, target_micro, future_temp, mask, metadata = load_model_and_predict(
            checkpoint_path=str(checkpoint_path),
            timestep=PRED_TIMESTEP,
            slice_index=PRED_SLICE,
            sequence_length=config['seq_length'],
            plane=PLANE,
            device=str(DEVICE),
        )

        # Create loss function for visualization (same settings as training)
        if is_combined:
            loss_fn = CombinedLoss(
                solidification_weight=config.get('solidification_weight', 0.7),
                global_weight=config.get('global_weight', 0.3),
                T_solidus=config['t_solidus'],
                T_liquidus=config['t_liquidus'],
                weight_type="gaussian",
                weight_scale=0.1,
                base_weight=0.1,
                return_components=True,
            )
        else:
            loss_fn = nn.MSELoss()

        # Generate standard prediction visualization
        save_prediction_visualization(
            pred_micro=pred_micro,
            target_micro=target_micro,
            mask=mask,
            save_path=str(pred_path),
            title=f"Model {config['id']}: {config['name']}",
            future_temp=future_temp,
            loss_fn=loss_fn,
        )
        print(f"  ✓ Saved prediction to {pred_path}")

        # Generate solidification mask visualization for combined loss models
        if is_combined:
            save_solidification_mask_visualization(
                future_temp=future_temp,  # temperature returned by load_model_and_predict
                pred_micro=pred_micro,
                target_micro=target_micro,
                mask=mask,
                loss_fn=loss_fn,
                save_path=str(solid_mask_path),
                title=f"Model {config['id']}: {config['name']}",
                timestep=metadata['timestep'],
                slice_coord=metadata['slice_coord'],
            )
            print(f"  ✓ Saved solidification mask to {solid_mask_path}")

    except Exception as e:
        # Best effort: report the failure and continue with the next model.
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()


print(f"\nGenerating predictions for timestep={PRED_TIMESTEP}, slice={PRED_SLICE}...\n")

# The earlier loop over `generate_prediction(config)` was removed: no function
# of that name is defined in the cells up to this point, so on a fresh kernel
# the call raised NameError and the cell stopped before any figure was made.
for config in MODEL_CONFIGS:
    generate_cascaded_prediction(config)

print("\nAll cascaded predictions generated!")

## Display Cascaded Predictions

Display cascaded prediction visualizations showing both TempNet and MicroNet outputs.

In [None]:
# Show the saved cascaded-prediction figure of every model, if there is one.
# NOTE(review): the generation cell visible above saves pred_* and
# solidification_mask_* files only; confirm where cascaded_pred_* files are
# produced, otherwise every model ends up in the "not found" branch.
print(f"Cascaded Prediction Visualizations (t={PRED_TIMESTEP}, s={PRED_SLICE}):\n")
print("Shows: Ground Truth Temp → Predicted Temp → Ground Truth Micro → Predicted Micro\n")

for config in MODEL_CONFIGS:
    run_dir = OUTPUT_DIR / config['name']
    cascaded_pred_path = run_dir / f"cascaded_pred_t{PRED_TIMESTEP}_s{PRED_SLICE}.png"

    if not cascaded_pred_path.exists():
        print(f"\nModel {config['id']}: No cascaded prediction found")
        continue

    print(f"\nModel {config['id']}: {config['name']}")
    display(Image(filename=str(cascaded_pred_path)))

## Display Training Losses

Display training loss plots for all models.

In [None]:
# Show the training-loss plot saved in each model's run directory.
print("Training Loss Plots:\n")

for config in MODEL_CONFIGS:
    run_dir = OUTPUT_DIR / config['name']
    loss_plot_path = run_dir / "training_losses.png"

    if not loss_plot_path.exists():
        print(f"\nModel {config['id']}: No training losses plot found")
        continue

    print(f"\nModel {config['id']}: {config['name']}")
    display(Image(filename=str(loss_plot_path)))

## Display Predictions

Display prediction visualizations for all models.

In [None]:
# Show the prediction figure saved for each model at the configured
# timestep and slice.
print(f"Prediction Visualizations (t={PRED_TIMESTEP}, s={PRED_SLICE}):\n")

for config in MODEL_CONFIGS:
    run_dir = OUTPUT_DIR / config['name']
    pred_path = run_dir / f"pred_t{PRED_TIMESTEP}_s{PRED_SLICE}.png"

    if not pred_path.exists():
        print(f"\nModel {config['id']}: No prediction found")
        continue

    print(f"\nModel {config['id']}: {config['name']}")
    display(Image(filename=str(pred_path)))

## Display Solidification Mask Visualizations

Display solidification mask visualizations for models trained with combined loss.

In [None]:
# Show the solidification-mask figure of every combined-loss model.
print(f"Solidification Mask Visualizations (t={PRED_TIMESTEP}, s={PRED_SLICE}):\n")
print("(Only for models trained with combined loss)\n")

for config in MODEL_CONFIGS:
    if config['loss_type'] == 'combined':
        run_dir = OUTPUT_DIR / config['name']
        solid_mask_path = run_dir / f"solidification_mask_t{PRED_TIMESTEP}_s{PRED_SLICE}.png"

        if not solid_mask_path.exists():
            print(f"\nModel {config['id']}: No solidification mask found")
            continue

        print(f"\nModel {config['id']}: {config['name']}")
        display(Image(filename=str(solid_mask_path)))

## Model Comparison

Compare training losses across all models.

In [None]:
# Collect the saved training history of every model that has one, keyed by
# model name; models without a history.json are left out.
histories = {}
for config in MODEL_CONFIGS:
    history_path = OUTPUT_DIR / config['name'] / "history.json"
    if not history_path.exists():
        continue
    with open(history_path, 'r') as f:
        loaded_history = json.load(f)
    histories[config['name']] = loaded_history

if histories:
    # Plot comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

    # Training loss
    for name, history in histories.items():
        ax1.plot(history['train_loss'], label=name, linewidth=2, alpha=0.7)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Training Loss', fontsize=12)
    ax1.set_title('Training Loss Comparison', fontsize=14, fontweight='bold')
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Validation loss
    for name, history in histories.items():
        ax2.plot(history['val_loss'], label=name, linewidth=2, alpha=0.7)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Validation Loss', fontsize=12)
    ax2.set_title('Validation Loss Comparison', fontsize=14, fontweight='bold')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    comparison_path = OUTPUT_DIR / "all_models_comparison.png"
    plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nComparison plot saved to {comparison_path}")

    # Calculate solidification loss for MSE models on validation set
    print("\nCalculating solidification loss for MSE models on validation set...")

    # Create a CombinedLoss function for solidification region evaluation (T_solidus=1560, T_liquidus=1620)
    solidification_loss_fn = CombinedLoss(
        solidification_weight=1.0,
        global_weight=0.0,
        T_solidus=1560.0,
        T_liquidus=1620.0,
        weight_type="gaussian",
        weight_scale=0.1,
        base_weight=0.1,
        return_components=True,
    )

    mse_solidification_losses = {}

    # Validation-set solidification loss for the MSE-trained models; results are
    # stored in mse_solidification_losses keyed by model name.
    #
    # Validation loaders are cached per sequence length: the dataset arguments
    # below depend on nothing else in the config, so models that share a
    # sequence length reuse one dataset instead of rebuilding (and, on the slow
    # path, preloading) it for every model.
    val_loaders_by_seq_length = {}

    for config in MODEL_CONFIGS:
        if config['loss_type'] != 'mse':
            continue

        checkpoint_path = OUTPUT_DIR / config['name'] / "checkpoints" / "best_model.pt"
        if not checkpoint_path.exists():
            # Nothing to evaluate; the model simply gets no entry in the results.
            continue

        print(f"  Model {config['id']}: Computing solidification loss...")

        # Rebuild the architecture named in the config and restore the saved weights.
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        model_cls = MicrostructureCNN_LSTM if config['model_type'] == 'cnn_lstm' else MicrostructurePredRNN
        model = model_cls(
            input_channels=10,
            future_channels=1,
            output_channels=9,
            use_skip_connections=config.get('use_skip_connections', False),
        ).to(DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Validation loader matching this model's sequence length (built once per length).
        seq_length = config['seq_length']
        if seq_length not in val_loaders_by_seq_length:
            dataset_kwargs = dict(
                plane=PLANE,
                split="val",
                sequence_length=seq_length,
                target_offset=1,
                train_ratio=train_ratio,
                val_ratio=val_ratio,
                test_ratio=test_ratio,
            )
            if fast_loading_available:
                eval_dataset = FastMicrostructureSequenceDataset(**dataset_kwargs)
            else:
                eval_dataset = MicrostructureSequenceDataset(preload=True, **dataset_kwargs)

            val_loaders_by_seq_length[seq_length] = DataLoader(
                eval_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
            )
        eval_loader = val_loaders_by_seq_length[seq_length]

        # Accumulate the per-batch solidification loss over the validation set.
        total_solid_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in eval_loader:
                context_temp = batch['context_temp'].to(DEVICE)
                context_micro = batch['context_micro'].to(DEVICE)
                future_temp = batch['future_temp'].to(DEVICE)
                target_micro = batch['target_micro'].to(DEVICE)
                mask = batch['target_mask'].to(DEVICE)

                # Stack temperature and microstructure into one context tensor.
                # NOTE(review): dim=2 is presumably the channel axis of a
                # (batch, time, channel, H, W) tensor — confirm against the dataset.
                context = torch.cat([context_temp, context_micro], dim=2)

                pred_micro = model(context, future_temp)

                # The loss returns three values; only the second (solidification
                # term) is used here.
                _, solid_loss, _ = solidification_loss_fn(pred_micro, target_micro, future_temp, mask)

                total_solid_loss += solid_loss.item()
                num_batches += 1

        if num_batches == 0:
            # An empty loader would otherwise raise ZeroDivisionError below and
            # abort the evaluation of all remaining models.
            print(f"    ⚠ No validation batches for sequence length {seq_length}; skipping")
            continue

        avg_solid_loss = total_solid_loss / num_batches
        mse_solidification_losses[config['name']] = avg_solid_loss
        print(f"    ✓ Solidification loss: {avg_solid_loss:.6f}")

    # Test-set metrics for every model that has a saved best checkpoint:
    # masked MSE goes into test_mse_losses and the solidification term of
    # solidification_loss_fn into test_solid_losses, both keyed by model name.
    print("\nCalculating test set metrics for all models...")

    test_mse_losses = {}
    test_solid_losses = {}

    mse_loss_fn = nn.MSELoss()

    # Test loaders are cached per sequence length: the dataset arguments below
    # depend on nothing else in the config, so models that share a sequence
    # length reuse one dataset instead of rebuilding (and, on the slow path,
    # preloading) it for every model.
    test_loaders_by_seq_length = {}

    for config in MODEL_CONFIGS:
        checkpoint_path = OUTPUT_DIR / config['name'] / "checkpoints" / "best_model.pt"
        if not checkpoint_path.exists():
            # No checkpoint: the model gets no entry in either result dict.
            continue

        print(f"  Model {config['id']}: Computing test set metrics...")

        # Rebuild the architecture named in the config and restore the saved weights.
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        model_cls = MicrostructureCNN_LSTM if config['model_type'] == 'cnn_lstm' else MicrostructurePredRNN
        model = model_cls(
            input_channels=10,
            future_channels=1,
            output_channels=9,
            use_skip_connections=config.get('use_skip_connections', False),
        ).to(DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Test loader matching this model's sequence length (built once per length).
        seq_length = config['seq_length']
        if seq_length not in test_loaders_by_seq_length:
            dataset_kwargs = dict(
                plane=PLANE,
                split="test",
                sequence_length=seq_length,
                target_offset=1,
                train_ratio=train_ratio,
                val_ratio=val_ratio,
                test_ratio=test_ratio,
            )
            if fast_loading_available:
                test_dataset = FastMicrostructureSequenceDataset(**dataset_kwargs)
            else:
                test_dataset = MicrostructureSequenceDataset(preload=True, **dataset_kwargs)

            test_loaders_by_seq_length[seq_length] = DataLoader(
                test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
            )
        test_loader = test_loaders_by_seq_length[seq_length]

        # Accumulate per-batch losses; both averages below are means of batch means.
        total_mse_loss = 0.0
        total_solid_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in test_loader:
                context_temp = batch['context_temp'].to(DEVICE)
                context_micro = batch['context_micro'].to(DEVICE)
                future_temp = batch['future_temp'].to(DEVICE)
                target_micro = batch['target_micro'].to(DEVICE)
                mask = batch['target_mask'].to(DEVICE)

                # Stack temperature and microstructure into one context tensor.
                # NOTE(review): dim=2 is presumably the channel axis of a
                # (batch, time, channel, H, W) tensor — confirm against the dataset.
                context = torch.cat([context_temp, context_micro], dim=2)

                pred_micro = model(context, future_temp)

                # MSE restricted to the pixels selected by the mask, which is
                # broadcast over dim 1 of the target.
                # NOTE(review): a batch whose mask selects nothing would
                # presumably yield NaN here — confirm this cannot happen.
                mask_expanded = mask.unsqueeze(1).expand_as(target_micro)
                mse_loss = mse_loss_fn(pred_micro[mask_expanded], target_micro[mask_expanded])

                # The loss returns three values; only the second (solidification
                # term) is used here.
                _, solid_loss, _ = solidification_loss_fn(pred_micro, target_micro, future_temp, mask)

                total_mse_loss += mse_loss.item()
                total_solid_loss += solid_loss.item()
                num_batches += 1

        if num_batches == 0:
            # An empty loader would otherwise raise ZeroDivisionError below and
            # abort the evaluation of all remaining models.
            print(f"    ⚠ No test batches for sequence length {seq_length}; skipping")
            continue

        avg_mse_loss = total_mse_loss / num_batches
        avg_solid_loss = total_solid_loss / num_batches

        test_mse_losses[config['name']] = avg_mse_loss
        test_solid_losses[config['name']] = avg_solid_loss

        print(f"    ✓ Test MSE: {avg_mse_loss:.6f}, Test Solid: {avg_solid_loss:.6f}")

    # Summary table: one row per configured model, with every loss shown as an
    # "accuracy" (1 - loss). Values that are unavailable are printed as N/A.
    print("\n" + "=" * 170)
    print("MODEL COMPARISON SUMMARY (Accuracy = 1 - Loss)")
    print("=" * 170)
    print(f"{'ID':<4} {'Name':<50} {'Train':<10} {'Val Global':<12} {'Val Solid':<12} {'Test Global':<12} {'Test Solid':<12} {'Epochs':<8}")
    print("-" * 170)

    for config in MODEL_CONFIGS:
        name = config['name']

        # Models without a recorded training history get a placeholder row.
        if name not in histories:
            print(f"{config['id']:<4} {name:<50} {'N/A':<10} {'N/A':<12} {'N/A':<12} {'N/A':<12} {'N/A':<12} {'N/A':<8}")
            continue

        history = histories[name]

        # Train / validation columns use the last recorded epoch.
        train_acc = 1.0 - history['train_loss'][-1]
        val_global_acc = 1.0 - history['val_loss'][-1]
        epochs = len(history['train_loss'])

        # Validation solidification loss: taken from the history for
        # combined-loss models, from the separately computed
        # mse_solidification_losses for MSE models, otherwise unavailable.
        val_solid_acc = None
        if config['loss_type'] == 'combined':
            if 'val_solidification_loss' in history:
                val_solid_acc = 1.0 - history['val_solidification_loss'][-1]
        elif config['loss_type'] == 'mse':
            if name in mse_solidification_losses:
                val_solid_acc = 1.0 - mse_solidification_losses[name]

        # Test metrics exist only for models that were evaluated on the test split.
        test_mse = test_mse_losses.get(name)
        test_solid = test_solid_losses.get(name)
        test_global_acc = None if test_mse is None else 1.0 - test_mse
        test_solid_acc = None if test_solid is None else 1.0 - test_solid

        # Optional columns are rendered as fixed-precision numbers or "N/A".
        val_solid_str, test_global_str, test_solid_str = (
            "N/A" if acc is None else f"{acc:.6f}"
            for acc in (val_solid_acc, test_global_acc, test_solid_acc)
        )

        print(f"{config['id']:<4} {name:<50} {train_acc:<10.6f} {val_global_acc:<12.6f} {val_solid_str:<12} {test_global_str:<12} {test_solid_str:<12} {epochs:<8}")

    print("=" * 170)
else:
    print("No training histories found.")

## Summary

All 16 models have been trained and evaluated!