In [3]:
# pytorch
# NOTE(review): package versions are not pinned, so re-running this cell later may
# install different releases; pin them (and prefer %pip) for a reproducible environment.
!pip install pytorch_lightning torchvision torchaudio --quiet

# imaging libraries (SimpleITK for registration, nibabel for NIfTI I/O, MONAI)
!pip install SimpleITK nibabel monai --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.3/776.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# google drive
from google.colab import drive

# Data libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# import supporting libraries
import os
import tarfile
from PIL import Image
from tqdm import tqdm
import tarfile
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor
import pickle

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns

# modeling libraries
import torch, torchvision as tv, torchaudio as ta
import pytorch_lightning as pl
from sklearn.decomposition import PCA
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.metrics import confusion_matrix

# imaging libraries
import SimpleITK as sitk, nibabel as nib, monai

In [5]:
# Mount Google Drive; the dataset and pickle paths used below all live under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
def list_folders_subfolders(directory):
  """Print an indented tree of every folder and file under `directory`.

  Each folder is printed as ``name/`` and the files it contains one level
  deeper, using four spaces per nesting level. Traversal order is that of
  `os.walk` (top-down).

  Parameters
  ----------
  directory : str
      Root folder to walk. A trailing path separator is accepted.
  """
  for root, _, files in os.walk(directory):
    relative = os.path.relpath(root, directory)
    # relpath is '.' for the root itself; otherwise depth = number of path components.
    level = 0 if relative == os.curdir else relative.count(os.sep) + 1
    indent = ' ' * 4 * level
    # normpath drops a trailing separator so the root's basename is not empty.
    print('{}{}/'.format(indent, os.path.basename(os.path.normpath(root))))
    subindent = ' ' * 4 * (level + 1)
    for f in files:
      print('{}{}'.format(subindent, f))

In [7]:
def dice(y_true, y_pred, empty_value = 1.0):
  """Dice similarity coefficient between two masks.

  Both inputs are interpreted as boolean masks (any nonzero value counts as
  foreground), and DSC = 2 * |A & B| / (|A| + |B|).

  Parameters
  ----------
  y_true, y_pred : array-like
      Masks of the same (broadcastable) shape.
  empty_value : float
      Value returned when both masks are empty. The formula would otherwise be
      0/0 (nan); two empty masks agree perfectly, so the default is 1.0.

  Returns
  -------
  float
      The Dice coefficient in [0, 1].
  """
  y_true = np.asarray(y_true, dtype = bool)
  y_pred = np.asarray(y_pred, dtype = bool)

  card_intersection = np.logical_and(y_true, y_pred).sum()
  card_sum = y_true.sum() + y_pred.sum()

  if card_sum == 0:
    return empty_value

  return 2. * card_intersection / card_sum

def mse(y_true, y_pred):
  """Mean squared error: the average of the element-wise squared differences."""
  residual = y_true - y_pred
  return (residual ** 2).mean()

In [8]:
def _register_image_to_sitk(image):
  """Return `image` as a SimpleITK image.

  numpy arrays and torch tensors (detached and moved to CPU) are converted with
  `sitk.GetImageFromArray`; any other object (e.g. an existing sitk.Image) is
  returned unchanged. float16 data is widened to float32 first
  (NOTE(review): SimpleITK is not expected to accept float16 arrays — confirm).
  """
  if isinstance(image, torch.Tensor):
    image = image.detach().cpu().numpy()
  if isinstance(image, np.ndarray):
    if image.dtype == np.float16:
      image = image.astype(np.float32)
    # Channel slices such as volume[:, :, :, c] are strided views; hand sitk a contiguous copy.
    image = sitk.GetImageFromArray(np.ascontiguousarray(image))
  return image


