In [2]:
import numpy as np
import nibabel as nib
import scipy as sp
import scipy.ndimage
import csv

In [3]:
# Seed NumPy's global RNG first; augment_by_transformation later draws from it.
np.random.seed(seed=0)
# Per-image metadata; the three arrays are indexed together by the same boolean
# masks below, so they must have the same length and ordering.
file_idx = np.genfromtxt('./subjects_idx.txt', dtype='str')
fold_idx = np.loadtxt('fold.txt')  # fold id per image, to keep same-patient images together
label = np.loadtxt('dx.txt')

In [4]:
file_idx.shape[0]  # number of image entries listed in subjects_idx.txt

1334

In [7]:
fold_idx.shape[0]  # number of fold assignments read from fold.txt

1334

In [9]:
# Only a cell's last expression is echoed, so the original `len(label)` line was
# silently discarded. Check the alignment explicitly and show both counts.
assert len(label) == file_idx.shape[0], "dx.txt and subjects_idx.txt have different lengths"
len(label), file_idx.shape[0]

1334

In [10]:
subject_num = len(file_idx)  # one row of `data` per listed image

In [11]:
# Size of the patch cropped from every volume, and the voxel where the crop starts.
patch_x = patch_y = patch_z = 64
min_x = min_y = min_z = 0
# Image counter, reset to 0 before loading.
i = 0

In [12]:
# One single-channel patch per subject, channels last. float64 by default:
# 1334 * 64**3 * 8 bytes is about 2.8 GB.
data = np.zeros(shape=(subject_num, patch_x, patch_y, patch_z, 1))

In [13]:
np.shape(data)  # (subjects, x, y, z, channels)

(1334, 64, 64, 64, 1)

In [14]:
# Load every subject's volume and crop a patch_x x patch_y x patch_z patch
# starting at (min_x, min_y, min_z). Each patch is z-scored into its row of `data`.
# enumerate() supplies the row index, so re-running this cell no longer depends
# on the global counter `i` having been reset by an earlier cell.
# After the loop, `i` holds the last index (subject_num - 1).
for i, img_idx in enumerate(file_idx):
    if i % 100 == 0:
        print(i)  # coarse progress indicator
    filename_full = './data/' + img_idx
    img = nib.load(filename_full)
    img_data = img.get_fdata()
    data[i, :, :, :, 0] = img_data[min_x:min_x+patch_x, min_y:min_y+patch_y, min_z:min_z+patch_z]
    patch = data[i, :, :, :, 0]
    patch_mean = np.mean(patch)
    patch_std = np.std(patch)
    # A constant patch (std == 0) would turn into 0/0 = NaN; only centre it instead.
    if patch_std > 0:
        data[i, :, :, :, 0] = (patch - patch_mean) / patch_std
    else:
        data[i, :, :, :, 0] = patch - patch_mean

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300


In [16]:
np.shape(img_data)  # uncropped shape of the last volume read by the loading loop

(64, 64, 64)

1334

In [19]:
# Boolean masks over subjects: folds <= 2 train, fold 3 validates, fold 4 tests.
train_idx, val_idx, test_idx = fold_idx <= 2, fold_idx == 3, fold_idx == 4

# Boolean indexing copies the selected rows out of `data`.
train_data, train_label = data[train_idx], label[train_idx]
val_data, val_label = data[val_idx], label[val_idx]
test_data, test_label = data[test_idx], label[test_idx]

In [20]:
# Training mask over all subjects; NumPy abbreviates long arrays when printed.
print(str(train_idx))

[ True  True  True ... False  True  True]


In [21]:
train_idx.shape[0]  # one mask entry per subject

1334

In [22]:
np.shape(test_data)

(251, 64, 64, 64, 1)

In [24]:
np.shape(val_data)

(269, 64, 64, 64, 1)

In [26]:
np.shape(train_data)

(814, 64, 64, 64, 1)

