In [19]:
%load_ext autoreload
%autoreload 2

import os
import random

import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from IPython.display import Video, display

# Import custom modules from 'src' package
from src.dataset import DataModule
from src.models import Forecaster
from src.trainer import Trainer
from src.losses import combined_loss_fn
from src import utils

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so weight
# initialisation, augmentation and any random splitting are reproducible on re-run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators of all devices, including CUDA

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running experiment on: {device}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Running experiment on: cuda


### Experiment Configuration

In [2]:
# Central experiment configuration: every tunable value used further down lives here.
CONFIG = dict(
    # --- Data paths ---
    DATA_DIR="Data_embryos",

    # --- Preprocessing ---
    TARGET_SIZE=(128, 128),   # target frame size handed to DataModule
    T_START=120,              # temporal window passed to DataModule as t_start / t_end;
    T_END=960,                #   numerically START_HOUR / END_HOUR x 60 (presumably minutes - confirm)
    START_HOUR=2.0,           # developmental time (hpf) at the start of the window
    END_HOUR=16.0,            # developmental time (hpf) at the end of the window

    # --- Output ---
    OUTPUT_DIR="results_paper_final",

    # --- Sequence parameters ---
    N_PAST=10,                # context frames fed to the model
    N_FUTURE=1,               # horizon used during training
    N_FUTURE_AR=30,           # number of steps for the autoregressive analysis

    # --- Training hyperparameters ---
    BATCH_SIZE=6,
    LR=0.0001,
    EPOCHS=50,
    PATIENCE=5,               # early-stopping patience (epochs)
    SPLIT=0.8,                # train/val split fraction
    AUGMENT=True,             # enable training-set data augmentation
)

print("Configuration loaded.")

Configuration loaded.


In [3]:
# Output layout: <OUTPUT_DIR>/videos for rendered clips, <OUTPUT_DIR>/plots for figures.
# exist_ok=True keeps this cell safe to re-run.
for subdir in ("videos", "plots"):
    os.makedirs(os.path.join(CONFIG["OUTPUT_DIR"], subdir), exist_ok=True)

### Data Loading & Inspection

In [4]:
# Configure the data module: source directory, target frame size and the
# temporal window [T_START, T_END] taken from CONFIG.
data_module = DataModule(
    data_dir=CONFIG["DATA_DIR"],
    t_start=CONFIG["T_START"],
    t_end=CONFIG["T_END"],
    target_size=CONFIG["TARGET_SIZE"],
)

# Read, crop and normalize every stack. This is the slow step
# (~21 s per stack in the logged run).
data_module.load_and_process_data()

Scanning directory: Data_embryos ...
Found 41 images in train set.


Processing train: 100%|█████████████████████████| 41/41 [14:27<00:00, 21.17s/it]


Found 12 images in test set.


Processing test: 100%|██████████████████████████| 12/12 [04:13<00:00, 21.13s/it]

Loaded 41 training stacks.
Loaded 12 test stacks.





In [5]:
# Group the held-out embryos by condition (BMP / Nodal / Normal) so every
# downstream analysis can be reported per group.
test_keys = data_module.test_keys
embryo_groups = utils.classify_embryo_types(test_keys)

print("\nEmbryo Classification (Test Set):")
for group, keys in embryo_groups.items():
    print(f"  - {group}: {len(keys)} embryos")


Embryo Classification (Test Set):
  - BMP: 4 embryos
  - Nodal: 4 embryos
  - Normal: 4 embryos


In [6]:
# Render an inspection video of one test embryo over the full time window.
# Prefer a BMP embryo when one exists, otherwise fall back to the first test embryo.
# .get() avoids a KeyError if the classifier returned no 'BMP' entry at all.
bmp_keys = embryo_groups.get('BMP')
sample_id_inspection = bmp_keys[0] if bmp_keys else test_keys[0]
print(f"Generating video for: {sample_id_inspection}")
save_path_insp = os.path.join(CONFIG["OUTPUT_DIR"], "videos", f"Inspection_{sample_id_inspection}.mp4")