def register_image(fixed_image, moving_image,
                   transform_type = 'similarity',
                   metric = 'mmi',
                   num_bins = 50,
                   learning_rate = 1.0,
                   max_iters = 200,
                   interp_method = 'linear',
                   default_pixel_value = 0.,
                   min_convergence = 1e-6, convergence_window = 20,
                   sitk_dtype = sitk.sitkFloat32):
  """Register `moving_image` onto `fixed_image` with SimpleITK and resample it onto the fixed grid.

  Parameters
  ----------
  fixed_image : np.ndarray | torch.Tensor | sitk.Image
      Reference image. If it is a float16/32/64 array or tensor, the output is
      returned with the same dtype (and, for tensors, on the same device).
  moving_image : str | np.ndarray | torch.Tensor | sitk.Image
      Image to align. A string is treated as a path and read as `sitk_dtype`.
  transform_type : str
      One of 'similarity', 'euler', 'translation', 'versor', 'versorrigid',
      'scale', 'scaleversor', 'scaleskewversor', 'composescaleskewversor',
      'affine', 'bspline', 'displacement', 'composite'.
  metric : {'mmi', 'ms', 'jhmi'}
      Mattes mutual information, mean squares, or joint-histogram mutual information.
  num_bins : int
      Histogram bins for the mutual-information metrics.
  learning_rate, max_iters, min_convergence, convergence_window :
      Gradient-descent optimizer settings.
  interp_method : {'linear', 'nn', 'spline'}
      Interpolator used during registration and for the final resampling.
  default_pixel_value : float
      Value for output voxels that map outside the moving image.
  sitk_dtype : int
      SimpleITK pixel type used when reading `moving_image` from a path, and for
      a fixed image that is not already floating point.

  Returns
  -------
  (output_image, metric_values)
      The resampled moving image (array/tensor matching the fixed image's float
      dtype, otherwise a sitk.Image) and the metric value at each iteration.

  Raises
  ------
  ValueError
      For an unknown `metric`, `transform_type` or `interp_method`.
  AssertionError
      If the two images differ in size, spacing or dimension.
  """

  # Remember the fixed image's float dtype (and, for tensors, its device) so the
  # registered output can be returned in the same form.
  device, final_dtype = None, None
  if isinstance(fixed_image, np.ndarray):
    if fixed_image.dtype in (np.float16, np.float32, np.float64):
      final_dtype = fixed_image.dtype.type
  elif isinstance(fixed_image, torch.Tensor):
    if fixed_image.dtype in (torch.float16, torch.float32, torch.float64):
      final_dtype = fixed_image.dtype
    device = fixed_image.device

  fixed_image = _register_image_to_sitk(fixed_image)
  if isinstance(moving_image, str):
    moving_image = sitk.ReadImage(moving_image, sitk_dtype)
  else:
    moving_image = _register_image_to_sitk(moving_image)

  assert fixed_image.GetSize() == moving_image.GetSize(), "Image sizes do not match"
  assert fixed_image.GetSpacing() == moving_image.GetSpacing(), "Image spacings do not match"
  assert fixed_image.GetDimension() == moving_image.GetDimension(), "Image dimensions do not match"

  # Give both inputs one floating-point pixel type (e.g. a float32 file read
  # against a float64 array, or an integer array).
  if fixed_image.GetPixelID() not in (sitk.sitkFloat32, sitk.sitkFloat64):
    fixed_image = sitk.Cast(fixed_image, sitk_dtype)
  if moving_image.GetPixelID() != fixed_image.GetPixelID():
    moving_image = sitk.Cast(moving_image, fixed_image.GetPixelID())

  # Instantiate registration method
  R = sitk.ImageRegistrationMethod()

  # Set up metric
  if metric == 'mmi':
    R.SetMetricAsMattesMutualInformation(num_bins)
  elif metric == 'ms':
    R.SetMetricAsMeanSquares()
  elif metric == 'jhmi':
    R.SetMetricAsJointHistogramMutualInformation(num_bins)
  else:
    raise ValueError(f"The metric ({metric}) must be 'mmi', 'ms', or 'jhmi'.")

  # Set optimizer
  R.SetOptimizerAsGradientDescent(learningRate = learning_rate,
                                  numberOfIterations = max_iters,
                                  convergenceMinimumValue = min_convergence,
                                  convergenceWindowSize = convergence_window)

  # Define transform; dimension-generic transforms get the image's dimension.
  dim = fixed_image.GetDimension()
  transform_factories = {
    'similarity': sitk.Similarity3DTransform,
    'euler': sitk.Euler3DTransform,
    'translation': lambda: sitk.TranslationTransform(dim),
    'versor': sitk.VersorTransform,
    'versorrigid': sitk.VersorRigid3DTransform,
    'scale': lambda: sitk.ScaleTransform(dim),
    'scaleversor': sitk.ScaleVersor3DTransform,
    'scaleskewversor': sitk.ScaleSkewVersor3DTransform,
    'composescaleskewversor': sitk.ComposeScaleSkewVersor3DTransform,
    'affine': lambda: sitk.AffineTransform(dim),
    'bspline': lambda: sitk.BSplineTransform(dim),
    'displacement': lambda: sitk.DisplacementFieldTransform(dim),
    'composite': lambda: sitk.CompositeTransform(dim),
  }
  if transform_type not in transform_factories:
    raise ValueError(f"The transform_type ({transform_type}) must be one of {list(transform_factories)}.")
  transform = transform_factories[transform_type]()

  # Align centers of fixed and moving image
  initial_transform = sitk.CenteredTransformInitializer(fixedImage = fixed_image,
                                                        movingImage = moving_image,
                                                        transform = transform)

  R.SetInitialTransform(initial_transform)

  # Set scales for optimization
  R.SetOptimizerScalesFromIndexShift()

  # Set up interpolator
  interpolators = {'linear': sitk.sitkLinear,
                   'nn': sitk.sitkNearestNeighbor,
                   'spline': sitk.sitkBSpline}
  if interp_method not in interpolators:
    raise ValueError(f"The interpolator ({interp_method}) must be 'linear', 'nn', or 'spline'.")
  interpolator = interpolators[interp_method]

  R.SetInterpolator(interpolator)

  # Record (and print) the metric value at every optimizer iteration.
  metric_values = []

  def command_iteration(method):
    metric_values.append(method.GetMetricValue())
    print(f"Iter. {method.GetOptimizerIteration():3} "
          + f"= {metric_values[-1]:.4f} ")

  R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))

  # Run the optimization
  final_transform = R.Execute(fixed = fixed_image, moving = moving_image)

  print(final_transform)
  print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
  print(f"Iteration: {R.GetOptimizerIteration()}")
  print(f"Metric: {R.GetMetricValue():.4f}")

  # Apply transformation to image and interpolate the result on the fixed grid
  resampler = sitk.ResampleImageFilter()
  resampler.SetReferenceImage(fixed_image)
  resampler.SetInterpolator(interpolator)
  resampler.SetDefaultPixelValue(default_pixel_value)
  resampler.SetTransform(final_transform)

  output_image = resampler.Execute(moving_image)

  # Convert the registered image back to the fixed image's container and dtype.
  if isinstance(final_dtype, torch.dtype):
    output_image = torch.tensor(sitk.GetArrayFromImage(output_image)).to(device = device,
                                                                         dtype = final_dtype)
  elif final_dtype is not None:
    output_image = sitk.GetArrayFromImage(output_image).astype(final_dtype)

  return output_image, metric_values

