# Create LiTS dataset

In [2]:
import os
import numpy as np
import torch
import nibabel as nib
from skimage.transform import downscale_local_mean

def read_vol(fn):
    """Load the image file ``fn`` with nibabel and return its voxel data.

    Args:
        fn: Path of the file handed to ``nibabel.load``.

    Returns:
        The array produced by ``get_fdata()``, converted to ``np.float32``.
    """
    import nibabel as nib

    voxels = nib.load(fn).get_fdata()
    return voxels.astype(np.float32)

def _foreground_midpoint(lab, axis):
    """Return the middle index of the foreground extent of ``lab`` along ``axis``.

    A position along ``axis`` counts as foreground when the maximum of the
    corresponding slice rounds to a value >= 1. Returns -1 when no position
    qualifies.
    """
    other_axes = tuple(a for a in range(lab.ndim) if a != axis)
    # One maximum per slice, computed in a single reduction instead of a
    # Python-level loop over every slice.
    occupied = np.flatnonzero(np.round(lab.max(axis=other_axes)) >= 1.0)
    if occupied.size == 0:
        return -1
    first = int(occupied[0])
    last = int(occupied[-1])
    # Same value as floor(first + (last - first) / 2), in exact integer math.
    return (first + last) // 2


def get_liver_center(lab):
    """Return the centre of the bounding box of the labelled region.

    For each of the first three axes, the first and last slice whose maximum
    label rounds to >= 1 are located, and the floored midpoint between them
    is taken as that axis' coordinate.

    Args:
        lab: 3-D NumPy label volume; values that round to >= 1 are foreground.

    Returns:
        Tuple ``(x, y, z)`` of Python ints. A coordinate is -1 when no slice
        along that axis contains foreground, so callers should check for it
        before using the result as an index.
    """
    return tuple(_foreground_midpoint(lab, axis) for axis in range(3))

def get_cube(vol, indx, size=128):
    """Extract a ``size``-sided cube from ``vol`` centred on ``indx``.

    The window along each of the three axes is ``[c - size//2, c + size//2)``.
    Any part of the window that falls outside the volume is filled with
    zeros, on the low side as well as the high side, so the returned cube
    always has shape ``(size, size, size)`` and stays centred on ``indx``.
    (Previously a window starting below index 0 produced a negative slice
    start, which NumPy interprets as counting from the end of the axis.)

    Args:
        vol: 3-D NumPy array to crop.
        indx: Sequence of three centre coordinates; each is cast with ``int``.
        size: Edge length of the cube. For an odd ``size`` the window covers
            ``size - 1`` voxels and the remaining voxel is zero-padded at the
            high end.

    Returns:
        Array of shape ``(size, size, size)``. It is a view into ``vol`` when
        no padding is needed and a new array otherwise.
    """
    half = int(size / 2)

    slices = []
    pad_widths = []
    for axis in range(3):
        center = int(indx[axis])
        lo = center - half
        hi = center + half
        dim = vol.shape[axis]

        # Clamp the window to the valid index range so a negative start can
        # never wrap around to the end of the axis.
        start = min(max(lo, 0), dim)
        stop = min(max(hi, 0), dim)

        # Voxels lost below index 0 are restored as leading zeros; whatever is
        # still missing to reach ``size`` goes at the end.
        pad_before = min(max(-lo, 0), 2 * half)
        pad_after = max(size - pad_before - (stop - start), 0)

        slices.append(slice(start, stop))
        pad_widths.append((pad_before, pad_after))

    sub_vol = vol[tuple(slices)]
    if any(before or after for before, after in pad_widths):
        sub_vol = np.pad(sub_vol, pad_widths, 'constant')

    return sub_vol


In [3]:
# Number of volume/segmentation pairs to convert; files are indexed 0 .. n-1.
n = 131

# Every axis is shrunk by this factor with a local mean.
downscale_factor = 2

# Number of slices along the third axis every volume is padded or cropped to
# before downscaling.
target_depth = 512

# torch.save does not create missing directories, so make sure it exists.
os.makedirs('unet/ds', exist_ok=True)

for i in range(n):
    # Load each file once; the shape is taken from the loaded array instead of
    # opening the volume a second time just to read its header.
    img = read_vol("full_dataset/volume-" + str(i) + ".nii")
    lab = read_vol("full_dataset/segmentation-" + str(i) + ".nii")
    print("vol " + str(i) + " " + str(img.shape))

    z = img.shape[2]

    if z <= target_depth:
        # Pad symmetrically along the third axis; an odd remainder goes to
        # the end.
        missing = target_depth - z
        pad_up = missing // 2
        pad_down = missing - pad_up

        # The image is padded with its own minimum value, the labels with 0.
        img = np.pad(img, ((0, 0), (0, 0), (pad_up, pad_down)), 'constant',
                     constant_values=img.min())
        lab = np.pad(lab, ((0, 0), (0, 0), (pad_up, pad_down)), 'constant')
    else:
        # Keep the central target_depth slices.
        cut_up = (z - target_depth) // 2
        cut_down = cut_up + target_depth

        img = img[:, :, cut_up:cut_down]
        lab = lab[:, :, cut_up:cut_down]

    factors = (downscale_factor, downscale_factor, downscale_factor)
    img = downscale_local_mean(img, factors)
    # NOTE(review): the local mean averages label values, so a block mixing
    # classes 0 and 2 rounds to class 1 -- confirm this is acceptable.
    lab = downscale_local_mean(lab, factors)
    lab = np.clip(np.round(lab), 0, 2)  # 3 classes: 0, 1, 2

    # Crop both arrays to the same cube centred on the labelled region.
    c = np.array(get_liver_center(lab))

    lab = get_cube(lab, c)
    img = get_cube(img, c)

    torch.save(img.astype(np.float32), 'unet/ds/img_' + str(i) + '.pt')
    torch.save(lab.astype(np.uint8), 'unet/ds/lab_' + str(i) + '.pt')



vol 0 (512, 512, 75)
c [111 127 138] [47 63 74] [175 191 202]
vol 1 (512, 512, 123)
c [113 139 126] [49 75 62] [177 203 190]
vol 2 (512, 512, 517)


KeyboardInterrupt: 