In [None]:
import numpy as np
import os

In [None]:
'''
Each MRI sample (3D volume) has shape (193, 229, 193). We split every volume along the axial
dimension (the last axis) and stack the slices from all volumes into one array of shape
(n_slices, 1, H, W). Each 193 x 229 slice is zero-padded to H = W = 256 to match the VQ-VAE
input size. Each 3D MRI sample therefore contributes 193 axial slices.
For VQ-VAE training, 72 samples are used for training and 31 for validation.
'''

def pad_data(arr):
    """Zero-pad the last two axes of ``arr`` up to the next power of two.

    The padding on each axis is split as evenly as possible between both sides.
    When the pad amount is odd, the extra element goes to the bottom/right, so
    the original content stays (nearly) centred.

    Parameters
    ----------
    arr : np.ndarray
        Array with at least two dimensions. The last two are treated as (H, W);
        any leading axes (e.g. ``(n_slices, 1)``) are left untouched.

    Returns
    -------
    np.ndarray
        Array of shape ``(..., H', W')``. H' and W' are the smallest powers of
        two >= H and W (e.g. 193 -> 256, 229 -> 256). Axes that are already a
        power of two are not padded. The dtype is preserved.

    Raises
    ------
    ValueError
        If ``arr`` has fewer than two dimensions or an empty spatial axis.
    """
    arr = np.asarray(arr)
    if arr.ndim < 2:
        raise ValueError(f"pad_data expects at least 2 dimensions, got shape {arr.shape}")
    h, w = arr.shape[-2:]
    if h < 1 or w < 1:
        raise ValueError(f"pad_data cannot pad an empty spatial axis, got shape {arr.shape}")

    def _split_pad(size):
        # Exact integer next power of two; avoids float rounding in log2/ceil.
        total = (1 << (size - 1).bit_length()) - size
        return total // 2, total - total // 2

    # Leading axes get no padding; only the trailing (H, W) axes are padded.
    pad_width = [(0, 0)] * (arr.ndim - 2) + [_split_pad(h), _split_pad(w)]
    return np.pad(arr, pad_width, mode='constant')

# --- Configuration (machine-specific root; edit DATA_ROOT to relocate everything) ---
DATA_ROOT = '/home/mingjie/mri230'
data_dir = os.path.join(DATA_ROOT, 'data')
train_dir = os.path.join(DATA_ROOT, 'train_data')
val_dir = os.path.join(DATA_ROOT, 'val_data')
TRAIN_FRACTION = 0.7  # fraction of volumes (not slices) assigned to training

# Sort so the train/val split is reproducible: os.listdir returns entries in an
# arbitrary, filesystem-dependent order. Only .npy volumes are considered.
data_lst = sorted(f for f in os.listdir(data_dir) if f.endswith('.npy'))

train_cutoff = int(len(data_lst) * TRAIN_FRACTION)
if not 0 < train_cutoff < len(data_lst):
    raise ValueError(
        f"Split of {len(data_lst)} volumes in {data_dir!r} with TRAIN_FRACTION="
        f"{TRAIN_FRACTION} leaves the train or validation set empty"
    )

# Split at the volume level so slices from one MRI never land in both sets.
train_mris = [np.load(os.path.join(data_dir, f)) for f in data_lst[:train_cutoff]]
val_mris = [np.load(os.path.join(data_dir, f)) for f in data_lst[train_cutoff:]]

# Stack all volumes along the last (axial) axis.
train_data = np.concatenate(train_mris, axis=-1).astype(np.float32)
val_data = np.concatenate(val_mris, axis=-1).astype(np.float32)

# Min-max scale using statistics from the TRAINING split only, so nothing from
# the validation volumes leaks into preprocessing. The names global_min /
# global_max are kept for downstream cells. Validation values may fall slightly
# outside [0, 1].
global_min = train_data.min()
global_max = train_data.max()
train_data = (train_data - global_min) / (global_max - global_min + 1e-8)
val_data   = (val_data - global_min) / (global_max - global_min + 1e-8)

train_data = np.moveaxis(train_data, -1, 0)  # (n_slices, H, W)
val_data   = np.moveaxis(val_data, -1, 0)

train_data = np.expand_dims(train_data, 1)  # (n_slices, 1, H, W)
val_data   = np.expand_dims(val_data, 1)

train_data = pad_data(train_data)
val_data   = pad_data(val_data)
assert train_data.shape[1:] == val_data.shape[1:], "train/val slice shapes differ"

# Save into the configured directories (previously these were defined but unused,
# and the output landed relative to the current working directory).
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
np.save(os.path.join(train_dir, 'train_data.npy'), train_data)
np.save(os.path.join(val_dir, 'val_data.npy'), val_data)

In [None]:
# Sanity check of the final array shapes, one per line:
#   train -> (72 * 193, 1, H, W), val -> (31 * 193, 1, H, W), with H = W after padding.
print(train_data.shape, val_data.shape, sep='\n')