In [6]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from skimage.io import imread

import nibabel as nib

In [2]:
train_dir = "./Dataset/Training/Data"

# os.listdir returns entries in arbitrary order; sort so nii_files[0] (used in the
# next cell) is the same file on every machine and every run.
file_list = sorted(os.listdir(train_dir))
nii_files = [os.path.join(train_dir, file) for file in file_list if file.endswith(".nii")]
# Set the *glm.nii companion files aside; nii_files keeps only the remaining volumes.
glm_files = [file for file in nii_files if file.endswith("glm.nii")]
nii_files = [file for file in nii_files if not file.endswith("glm.nii")]

In [3]:
# Workaround for https://github.com/nipy/nibabel/issues/626 (quaternion check on
# these headers). NOTE: this mutates a class attribute, so it affects every later nib.load.
nib.Nifti1Header.quaternion_threshold = -1e-06
img = np.array(nib.load(nii_files[0]).dataobj).astype(np.float64)

### padding
# Half-width of the square 2D patches: each patch is (2 * PATCH_RADIUS + 1) voxels wide.
PATCH_RADIUS = 14
# Zero-pad both ends of all three axes by PATCH_RADIUS so a patch centred on any
# original voxel stays inside the array; indices below are in padded coordinates.
img = np.pad(img, ((PATCH_RADIUS, PATCH_RADIUS),) * 3, 'constant', constant_values=0)

print(img.shape)

i, j, k = 100, 100, 100
patch_i = slice(i - PATCH_RADIUS, i + PATCH_RADIUS + 1)
patch_j = slice(j - PATCH_RADIUS, j + PATCH_RADIUS + 1)
patch_k = slice(k - PATCH_RADIUS, k + PATCH_RADIUS + 1)
# Three orthogonal patches through voxel (i, j, k); inside each patch the voxel
# sits at index (PATCH_RADIUS, PATCH_RADIUS).
print(img[patch_i, patch_j, k].shape)  # third axis fixed
print(img[i, patch_j, patch_k].shape)  # first axis fixed
print(img[patch_i, j, patch_k].shape)  # second axis fixed

(284, 284, 347)
(29, 29)
(29, 29)
(29, 29)


In [4]:
data_dir = "./Dataset/Training/Data"
# data_file = "1000_3.nii"
data_file = None

if data_file is None and data_dir:
    # Every .nii in data_dir except the *glm.nii companions. Sorted because
    # os.listdir order is arbitrary and data_list indices should be reproducible.
    file_list = sorted(os.listdir(data_dir))
    all_nii = [os.path.join(data_dir, file) for file in file_list if file.endswith(".nii")]
    glm_files = [file for file in all_nii if file.endswith("glm.nii")]
    nii_files = [file for file in all_nii if not file.endswith("glm.nii")]

elif isinstance(data_file, str):
    nii_files = [os.path.join(data_dir, data_file)]
    # Companion named <stem>_glm.nii. splitext strips only the final extension;
    # the old split(".n") cut at the first ".n" anywhere in the path.
    glm_files = [os.path.splitext(nii_files[0])[0] + "_glm.nii"]

else:
    # Without this, the cell would silently reuse nii_files from an earlier cell.
    raise ValueError("Set data_dir (load every volume) or data_file (a str file name inside data_dir).")


### settings for nibabel ###
# Workaround for https://github.com/nipy/nibabel/issues/626 (class-wide setting).
nib.Nifti1Header.quaternion_threshold = -1e-06
data_list = [np.array(nib.load(path).dataobj).astype(np.float64) for path in nii_files]

len(data_list)

15

In [7]:
class BrainSegmentationDataset(Dataset):
    """Map-style dataset over the NIfTI volumes in ``data_dir``.

    Parameters
    ----------
    data_dir : str
        Directory containing the ``.nii`` volumes.
    data_file : str, optional
        Name of a single volume inside ``data_dir``. When None, every ``.nii``
        file in ``data_dir`` whose name does not end in ``glm.nii`` is loaded.
    transform : callable, optional
        Applied to each sample dict before it is returned.

    Each sample is ``{'image': ndarray}`` holding one float64 volume.

    Raises
    ------
    ValueError
        If neither a non-empty ``data_dir`` nor a str ``data_file`` is given.

    NOTE(review): the ``*glm.nii`` paths are collected in ``self.glm_files`` but
    never loaded; presumably they are the segmentation targets. TODO: confirm and
    add them to the sample.
    """

    def __init__(self, data_dir, data_file=None, transform=None):
        self.transform = transform
        self.nii_files, self.glm_files = self._resolve_files(data_dir, data_file)

        ### settings for nibabel ###
        # Workaround for https://github.com/nipy/nibabel/issues/626. NOTE: this
        # mutates a class attribute, so it affects every nib.load in the process.
        nib.Nifti1Header.quaternion_threshold = -1e-06

        # All volumes are read eagerly into memory as float64 arrays.
        self.data_list = [np.array(nib.load(path).dataobj).astype(np.float64) for path in self.nii_files]

    @staticmethod
    def _resolve_files(data_dir, data_file=None):
        """Return ``(nii_files, glm_files)``: lists of volume paths and *glm.nii paths."""
        if data_file is None and data_dir:
            # Sorted: os.listdir order is arbitrary, and dataset indices should be stable.
            all_nii = sorted(
                os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith(".nii")
            )
            glm_files = [path for path in all_nii if path.endswith("glm.nii")]
            nii_files = [path for path in all_nii if not path.endswith("glm.nii")]
            return nii_files, glm_files

        if isinstance(data_file, str):
            nii_file = os.path.join(data_dir, data_file)
            # Companion named <stem>_glm.nii; splitext strips only the final extension.
            return [nii_file], [os.path.splitext(nii_file)[0] + "_glm.nii"]

        raise ValueError("Set data_dir (load every volume) or data_file (a str file name inside data_dir).")

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        # Samplers may hand over a 0-d tensor index.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {'image': self.data_list[idx]}

        if self.transform:
            sample = self.transform(sample)

        return sample