def register_subject_images(images,
                            fixed_image_idx = 0,
                            normalize_images = False,
                            transform_type = 'similarity',
                            metric = 'mmi',
                            num_bins = 50,
                            learning_rate = 1.0,
                            max_iters = 200,
                            interp_method = 'linear',
                            default_pixel_value = 0.,
                            min_convergence = 1e-6, convergence_window = 20,
                            sitk_dtype = sitk.sitkFloat32):
  """Register every channel of each subject to one reference channel, in place.

  Parameters
  ----------
  images : list of dict
      One dict per subject with an 'id' and an 'images' array/tensor whose last
      axis indexes the channels (channels are read as [:, :, :, c], so a 4-D
      volume is assumed).
  fixed_image_idx : int
      Channel used as the fixed (reference) image; every other channel is
      registered to it with `register_image`.
  normalize_images : bool
      If True, min-max scale each channel to [0, 1] before registration and store
      the fitted scalers (one per channel) under 'scaler'; otherwise 'scaler' is None.
  Remaining keyword arguments are forwarded unchanged to `register_image`.

  Returns
  -------
  list of dict
      The same `images` list; each subject dict is modified in place.
  """
  num_subjects = len(images)

  for subject_num, subject_images in enumerate(images, start = 1):
    volumes = subject_images['images']
    num_images = volumes.shape[-1]
    moving_image_idx = [c for c in range(num_images) if c != fixed_image_idx]

    subject_images['scaler'] = None
    if normalize_images:
      subject_images['scaler'] = []
      for channel in range(num_images):
        scaler = MinMaxScaler(feature_range = (0, 1))

        channel_volume = volumes[:, :, :, channel]
        if isinstance(channel_volume, torch.Tensor):
          # Scale on CPU with sklearn, then restore the tensor's device and dtype.
          flat = channel_volume.detach().cpu().numpy().reshape(-1, 1)
          scaled = torch.tensor(scaler.fit_transform(flat).reshape(tuple(channel_volume.shape))).to(channel_volume.device, channel_volume.dtype)
        else:
          scaled = scaler.fit_transform(channel_volume.reshape(-1, 1)).reshape(channel_volume.shape)

        volumes[:, :, :, channel] = scaled
        subject_images['scaler'].append(scaler)

    fixed_image = volumes[:, :, :, fixed_image_idx]

    for idx in moving_image_idx:
      print(f"Registering image {idx+1} to image {fixed_image_idx+1} for subject {subject_images['id']} ({subject_num}/{num_subjects}))")

      volumes[:, :, :, idx], _ = register_image(fixed_image = fixed_image,
                                                moving_image = volumes[:, :, :, idx],
                                                transform_type = transform_type,
                                                metric = metric,
                                                num_bins = num_bins,
                                                learning_rate = learning_rate,
                                                max_iters = max_iters,
                                                interp_method = interp_method,
                                                default_pixel_value = default_pixel_value,
                                                min_convergence = min_convergence,
                                                convergence_window = convergence_window,
                                                sitk_dtype = sitk_dtype)

  return images

