In [None]:
# Install ANTsPyX (Python bindings for ANTs) at runtime; the output below shows it was not already present.
# ANTsPyX documentation: https://antspyx.readthedocs.io/en/latest/core.html
# NOTE(review): the recorded run installed antspyx 0.3.7; consider pinning the version
# (e.g. `%pip install antspyx==0.3.7`) so that later reruns use the same release.
!pip install antspyx

# Previously required environment setting; not needed anymore.
# %env SM_FRAMEWORK=tf.keras


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting antspyx
  Downloading antspyx-0.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (326.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m326.4/326.4 MB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting chart-studio
  Downloading chart_studio-1.1.0-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 KB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Collecting webcolors
  Downloading webcolors-1.12-py3-none-any.whl (9.9 kB)
Collecting retrying>=1.3.3
  Downloading retrying-1.3.4-py3-none-any.whl (11 kB)
Installing collected packages: webcolors, retrying, chart-studio, antspyx
Successfully installed antspyx-0.3.7 chart-studio-1.1.0 retrying-1.3.4 webcolors-1.12


In [None]:
import os 
import ants
import zipfile
import pandas as pd
from pathlib import Path
import nibabel as nib
from tqdm import tqdm
import psutil
import matplotlib.pyplot as plt

import numpy as np
from sklearn.model_selection import train_test_split

import random
from google.colab import drive
# Mount Google Drive; the project data used below lives under /content/drive/MyDrive/Capstone-Project.
# force_remount=True forces a fresh mount even if Drive is already mounted in this session.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Project layout on Google Drive; everything below is relative to proj_path.
proj_path = r'/content/drive/MyDrive/Capstone-Project/'
os.chdir(proj_path)

base_data_path = os.path.join(proj_path, 'data')
nfbs_path = os.path.join(base_data_path, 'nfbs-data')
nfbs_zip = 'NFBS_Dataset.tar.gz'

# Path to the downloaded NFBS archive (reuse nfbs_zip rather than repeating the name).
full_path = os.path.join(nfbs_path, nfbs_zip)

# One-time extraction (already done). Later cells read from nfbs_path/NFBS_Dataset,
# so extract into nfbs_path (presumably the archive's top-level folder is
# NFBS_Dataset — confirm before re-running).
# import tarfile
# with tarfile.open(full_path) as tar:
#   tar.extractall(nfbs_path)

In [None]:
# Peek inside one subject folder to see which files every subject provides.
print('Each folder contains..')
print(os.listdir(os.path.join(nfbs_path, 'NFBS_Dataset', 'A00028185')))

Each folder contains..
['sub-A00028185_ses-NFB3_T1w_brainmask.nii.gz', 'sub-A00028185_ses-NFB3_T1w_brain.nii.gz', 'sub-A00028185_ses-NFB3_T1w.nii.gz']


Each subject folder contains three types of images in NIfTI format:

T1w (raw): the raw single-channel T1-weighted MRI image. The image is 3D and can be thought of as a stack of 2D slices.

T1w_brainmask: the brain mask, i.e. the ground truth for skull stripping. It was obtained with BEaST (Brain Extraction based on nonlocal Segmentation Technique) and then manually edited by domain experts to remove non-brain tissue.

T1w_brain: the skull-stripped brain, i.e. the part of the raw T1w image that remains after applying the brain mask (similar to overlaying the mask on the raw image).

In [None]:
# Load the raw T1w scan of the sample subject and report its voxel dimensions.
sample_t1w = os.path.join(nfbs_path, 'NFBS_Dataset', 'A00028185', 'sub-A00028185_ses-NFB3_T1w.nii.gz')
img = ants.image_read(sample_t1w)
print('Shape of image=', img.shape)

Shape of image= (256, 256, 192)


In [None]:
# Gather the paths of the three NIfTI variants stored in every subject folder.
dataset_root = os.path.join(nfbs_path, 'NFBS_Dataset')
brain_mask, brain, raw = [], [], []
for subdir, dirs, files in os.walk(dataset_root):
  for fname in files:
    file_path = os.path.join(subdir, fname)
    if not file_path.endswith(".gz"):
      continue
    # classify by filename marker: mask, skull-stripped brain, otherwise raw T1w
    if '_brainmask.' in file_path:
      brain_mask.append(file_path)
    elif '_brain.' in file_path:
      brain.append(file_path)
    else:
      raw.append(file_path)

# One row per subject; assumes each folder holds exactly one file of each kind,
# since the three lists are aligned purely by append order.
data = pd.DataFrame({'brain_mask': brain_mask, 'brain': brain, 'raw': raw})
data.head()

Unnamed: 0,brain_mask,brain,raw
0,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...
1,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...
2,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...
3,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...
4,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...,/content/drive/MyDrive/Capstone-Project/data/n...


In [None]:
# Show the first raw volume along each of its three axes. A plain loop is used
# because ants.plot is called only for its side effect (drawing the figure).
img = ants.image_read(data.raw.iloc[0])
for axis in range(3):
  ants.plot(img, crop=False, axis=axis)

In [None]:
# Visualize a couple of examples: the middle slice along the first array axis of
# the raw image, the skull-stripped brain and the brain mask, side by side.
panels = [('Raw image', 'raw'),
          ('Skull stripped image', 'brain'),
          ('Brain mask image', 'brain_mask')]
n_examples = 2

for i in range(n_examples):
  fig, ax = plt.subplots(1, 3, figsize=(10, 7))
  for axis, (title, column) in zip(ax, panels):
    volume = nib.load(data[column].iloc[i]).get_fdata()
    axis.set_title(title)
    # grayscale is the conventional colormap for single-channel MRI intensities
    axis.imshow(volume[volume.shape[0] // 2], cmap='gray')
  plt.show()

We compute a histogram of the image intensity values, which counts the number of voxels that fall into each bin. Since there are 256 bins, we take the mean of the bin edges with indices 125 through 174 (the middle of the intensity range) as a rough estimate of the brain-tissue intensity. Note that this averages bin-edge values; it is not a count-weighted mean of the voxel intensities.

From this estimate we set a low threshold 10% above the mean and a high threshold 10% below the image maximum, so that the mask includes the brain tissue and not the empty space around it. This excludes regions whose intensity is lower or higher than that of brain tissue, which are likely to be background noise or non-brain tissue.

The resulting mask is passed to the N4 bias-field correction. This kind of masking is also useful for subsequent image processing steps, such as segmentation or registration, where we want to focus on the brain tissue and avoid unwanted artifacts or background noise.

In [None]:
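# Disabled experiment: loads the full-resolution raw images one by one and prints process memory every 10 images.
# a = np.array([])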
# for idx, row in tqdm(data.iterrows()):
#   img = ants.image_read(row['raw'])
#   a = np.append(a, img)
#   process = psutil.Process()
#   memory_info = process.memory_info()
#   # print(memory_info)
#   if idx % 10 == 0:
#     print(f" After loading {idx + 1} images, Memory used: {memory_info.rss / 1024 / 1024:.2f} MB")
#   #print(f"Finish reading {idx}")

In [None]:
# Preprocess every subject and cache the results in PP_Images.
# TAKES ~1 HOUR ON A FULL RUN. Subjects whose outputs already exist are skipped,
# so re-running this cell is cheap; set overwrite_pp = True to regenerate all.
overwrite_pp = False

dataset_root = os.path.join(nfbs_path, 'NFBS_Dataset')
# Subject folder names at the top level of the dataset (first os.walk step only;
# no need to walk the whole tree).
file_labels = next(os.walk(dataset_root))[1]

pp_img_dir = os.path.join(nfbs_path, 'PP_Images', 'raw')
pp_mask_dir = os.path.join(nfbs_path, 'PP_Images', 'mask')

# create the output folders if they do not exist yet
Path(pp_img_dir).mkdir(parents=True, exist_ok=True)
Path(pp_mask_dir).mkdir(parents=True, exist_ok=True)

suffix = '_pp.nii.gz'

for idx, row in tqdm(data.iterrows(), total=len(data)):
  # Name outputs after the subject folder that holds this row's raw image, so a
  # label can never be attached to another subject's data even if directory
  # listing order differs between separate os.walk calls.
  label = os.path.basename(os.path.dirname(row['raw']))
  img_out = os.path.join(pp_img_dir, label + '_raw' + suffix)
  mask_out = os.path.join(pp_mask_dir, label + '_mask' + suffix)
  if not overwrite_pp and os.path.exists(img_out) and os.path.exists(mask_out):
    continue

  img = ants.image_read(row['raw'])
  voxels = img.numpy()

  # crude brain-tissue intensity: mean of the bin edges with indices 125..174
  _, bins = np.histogram(voxels.ravel(), bins=256)
  mean_intensity = bins[125:175].mean()

  # keep voxels between 10% above that estimate and 10% below the image maximum
  low_threshold = mean_intensity + 0.1 * mean_intensity
  max_intensity = np.amax(voxels)
  high_threshold = max_intensity - 0.1 * max_intensity
  mask = ants.get_mask(img, low_thresh=low_threshold, high_thresh=high_threshold, cleanup=2)

  # N4 bias-field correction within the mask, then iMath "Normalize", which
  # rescales intensities to [0, 1] (it is not a z-score)
  img_n4 = ants.n4_bias_field_correction(img, shrink_factor=3, mask=mask, convergence={'iters': [20, 10, 10, 5], 'tol': 1e-07}, rescale_intensities=True).iMath_normalize()

  # resample the image and its ground-truth mask to a 96x128x160 voxel grid;
  # interp_type=1 is nearest neighbor.
  # NOTE(review): nearest neighbor suits the mask; for the intensity image
  # linear interpolation (interp_type=0) may be preferable — confirm.
  img_mask_ds = ants.image_read(row['brain_mask']).resample_image((96, 128, 160), use_voxels=True, interp_type=1)
  img_n4_ds = img_n4.resample_image((96, 128, 160), use_voxels=True, interp_type=1)

  # saving
  ants.image_write(img_n4_ds, img_out, ri=False)
  ants.image_write(img_mask_ds, mask_out, ri=False)


In [None]:
# Load the cached, preprocessed volumes back into memory, printing the process
# RSS every 10 images to monitor memory use.
pp_path = os.path.join(nfbs_path, 'PP_Images')
process = psutil.Process()


def load_pp_images(folder):
  """Read every .nii.gz file in `folder`, in sorted filename order.

  Sorting lines the raw and mask lists up by subject, because the files are
  named '<subject>_raw_pp.nii.gz' / '<subject>_mask_pp.nii.gz'. This does not
  depend on the order in which os.walk happens to visit the folders.
  Returns a list of ANTs images.
  """
  images = []
  names = sorted(name for name in os.listdir(folder) if name.endswith('.nii.gz'))
  for i, name in enumerate(names):
    images.append(ants.image_read(os.path.join(folder, name)))
    if i % 10 == 0:
      rss_mb = process.memory_info().rss / 1024 / 1024
      print(f" After loading {i + 1} images, Memory used: {rss_mb:.2f} MB")
  return images


raw = load_pp_images(os.path.join(pp_path, 'raw'))
mask = load_pp_images(os.path.join(pp_path, 'mask'))

# raw[i] and mask[i] refer to the same subject only if both folders hold the same subjects
if len(raw) != len(mask):
  print(f"WARNING: {len(raw)} raw images but {len(mask)} masks; index pairing is unreliable")

 After loading 1 images, Memory used: 1108.59 MB
 After loading 11 images, Memory used: 1135.56 MB
 After loading 21 images, Memory used: 1210.33 MB
 After loading 31 images, Memory used: 1285.35 MB
 After loading 41 images, Memory used: 1360.38 MB
 After loading 51 images, Memory used: 1435.39 MB
 After loading 61 images, Memory used: 1510.46 MB
 After loading 71 images, Memory used: 1585.48 MB
 After loading 81 images, Memory used: 1660.50 MB
 After loading 91 images, Memory used: 1735.53 MB
 After loading 101 images, Memory used: 1810.62 MB
 After loading 111 images, Memory used: 1885.69 MB
 After loading 121 images, Memory used: 1960.58 MB
 After loading 1 images, Memory used: 1998.22 MB
 After loading 11 images, Memory used: 2072.99 MB
 After loading 21 images, Memory used: 2148.01 MB
 After loading 31 images, Memory used: 2223.04 MB
 After loading 41 images, Memory used: 2298.06 MB
 After loading 51 images, Memory used: 2373.08 MB
 After loading 61 images, Memory used: 2448.11 MB

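Far less memory is used: each preprocessed 96×128×160 volume adds roughly 7.5 MB (about 75 MB per 10 images in the log above). These volumes have 6.4× fewer voxels than the 256×256×192 originals.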

In [None]:
# Preprocessed 3D brain mask of the first subject, viewed along each axis.
# A plain loop is used because ants.plot is called only for its side effect.
ms = mask[0]
for axis in range(3):
  ants.plot(ms, crop=False, axis=axis)

NameError: ignored

In [None]:
import torch
from skimage.io import imread
from torch.utils import data as d
import matplotlib.pyplot as plt

class SegmentationDataSet(d.Dataset):
  """PyTorch dataset of (T1w image, brain mask) tensor pairs read from NIfTI paths.

  Args:
    inputs: sequence of paths to raw T1w volumes. pandas objects are indexed by
      position, so slices that do not start at 0 (e.g. data['raw'][2:4]) work
      with the 0..len-1 indices a DataLoader requests.
    targets: sequence of paths to the matching brain-mask volumes.
    transform: used only as an on/off flag (it is never called). When not None,
      each image gets N4 bias-field correction and intensity normalization to
      [0, 1], and both image and mask are resampled to a 96x128x160 voxel grid.
  """

  def __init__(self,
               inputs: object,
               targets: object,
               transform=None
               ):
    self.inputs = inputs
    self.targets = targets
    self.transform = transform
    self.inputs_dtype = torch.float32
    # NOTE(review): declared but unused -- __getitem__ casts the mask with
    # inputs_dtype (float32), as the printed batch dtype shows. Confirm what the
    # loss expects before switching the target cast to targets_dtype.
    self.targets_dtype = torch.long

  def __len__(self):
    return len(self.inputs)

  @staticmethod
  def _get_positional(seq, index):
    """Return the element at position `index` (via .iloc for pandas objects)."""
    return seq.iloc[index] if hasattr(seq, 'iloc') else seq[index]

  @staticmethod
  def _threshold_mask(image):
    """Rough foreground mask from intensity thresholds, used to guide N4."""
    voxels = image.numpy()
    # Mean of the bin edges with indices 125..174 (middle of the intensity
    # range) serves as a crude brain-tissue intensity estimate.
    _, bins = np.histogram(voxels.ravel(), bins=256)
    mean_intensity = bins[125:175].mean()
    # Keep voxels between 10% above that estimate and 10% below the maximum.
    low_threshold = mean_intensity + 0.1 * mean_intensity
    max_intensity = np.amax(voxels)
    high_threshold = max_intensity - 0.1 * max_intensity
    return ants.get_mask(image, low_thresh=low_threshold,
                         high_thresh=high_threshold, cleanup=2)

  def __getitem__(self,
                  index: int):
    # select the sample by position
    input_ID = self._get_positional(self.inputs, index)
    target_ID = self._get_positional(self.targets, index)

    # load input and target
    x, y = ants.image_read(input_ID), ants.image_read(target_ID)

    # PREPROCESSING
    if self.transform is not None:
      # N4 bias-field correction restricted to the threshold mask
      x = ants.n4_bias_field_correction(x,
                                        shrink_factor=3,
                                        mask=self._threshold_mask(x),
                                        convergence={'iters': [20, 10, 10, 5],
                                                     'tol': 1e-07},
                                        rescale_intensities=True)

      # iMath "Normalize" rescales intensities to [0, 1] (it is not a z-score)
      x = x.iMath_normalize()

      # resample both to a fixed 96x128x160 voxel grid;
      # interp_type=1 is nearest neighbor
      y = y.resample_image((96, 128, 160), use_voxels=True, interp_type=1)
      x = x.resample_image((96, 128, 160), use_voxels=True, interp_type=1)

      # TODO: optionally cache the preprocessed pair to disk, as the
      # preprocessing cell does.

    # convert to tensors (both cast with inputs_dtype; see note in __init__)
    x = torch.from_numpy(x.numpy()).type(self.inputs_dtype)
    y = torch.from_numpy(y.numpy()).type(self.inputs_dtype)

    return x, y

In [None]:
# Tiny two-subject dataset used only to sanity-check the preprocessing pipeline.
n_check = 2
training_dataset = SegmentationDataSet(
    inputs=data['raw'].iloc[:n_check],
    targets=data['brain_mask'].iloc[:n_check],
    transform=True,
)

In [None]:
# Batch the sanity-check dataset and inspect one batch's shapes and value ranges.
training_dataloader = d.DataLoader(
    dataset=training_dataset,
    batch_size=2,
    shuffle=True,
)

# iter() builds an iterator over the DataLoader; next() takes its first batch.
# https://stackoverflow.com/a/62550190/16800940
x, y = next(iter(training_dataloader))

summary_lines = (
    f'x = shape: {x.shape}; type: {x.dtype}',
    f'x = min: {x.min()}; max: {x.max()}',
    f'y = shape: {y.shape}; class: {y.unique()}; type: {y.dtype}',
)
for summary_line in summary_lines:
  print(summary_line)

x = shape: torch.Size([2, 96, 128, 160]); type: torch.float32
x = min: 0.0; max: 1.0
y = shape: torch.Size([2, 96, 128, 160]); class: tensor([0.0000, 0.0819, 0.3243, 0.3337, 1.0000]); type: torch.float32
