In [2]:
''' Notebook for creating VAE liver dataset. Therefore labels are not included. '''
import json
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils as U
import torchvision.transforms as T
from torch import Tensor

In [3]:
# Dataset description file, given relative to the current working directory.
path = 'dataset.json'

# Parse the JSON description; the entries under its 'training' key are kept
# separately as training_paths.
with open(path) as json_file:
    data_set = json.load(json_file)

training_paths = data_set['training']

In [4]:
# Target edge length: images are resized to d x d.
d = 256

# Border widths used by the padding transforms: a quarter, an eighth and a
# sixteenth of the target edge length.
p1, p2, p3 = d // 4, d // 8, d // 16

resize_transform = T.Resize((d, d))
pad_transforms = [T.Pad(border) for border in (p1, p2, p3)]
# A single number makes torchvision draw the angle from (-180, +180) degrees.
rand_rot_transform = T.RandomRotation(180)

In [8]:
# Build the augmented slice dataset and write it to disk in chunks.
#
# Each training volume is loaded, min-max scaled to [0, 1], resized to (d, d)
# per slice and expanded into several augmented copies (see augment_volume).
# The results are buffered and flushed to CHUNK_DIR whenever the volume index
# is a non-zero multiple of CHUNK_SIZE, so data_chunk_1 holds volumes
# 0..CHUNK_SIZE and every later full chunk holds CHUNK_SIZE volumes. Whatever
# is left after the last volume goes into one extra chunk with the next index.
CHUNK_DIR = 'augmented_data_chunks'
CHUNK_SIZE = 10
SEED = 0


def load_normalized_volume(image_path):
    """Load a NIfTI image and return it as a (slice, d, d) tensor scaled to [0, 1].

    The scaling is per volume (its own min and max), so absolute intensities
    are not comparable between volumes.
    """
    volume = Tensor(nib.load(image_path).get_fdata())
    volume -= volume.min()
    peak = volume.max()
    if peak > 0:  # a constant volume would otherwise be divided by zero -> NaN
        volume /= peak
    # Move the last axis to the front so the tensor is (slice, rows, columns).
    volume = volume.permute(2, 0, 1)
    return resize_transform(volume)


def augment_volume(volume):
    """Return `volume` and its augmented copies concatenated along dim 0.

    Order of the copies: original, horizontal flip, vertical flip, rotations by
    90/180/270 degrees, then for every entry of pad_transforms the padded
    volume followed by the padded randomly-rotated volume (both resized back
    to (d, d)). One random angle is drawn per call and reused for all pads.
    """
    randomly_rotated = rand_rot_transform(volume)
    variants = [
        volume,
        T.functional.hflip(volume),
        T.functional.vflip(volume),
    ]
    variants.extend(T.functional.rotate(volume, turns * 90) for turns in range(1, 4))
    for pad_transform in pad_transforms:
        variants.append(resize_transform(pad_transform(volume)))
        variants.append(resize_transform(pad_transform(randomly_rotated)))
    return torch.cat(variants, 0)


# Seed torch's RNG so the random rotations are the same on every run.
torch.manual_seed(SEED)
os.makedirs(CHUNK_DIR, exist_ok=True)

data_tensor = torch.zeros((0, d, d))
total_paths = len(training_paths)
pending = []  # augmented volumes collected since the last flush

for n, path in enumerate(training_paths):
    print(f'Iteration {n}/{total_paths}')
    pending.append(augment_volume(load_normalized_volume(path['image'])))

    if n % CHUNK_SIZE == 0 and n != 0:
        chunk_index = n // CHUNK_SIZE
    elif n == total_paths - 1:
        # Trailing partial chunk: use the index after the last flushed chunk so
        # it can never overwrite a chunk that was already written.
        chunk_index = n // CHUNK_SIZE + 1
    else:
        continue

    # Concatenate once per chunk instead of growing a tensor volume by volume.
    data_tensor = torch.cat(pending, 0)
    torch.save(data_tensor, f'{CHUNK_DIR}/data_chunk_{chunk_index}.pt')
    pending = []  # reset

Iteration 0/63
Iteration 1/63
Iteration 2/63
Iteration 3/63
Iteration 4/63
Iteration 5/63
Iteration 6/63
Iteration 7/63
Iteration 8/63
Iteration 9/63
Iteration 10/63
Iteration 11/63
Iteration 12/63
Iteration 13/63
Iteration 14/63
Iteration 15/63
Iteration 16/63
Iteration 17/63
Iteration 18/63
Iteration 19/63
Iteration 20/63
Iteration 21/63
Iteration 22/63
Iteration 23/63
Iteration 24/63
Iteration 25/63
Iteration 26/63
Iteration 27/63
Iteration 28/63
Iteration 29/63
Iteration 30/63
Iteration 31/63
Iteration 32/63
Iteration 33/63
Iteration 34/63
Iteration 35/63
Iteration 36/63
Iteration 37/63
Iteration 38/63
Iteration 39/63
Iteration 40/63
Iteration 41/63
Iteration 42/63
Iteration 43/63
Iteration 44/63
Iteration 45/63
Iteration 46/63
Iteration 47/63
Iteration 48/63
Iteration 49/63
Iteration 50/63
Iteration 51/63
Iteration 52/63
Iteration 53/63
Iteration 54/63
Iteration 55/63
Iteration 56/63
Iteration 57/63
Iteration 58/63
Iteration 59/63
Iteration 60/63
Iteration 61/63
Iteration 62/63