In [9]:

def get_images(image_path, dtype=np.float32):
    """Read the NIfTI file at `image_path` and return its voxel data cast to `dtype`.

    The voxels are obtained with nibabel's ``get_fdata()`` and then converted
    with ``astype(dtype)``.
    """
    nifti = nib.load(image_path)
    voxels = nifti.get_fdata()
    return voxels.astype(dtype)

def access_images(task_path,
                  image_dir='imagesTr/',
                  label_dir='labelsTr/',
                  dtype=None,
                  sample_size=None, random_seed=42,
                  register_images=False,
                  fixed_image_idx=0,
                  normalize_images=False,
                  transform_type='similarity',
                  metric='mmi',
                  num_bins=50,
                  learning_rate=1.0,
                  max_iters=200,
                  interp_method='linear',
                  default_pixel_value=0.,
                  min_convergence=1e-6, convergence_window=20,
                  sitk_dtype=sitk.sitkFloat32):
    """Load (and optionally register) the BRATS NIfTI volumes of a task folder.

    Parameters
    ----------
    task_path : str
        Folder containing `image_dir` and `label_dir`.
    image_dir, label_dir : str
        Sub-folders holding images and labels. Labels are only loaded when
        'Tr' appears in `image_dir`.
    dtype : optional
        Currently unused: volumes are always loaded with `get_images`' default dtype.
    sample_size : int or None
        If given, load only this many subjects, picked by a shuffle seeded with `random_seed`.
    random_seed : int or None
        Seed for the sampling shuffle; None uses numpy's global RNG.
    register_images : bool
        If True, register every channel (last axis) of each subject to channel
        `fixed_image_idx` with `register_image`.
    normalize_images : bool
        Only used when registering: min-max scale each channel to [0, 1] first and
        keep the fitted scalers under 'image_scaler'.
    Remaining keyword arguments are forwarded to `register_image`.

    Returns
    -------
    list of dict
        One dict per subject with 'id', 'images', 'image_scaler' (only when
        `register_images` is True) and 'labels' (only when labels are loaded).
    """
    # Construct the dataset and label paths from the given directories
    dataset_path = os.path.join(task_path, image_dir)
    label_path = os.path.join(task_path, label_dir) if 'Tr' in image_dir else None

    # os.listdir order is arbitrary; sort so the seeded shuffle below selects the
    # same subjects on every machine/filesystem.
    nii_list = sorted(nii for nii in os.listdir(dataset_path) if ('nii' in nii) and nii.startswith('BRATS'))

    # Sample a subset of files if a sample size is specified
    if sample_size is not None:
        # A local RandomState yields the same permutation as np.random.seed + np.random.shuffle
        # without resetting numpy's global RNG as a side effect.
        rng = np.random.RandomState(random_seed) if random_seed is not None else np.random
        rng.shuffle(nii_list)
        nii_list = nii_list[:sample_size]

    num_subjects = len(nii_list)

    data = []
    for i, image_file in enumerate(tqdm(nii_list, desc="Loading NIfTI files", unit='file')):
        # e.g. 'BRATS_001.nii.gz' -> '001'
        subject_id = image_file.split('_')[1].split('.')[0]

        subject = {'id': subject_id,
                   'images': get_images(os.path.join(dataset_path, image_file))}
        data.append(subject)

        if register_images:
            num_images = subject['images'].shape[-1]
            moving_image_idx = [c for c in range(num_images) if c != fixed_image_idx]

            # Optionally min-max normalize each channel, keeping scalers for inverse transforms
            subject['image_scaler'] = None
            if normalize_images:
                subject['image_scaler'] = []
                for j in range(num_images):
                    scaler = MinMaxScaler(feature_range=(0, 1))

                    image_j = subject['images'][:, :, :, j]
                    if isinstance(image_j, torch.Tensor):
                        # Scale on CPU with sklearn, then restore the tensor's device and dtype
                        image_js = torch.tensor(scaler.fit_transform(image_j.detach().cpu().numpy().reshape(-1, 1))
                                                .reshape(tuple(image_j.shape))).to(image_j.device, image_j.dtype)
                    else:
                        image_js = scaler.fit_transform(image_j.reshape(-1, 1)).reshape(image_j.shape)

                    subject['images'][:, :, :, j] = image_js
                    subject['image_scaler'].append(scaler)

            # Register each moving channel to the fixed channel, overwriting it in place
            fixed_image = subject['images'][:, :, :, fixed_image_idx]
            for idx in moving_image_idx:
                print()
                print(f"Registering image {idx+1} to image {fixed_image_idx+1} for subject {subject['id']} ({i+1}/{num_subjects}))")
                subject['images'][:, :, :, idx], _ = register_image(fixed_image=fixed_image,
                                                                    moving_image=subject['images'][:, :, :, idx],
                                                                    transform_type=transform_type,
                                                                    metric=metric,
                                                                    num_bins=num_bins,
                                                                    learning_rate=learning_rate,
                                                                    max_iters=max_iters,
                                                                    interp_method=interp_method,
                                                                    default_pixel_value=default_pixel_value,
                                                                    min_convergence=min_convergence,
                                                                    convergence_window=convergence_window,
                                                                    sitk_dtype=sitk_dtype)

        # Labels share the image's file name inside label_path
        if label_path is not None:
            subject['labels'] = get_images(os.path.join(label_path, image_file))

    return data


