# U-Net Training for Ultrasound Image Reconstruction
This notebook sets up the baseline training pipeline for U-Net using ultrasound data.  

In [7]:
import os
import numpy as np
import matplotlib.pyplot as plt

# pytorch modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torchvision.transforms as T

# for dataset preprocessing
import scipy.io as sio
from PIL import Image

# evaluation metrics for reconstructed images
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


## Configuration
Define all the training parameters in one variable.

In [None]:
# Single source of truth for every tunable value used by the cells below.
CONFIG = dict(
    # --- model and optimisation ---
    model_name='unet',
    batch_size=4,
    lr=1e-3,
    epochs=150,
    input_size=(128, 128),
    # Use the GPU when PyTorch can see one, otherwise run on the CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # --- filesystem locations (relative to the notebook's working directory) ---
    data_dir_in=os.path.join('..', 'data', 'raw'),
    data_dir_out=os.path.join('..', 'data', 'processed'),
    checkpoint_dir='./checkpoints/unet/',
)

## Data Loading
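**TODO: this placeholder loader needs replacing.** We assume `.mat` or `.npy` files for the input and `.npy` or `.png` files for the output. Note that the dataset class below currently reads only `.mat` inputs.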

In [None]:
class UltrasoundDataset(Dataset):
    """Paired dataset of RF inputs (``.mat``) and target images (``.png`` / ``.npy``).

    The two folders are listed, filtered by extension and sorted by path;
    sample ``i`` pairs the ``i``-th RF file with the ``i``-th image file, so
    both folders must sort into the same order.

    Args:
        rf_dir: Folder containing the ``.mat`` input files.
        img_dir: Folder containing the ``.png`` or ``.npy`` target files.
        transform: Optional callable applied to the input tensor.
        target_transform: Optional callable applied to the target tensor.
        rf_key: Name of the variable to read from each ``.mat`` file. When
            ``None``, the last non-metadata variable stored in the file is used.

    Raises:
        AssertionError: If the two folders hold a different number of files.
    """

    def __init__(self, rf_dir, img_dir, transform=None, target_transform=None, rf_key=None):
        self.rf_paths = sorted(
            os.path.join(rf_dir, f) for f in os.listdir(rf_dir) if f.endswith('.mat')
        )
        self.img_paths = sorted(
            os.path.join(img_dir, f)
            for f in os.listdir(img_dir)
            if f.endswith(('.png', '.npy'))
        )
        self.transform = transform
        self.target_transform = target_transform
        self.rf_key = rf_key

        assert len(self.rf_paths) == len(self.img_paths), "Mismatched dataset sizes!"

    def __len__(self):
        """Return the number of (RF, image) pairs."""
        return len(self.rf_paths)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as float32 tensors ``(rf, img)``.

        2-D arrays gain a leading channel axis; arrays of any other rank are
        passed through with their axes unchanged (multi-band PNGs are moved to
        channel-first before that).
        """
        rf = self._to_tensor(self._load_rf(self.rf_paths[idx]))
        img = self._to_tensor(self._load_img(self.img_paths[idx]))

        # The transforms were previously accepted but silently ignored.
        if self.transform is not None:
            rf = self.transform(rf)
        if self.target_transform is not None:
            img = self.target_transform(img)

        return rf, img

    def _load_rf(self, path):
        """Read the RF array stored in the ``.mat`` file at ``path``."""
        mat = sio.loadmat(path)
        if self.rf_key is not None:
            return mat[self.rf_key]

        # loadmat also returns bookkeeping entries such as '__header__';
        # only the remaining keys are actual MATLAB variables.
        data_keys = [key for key in mat if not key.startswith('__')]
        if not data_keys:
            raise KeyError(f"No data variables found in {path}")
        return mat[data_keys[-1]]

    @staticmethod
    def _load_img(path):
        """Read a target image from a ``.png`` (scaled by 1/255) or ``.npy`` file."""
        if path.endswith('.png'):
            # The context manager closes the file handle deterministically.
            with Image.open(path) as png:
                img = np.array(png) / 255.0
            if img.ndim == 3:
                # PIL yields (H, W, C) for multi-band images; move to (C, H, W).
                img = np.moveaxis(img, -1, 0)
            return img
        return np.load(path)

    @staticmethod
    def _to_tensor(array):
        """Convert ``array`` to float32, adding a channel axis to 2-D data."""
        tensor = torch.tensor(array, dtype=torch.float32)
        return tensor.unsqueeze(0) if tensor.ndim == 2 else tensor

# UltrasoundDataset needs both an RF folder and an image folder; the previous
# single-argument calls raised TypeError.
# NOTE(review): assumes each split lives in a sub-folder named after it under
# CONFIG['data_dir_in'] (RF inputs) and CONFIG['data_dir_out'] (target images)
# - confirm against the actual data layout.
train_dataset = UltrasoundDataset(
    os.path.join(CONFIG['data_dir_in'], 'train'),
    os.path.join(CONFIG['data_dir_out'], 'train'),
)
val_dataset = UltrasoundDataset(
    os.path.join(CONFIG['data_dir_in'], 'val'),
    os.path.join(CONFIG['data_dir_out'], 'val'),
)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'])

## Model Architecture
Here we define the model architecture used for training.

In [5]:
# U-Net example; swap this cell for another architecture (e.g. ResNet) later.
from model_library import UNet  # optional if in separate file

# Build the single-channel-in / single-channel-out network, then move it to
# the device selected in CONFIG.
model = UNet(in_channels=1, out_channels=1)
model = model.to(CONFIG['device'])


In [None]:
# Training objective: mean squared error (SSIM, L1Loss, etc. are alternatives).
criterion = nn.MSELoss()

# Adam over every model parameter, using the learning rate from CONFIG.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=CONFIG['lr'],
)

# Halve the learning rate every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=10,
    gamma=0.5,
)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable of ``(input, target)`` batches.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(prediction, target)``.

    Returns:
        The sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(CONFIG['device'])
        targets = targets.to(CONFIG['device'])

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

In [None]:
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(input, target)`` batches.
        criterion: Loss called as ``criterion(prediction, target)``.

    Returns:
        A tuple ``(loss, psnr, ssim)``: the summed batch losses divided by
        ``len(loader)``, and PSNR / SSIM averaged over every image scored.
    """
    model.eval()
    val_loss = 0
    total_psnr, total_ssim = 0, 0
    n_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(CONFIG['device']), y.to(CONFIG['device'])
            y_pred = model(x)
            val_loss += criterion(y_pred, y).item()

            # The image metrics run on NumPy arrays, one image at a time.
            pred_np = y_pred.cpu().numpy()
            target_np = y.cpu().numpy()
            for target_img, pred_img in zip(target_np, pred_np):
                p, s = compute_metrics(target_img, pred_img)
                total_psnr += p
                total_ssim += s
                n_samples += 1

    # Divide by the number of images actually scored: len(loader.dataset)
    # overcounts when the loader skips samples (drop_last, samplers).
    return (
        val_loss / len(loader),
        total_psnr / n_samples,
        total_ssim / n_samples
    )

In [None]:
def compute_metrics(y_true, y_pred, data_range=1.0):
    """Compute PSNR and SSIM between a target image and a reconstruction.

    Args:
        y_true: Reference image as a NumPy array with shape (1, H, W).
        y_pred: Reconstructed image, same shape as ``y_true``.
        data_range: Distance between the minimum and maximum possible pixel
            values, passed to both metrics. The default of 1.0 matches images
            scaled to [0, 1]; override it for targets on another scale.

    Returns:
        A tuple ``(psnr, ssim)``.
    """
    # Drop singleton axes so both metrics see plain (H, W) arrays.
    y_true = y_true.squeeze()
    y_pred = y_pred.squeeze()

    _psnr = psnr(y_true, y_pred, data_range=data_range)
    _ssim = ssim(y_true, y_pred, data_range=data_range)
    return _psnr, _ssim

## Model Training
In this part we set up the training loop, train the model, and then save the parameters of the final trained model.

In [None]:
for epoch in range(CONFIG['epochs']):
    # Optimise on the training split, then score the validation split.
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_metrics = validate(model, val_loader, criterion)
    val_loss, val_psnr, val_ssim = val_metrics

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    epoch_summary = f"Epoch {epoch+1}/{CONFIG['epochs']} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | PSNR: {val_psnr:.2f} | SSIM: {val_ssim:.4f}"
    print(epoch_summary)


In [None]:
# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before writing into it.
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

# Persist only the learned parameters (state_dict), not the whole module.
final_model_path = os.path.join(CONFIG['checkpoint_dir'], 'final_model.pt')
torch.save(model.state_dict(), final_model_path)

## Model Visualization
We visualize one example of the reconstructed images.

In [None]:
def show_example(model, loader):
    """Plot input, target and prediction for the first sample of the first batch.

    Args:
        model: Network used to produce the prediction; switched to eval mode.
        loader: Iterable of ``(input, target)`` batches; only its first batch
            is consumed.
    """
    model.eval()
    inputs, targets = next(iter(loader))
    inputs = inputs.to(CONFIG['device'])
    with torch.no_grad():
        predictions = model(inputs).cpu()

    # Channel 0 of sample 0 for each of the three panels, left to right.
    panels = (
        ("Input", inputs[0, 0].cpu()),
        ("Target", targets[0, 0]),
        ("Prediction", predictions[0, 0]),
    )

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, image) in zip(axs, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
    plt.show()