In [27]:
## Data Augmentation
def augment_by_transformation(data, n, max_angle=0.5, max_shift=0.5):
    """Grow ``data`` to ``n`` examples by appending randomly perturbed copies.

    Each synthetic example starts as a copy of a randomly chosen original. It is
    rotated three times, in the (1, 0), (0, 2) and (1, 2) axis planes, each by an
    angle drawn from U(-max_angle, max_angle) degrees. It is then shifted along
    every spatial axis by one offset drawn from U(-max_shift, max_shift) voxels.
    All randomness comes from the global ``np.random`` state, so results follow
    ``np.random.seed``.

    Parameters
    ----------
    data : ndarray, shape (N, X, Y, Z, C)
        Original examples. Every channel of a synthetic copy receives the same
        random transform.
    n : int
        Desired total number of examples.
    max_angle : float, default 0.5
        Half-width, in degrees, of the uniform rotation-angle range.
    max_shift : float, default 0.5
        Half-width, in voxels, of the uniform shift range.

    Returns
    -------
    ndarray
        ``data`` itself (not a copy) when it already has at least ``n`` examples.
        Otherwise a new float array of length ``n`` whose first N rows are ``data``.

    Raises
    ------
    ValueError
        If ``data`` is empty but ``n`` > 0, since there is nothing to copy from.
    """
    raw_n = data.shape[0]  # number of examples we actually have
    if n <= raw_n:
        return data
    if raw_n == 0:
        raise ValueError("cannot augment an empty array: no examples to copy from")

    m = n - raw_n  # number of synthetic examples needed to reach n in total
    new_data = np.zeros((m,) + data.shape[1:])
    for k in range(m):
        idx = np.random.randint(0, raw_n)
        # Draw the parameters in the same order as the original per-step draws
        # (three angles, then the shift). This keeps seeded runs reproducible and
        # gives every channel an identical geometric transform.
        angles = [np.random.uniform(-max_angle, max_angle) for _ in range(3)]
        offset = np.random.uniform(-max_shift, max_shift)
        new_data[k] = data[idx]
        for c in range(data.shape[-1]):
            vol = new_data[k, :, :, :, c]
            vol = sp.ndimage.rotate(vol, angles[0], axes=(1, 0), reshape=False)
            vol = sp.ndimage.rotate(vol, angles[1], axes=(0, 2), reshape=False)
            vol = sp.ndimage.rotate(vol, angles[2], axes=(1, 2), reshape=False)
            new_data[k, :, :, :, c] = sp.ndimage.shift(vol, offset)
    return np.concatenate((data, new_data), axis=0)

In [28]:
# Augment data
# Target per-class example counts passed to augment_by_transformation.
augment_size = 1024
augment_size_val = augment_size_test = 256

In [29]:
# Separate training volumes by label (1 vs 0) so each class can be grown independently.
train_data_pos, train_data_neg = train_data[train_label == 1], train_data[train_label == 0]

In [None]:
# Negative-class shape on the first line, positive-class shape on the second.
print(train_data_neg.shape, train_data_pos.shape, sep='\n')

In [32]:
# Grow each training class to augment_size examples. Both calls draw from the
# global np.random stream, so the positive class is augmented first, as before.
# NOTE(review): the saved run was stopped with KeyboardInterrupt here; three
# spline rotations per synthetic 64^3 volume are slow, so consider timing this cell.
train_data_pos_aug, train_data_neg_aug = (
    augment_by_transformation(class_data, augment_size)
    for class_data in (train_data_pos, train_data_neg)
)

KeyboardInterrupt: 

In [None]:
train_data_aug = np.concatenate((train_data_neg_aug, train_data_pos_aug), axis=0)
# Labels follow the concatenation order: negatives (0) first, then positives (1).
# augment_by_transformation returns a class unchanged when it already has at
# least augment_size examples, so each class's count is taken from the array
# itself rather than assumed to equal augment_size.
train_label_aug = np.concatenate((np.zeros(len(train_data_neg_aug)), np.ones(len(train_data_pos_aug))))
assert train_label_aug.shape[0] == train_data_aug.shape[0]

In [None]:
    # Balance the validation split by label: grow each class to augment_size_val
    # examples, then stack negatives (label 0) before positives (label 1).
    # NOTE(review): augmenting an evaluation split adds near-duplicate copies of
    # some subjects, which re-weights validation metrics. Confirm this is intended.
    val_data_pos = val_data[val_label==1]
    val_data_neg = val_data[val_label==0]
    val_data_pos_aug = augment_by_transformation(val_data_pos, augment_size_val)
    val_data_neg_aug = augment_by_transformation(val_data_neg, augment_size_val)
    val_data_aug = np.concatenate((val_data_neg_aug, val_data_pos_aug), axis=0)

    # A class that already has >= augment_size_val examples is returned unchanged,
    # so size the labels from the actual per-class counts.
    val_label_aug = np.concatenate((np.zeros(len(val_data_neg_aug)), np.ones(len(val_data_pos_aug))))
    assert val_label_aug.shape[0] == val_data_aug.shape[0]

In [None]:
    # Balance the test split by label: grow each class to augment_size_test
    # examples, then stack negatives (label 0) before positives (label 1).
    # NOTE(review): augmenting the test split adds near-duplicate copies of some
    # subjects, which re-weights reported test metrics. Confirm this is intended.
    test_data_pos = test_data[test_label==1]
    test_data_neg = test_data[test_label==0]
    test_data_pos_aug = augment_by_transformation(test_data_pos, augment_size_test)
    test_data_neg_aug = augment_by_transformation(test_data_neg, augment_size_test)
    test_data_aug = np.concatenate((test_data_neg_aug, test_data_pos_aug), axis=0)

    # A class that already has >= augment_size_test examples is returned unchanged,
    # so size the labels from the actual per-class counts.
    test_label_aug = np.concatenate((np.zeros(len(test_data_neg_aug)), np.ones(len(test_data_pos_aug))))
    assert test_label_aug.shape[0] == test_data_aug.shape[0]