In [10]:
# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
dtype = torch.float32

print(f"device = {device}, dtype = {dtype}")

device = cuda, dtype = torch.float32


In [11]:
# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

In [None]:
task_path = '/content/drive/MyDrive/data/MSD/Task01_BrainTumour'

# Load a random sample of 100 subjects (access_images' default seed) and register
# every channel of each subject to channel 0. NOTE: access_images does not use `dtype`.
load_options = dict(dtype = dtype,
                    sample_size = 100,
                    register_images = True,
                    num_bins = 25,
                    convergence_window = 10)
images = access_images(task_path, **load_options)

In [None]:
# Cache the loaded/registered subjects so the expensive loading cell need not be re-run.
file_path = "/content/drive/MyDrive/MRI/images.pkl"
# open(..., "wb") fails if the folder is missing, so create it first.
os.makedirs(os.path.dirname(file_path), exist_ok = True)
with open(file_path, "wb") as file:
  pickle.dump(images, file)

In [13]:
# Define the file path where your data is stored
file_path = "/content/drive/MyDrive/MRI/images.pkl"

# Open the file in binary read mode
with open(file_path, "rb") as file:
    # Load the data from the pickle file
    images = pickle.load(file)

In [None]:
# One scaler per modality, so each can later invert its own scaling.
scaler_t1 = MinMaxScaler(feature_range = (0, 1))
scaler_t1c = MinMaxScaler(feature_range = (0, 1))
scaler_t2 = MinMaxScaler(feature_range = (0, 1))
scaler_flair = MinMaxScaler(feature_range = (0, 1))

