# Promethium: Model Training Tutorial

This notebook demonstrates how to train a U-Net model for seismic data reconstruction using the Promethium framework.

**Author:** Olaf Yunus Laitinen Imanov  
**Date:** December 2025  
**Framework:** Promethium v1.0.0

In [None]:
import torch
import numpy as np
from pathlib import Path

from promethium.ml.models import UNet
from promethium.ml.training import Trainer
from promethium.ml.data import SeismicDataset

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Configure Training

Define the model architecture, training hyperparameters, and data paths in a single dictionary.

In [None]:
config = {
    "model": {
        "architecture": "unet",
        "in_channels": 1,
        "out_channels": 1,
        "features": [64, 128, 256, 512]
    },
    "training": {
        "epochs": 100,
        "batch_size": 16,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5
    },
    "data": {
        "train_path": "data/train",
        "val_path": "data/val",
        "patch_size": 256
    }
}
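Because the whole setup lives in one nested dictionary, a mistyped key only surfaces later as a `KeyError` deep inside training. A quick check right after defining `config` can catch that earlier. The sketch below uses a hypothetical `validate_config` helper (not part of Promethium). Its divisibility check assumes the U-Net halves the spatial size once per entry in `features` (2x2 pooling), which is the common design but is not confirmed for Promethium's `UNet`.

```python
def validate_config(config: dict) -> None:
    """Raise ValueError if the config is missing keys or is inconsistent.

    Hypothetical helper, not part of Promethium.
    """
    required = {
        "model": ["in_channels", "out_channels", "features"],
        "training": ["epochs", "batch_size", "learning_rate", "weight_decay"],
        "data": ["train_path", "val_path", "patch_size"],
    }
    for section, keys in required.items():
        if section not in config:
            raise ValueError(f"missing config section: {section!r}")
        for key in keys:
            if key not in config[section]:
                raise ValueError(f"missing config key: {section}.{key}")

    if config["training"]["learning_rate"] <= 0:
        raise ValueError("learning_rate must be positive")
    if config["training"]["batch_size"] < 1:
        raise ValueError("batch_size must be >= 1")

    # Assumption: one 2x downsampling per feature level, so the patch must be a
    # multiple of 2 ** levels for skip connections to line up after upsampling.
    factor = 2 ** len(config["model"]["features"])
    patch_size = config["data"]["patch_size"]
    if patch_size % factor != 0:
        raise ValueError(f"patch_size {patch_size} is not divisible by {factor}")
```

With the values above, `validate_config(config)` passes: four feature levels give a factor of 16, and 256 is a multiple of 16. A patch size such as 250 would be rejected.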

## 2. Initialize Model

Create the U-Net model for reconstruction.

In [None]:
model = UNet(
    in_channels=config["model"]["in_channels"],
    out_channels=config["model"]["out_channels"],
    features=config["model"]["features"]
)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Report the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
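The printed count can be sanity-checked by hand. The sketch below assumes Promethium's `UNet` follows a widely used layout, which the source does not confirm: each level is a "double conv" (two bias-free 3x3 convolutions, each followed by BatchNorm), the bottleneck doubles the deepest channel count, the decoder upsamples with 2x2 transposed convolutions, and a final 1x1 convolution maps to the output channels. `double_conv_params` and `estimate_unet_params` are hypothetical helpers.

```python
def double_conv_params(c_in: int, c_out: int) -> int:
    # Two 3x3 convolutions without bias, each followed by BatchNorm
    # (learnable weight + bias per channel; running stats are buffers, not parameters).
    return 9 * c_in * c_out + 2 * c_out + 9 * c_out * c_out + 2 * c_out


def estimate_unet_params(in_channels: int, out_channels: int, features: list) -> int:
    """Parameter count for an assumed standard U-Net layout."""
    total = 0
    c = in_channels
    # Encoder: one double-conv block per feature level.
    for f in features:
        total += double_conv_params(c, f)
        c = f
    # Bottleneck: doubles the channel count of the deepest level.
    total += double_conv_params(features[-1], 2 * features[-1])
    # Decoder: 2x2 transposed conv with bias (2f -> f), then a double conv over
    # the upsampled map concatenated with the skip connection (2f -> f).
    for f in reversed(features):
        total += 4 * (2 * f) * f + f
        total += double_conv_params(2 * f, f)
    # Final 1x1 convolution with bias.
    total += features[0] * out_channels + out_channels
    return total


print(f"{estimate_unet_params(1, 1, [64, 128, 256, 512]):,}")  # 31,036,481
```

If the notebook prints a different number, Promethium's `UNet` uses a different block design (for example, biased convolutions or a different upsampling method).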

## 3. Prepare Data

Create data loaders for training and validation.

In [None]:
from torch.utils.data import DataLoader

# Create datasets
train_dataset = SeismicDataset(
    config["data"]["train_path"],
    patch_size=config["data"]["patch_size"]
)

val_dataset = SeismicDataset(
    config["data"]["val_path"],
    patch_size=config["data"]["patch_size"]
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=4
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
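The internals of `SeismicDataset` are not shown here. As a hedged illustration of what `patch_size` typically controls, the sketch below cuts a 2D section, assumed to be stored as a `(traces, time samples)` array, into square patches with a channel axis, and zeroes a random subset of traces to form a corrupted input. Masking traces is one common way to build (input, target) pairs for reconstruction, but whether Promethium does this is an assumption; `extract_patches` and `mask_traces` are hypothetical helpers, and the section is synthetic noise.

```python
from typing import Optional, Tuple

import numpy as np


def extract_patches(section: np.ndarray, patch_size: int,
                    stride: Optional[int] = None) -> np.ndarray:
    """Cut a (traces, samples) section into patches of shape (N, 1, P, P)."""
    stride = stride or patch_size  # default: non-overlapping patches
    n_traces, n_samples = section.shape
    patches = [
        section[i:i + patch_size, j:j + patch_size]
        for i in range(0, n_traces - patch_size + 1, stride)
        for j in range(0, n_samples - patch_size + 1, stride)
    ]
    if not patches:
        raise ValueError("section is smaller than patch_size")
    # Add the singleton channel axis expected by a 1-channel model.
    return np.stack(patches)[:, np.newaxis, :, :]


def mask_traces(patch: np.ndarray, missing_fraction: float,
                rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    """Zero random traces (second-to-last axis); return (corrupted, mask)."""
    n_traces = patch.shape[-2]
    n_missing = int(round(missing_fraction * n_traces))
    missing = rng.choice(n_traces, size=n_missing, replace=False)
    mask = np.ones(n_traces, dtype=patch.dtype)
    mask[missing] = 0
    mask = mask[:, np.newaxis]  # shape (traces, 1) broadcasts over time samples
    return patch * mask, mask


rng = np.random.default_rng(0)
section = rng.standard_normal((600, 1000)).astype(np.float32)  # synthetic stand-in
patches = extract_patches(section, patch_size=256)
print(patches.shape)  # (6, 1, 256, 256)
corrupted, mask = mask_traces(patches[0], missing_fraction=0.3, rng=rng)
```

With non-overlapping patches, the trailing traces and samples that do not fill a whole patch are dropped; a smaller `stride` produces overlapping patches and more training samples from the same data.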

## 4. Train Model

Run the training loop with validation.

In [None]:
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=config["training"]["learning_rate"],
    weight_decay=config["training"]["weight_decay"],
    device=device
)

# Train for specified epochs
history = trainer.fit(epochs=config["training"]["epochs"])
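How `Trainer.fit` uses the validation set is not shown in this notebook. One common pattern is to track the epoch with the lowest validation loss and stop once it has not improved for a number of epochs. The hypothetical `EarlyStopping` class below sketches that logic in plain Python; it is not a Promethium API, and whether the `Trainer` does anything similar is an assumption.

```python
import math


class EarlyStopping:
    """Track validation loss and signal when it has stopped improving.

    Hypothetical helper, not part of Promethium.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best_loss = math.inf
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.5]
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        print(f"Stopping at epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
# prints: Stopping at epoch 5; best epoch was 2
```

In a real loop you would also save the model weights whenever `best_epoch` changes, so the checkpoint you keep corresponds to the best validation loss rather than the last epoch.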

## 5. Visualize Training

Plot training and validation losses.

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(history["train_loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Progress")
plt.legend()
plt.grid(True)
plt.show()
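Per-epoch validation loss is often noisy, which can hide the overall trend. A trailing moving average is a simple way to smooth the curve before plotting. The helper below, `moving_average`, is a hypothetical addition built on `numpy.convolve`, and it assumes `history["val_loss"]` is a sequence of floats.

```python
import numpy as np


def moving_average(values, window: int) -> np.ndarray:
    """Trailing moving average; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=float)
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    kernel = np.ones(window) / window
    # mode="valid" keeps only positions where the window fully overlaps the data.
    return np.convolve(values, kernel, mode="valid")


print(moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=3))  # [2. 3. 4.]
```

To overlay the smoothed curve, plot it against `np.arange(window - 1, len(history["val_loss"]))` so each point lines up with the last epoch in its window. Adding `plt.yscale("log")` can also make late-training improvements easier to see when the loss spans several orders of magnitude.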

## 6. Save Model

Save the trained weights together with the configuration and loss history so the model can be rebuilt for inference.

In [None]:
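# Create the output directory first; torch.save does not create missing folders
checkpoint_path = Path("checkpoints") / "unet_trained.pt"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# Save weights together with the config and loss history
torch.save({
    "model_state_dict": model.state_dict(),
    "config": config,
    "history": history
}, checkpoint_path)

print(f"Model saved to {checkpoint_path}")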
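To use the checkpoint for inference, load it with `map_location` (so a file saved from a GPU also opens on a CPU-only machine), restore the state dict into a freshly built model, and switch to `eval()` mode. To stay self-contained, the sketch below substitutes a single `nn.Conv2d` for the U-Net through a hypothetical `build_model` helper; in the notebook you would instead rebuild `UNet` from `checkpoint["config"]["model"]` with the same arguments as in Section 2. `weights_only=True` requires PyTorch 1.13 or newer and only unpickles tensors and plain containers. If the saved `history` holds other object types, loading may fail, and `weights_only=False` should only be used for files you trust.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for UNet(...) so this sketch runs on its own.
    return nn.Conv2d(1, 1, kernel_size=3, padding=1)


def load_checkpoint(path, device: str = "cpu"):
    # map_location remaps stored tensors (e.g. CUDA) onto the target device.
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    model = build_model()
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()  # disable dropout and use running BatchNorm statistics
    return model, checkpoint


# Round trip: save a model, reload it, and compare outputs on the same input.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "unet_trained.pt"
    original = build_model()
    torch.save({"model_state_dict": original.state_dict(),
                "config": {"patch_size": 256}}, path)
    restored, ckpt = load_checkpoint(path)

x = torch.randn(1, 1, 32, 32)
with torch.no_grad():
    print(torch.allclose(original(x), restored(x)))  # True
```

Running inference under `torch.no_grad()` avoids building the autograd graph, which saves memory when reconstructing large sections.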

---

**Promethium** - State-of-the-art seismic data reconstruction