In [1]:
import os
import numpy as np
import SimpleITK as sitk
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


In [2]:
class SingleModalityNiftiDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    '''
    def __getitem__(self, idx):
        path = self.file_list[idx]
        image = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        image = (image - np.mean(image)) / (np.std(image) + 1e-5)
        image = np.clip(image, -1, 1)
        image = np.expand_dims(image, axis=0)  # Add channel dimension
        return torch.tensor(image, dtype=torch.float32)
    '''
    def __getitem__(self, idx):
        path = self.file_list[idx]
        image = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)

        # Z-score normalization
        mean, std = np.mean(image), np.std(image)
        zscore_img = (image - mean) / (std + 1e-5)

        # Min-max normalization to [0, 1]
        z_min, z_max = np.min(zscore_img), np.max(zscore_img)
        normalized_img = (zscore_img - z_min) / (z_max - z_min)

        normalized_img = np.expand_dims(normalized_img, axis=0)  # [C, H, W, D]
        return torch.tensor(normalized_img, dtype=torch.float32)

class SingleChannelAutoencoder3D(nn.Module):
    def __init__(self):
        super(SingleChannelAutoencoder3D, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose3d(16, 1, 2, stride=2), nn.Tanh()
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.
        :param x: Input tensor of shape [B, C, D, H, W]
        :return: Output tensor of the same shape as input
        """
        return self.decoder(self.encoder(x))


In [3]:
# Load and combine all .nii.gz files from both T1 and T2 folders
t1_files = sorted(glob(r'D:\DS18\data\IXI-T1_resampled\*.nii.gz'))
#t2_files = sorted(glob("/mnt/data/IXI-T2_resampled/*.nii.gz"))
all_files = t1_files # + t2_files

# Split into training and validation sets
train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=42)


In [5]:
# Prepare datasets and loaders
train_dataset = SingleModalityNiftiDataset(train_files)
val_dataset = SingleModalityNiftiDataset(val_files)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)



In [6]:
# Prepare model and training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SingleChannelAutoencoder3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [7]:
# Train loop (5 epochs)
n_epochs = 5
history = []
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = model(batch)
        loss = loss_fn(recon, batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    history.append(avg_loss)
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/5, Loss: 0.0384
Epoch 2/5, Loss: 0.0040
Epoch 3/5, Loss: 0.0030
Epoch 4/5, Loss: 0.0025
Epoch 5/5, Loss: 0.0022


1. "AE_T1_single-channel_ixi_ep_1-10.pt" - baseline, noisy
2. "AE_T1_10ep_v1.1.pt" - first improvement to fix noise & extra bright, low dynamic range
3. "AE_T1_10ep_v1.2.pt" - as 2, restored tanh to decoder
4. "AE_T1_10ep_v1.3.pt" - went back to 2, no tanh decoder, err map created against normalized input => BAD RESULT
5. "AE_T1_10ep_v1.4.pt" - reasonable...
6. "AE_T1_10ep_v1.5.pt" - learning rate lr=1e-5, normalize & clip (-3,3)
7. "AE_T1_10ep_v1.6.pt" - learning rate lr=1e-4, min/mas & z-score normalization =>BAD, go back to 5

In [8]:
latest_model_name = "AE_T1_10ep_v1.6.pt"
torch.save(model.state_dict(), latest_model_name)

In [None]:
# 1. Load model checkpoint if saved
'''
model.load_state_dict(torch.load("AE_T1_single-channel_ixi_ep_1-5.pt"))
model.to(device)

# 2. Continue training for more epochs
extra_epochs = 5
for epoch in range(n_epochs + 1, n_epochs + extra_epochs + 1):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = model(batch)
        loss = loss_fn(recon, batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch}/{n_epochs + extra_epochs}, Loss: {avg_loss:.4f}")
'''

In [9]:
#torch.save(model.state_dict(), "AE_T1_single-channel_ixi_ep_1-10.pt")

# Inference: Reconstruct and Generate Heatmap with Proper Denormalization

In [9]:
from pathlib import Path

In [10]:
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SingleChannelAutoencoder3D().to(device)

model_path = latest_model_name
if Path(model_path).exists():
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
else:
    raise FileNotFoundError("Model file not found.")

  model.load_state_dict(torch.load(model_path, map_location=device))


In [14]:
# Load input volume
#input_path = r"D:\DS18\data\IXI-T1_resampled\IXI002-Guys-0828-T1.nii.gz"
input_path = r"D:\DS18\data\BrainTumour\imagesTr\T1_resampled\BRATS_001_T1.nii.gz"
input_img = sitk.ReadImage(input_path)
input_arr = sitk.GetArrayFromImage(input_img).astype(np.float32)

# Normalize input
#mean, std = np.mean(input_arr), np.std(input_arr)
#normalized_input = (input_arr - mean) / (std + 1e-5)
#normalized_input = np.clip(normalized_input, -3, 3)

# Z-score normalization
mean, std = np.mean(input_arr), np.std(input_arr)
zscore_img = (input_arr - mean) / (std + 1e-5)

# Min-max normalization to [0, 1]
z_min, z_max = np.min(zscore_img), np.max(zscore_img)
normalized_input = (zscore_img - z_min) / (z_max - z_min)


input_tensor = torch.tensor(normalized_input[None, None, ...], dtype=torch.float32).to(device)



In [15]:
# Reconstruct
with torch.no_grad():
    reconstructed = model(input_tensor).cpu().numpy()[0, 0]

# Denormalize reconstruction
reconstructed_denorm = reconstructed * (std + 1e-5) + mean

# Compute error map in original scale
#error_map = np.abs(input_arr - reconstructed_denorm).astype(np.float32)
error_map = np.abs(normalized_input - reconstructed).astype(np.float32)



In [16]:

# Convert to SimpleITK images
#recon_img = sitk.GetImageFromArray(reconstructed_denorm)
recon_img = sitk.GetImageFromArray(reconstructed)
recon_img.CopyInformation(input_img)

error_img = sitk.GetImageFromArray(error_map)
error_img.CopyInformation(input_img)

normalized_input_img = sitk.GetImageFromArray(normalized_input)
normalized_input_img.CopyInformation(input_img)

# Save outputs
norm_input_path = "normalized_input_volume.nii.gz"
recon_out_path = "reconstructed_volume.nii.gz"
error_out_path = "error_heatmap.nii.gz"
sitk.WriteImage(normalized_input_img, norm_input_path)
sitk.WriteImage(recon_img, recon_out_path)
sitk.WriteImage(error_img, error_out_path)

norm_input_path, recon_out_path, error_out_path

('normalized_input_volume.nii.gz',
 'reconstructed_volume.nii.gz',
 'error_heatmap.nii.gz')