# Model Training for Ultrasound Image Reconstruction
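This notebook sets up the baseline training pipeline for reconstructing ultrasound images from raw RF data. The architecture (ResNet, U-Net or EfficientNet-U-Net) is selected in the configuration below.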

In [None]:
import os       # file access
import matplotlib.pyplot as plt     # plotting of the results (possibly later)

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm   # for epoch iteration
from glob import glob   # for dataset preprocessing
from sklearn.model_selection import train_test_split # to perform validation on simulation data

############################################################################################################################
# user defined models
from models.resnet import CustomResNet      # custom resnet architecture - CNN baseline
from models.unet   import CustomUNet        # custom unet architecture - better CNN baseline
from models.effunet import EfficientUNetBeamformer  # efficientnet encoder + unet decoder - final model

# user defined scripts
from dataloader.dataset import UltrasoundDataset                # custom class to handle ultrasound RF and image data
from utils.losses import ssim_loss, mae_loss, combined_loss     # loss functions possibly interesting for US data
from utils.metrics import compute_metrics                       # function to compute SSIM and PSNR for comparisons later


## Configuration
All training parameters are defined in a single `CONFIG` dictionary.

In [None]:
# Single source of truth for every hyper-parameter and path used in this notebook.
# TODO: Finish the config when the dataset becomes available
# Pick the model and the loss function before every run.
CONFIG = dict(
    model_name='resnet',      # options: 'resnet', 'unet', 'effunet'
    loss_function='ssim',     # options: 'ssim', 'mae', 'combined'
    batch_size=4,
    lr=1e-3,
    epochs=150,
    input_size=(75, 128, 128),  # [channels, height, width]
    device='cuda' if torch.cuda.is_available() else 'cpu',
    data_dir_rf=os.path.join('..', 'data', 'raw'),
    data_dir_img=os.path.join('..', 'data', 'processed'),
    checkpoint_dir=None,      # placeholder; derived from the model name right below
)

# Checkpoints are grouped per architecture so runs of different models never overwrite each other.
CONFIG.update(checkpoint_dir=f"./checkpoints/{CONFIG['model_name']}/")

## Data Loading
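We assume that both the input RF data (`rf_*.mat`) and the reconstructed target images (`img_*.mat`) are stored as `.mat` files. An RF file and an image file are paired by sharing the same sample id in their file names.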

In [None]:
# Collect the RF input files and the reconstructed target images.
# Both lists are sorted, so rf_<id>.mat and img_<id>.mat end up at the same index;
# the dataset below pairs them purely by position.
rf_dir   =  CONFIG['data_dir_rf']      # path to RF .mat files
img_dir  =  CONFIG['data_dir_img']     # path to image .mat files

rf_paths   =  sorted(glob(os.path.join(rf_dir, 'rf_*.mat')))
img_paths  =  sorted(glob(os.path.join(img_dir, 'img_*.mat')))

# Fail early with a readable message instead of letting train_test_split
# complain about an empty dataset later on.
if not rf_paths or not img_paths:
    raise FileNotFoundError(
        f"No training data found: {len(rf_paths)} RF files in {rf_dir!r}, "
        f"{len(img_paths)} image files in {img_dir!r}"
    )

# Make sure we have a matching number of RF data and reconstructed images.
# (An explicit raise instead of `assert`, which is skipped under `python -O`.)
if len(rf_paths) != len(img_paths):
    raise ValueError("Mismatch between RF and image files!")

# Equal counts do not guarantee the same sample ids, so check the pairing too:
# rf_<id>.mat must line up with img_<id>.mat, otherwise inputs and targets are silently misaligned.
mismatched_pairs = [
    (rf_path, img_path)
    for rf_path, img_path in zip(rf_paths, img_paths)
    if os.path.basename(rf_path)[len('rf_'):] != os.path.basename(img_path)[len('img_'):]
]
if mismatched_pairs:
    raise ValueError(
        f"{len(mismatched_pairs)} RF/image pairs have different sample ids, "
        f"e.g. {mismatched_pairs[0]}"
    )

In [None]:
# Hold out 10% of the (RF, image) pairs for validation.
# Both path lists go into one train_test_split call, so they are shuffled with the
# same permutation and the pairs stay aligned; random_state makes the split reproducible.
rf_train, rf_val, img_train, img_val = train_test_split(
    rf_paths, img_paths, test_size=0.1, random_state=42
)

# Build one dataset per split. target_shape is taken from CONFIG['input_size']
# (presumably used by UltrasoundDataset to resize/validate samples — confirm in dataloader.dataset).
target_shape = CONFIG['input_size']
train_dataset, val_dataset = (
    UltrasoundDataset(rf_split, img_split, target_shape=target_shape)
    for rf_split, img_split in ((rf_train, img_train), (rf_val, img_val))
)


In [None]:
# Wrap the datasets in loaders. num_workers=0 loads data in the main process to
# avoid multiprocessing spawn issues (e.g. when running inside a notebook).
# Page-locked (pinned) host memory only speeds up host-to-device copies, so it is
# enabled only when training on CUDA; on CPU-only runs it brings nothing and
# PyTorch may emit a warning about it.
use_pin_memory = CONFIG['device'] == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,  # reshuffle training samples every epoch
    num_workers=0,
    pin_memory=use_pin_memory,
)

# Validation data is not shuffled: averaged metrics do not depend on order, and a
# fixed order keeps per-batch inspection and example plots reproducible between runs.
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=use_pin_memory,
)

In [None]:
# Sanity check: fetch a single training batch (if any) and print the tensor shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    rf_batch, img_batch = first_batch
    print(f"RF shape: {rf_batch.shape}")   # expected [B, 75, 128, 128]
    print(f"IMG shape: {img_batch.shape}") # expected [B, 1, 128, 128]

## Model Architecture
Here we define the model architecture used for training.

In [None]:
# The idea is to feed all plane waves as input channels (e.g. 75) and predict one image from them.
# TODO: change input channels when the dataset is known
if CONFIG['model_name'] == 'resnet':
    model = CustomResNet()
elif CONFIG['model_name'] == 'unet':
    model = CustomUNet()
elif CONFIG['model_name'] == 'effunet':
    model = EfficientUNetBeamformer()
else:
    # Without this branch a typo would leave `model` undefined and fail later with a NameError.
    raise ValueError(
        f"Unknown model_name {CONFIG['model_name']!r}; expected 'resnet', 'unet' or 'effunet'"
    )

# Move the parameters to the training device. This must happen before the optimizer is
# built: every batch is sent to CONFIG['device'] during training/validation, so a model
# left on the CPU would fail with a device-mismatch error on CUDA runs.
model = model.to(CONFIG['device'])

In [None]:
# Adam over all model parameters at the configured learning rate. The loss function is
# not chosen here; it is selected from CONFIG inside the training step.
optimizer = torch.optim.Adam(params=model.parameters(), lr=CONFIG['lr'])

# StepLR halves the learning rate every 10 scheduler steps (the training loop steps once per epoch).
# NOTE(review): with 150 epochs the LR ends around 1e-3 * 0.5**14 ≈ 6e-8 — check this decay
# is not too aggressive.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=10)

In [None]:
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss is chosen by ``CONFIG['loss_function']`` ('ssim', 'mae' or 'combined'),
    and every batch is moved to ``CONFIG['device']`` before the forward pass.

    Args:
        model: network to train; switched to training mode here.
        loader: iterable of ``(x, y)`` batches (RF input, target image).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Returns:
        Average of the per-batch loss values, or NaN if ``loader`` yields no batches.

    Raises:
        ValueError: if ``CONFIG['loss_function']`` is not a known loss name.
    """
    # Resolve the loss once per epoch instead of re-checking the name on every batch.
    # An unknown name used to surface as an UnboundLocalError on `loss`.
    loss_fns = {'ssim': ssim_loss, 'mae': mae_loss, 'combined': combined_loss}
    try:
        loss_fn = loss_fns[CONFIG['loss_function']]
    except KeyError:
        raise ValueError(
            f"Unknown loss_function {CONFIG['loss_function']!r}; "
            f"expected one of {sorted(loss_fns)}"
        ) from None

    device = CONFIG['device']
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for x, y in tqdm(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    # An empty loader trained nothing: report NaN rather than dividing by zero.
    return epoch_loss / n_batches if n_batches else float('nan')


In [None]:
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The loss is the one selected by ``CONFIG['loss_function']``, i.e. the same objective
    that ``train_one_epoch`` reports, so training and validation losses are directly
    comparable. (Previously validation always used ``combined_loss``.)

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(x, y)`` batches; outputs/targets are assumed to be
            ``[B, C, H, W]`` with the image in channel 0 — TODO confirm for new models.

    Returns:
        ``(mean batch loss, mean per-sample PSNR, mean per-sample SSIM)``. PSNR and SSIM
        come from ``compute_metrics(target, prediction)`` on channel 0 of every sample.
        All three values are NaN when ``loader`` yields no batches.

    Raises:
        ValueError: if ``CONFIG['loss_function']`` is not a known loss name.
    """
    loss_fns = {'ssim': ssim_loss, 'mae': mae_loss, 'combined': combined_loss}
    try:
        loss_fn = loss_fns[CONFIG['loss_function']]
    except KeyError:
        raise ValueError(
            f"Unknown loss_function {CONFIG['loss_function']!r}; "
            f"expected one of {sorted(loss_fns)}"
        ) from None

    device = CONFIG['device']
    model.eval()
    val_loss = 0.0
    total_psnr, total_ssim = 0.0, 0.0
    n_batches = 0
    n_samples = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            val_loss += loss_fn(y_pred, y).item()
            n_batches += 1

            # Tensors produced under no_grad carry no graph, so .cpu().numpy() is enough.
            preds = y_pred.cpu().numpy()
            targets = y.cpu().numpy()

            # Compute metrics per sample on channel 0: [C, H, W] -> [H, W]
            for pred, target in zip(preds, targets):
                psnr, ssim = compute_metrics(target[0], pred[0])
                total_psnr += psnr
                total_ssim += ssim
                n_samples += 1

    if n_batches == 0:
        # Nothing to evaluate: avoid ZeroDivisionError and make the gap visible in the logs.
        nan = float('nan')
        return nan, nan, nan

    return (
        val_loss / n_batches,
        total_psnr / n_samples,
        total_ssim / n_samples,
    )


## Model Training
In this part we set up the training loop and train the model. We also save periodic checkpoints and, at the end, the state of the final model.

In [None]:
# Create the checkpoint folder up front (no error if it already exists).
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

# Outer progress bar tracks epochs; train_one_epoch shows its own per-batch bar.
for epoch in tqdm(range(CONFIG['epochs'])):
    epoch_no = epoch + 1  # 1-based epoch number used in logs and file names

    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_psnr, val_ssim = validate(model, val_loader)
    scheduler.step()  # one scheduler step per epoch

    summary = " | ".join([
        f"Epoch {epoch_no}/{CONFIG['epochs']}",
        f"Train Loss: {train_loss:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"PSNR: {val_psnr:.2f}",
        f"SSIM: {val_ssim:.4f}",
    ])
    print(summary)

    # Snapshot only every 50 epochs and after the very last epoch.
    if epoch_no % 50 and epoch_no != CONFIG['epochs']:
        continue

    ckpt_path = os.path.join(
        CONFIG['checkpoint_dir'], f"{CONFIG['model_name']}_epoch_{epoch_no}.pt"
    )
    snapshot = {
        'epoch': epoch_no,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'psnr': val_psnr,
        'ssim': val_ssim,
    }
    torch.save(snapshot, ckpt_path)
    print(f"Saved checkpoint: {ckpt_path}")

In [None]:
# Persist the end-of-training state (weights, optimizer and scheduler) as final_model.pt
# inside the per-model checkpoint directory.
final_state = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    epoch=CONFIG['epochs'],
)
final_model_path = os.path.join(CONFIG['checkpoint_dir'], 'final_model.pt')
torch.save(final_state, final_model_path)


## Model Visualization
We visualize one example of the reconstructed images.

In [None]:
def show_example(model, loader, index=0):
    """Plot input, target and prediction for one sample of the first batch of ``loader``.

    Args:
        model: trained network; it is switched to eval mode (and left there) and run
            without gradient tracking.
        loader: iterable of ``(x, y)`` batches; only its first batch is used. With a
            shuffled loader this is a different batch on every call.
        index: position of the sample inside that batch. It defaults to 0, which matches
            the previous hard-coded behaviour, and must be smaller than the batch size.

    The "Input" panel shows only channel 0 of the RF input (presumably the first plane
    wave); "Target" and "Prediction" show channel 0 of the image tensors.
    """
    model.eval()
    x, y = next(iter(loader))
    with torch.no_grad():
        # Only the model input goes to the device; the CPU copy of x is kept for plotting.
        y_pred = model(x.to(CONFIG['device'])).cpu()

    panels = [
        (x[index, 0], "Input"),
        (y[index, 0], "Target"),
        (y_pred[index, 0], "Prediction"),
    ]
    fig, axs = plt.subplots(1, len(panels), figsize=(12, 4))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
    fig.tight_layout()  # keep titles from overlapping neighbouring panels
    plt.show()
