# Baseline U-Net Model for Spondylolisthesis Grading

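This notebook implements the baseline U-Net model for automated spondylolisthesis grading. The network is trained to predict a single-channel segmentation mask for each input image. The notebook covers data loading, model training, evaluation (currently on the training set), and visualization of predicted masks.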

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from src.data.dataset import SpondylolisthesisDataset
from src.models.unet import UNet
from src.training.trainer import Trainer
from experiments.configs.unet_config import get_config
import matplotlib.pyplot as plt
import numpy as np

# Pick the compute device once for the whole notebook: the first CUDA device
# when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Build the training dataset and a shuffling loader from the experiment config.
# Each sample's image is resized to 256x256 and converted to a tensor.
# NOTE(review): whether `transform` is also applied to the masks is decided
# inside SpondylolisthesisDataset -- confirm; torchvision's Resize defaults to
# bilinear interpolation, which would blur label values in a mask.
config = get_config()
train_root = config['data']['train_dir']
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
train_dataset = SpondylolisthesisDataset(root=train_root, transform=train_transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
)


In [None]:
# Model, loss and optimizer. The U-Net maps 3-channel inputs to a single output
# channel; BCEWithLogitsLoss applies the sigmoid internally, so the network is
# expected to emit raw logits. Adam uses the learning rate from the config.
model = UNet(in_channels=3, out_channels=1)
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
learning_rate = config['training']['learning_rate']
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
# Hand the model, loss, optimizer and device to the project Trainer, then run
# the configured number of epochs over the training loader.
trainer = Trainer(model, criterion, optimizer, device)
epochs = config['training']['num_epochs']
trainer.train(train_loader, num_epochs=epochs)


In [None]:
# Score the trained model with the Trainer's evaluation routine.
# NOTE(review): this evaluates on the same loader used for training, so the
# numbers are optimistic and say nothing about generalization -- point it at
# a held-out loader for a real evaluation.
metrics = trainer.evaluate(train_loader)
print("Training Metrics: {}".format(metrics))


In [None]:
# Visualize some results
def _as_2d(arr):
    """Return channel 0 of a (C, H, W) per-sample map, or a (H, W) map unchanged."""
    return arr[0] if arr.ndim == 3 else arr


def _image_for_display(image):
    """Convert a per-sample image array into an ``imshow``-ready array and colormap.

    Assumes channels-first (C, H, W) input, as produced by ``ToTensor``.
    Single-channel (and already 2-D) images are shown with a gray colormap;
    multi-channel images are moved to channels-last (H, W, C).
    """
    if image.ndim == 2:
        return image, 'gray'
    if image.shape[0] == 1:
        return image[0], 'gray'
    return image.transpose(1, 2, 0), None


def _plot_triplet(image, mask, prob):
    """Show one figure with the input image, ground-truth mask and predicted mask."""
    shown_image, image_cmap = _image_for_display(image)
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(shown_image, cmap=image_cmap)
    axes[0].set_title('Input Image')
    axes[1].imshow(_as_2d(mask), cmap='gray')
    axes[1].set_title('Ground Truth Mask')
    axes[2].imshow(_as_2d(prob), cmap='gray')
    axes[2].set_title('Predicted Mask')
    plt.show()


def visualize_predictions(loader, model, num_images=5, device=None):
    """Plot input / ground-truth / predicted-mask triplets for a few samples.

    Runs ``model`` without gradient tracking on batches drawn from ``loader``
    and shows one figure per sample. The predicted panel is
    ``sigmoid(model(images))``.

    Args:
        loader: Iterable yielding ``(images, masks)`` tensor batches.
        model: Network whose output has one logit channel per image.
        num_images: Total number of samples to plot, taken in loader order
            (several per batch when the batch is large enough). Values <= 0
            plot nothing.
        device: Device for inference. Defaults to the device of the model's
            first parameter, or the CPU if the model has no parameters.

    The model's train/eval mode is restored before returning.
    """
    if num_images <= 0:
        return
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    shown = 0
    try:
        with torch.no_grad():
            for images, masks in loader:
                probs = torch.sigmoid(model(images.to(device))).cpu().numpy()
                images_np = images.cpu().numpy()
                masks_np = masks.cpu().numpy()
                for image, mask, prob in zip(images_np, masks_np, probs):
                    _plot_triplet(image, mask, prob)
                    shown += 1
                    if shown >= num_images:
                        return
    finally:
        # Put the model back in whatever mode the caller left it in, so later
        # training cells are not silently run with dropout/batch-norm in eval mode.
        model.train(was_training)

visualize_predictions(train_loader, model)