i_subject = 1

img = images[i_subject]['images']
img_shape = img.shape

labels = images[i_subject]['labels']

# Split the 4-D volume into its channels along the last axis.
# NOTE(review): these names assume channel order T1, T1c, T2, FLAIR; the task's
# dataset.json may list a different order (e.g. FLAIR first) — confirm.
t1 = img[:, :, :, 0]
t1c = img[:, :, :, 1]
t2 = img[:, :, :, 2]
flair = img[:, :, :, 3]

def minmax_volume_to_tensor(scaler, volume):
  """Fit `scaler` on all voxels of `volume`, rescale them, and return a same-shape tensor on `device` with `dtype`."""
  scaled = scaler.fit_transform(volume.reshape(-1, 1)).reshape(volume.shape)
  return torch.tensor(scaled).to(device = device, dtype = dtype)

t1_n = minmax_volume_to_tensor(scaler_t1, t1)
t1c_n = minmax_volume_to_tensor(scaler_t1c, t1c)
t2_n = minmax_volume_to_tensor(scaler_t2, t2)
flair_n = minmax_volume_to_tensor(scaler_flair, flair)


In [None]:

# Slice indices along axes 0, 1 and 2 of the volume.
i, j, k = 120, 120, 77

# Rows: one per channel plus the label map; columns: one slice per axis.
row_volumes = [('T1', t1), ('T1c', t1c), ('T2', t2), ('FLAIR', flair), ('Labels', labels)]
column_slices = [(f'axis 0 = {i}', np.s_[i, :, :]),
                 (f'axis 1 = {j}', np.s_[:, j, :]),
                 (f'axis 2 = {k}', np.s_[:, :, k])]

fig, ax = plt.subplots(len(row_volumes), len(column_slices), figsize = (20, 20))
for row_idx, (row_name, volume) in enumerate(row_volumes):
  for col_idx, (col_name, slicer) in enumerate(column_slices):
    # Intensity channels in grayscale; labels with the default colormap.
    cmap = None if row_name == 'Labels' else 'gray'
    ax[row_idx, col_idx].imshow(volume[slicer], cmap = cmap)
    ax[row_idx, col_idx].set_title(f'{row_name} ({col_name})')

fig.tight_layout()