utils.save_inspection_video(
    data_module=data_module,
    embryo_key=sample_id_inspection,
    save_path=save_path_insp,
    fps=20,
    start_h=CONFIG["START_HOUR"],
    end_h=CONFIG["END_HOUR"]
)

Generating video for: Key_Bmp_tes_E0007
Inspection video saved: results_paper_final/videos/Inspection_Key_Bmp_tes_E0007.mp4


### Dataloader Setup

In [8]:
# Build train/val/test loaders of (N_PAST context -> N_FUTURE target) sequences.
# Per the logged output, the train/val split is made by embryo ID, not by sequence.
data_module.prepare_dataloaders(
    batch_size=CONFIG["BATCH_SIZE"],
    n_past=CONFIG["N_PAST"],
    n_future=CONFIG["N_FUTURE"],
    use_augmentation=CONFIG["AUGMENT"],
    train_split_percent=CONFIG["SPLIT"],
)

Splitting Logic (By Embryo ID): Train=32, Val=9
Total Sequences Generated: Train=14077, Val=3960, Test=5280
Augmentation pipeline enabled for training set.


### Model Initialization

In [9]:
# Define Architecture Hyperparameters
model_params = {
    "input_dim": 1,
    "hidden_dims": [32, 32, 64],
    "kernel_size": (3, 3),
    "num_layers": 3,
}

# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim docs recommend. This also makes the "initialized on" message below
# true, instead of relying on a later cell to move it.
model = Forecaster(**model_params).to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=CONFIG["LR"])

