# Medical Image Registration Workflow

This notebook demonstrates the complete workflow for training and testing a deformation-based medical image registration model using segmentation data.

## 1. File Path Configuration

Rewrite the paths in the training file list: read `train_npy_copy.txt`, replace the `neurite-oasis.v1.0/` prefix with `segmentation_data/`, and write the result to `train_npy.txt`.

In [None]:
# Rewrite every path in the training list so it points at `segmentation_data/`
# instead of the original `neurite-oasis.v1.0/` location.
OLD_PREFIX = 'neurite-oasis.v1.0/'
NEW_PREFIX = 'segmentation_data/'

# Read the original list (one path per line, newline characters preserved).
with open('train_npy_copy.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()

# str.replace rewrites every occurrence, so paths that contain the old prefix
# after some leading directory are rewritten too.
new_lines = [line.replace(OLD_PREFIX, NEW_PREFIX) for line in lines]

# Write the updated list that train.py consumes.
with open('train_npy.txt', 'w', encoding='utf-8') as f:
    f.writelines(new_lines)

# Count only lines that actually changed; len(new_lines) would also count
# untouched and blank lines and overstate the number of updated paths.
n_changed = sum(old != new for old, new in zip(lines, new_lines))
print("Path replacement completed!")
print(f"Updated {n_changed} of {len(new_lines)} file paths")

## 2. Model Training Setup

Run `train.py` with the training file list, the template segmentation, the batch size, the number of epochs, and the path where the trained checkpoint is saved.

In [None]:
# Train the model using the updated paths and simplified loss functions
!python train.py --train_txt train_npy.txt --template_path segmentation_data/OASIS_OAS1_0016_MR1/seg4_onehot.npy --batch_size 1 --epochs 5 --save_model_path checkpoints/regloss_model_latest.pth

## 3. Quick Test & Save Model

Run a quick sanity check on a single batch to confirm that training produced a usable model, then copy the checkpoint to Google Drive.

In [None]:
from get_data import SegDataset
from model import UNet, SpatialTransformer
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_txt_path = 'train_npy.txt'
template_path = 'segmentation_data/OASIS_OAS1_0016_MR1/seg4_onehot.npy'
model_path = 'checkpoints/regloss_model_latest.pth'

# Seed for the DataLoader shuffle, so the batch inspected below is the same on every
# Restart & Run All. NOTE(review): if SegDataset applies its own random augmentation,
# that needs seeding separately.
SEED = 42

# Re-create dataset and dataloader
train_dataset = SegDataset(train_txt_path, template_path)
loader_generator = torch.Generator().manual_seed(SEED)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=2, shuffle=True, generator=loader_generator
)

# Re-create the architecture used in training (UNet + SpatialTransformer).
# in_channels=10 is the channel count of the concatenated (moving, fixed) input built later.
model = UNet(in_channels=10, out_channels=3).to(device)
stn = SpatialTransformer(size=(128, 128, 128), device=device).to(device)

# Load the trained weights; the checkpoint is a dict holding the state under 'model_state_dict'.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Model loaded successfully!")

In [None]:
# Take one batch from the loader and move both volumes onto the target device.
sample_batch = next(iter(train_loader))
moving, fixed = (volume.to(device) for volume in sample_batch)

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # The UNet sees moving and fixed stacked along the channel axis (dim=1);
    # expected shape (B, 10, 128, 128, 128).
    unet_input = torch.cat((moving, fixed), dim=1)

    # The UNet predicts a deformation field, which the SpatialTransformer applies to `moving`.
    deformation_field = model(unet_input)
    warped_template = stn(moving, deformation_field)

In [None]:
# Visualize the moving, fixed and warped volumes from the inference cell using the
# project's `visualize` module.
# NOTE(review): these tensors are still on `device` (possibly CUDA). This assumes the
# helper moves them to CPU / NumPy itself — confirm in the visualize module.
from visualize import visualize_registration_colab
visualize_registration_colab(moving, fixed, warped_template)

In [None]:
import shutil
from pathlib import Path

# Copy the trained checkpoint into Google Drive (mounted at /content/drive on Colab).
# Unlike `!cp`, this creates the project folder if it is missing and raises on failure,
# instead of printing a shell error while the notebook carries on.
checkpoint_src = Path('checkpoints/regloss_model_latest.pth')
drive_root = Path('/content/drive/MyDrive')
drive_dest_dir = drive_root / 'segmentation-project'

if not drive_root.is_dir():
    raise FileNotFoundError(
        f"{drive_root} not found - mount Google Drive first "
        "(from google.colab import drive; drive.mount('/content/drive'))."
    )
drive_dest_dir.mkdir(parents=True, exist_ok=True)

# shutil.copy mirrors `cp`: it copies the file contents and permission bits into the directory.
saved_path = shutil.copy(checkpoint_src, drive_dest_dir)
print(f"Checkpoint copied to {saved_path}")