In [5]:
import os
import numpy as np
import SimpleITK as sitk
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [7]:

# Report which CUDA / cuDNN builds this PyTorch install uses, then select the
# GPU when one is visible and fall back to the CPU otherwise.
print(torch.version.cuda)              # CUDA version PyTorch was compiled with
print(torch.backends.cudnn.version())  # cuDNN version, reported as an integer

cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("⚠️ CUDA not available. Using CPU.")
else:
    print(f"✅ CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if cuda_ok else "cpu")

11.8
90100
✅ CUDA is available. Using GPU: NVIDIA T600 Laptop GPU


In [None]:
# Dataset that loads paired 3D NIfTI volumes (T1 and T2) and stacks them as channels
class IXINiftiDataset(Dataset):
    """Paired T1/T2 NIfTI volumes returned as one 2-channel float32 tensor.

    Each item has shape ``[2, D, H, W]`` (channel 0 = T1, channel 1 = T2).
    Every modality is z-score normalised over its own volume, clipped to
    ``[-clip_value, clip_value]`` and divided by ``clip_value``. The values
    therefore lie in ``[-1, 1]``, which matches the range the autoencoder's
    final ``Tanh`` can produce.

    Args:
        t1_files: paths to T1 volumes.
        t2_files: paths to T2 volumes; ``t2_files[i]`` is paired with
            ``t1_files[i]``.
        clip_value: z-score clipping threshold (default 3.0, as before).

    Raises:
        ValueError: if the two path lists differ in length or ``clip_value``
            is not positive (at construction), or if a pair's volumes have
            different array shapes (when the item is loaded).
    """

    def __init__(self, t1_files, t2_files, clip_value=3.0):
        t1_files, t2_files = list(t1_files), list(t2_files)
        # zip() would silently drop the surplus files and hide a pairing bug.
        if len(t1_files) != len(t2_files):
            raise ValueError(
                f"Got {len(t1_files)} T1 files but {len(t2_files)} T2 files; "
                "they must be paired one-to-one."
            )
        if clip_value <= 0:
            raise ValueError(f"clip_value must be positive, got {clip_value}")
        self.file_list = list(zip(t1_files, t2_files))
        self.clip_value = clip_value

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _normalize(img, clip_value=3.0):
        """Z-score ``img``, clip to +/-``clip_value`` and rescale into [-1, 1] (float32)."""
        # Compute in float64 for accurate statistics over large volumes.
        img = np.asarray(img, dtype=np.float64)
        img = (img - img.mean()) / (img.std() + 1e-5)
        return (np.clip(img, -clip_value, clip_value) / clip_value).astype(np.float32)

    def __getitem__(self, idx):
        t1_path, t2_path = self.file_list[idx]
        t1_img = self._normalize(sitk.GetArrayFromImage(sitk.ReadImage(t1_path)), self.clip_value)
        t2_img = self._normalize(sitk.GetArrayFromImage(sitk.ReadImage(t2_path)), self.clip_value)

        # np.stack needs identical shapes; report which pair is off instead of
        # failing with a generic stacking error.
        if t1_img.shape != t2_img.shape:
            raise ValueError(
                f"Shape mismatch for pair ({t1_path}, {t2_path}): "
                f"T1 {t1_img.shape} vs T2 {t2_img.shape}. "
                "Resample/register the volumes to a common grid first."
            )

        # Stack to form 2-channel input: shape [2, D, H, W]
        volume = np.stack([t1_img, t2_img], axis=0)
        return torch.from_numpy(volume)



In [None]:
class MultiChannelAutoencoder3D(nn.Module):
    """Small 3D convolutional autoencoder for 2-channel volumes.

    Input and output shape: ``[B, 2, D, H, W]``. The encoder halves each
    spatial dimension twice (two ``MaxPool3d(2)``), and the decoder doubles it
    twice (two stride-2 ``ConvTranspose3d``). Without padding, any spatial size
    that is not a multiple of 4 would come back smaller than the input. Such
    inputs are therefore zero-padded at the far end before encoding, and the
    reconstruction is cropped back to the input size. Outputs pass through
    ``Tanh`` and lie in ``[-1, 1]``.
    """

    # Total spatial downsampling factor of the encoder (two MaxPool3d(2) stages).
    _DOWNSAMPLE = 4

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(2, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose3d(16, 2, 2, stride=2), nn.Tanh()
        )

    def forward(self, x):
        """Reconstruct ``x``; the returned tensor has exactly ``x``'s shape."""
        depth, height, width = x.shape[-3:]
        factor = self._DOWNSAMPLE
        pad_d, pad_h, pad_w = (-depth % factor, -height % factor, -width % factor)
        if pad_d or pad_h or pad_w:
            # pad() takes (before, after) pairs starting from the LAST dimension.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
        out = self.decoder(self.encoder(x))
        # Drop the padded border so the output aligns voxel-for-voxel with the input.
        return out[..., :depth, :height, :width]


In [None]:
# File paths. The data root defaults to /mnt/data but can be overridden with the
# IXI_DATA_DIR environment variable, so the notebook runs elsewhere without edits.
DATA_DIR = os.environ.get("IXI_DATA_DIR", "/mnt/data")
t1_dir = os.path.join(DATA_DIR, "IXI-T1")
t2_dir = os.path.join(DATA_DIR, "IXI-T2")
t1_files = sorted(glob(os.path.join(t1_dir, '*.nii.gz')))
t2_files = sorted(glob(os.path.join(t2_dir, '*.nii.gz')))


In [None]:
# Pair T1 and T2 volumes by subject rather than by list position. Zipping two
# independently sorted globs misaligns every later subject as soon as one
# modality is missing a scan, and zip() drops the surplus without warning.
def pair_by_subject(t1_paths, t2_paths):
    """Return sorted ``(t1_path, t2_path)`` tuples for subjects present in both lists.

    The subject key is the file name with its final ``-``-separated token
    removed; e.g. ``IXI002-Guys-0828-T1.nii.gz`` -> ``IXI002-Guys-0828``.
    NOTE(review): assumes names end in a ``-<modality>...`` token; confirm
    against the files on disk.

    Raises:
        ValueError: if no subject has both a T1 and a T2 volume.
    """
    def subject_key(path):
        return os.path.basename(path).rsplit('-', 1)[0]

    t2_by_subject = {subject_key(p): p for p in t2_paths}
    pairs = [
        (t1_path, t2_by_subject[subject_key(t1_path)])
        for t1_path in sorted(t1_paths)
        if subject_key(t1_path) in t2_by_subject
    ]
    if not pairs:
        raise ValueError(
            f"No subject has both a T1 and a T2 volume "
            f"(found {len(t1_paths)} T1 files, {len(t2_paths)} T2 files); "
            "check the data directories and file naming."
        )
    print(f"Paired {len(pairs)} subjects "
          f"({len(t1_paths) - len(pairs)} T1-only, {len(t2_paths) - len(pairs)} T2-only files skipped)")
    return pairs


file_pairs = pair_by_subject(t1_files, t2_files)
train_pairs, val_pairs = train_test_split(file_pairs, test_size=0.2, random_state=42)


In [None]:
# Dataloaders: each split's (T1 path, T2 path) pairs are unzipped into the two
# per-modality lists the dataset expects. Only the training loader shuffles.
train_dataset = IXINiftiDataset(
    [t1_path for t1_path, _ in train_pairs],
    [t2_path for _, t2_path in train_pairs],
)
val_dataset = IXINiftiDataset(
    [t1_path for t1_path, _ in val_pairs],
    [t2_path for _, t2_path in val_pairs],
)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


In [None]:
# Training setup. Seed torch's global RNG so weight initialisation and the
# training loader's shuffle order repeat across Restart & Run All
# (torch.manual_seed also seeds the CUDA generators).
# NOTE(review): some cuDNN kernels are non-deterministic, so GPU runs may still
# differ slightly between runs.
SEED = 42
LEARNING_RATE = 1e-4
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiChannelAutoencoder3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

In [None]:
# Training loop: one optimisation pass over the training set per epoch,
# followed by a no-grad pass over the validation set so that over- or
# under-fitting shows up in the log.
n_epochs = 5
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = model(batch)
        loss = loss_fn(recon, batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            val_loss += loss_fn(model(batch), batch).item()

    # max(..., 1) avoids ZeroDivisionError when a loader is empty.
    train_avg = epoch_loss / max(len(train_loader), 1)
    val_avg = val_loss / max(len(val_loader), 1)
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {train_avg:.4f}, Val Loss: {val_avg:.4f}")



In [None]:
# Visual sanity check: for one validation volume, show the middle depth slice
# of each channel next to its reconstruction and the absolute per-voxel error.
model.eval()
sample = next(iter(val_loader)).to(device)
with torch.no_grad():
    recon = model(sample)

# Middle index along the depth axis of the [B, C, D, H, W] batch.
slice_idx = sample.shape[2] // 2
fig, axs = plt.subplots(2, 3, figsize=(12, 6))
for channel, row in enumerate(axs):  # row 0 = T1 channel, row 1 = T2 channel
    original = sample[0, channel, slice_idx]
    reconstructed = recon[0, channel, slice_idx]
    panels = (
        (original.cpu(), 'gray', f'Original Modality {channel+1}'),
        (reconstructed.cpu(), 'gray', f'Reconstructed Modality {channel+1}'),
        (torch.abs(original - reconstructed).cpu(), 'hot', f'Error Map Modality {channel+1}'),
    )
    for ax, (image, cmap, title) in zip(row, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
plt.tight_layout()
plt.show()