In [1]:
import os
import numpy as np
import SimpleITK as sitk
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


In [2]:
# Dataset class that loads each modality (T1 or T2) separately
class SingleModalityNiftiDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]
        image = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        image = (image - np.mean(image)) / (np.std(image) + 1e-5)
        image = np.clip(image, -3, 3)
        image = np.expand_dims(image, axis=0)  # Add channel dimension
        return torch.tensor(image, dtype=torch.float32)

# Define a simple 3D Autoencoder for single-channel input
class SingleChannelAutoencoder3D(nn.Module):
    def __init__(self):
        super(SingleChannelAutoencoder3D, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose3d(16, 1, 2, stride=2), nn.Tanh()
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


In [4]:
# Load and combine all .nii.gz files from both T1 and T2 folders
t1_files = sorted(glob(r'D:\DS18\data\IXI-T1_resampled\*.nii.gz'))
#t2_files = sorted(glob("/mnt/data/IXI-T2_resampled/*.nii.gz"))
all_files = t1_files # + t2_files

# Split into training and validation sets
train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=42)


In [5]:
# Prepare datasets and loaders
train_dataset = SingleModalityNiftiDataset(train_files)
val_dataset = SingleModalityNiftiDataset(val_files)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

# Prepare model and training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SingleChannelAutoencoder3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [6]:
# Train loop (5 epochs)
n_epochs = 5
history = []
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = model(batch)
        loss = loss_fn(recon, batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    history.append(avg_loss)
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/5, Loss: 0.4535
Epoch 2/5, Loss: 0.2123
Epoch 3/5, Loss: 0.1982
Epoch 4/5, Loss: 0.1923
Epoch 5/5, Loss: 0.1887


In [7]:
torch.save(model.state_dict(), "AE_T1_single-channel_ixi_ep_1-5.pt")

In [8]:
# 1. Load model checkpoint if saved
model.load_state_dict(torch.load("AE_T1_single-channel_ixi_ep_1-5.pt"))
model.to(device)

# 2. Continue training for more epochs
extra_epochs = 5
for epoch in range(n_epochs + 1, n_epochs + extra_epochs + 1):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = model(batch)
        loss = loss_fn(recon, batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch}/{n_epochs + extra_epochs}, Loss: {avg_loss:.4f}")


  model.load_state_dict(torch.load("AE_T1_single-channel_ixi_ep_1-5.pt"))


Epoch 6/10, Loss: 0.1858
Epoch 7/10, Loss: 0.1834
Epoch 8/10, Loss: 0.1815
Epoch 9/10, Loss: 0.1799
Epoch 10/10, Loss: 0.1786


In [9]:
torch.save(model.state_dict(), "AE_T1_single-channel_ixi_ep_1-10.pt")

##Load & test

In [10]:
import torch
import SimpleITK as sitk
import numpy as np
from pathlib import Path
import torch.nn as nn

In [11]:
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SingleChannelAutoencoder3D().to(device)

model_path = "AE_T1_single-channel_ixi_ep_1-10.pt"
if Path(model_path).exists():
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
else:
    raise FileNotFoundError("Model file not found.")

  model.load_state_dict(torch.load(model_path, map_location=device))


In [16]:
# Load input volume
#input_path = r"D:\DS18\data\IXI-T1_resampled\IXI002-Guys-0828-T1.nii.gz"
input_path = r"D:\DS18\data\BrainTumour\imagesTr\T1_resampled\BRATS_001_T1.nii.gz"
input_img = sitk.ReadImage(input_path)
input_arr = sitk.GetArrayFromImage(input_img).astype(np.float32)

# Normalize input
normalized_input = (input_arr - np.mean(input_arr)) / (np.std(input_arr) + 1e-5)
normalized_input = np.clip(normalized_input, -3, 3)
input_tensor = torch.tensor(normalized_input[None, None, ...], dtype=torch.float32).to(device)


In [17]:
# Reconstruct
with torch.no_grad():
    reconstructed = model(input_tensor).cpu().numpy()[0, 0]

# Compute reconstruction error map
error_map = np.square(normalized_input - reconstructed).astype(np.float32)

# Convert to SimpleITK images
recon_img = sitk.GetImageFromArray(reconstructed)
recon_img.CopyInformation(input_img)

error_img = sitk.GetImageFromArray(error_map)
error_img.CopyInformation(input_img)

# Save outputs
recon_out_path = "reconstructed_volume.nii.gz"
error_out_path = "error_heatmap.nii.gz"
sitk.WriteImage(recon_img, recon_out_path)
sitk.WriteImage(error_img, error_out_path)

recon_out_path, error_out_path

('reconstructed_volume.nii.gz', 'error_heatmap.nii.gz')