# U-Net Training for Ultrasound Image Reconstruction
This notebook sets up the baseline training pipeline for U-Net using ultrasound data.  

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt


from model_library import UNet  

# pytorch modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torchvision.transforms as T
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

# for dataset preprocessing
from glob import glob
import scipy.io as sio
from PIL import Image

# evaluation metrics for reconstructed images
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from piq import SSIMLoss # loss function


## Configuration
Define all the training parameters in one variable.

In [None]:
# Single dictionary holding every hyper-parameter and path used by the notebook.
CONFIG = dict(
    model_name='unet',
    batch_size=4,
    lr=1e-3,
    epochs=150,
    # TODO: change this when the dataset becomes available
    input_size=(75, 128, 128),
    # Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    data_dir_rf=os.path.join('..', 'data', 'raw'),
    data_dir_img=os.path.join('..', 'data', 'processed'),
    checkpoint_dir='./checkpoints/unet/',
)

## Data Loading
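THIS NEEDS REPLACING! The intended formats are .mat or .npy for the input and .npy or .png for the output; the placeholder loader below currently reads only `.mat` files for both (variables `rf_raw` and `img`).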

In [None]:
# we first define a dataset class that will take care of data handling DONE
class UltrasoundDataset(torch.utils.data.Dataset):
    """Paired dataset of raw RF recordings and their reconstructed images.

    Every sample is read from two MATLAB ``.mat`` files: the RF file must
    contain the variable ``rf_raw`` and the image file the variable ``img``.
    Both arrays are converted to ``float32`` tensors and bilinearly resized
    to the spatial size given by ``target_shape``.

    Args:
        rf_paths: Sequence of paths to the RF ``.mat`` files.
        img_paths: Sequence of paths to the image ``.mat`` files; entry ``i``
            is used as the target for ``rf_paths[i]``.
        target_shape: ``(C, H, W)``. Only ``H`` and ``W`` are enforced (by
            resizing); the channel count of the loaded RF data is passed
            through unchanged.

    Raises:
        ValueError: If ``rf_paths`` and ``img_paths`` differ in length.
    """

    def __init__(self, rf_paths, img_paths, target_shape=(75, 128, 128)):
        # Samples are paired purely by position, so a length mismatch would
        # either silently drop images or fail with an IndexError mid-epoch.
        if len(rf_paths) != len(img_paths):
            raise ValueError(
                "rf_paths and img_paths must have the same length, "
                f"got {len(rf_paths)} and {len(img_paths)}"
            )
        self.rf_paths      =  rf_paths          # paths to RF data files
        self.img_paths     =  img_paths         # paths to img files
        self.target_shape  =  target_shape      # (C, H, W) each sample is resized towards

    def __len__(self):
        """Return the number of (RF, image) pairs."""
        return len(self.rf_paths)

    def __getitem__(self, idx):
        """Load, resize and normalise the ``idx``-th (RF, image) pair.

        Returns:
            tuple: ``(rf, img)`` where ``rf`` keeps its channel count and has
            the target height/width, and ``img`` has shape ``[1, H, W]``.
        """
        # load the RF and img data from .mat files
        rf   =  sio.loadmat(self.rf_paths[idx])['rf_raw']
        img  =  sio.loadmat(self.img_paths[idx])['img']

        # turn the RF and img data into tensors
        rf   =  torch.tensor(rf, dtype=torch.float32)                 # [C, H, W] -> [plane wave, depth, transducer array elements]
        img  =  torch.tensor(img, dtype=torch.float32).unsqueeze(0)   # [1, H, W] -> [1, height, width] greyscale image

        # resize both to the target height/width
        rf   =  self.resize_tensor(rf, self.target_shape)
        img  =  self.resize_tensor(img, (1, self.target_shape[1], self.target_shape[2]))

        # normalize the image to avoid problems with SSIM loss later on
        # NOTE(review): assumes the stored image uses grey levels in [0, 255] — confirm against the real dataset
        img = img / 255.0
        return rf, img

    def resize_tensor(self, tensor, target_shape):
        """Bilinearly resize the last two dims of a ``[C, H, W]`` tensor.

        Only ``target_shape[1:]`` (height, width) is used; ``target_shape[0]``
        is ignored, so the number of channels is never changed.
        """
        # F.interpolate expects a batch dimension, so add one and strip it again.
        return F.interpolate(
            tensor.unsqueeze(0), size=target_shape[1:], mode='bilinear', align_corners=False
        ).squeeze(0)


In [None]:
rf_dir   =  CONFIG['data_dir_rf']      # path to RF .mat files
img_dir  =  CONFIG['data_dir_img']     # path to image .mat files

# Both listings are sorted so that position i in one list is paired with
# position i in the other (pairing is positional only).
rf_paths   =  sorted(glob(os.path.join(rf_dir, 'rf_*.mat')))
img_paths  =  sorted(glob(os.path.join(img_dir, 'img_*.mat')))

# Fail early with an explicit exception; an `assert` would be skipped when
# Python runs with -O, and an empty listing would only fail later on.
if not rf_paths:
    raise FileNotFoundError(f"No RF files matching 'rf_*.mat' found in {rf_dir!r}")
if len(rf_paths) != len(img_paths):
    # make sure we have matching number of RF data and reconstructed images
    raise ValueError(
        "Mismatch between RF and image files! "
        f"({len(rf_paths)} RF files vs {len(img_paths)} image files)"
    )

In [None]:
# Hold out 20% of the (RF, image) pairs for validation on simulated data;
# the fixed random_state keeps the split reproducible between runs.
rf_train, rf_val, img_train, img_val = train_test_split(
    rf_paths,
    img_paths,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Take the sample shape from CONFIG so it is defined in exactly one place.
target_shape = CONFIG['input_size']

train_dataset = UltrasoundDataset(rf_train, img_train, target_shape=target_shape)
val_dataset   = UltrasoundDataset(rf_val, img_val, target_shape=target_shape)

In [None]:
# we create two dataloaders, we set num_workers=0 to avoid spawn issues
# Pinned host memory only helps host-to-GPU copies, so it is enabled only when
# the configured device is CUDA (on CPU-only machines PyTorch warns about it).
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,               # reshuffle the training samples every epoch
    num_workers=0,
    pin_memory=(CONFIG['device'] == 'cuda'),
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,              # keep validation order fixed
    num_workers=0,
    pin_memory=(CONFIG['device'] == 'cuda'),
)

In [None]:
# sanity check: pull a single batch (if there is one) and print its shapes
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    rf_batch, img_batch = _first_batch
    print(f"RF shape: {rf_batch.shape}")   # Should be [B, 75, 128, 128]
    print(f"IMG shape: {img_batch.shape}") # Should be [B, 1, 128, 128]

## Model Architecture
Here we define the model architecture used for training.

In [None]:
# the idea is that we use all plane waves as input (for example 75) and then predict 1 image based on that
# The number of input channels is read from CONFIG['input_size'] (its first
# entry is the plane-wave/channel count) so it follows the configured sample shape.
model = UNet(in_channels=CONFIG['input_size'][0], out_channels=1).to(CONFIG['device'])

In [None]:
# SSIM-based loss; data_range=1.0 assumes your images are normalized to [0, 1]
ssim_loss_fn = SSIMLoss(data_range=1.0)
# TODO: consider combining SSIMLoss with MAE

# Adam on all model parameters, using the learning rate from CONFIG.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CONFIG['lr'],
)
# Multiply the learning rate by 0.5 every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.5,
)

In [None]:
def combined_loss(pred, target, alpha=0.84):
    """Return a weighted sum of the SSIM loss and the L1 loss.

    Args:
        pred: Model output tensor.
        target: Ground-truth tensor with the same shape as ``pred``.
        alpha: Weight of the SSIM term; the L1 term gets ``1 - alpha``.

    Returns:
        ``alpha * ssim_loss + (1 - alpha) * l1_loss`` as a tensor.
    """
    # piq's SSIM validates that its inputs lie in [0, data_range] and raises
    # otherwise. `ssim_loss_fn` is built with data_range=1.0, so both inputs
    # are clamped to [0, 1] for the SSIM term only; in-range values are
    # unaffected, and the L1 term below still sees the raw prediction.
    ssim_term = ssim_loss_fn(pred.clamp(0.0, 1.0), target.clamp(0.0, 1.0))
    l1_term = torch.nn.functional.l1_loss(pred, target)
    return alpha * ssim_term + (1 - alpha) * l1_term

In [None]:
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(input, target)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    device = CONFIG['device']
    model.train()

    batch_losses = []
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = combined_loss(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(loader)


In [None]:
def compute_metrics(y_true, y_pred):
    """
    Computes PSNR and SSIM between two numpy arrays.

    Both metrics are evaluated with ``data_range=1.0``.

    Args:
        y_true (np.ndarray): Ground truth image, shape (H, W) or (1, H, W)
        y_pred (np.ndarray): Predicted image, shape (H, W) or (1, H, W)

    Returns:
        tuple: (PSNR, SSIM) float values

    Raises:
        ValueError: If the squeezed arrays differ in shape or are not 2D.
    """
    # Squeeze unnecessary singleton dimensions (e.g., channel)
    y_true = np.squeeze(y_true)
    y_pred = np.squeeze(y_pred)

    # Explicit exceptions rather than `assert`, which is removed under `python -O`.
    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shape mismatch: {y_true.shape} vs {y_pred.shape}")
    if y_true.ndim != 2:
        raise ValueError(f"Expected 2D arrays after squeeze, got {y_true.ndim}D")

    psnr_val = psnr(y_true, y_pred, data_range=1.0)
    ssim_val = ssim(y_true, y_pred, data_range=1.0)

    return psnr_val, ssim_val

In [None]:
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Data loader yielding ``(input, target)`` batches; its
            ``dataset`` attribute is used for the sample count.

    Returns:
        tuple: (summed batch loss / number of batches,
                summed PSNR / number of samples,
                summed SSIM / number of samples)
    """
    device = CONFIG['device']
    model.eval()

    loss_sum = 0
    psnr_sum, ssim_sum = 0.0, 0.0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += combined_loss(outputs, targets).item()

            # Move to CPU/numpy and drop the channel dim: [B, 1, H, W] → [B, H, W]
            pred_imgs = outputs.detach().cpu().numpy()[:, 0]
            true_imgs = targets.detach().cpu().numpy()[:, 0]

            # Image metrics are computed one sample at a time.
            for true_img, pred_img in zip(true_imgs, pred_imgs):
                sample_psnr, sample_ssim = compute_metrics(true_img, pred_img)
                psnr_sum += sample_psnr
                ssim_sum += sample_ssim

    n_samples = len(loader.dataset)
    return loss_sum / len(loader), psnr_sum / n_samples, ssim_sum / n_samples


## Model Training
In this part we set up the training loop, train the model and then save the parameters of the final one.

In [None]:
# Create the checkpoint directory up front (no error if it already exists)
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

for epoch in tqdm(range(CONFIG['epochs'])):
    # One pass over the training data, then evaluation on the validation split.
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_psnr, val_ssim = validate(model, val_loader)
    scheduler.step()

    print(f"Epoch {epoch+1}/{CONFIG['epochs']} | "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"PSNR: {val_psnr:.2f} | SSIM: {val_ssim:.4f}")

    # Checkpoint every 50th epoch and always after the last one
    is_milestone = (epoch + 1) % 50 == 0
    is_last_epoch = (epoch + 1) == CONFIG['epochs']
    if is_milestone or is_last_epoch:
        ckpt_path = os.path.join(
            CONFIG['checkpoint_dir'],
            f"{CONFIG['model_name']}_epoch_{epoch+1}.pt",
        )
        # Store everything needed to resume training plus the latest metrics.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'psnr': val_psnr,
            'ssim': val_ssim,
        }
        torch.save(checkpoint, ckpt_path)
        print(f"✅ Saved checkpoint: {ckpt_path}")

In [None]:
# Save the end-of-training state under a fixed file name.
final_model_path = os.path.join(CONFIG['checkpoint_dir'], 'final_model.pt')
final_state = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'epoch': CONFIG['epochs'],   # the configured epoch count, not a loop counter
}
torch.save(final_state, final_model_path)


## Model Visualization
We visualize one example of the reconstructed images.

In [None]:
def show_example(model, loader):
    """Plot input, target and prediction for the first sample of a batch.

    The first batch is drawn from ``loader`` and passed through ``model`` in
    eval mode without gradients. Channel 0 of the first sample's input, the
    target and the prediction are shown side by side in greyscale.
    """
    model.eval()
    inputs, targets = next(iter(loader))
    inputs = inputs.to(CONFIG['device'])
    with torch.no_grad():
        outputs = model(inputs).cpu()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    # (title, image) for each panel, left to right.
    panels = (
        ("Input", inputs[0, 0].cpu()),
        ("Target", targets[0, 0]),
        ("Prediction", outputs[0, 0]),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
    plt.show()