print(f"Model initialized on {device}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Model initialized on cuda
Trainable parameters: 334,017


### Training Loop

**Load a pretrained model (optional):**

In [10]:
# Optional: restore previously trained weights instead of training from scratch.
# Disabled by default. Set LOAD_PRETRAINED = True to use the checkpoint below.
LOAD_PRETRAINED = False
pretrained_path = os.path.join("saved_models", "best_embryo_model_1.pth")

if LOAD_PRETRAINED:
    # Fail loudly rather than silently continuing with an untrained model
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained weights not found: {pretrained_path}")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code
    state_dict = torch.load(pretrained_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    print(f"Loaded pretrained weights from: {pretrained_path}")

  state_dict = torch.load(model_path, map_location=device)


**Train a New Model:**

In [11]:
# Wire the model, loss, optimizer and the three data splits into the training driver.
# n_future is the training horizon (CONFIG["N_FUTURE"]); early stopping uses CONFIG["PATIENCE"].
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=combined_loss_fn,  # from src.losses
    device=device,
    train_loader=data_module.train_loader,
    val_loader=data_module.val_loader,
    test_loader=data_module.test_loader,
    early_stopping_patience=CONFIG["PATIENCE"],
    n_future=CONFIG["N_FUTURE"],
)

In [None]:
# Run training for up to CONFIG["EPOCHS"] epochs. The Trainer was built with
# early_stopping_patience=CONFIG["PATIENCE"], so it may presumably stop earlier.
trainer.train(num_epochs=CONFIG["EPOCHS"])

### Save Trained Model

In [None]:
# Persist the trained weights as a state_dict (parameters only, not the module object).
save_dir = "saved_models"
model_path = os.path.join(save_dir, "best_embryo_model.pth")
os.makedirs(save_dir, exist_ok=True)

torch.save(trainer.model.state_dict(), model_path)
print(f"Model saved in: {model_path}")

### Learning Curves:

In [None]:
# Plot the per-epoch training and validation losses recorded by the Trainer.
# Epochs are numbered from 1 (the first completed epoch is epoch 1).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(trainer.train_losses) + 1), trainer.train_losses, label='Train Loss')
ax.plot(range(1, len(trainer.val_losses) + 1), trainer.val_losses, label='Val Loss')
ax.set_title("Training Dynamics")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss (Combined)")
ax.legend()
ax.grid(True, alpha=0.3)

# Save next to the other experiment figures so the learning curve is not lost
fig.savefig(os.path.join(CONFIG["OUTPUT_DIR"], "plots", "Learning_Curves.svg"), bbox_inches="tight")
plt.show()

## Qualitative Results:

### Long-Term Autoregressive Forecasting

In [13]:
# Long-term autoregressive rollout for one representative embryo per group.
# Starting mid-sequence, forecast N_FUTURE_AR steps with
# utils.perform_autoregressive_inference and save a GT-vs-prediction video.
model.eval()  # put dropout / batch-norm layers (if any) into inference mode

for group_name, keys in embryo_groups.items():
    # Skip empty groups (same guard as the quantitative analysis) instead of
    # failing with IndexError on keys[0]
    if not keys:
        continue
    # Select representative sample
    sample_id = keys[0]
    print(f"Generating AR Video for Group: {group_name} | Sample: {sample_id}")
    full_seq = data_module.normalized_data[sample_id]

    # Rollout window: from the middle of the sequence, N_FUTURE_AR steps ahead
    start_idx_ar = len(full_seq) // 2
    end_idx_ar = start_idx_ar + CONFIG["N_FUTURE_AR"]

    ar_results = utils.perform_autoregressive_inference(
        model=model,
        full_sequence=full_seq,
        start_idx=start_idx_ar,
        end_idx=end_idx_ar,
        n_past=CONFIG["N_PAST"],
        device=device
    )

    utils.save_prediction_video(
        gt_seq=ar_results['gt_future'],
        pred_seq=ar_results['predictions'],
        save_path=os.path.join(CONFIG["OUTPUT_DIR"], "videos", f"AR_LongTerm_{sample_id}.mp4"),
        fps=4,
        start_h=ar_results['t_start'],
        end_h=ar_results['t_end'],
        is_autoregressive=True,
        n_past=CONFIG["N_PAST"]
    )

Generating AR Video for Group: BMP | Sample: Key_Bmp_tes_E0007
Prediction video saved: results_paper_final/videos/AR_LongTerm_Key_Bmp_tes_E0007.mp4
Generating AR Video for Group: Nodal | Sample: Acq_Nd_tes_E1005
Prediction video saved: results_paper_final/videos/AR_LongTerm_Acq_Nd_tes_E1005.mp4
Generating AR Video for Group: Normal | Sample: Key_Nr_tes_E0015
Prediction video saved: results_paper_final/videos/AR_LongTerm_Key_Nr_tes_E0015.mp4


In [14]:
# One-step (teacher-forced) forecasting over the whole sequence: every prediction
# is conditioned on N_PAST ground-truth frames, which demonstrates the
# theoretical upper-bound performance.
model.eval()  # inference mode is loop-invariant, so set it once

for group_name, keys in embryo_groups.items():
    # Skip empty groups instead of failing with IndexError on keys[0]
    if not keys:
        continue
    # Select representative sample
    sample_id = keys[0]
    print(f"Generating OneStep Video for Group: {group_name} | Sample: {sample_id}")
    full_seq = data_module.normalized_data[sample_id]
    if len(full_seq) <= CONFIG["N_PAST"]:
        # Not enough frames to form a single context window
        print(f"Skipping {sample_id}: only {len(full_seq)} frames (N_PAST={CONFIG['N_PAST']})")
        continue

    preds_os, gt_os = [], []
    with torch.no_grad():
        for t in range(CONFIG["N_PAST"], len(full_seq)):
            ctx = full_seq[t - CONFIG["N_PAST"] : t]
            # Add batch and channel axes: (N_PAST, ...) -> (1, N_PAST, 1, ...)
            inp = torch.from_numpy(ctx).float().unsqueeze(0).unsqueeze(2).to(device)
            preds_os.append(utils.tensor_to_numpy(model(inp)))
            gt_os.append(full_seq[t])

    # Predictions start at frame N_PAST, not frame 0, so the video's time axis must
    # start there too. This assumes frames are evenly spaced over
    # [START_HOUR, END_HOUR], the same mapping the full-sequence inspection video uses.
    hours_per_frame = (CONFIG["END_HOUR"] - CONFIG["START_HOUR"]) / max(len(full_seq) - 1, 1)
    start_h_os = CONFIG["START_HOUR"] + CONFIG["N_PAST"] * hours_per_frame

    utils.save_prediction_video(
        gt_seq=np.array(gt_os),
        pred_seq=np.array(preds_os).squeeze(),
        save_path=os.path.join(CONFIG["OUTPUT_DIR"], "videos", f"OneStep_Full_{sample_id}.mp4"),
        fps=20,
        start_h=start_h_os,
        end_h=CONFIG["END_HOUR"],
        is_autoregressive=False
    )

Generating OneStep Video for Group: BMP | Sample: Key_Bmp_tes_E0007
Prediction video saved: results_paper_final/videos/OneStep_Full_Key_Bmp_tes_E0007.mp4
Generating OneStep Video for Group: Nodal | Sample: Acq_Nd_tes_E1005
Prediction video saved: results_paper_final/videos/OneStep_Full_Acq_Nd_tes_E1005.mp4
Generating OneStep Video for Group: Normal | Sample: Key_Nr_tes_E0015
Prediction video saved: results_paper_final/videos/OneStep_Full_Key_Nr_tes_E0015.mp4


## Quantitative Results:

In [20]:
# Per-group quantitative evaluation on the test embryos:
#   metrics_ar - error growth along an N_FUTURE_AR-step autoregressive rollout
#   metrics_os - one-step (teacher-forced) error across developmental time
metrics_ar = {}
metrics_os = {}

for group, keys in embryo_groups.items():
    if not keys:
        continue  # nothing to evaluate for an empty group
    print(f"Analyzing group: {group} (n={len(keys)})")

    metrics_ar[group] = utils.compute_autoregressive_metrics(
        model=model,
        keys=keys,
        data_module=data_module,
        device=device,
        n_past=CONFIG["N_PAST"],
        n_future=CONFIG["N_FUTURE_AR"],
        start_h=CONFIG["START_HOUR"],
        end_h=CONFIG["END_HOUR"],
    )
    metrics_os[group] = utils.compute_onestep_metrics(
        model=model,
        keys=keys,
        data_module=data_module,
        device=device,
        n_past=CONFIG["N_PAST"],
        start_h=CONFIG["START_HOUR"],
        end_h=CONFIG["END_HOUR"],
    )

# Each entry: (metrics, metric key, y label, x label, title, output file name)
plot_configs = [
    # Autoregressive degradation vs. forecast horizon
    (metrics_ar, 'mse', 'Accumulated MSE', 'Forecast Horizon (Hours)', '', 'AR_Degradation_MSE.svg'),
    (metrics_ar, 'ssim_error', '1 - SSIM', 'Forecast Horizon (Hours)', '', 'AR_Degradation_SSIM.svg'),
    (metrics_ar, 'grad', 'Gradient Loss', 'Forecast Horizon (Hours)', '', 'AR_Degradation_Grad.svg'),
] + [
    # One-step error vs. developmental age
    (metrics_os, 'mse', 'Instantaneous MSE', 'Biological Age (hpf)', '', 'Dev_Error_MSE.svg'),
    (metrics_os, 'ssim_error', '1 - SSIM', 'Biological Age (hpf)', '', 'Dev_Error_SSIM.svg'),
    (metrics_os, 'grad', 'Gradient Loss', 'Biological Age (hpf)', '', 'Dev_Error_Grad.svg'),
]

for metrics, key, ylab, xlab, title, fname in plot_configs:
    utils.plot_curves(
        metrics_by_type=metrics,
        metric_key=key,
        x_label=xlab,
        y_label=ylab,
        title=title,
        save_path=os.path.join(CONFIG["OUTPUT_DIR"], "plots", fname),
    )

print("\nExperiment complete. Results saved to:", CONFIG["OUTPUT_DIR"])

Analyzing group: BMP (n=4)
Analyzing group: Nodal (n=4)
Analyzing group: Normal (n=4)
Plot saved: results_paper_final/plots/AR_Degradation_MSE.svg
Plot saved: results_paper_final/plots/AR_Degradation_SSIM.svg
Plot saved: results_paper_final/plots/AR_Degradation_Grad.svg
Plot saved: results_paper_final/plots/Dev_Error_MSE.svg
Plot saved: results_paper_final/plots/Dev_Error_SSIM.svg
Plot saved: results_paper_final/plots/Dev_Error_Grad.svg

Experiment complete. Results saved to: results_paper_final
