# Import libraries

In [1]:
import os
import shutil
from pathlib import Path
from random import randint

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
from scipy.ndimage import rotate
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates

def rm_tree(pth):
    """Recursively delete the directory `pth` and everything inside it.

    Delegates to `shutil.rmtree`, which unlinks symlinks found inside the tree
    instead of descending into them, so data outside `pth` is never touched.
    Raises an OSError subclass (e.g. FileNotFoundError, NotADirectoryError) if
    `pth` is missing, is a regular file, or is itself a symlink.
    """
    shutil.rmtree(Path(pth))

def fg_ratio(mask):
    """Return the fraction of non-zero (foreground) elements in `mask`.

    Accepts a numpy array, anything `np.asarray` understands, or a torch tensor.
    A tensor is detached and moved to CPU first, so grad or GPU tensors work.
    An empty mask yields 0.0 instead of raising ZeroDivisionError.
    """
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    if mask.size == 0:
        return 0.0
    # mask.size is the element count; no need to materialise np.ones_like(mask).
    return np.count_nonzero(mask) / mask.size

# Helper functions

In [2]:
def rotate3d(volume, degree):
    """Rotate every depth slice of a (D, H, W) numpy volume by `degree` degrees.

    Each 2-D slice volume[k] is rotated in its own plane with
    scipy.ndimage.rotate (reshape=False, so the slice shape is kept).
    The result has the same shape and dtype as `volume`.
    """
    result = np.zeros_like(volume)
    rotated_planes = (rotate(plane, degree, reshape=False) for plane in volume)
    for depth_idx, plane in enumerate(rotated_planes):
        result[depth_idx] = plane
    return result

def crop_center_3d(mask, output_size=(14, 110, 110)):
    """Center-crop the last three (D, H, W) axes of `mask` to `output_size`.

    Works for any array or tensor with at least 3 dims, e.g. (D, H, W),
    (C, D, H, W) or (B, C, D, H, W). Leading axes are left untouched.

    NOTE(review): if `output_size` exceeds the input along an axis, the start
    offset becomes negative and the slice counts from the end of that axis.
    Callers should make sure `output_size` fits.

    Raises:
        ValueError: if `mask` has fewer than 3 dimensions.
    """
    if len(mask.shape) < 3:
        raise ValueError(f'expected at least 3 dims (D, H, W), got shape {tuple(mask.shape)}')
    start_d, start_h, start_w = (
        (dim - out) // 2 for dim, out in zip(mask.shape[-3:], output_size)
    )
    return mask[...,
                start_d:start_d + output_size[0],
                start_h:start_h + output_size[1],
                start_w:start_w + output_size[2]]
    
def elastic_transform_3d(image, mask, alpha, sigma, alpha_affine=-1, random_state=None):
    """
    Apply one random (optional affine +) elastic deformation to both an image and its mask.

    Param:
        image (np.ndarray): image to be deformed. A 2-D input gets a trailing
            singleton axis during processing; the results are reshaped back to
            the input shape.
        mask (np.ndarray): warped with exactly the same displacement field as
            `image`. It must have the same number of elements as `image`,
            because its result is reshaped to `image`'s original shape.
            map_coordinates keeps the input dtype, so an integer mask is cast
            back after linear interpolation.
        alpha (float): scale of transformation for each dimension, where larger
            values have more deformation
        sigma (float): Gaussian window of deformation for each dimension, where
            smaller values have more localised deformation
        alpha_affine (float): if > 0, first apply a random affine warp (OpenCV)
            over the first two axes. Each of its three control points is
            jittered by a uniform offset in [-alpha_affine, alpha_affine].
            Disabled by default (-1).
        random_state (np.random.RandomState | None): source of randomness. If
            None, a RandomState is seeded with `randint(1, 200)` from Python's
            global `random` module, so at most 200 distinct deformations occur.
    Returns:
        tuple[np.ndarray, np.ndarray]: (deformed image, deformed mask), each
            reshaped to the original input shape.
    """
    ori_shape = image.shape

    if len(ori_shape) < 3:
        # Work in 3-D internally; restored by the final reshape(ori_shape).
        image = np.expand_dims(image, axis=-1)
        mask = np.expand_dims(mask, axis=-1)
    
    if random_state is None:
        seed = randint(1, 200)
        random_state = np.random.RandomState(seed)
    
    shape = image.shape
    
    # Random affine
    if alpha_affine > 0:
        # print('affine')
        shape_size = shape[:2]
        center_square = np.float32(shape_size) // 2
        square_size = min(shape_size) // 3
        # Three control points around the centre of the first-two-axes plane, randomly jittered.
        pts1 = np.float32([center_square + square_size, [center_square[0]+square_size, center_square[1]-square_size], center_square - square_size])
        pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)
        # cv2 expects dsize as (width, height), hence the reversed shape.
        image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
        mask = cv2.warpAffine(mask, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)

        # warpAffine can drop a trailing singleton channel axis; restore it.
        if len(image.shape) < 3:
            image = np.expand_dims(image, axis=-1)
            mask = np.expand_dims(mask, axis=-1)

    # Smooth random displacement fields in [-1, 1] scaled by alpha. cval is ignored with mode="reflect".
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="reflect", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="reflect", cval=0) * alpha
    # The third-axis displacement is additionally scaled by shape[2] / shape[0].
    dz = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="reflect", cval=0) * alpha * shape[2]/shape[0]

    # meshgrid's default 'xy' indexing swaps the first two inputs, so y varies along axis 0
    # and x along axis 1; indices is therefore ordered (axis0, axis1, axis2).
    x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
    indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)), np.reshape(z+dz, (-1, 1))

    # Linear (order=1) resampling of both arrays at the displaced coordinates.
    image_trans = map_coordinates(image, indices, order=1, mode='reflect').reshape(ori_shape)
    mask_trans = map_coordinates(mask, indices, order=1, mode='reflect').reshape(ori_shape)
    return image_trans, mask_trans

# Data augmentation

In [3]:
import random
def aug_data(volume, mask, rotate_prob=None, elastic_prob=None,
             max_rotate_angles=None, rng=None):
    """Randomly rotate and elastically deform a (D, H, W) volume and its mask.

    The same geometric transform is applied to both arrays. Processing order:
      1. Up to three rotations, each in its own axis plane. Each fires
         independently with probability `rotate_prob` when its max angle is > 0.
      2. An elastic transform with probability `elastic_prob`.
      3. The volume is clipped to [0, 1] and the mask is binarised at 0.5.

    Args:
        volume: array-like (D, H, W) intensities, presumably already scaled to [0, 1].
        mask: array-like label mask with the same shape as `volume`.
        rotate_prob: probability of each rotation. Defaults to the global ROTATE_PROB.
        elastic_prob: probability of the elastic transform. Defaults to the global
            ELASTIC_PROB.
        max_rotate_angles: (x, y, z) maximum absolute angles in degrees, applied to
            rotations in the (0, 2), (0, 1) and (1, 2) planes respectively.
            Defaults to the globals (MAX_ROTATE_ANGLE_X, MAX_ROTATE_ANGLE_Y,
            MAX_ROTATE_ANGLE_Z).
        rng: object with `uniform` and `randint` methods (e.g. random.Random(seed))
            used for every random draw, including the elastic deformation seed.
            Defaults to the global `random` module, which gives the same draws as
            before.

    Returns:
        (float32 volume, uint8 mask): new arrays with the input shape.
    """
    if rotate_prob is None:
        rotate_prob = ROTATE_PROB
    if elastic_prob is None:
        elastic_prob = ELASTIC_PROB
    if max_rotate_angles is None:
        max_rotate_angles = (MAX_ROTATE_ANGLE_X, MAX_ROTATE_ANGLE_Y, MAX_ROTATE_ANGLE_Z)
    if rng is None:
        rng = random

    # np.array copies, so the caller's arrays are never modified.
    aug_volume = np.array(volume, dtype=np.float32)
    aug_mask = np.array(mask, dtype=np.uint8)

    # Random rotations. The probability draw happens even when the angle limit is 0,
    # which keeps the sequence of random draws fixed.
    for max_angle, axes in zip(max_rotate_angles, ((0, 2), (0, 1), (1, 2))):
        if rng.uniform(0, 1) < rotate_prob and max_angle > 0.:
            angle = rng.uniform(-max_angle, max_angle)
            # scipy.ndimage.rotate; the scipy.ndimage.interpolation alias is deprecated.
            aug_volume = rotate(aug_volume, angle, axes, reshape=False, mode='nearest')
            aug_mask = rotate(aug_mask, angle, axes, reshape=False, mode='nearest')

    # Elastic transform
    if rng.uniform(0, 1) < elastic_prob:
        # elastic_transform_3d is called with depth last: D H W -> H W D.
        aug_volume = np.moveaxis(aug_volume, 0, -1)
        aug_mask = np.moveaxis(aug_mask, 0, -1)
        alpha_factor = 2
        sigma_factor = 0.08
        alpha_affine_factor = 0.08
        width = aug_volume.shape[1]
        # Draw the deformation seed from `rng` (the same randint(1, 200) draw the
        # transform would otherwise make internally), so a seeded rng controls it too.
        random_state = np.random.RandomState(rng.randint(1, 200))
        aug_volume, aug_mask = elastic_transform_3d(
            aug_volume, aug_mask,
            alpha=width * alpha_factor,
            sigma=width * sigma_factor,
            alpha_affine=width * alpha_affine_factor,
            random_state=random_state)
        # H W D -> D H W
        aug_volume = np.moveaxis(aug_volume, -1, 0)
        aug_mask = np.moveaxis(aug_mask, -1, 0)

    # Clip intensities to [0, 1] and binarise the interpolated mask.
    np.clip(aug_volume, 0., 1., out=aug_volume)
    aug_mask = (aug_mask >= 0.5).astype(np.uint8)

    return aug_volume, aug_mask

# Patch extraction

In [4]:
def _patch_starts(length, size, stride, padding):
    """Return the sliding-window start offsets along one axis.

    Windows start at 0, stride, 2*stride, ... and stop after the first window
    that reaches `length`. Without padding, a final window that would overrun
    `length` is dropped. With padding it is kept, and the caller pads to fit it.
    """
    if stride <= 0:
        raise ValueError(f'stride must be positive, got {stride}')
    starts = []
    start = 0
    while True:
        end = start + size
        if padding or end <= length:
            starts.append(start)
        if end >= length:
            return starts
        start += stride


def generate_patches(volume, mask, size=(1, 128, 128), stride=(1, 120, 120), padding=True, remove=True, percent=0.01):
    """Cut a (D, H, W) volume and its mask into sliding-window patches.

    volume, mask: numpy arrays (D, H, W) with identical shapes
    size: (c, h, w) for 2-D slabs of c slices (the depth window equals c), or
          (c, d, h, w) for 3-D patches of depth d
    stride: (stride_d, stride_h, stride_w); all must be positive
    padding: if True, zero-pad the end of each axis just enough that the last
             window is full-size (also when the axis is shorter than the window);
             if False, windows that would overrun the volume are skipped
    remove: if True, drop patches whose non-zero mask fraction is <= percent
    @return numpy arrays (n, c, h, w) or (n, 1, d, h, w) for volume and mask;
            two empty float32 arrays if no patch is kept
    """
    assert volume.shape == mask.shape, 'Shape of volume and mask are different'
    assert len(volume.shape) == 3, 'Invalid volume shape'
    if len(size) == 4:
        ch, size_d, size_h, size_w = size
    elif len(size) == 3:
        ch, size_h, size_w = size
        size_d = ch
    else:
        raise ValueError(f'size must have 3 or 4 elements, got {size}')
    stride_d, stride_h, stride_w = stride

    D, H, W = volume.shape
    d_starts = _patch_starts(D, size_d, stride_d, padding)
    h_starts = _patch_starts(H, size_h, stride_h, padding)
    w_starts = _patch_starts(W, size_w, stride_w, padding)

    if padding:
        # Pad exactly up to the end of the last window on each axis.
        pad_width = [
            (0, max(0, starts[-1] + window - length))
            for starts, window, length in ((d_starts, size_d, D), (h_starts, size_h, H), (w_starts, size_w, W))
        ]
        volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        mask = np.pad(mask, pad_width, mode='constant', constant_values=0)

    out_volume = []
    out_mask = []
    remove_patch = 0
    for start_d in d_starts:
        for start_h in h_starts:
            for start_w in w_starts:
                window = (slice(start_d, start_d + size_d),
                          slice(start_h, start_h + size_h),
                          slice(start_w, start_w + size_w))
                vol_voxel = volume[window]
                mask_voxel = mask[window]
                # 3-D patches get a leading channel axis; 2-D slabs already have c leading slices.
                if vol_voxel.shape[0] != ch:
                    vol_voxel = np.expand_dims(vol_voxel, axis=0)
                    mask_voxel = np.expand_dims(mask_voxel, axis=0)
                keep = (not remove) or np.count_nonzero(mask_voxel) / mask_voxel.size > percent
                if keep:
                    out_volume.append(vol_voxel)
                    out_mask.append(mask_voxel)
                else:
                    remove_patch += 1
    print(f"The number of removed patchs: {remove_patch}")
    if len(out_volume) > 0:
        return np.stack(out_volume), np.stack(out_mask)
    return np.array([], dtype=np.float32), np.array([], dtype=np.float32)

# Preprocessing data

In [5]:
import SimpleITK as sitk

# CT intensity window (Hounsfield units) used by preprocess()
HU_MIN, HU_MAX = -100, 400

def scale_to_0_1(arr):
    """Min-max scale `arr` linearly onto [0, 1].

    A constant array (max == min) maps to all zeros instead of the NaNs
    produced by 0 / 0.
    """
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return (arr - lo) * 0.0
    return (arr - lo) / span

def preprocess(volume, mask, AD_filter=False):
    """Window a CT volume to [HU_MIN, HU_MAX], scale it to [0, 1] and optionally denoise it.

    volume: CT intensities, presumably in Hounsfield units
    mask: label array, clipped to [0, 1]
    AD_filter: if True, run SimpleITK's curvature anisotropic diffusion filter
        (time step 0.0625, 4 iterations, conductance 1.5) on the scaled volume
    Returns (volume, mask). The volume is rescaled to [0, 1] once more at the end.
    NOTE: later cells redefine `preprocess` with a different signature.
    """
    windowed = scale_to_0_1(np.clip(volume, HU_MIN, HU_MAX))
    mask = np.clip(mask, 0, 1)

    if AD_filter:
        diffusion = sitk.CurvatureAnisotropicDiffusionImageFilter()
        diffusion.SetTimeStep(0.0625)
        diffusion.SetNumberOfIterations(4)
        diffusion.SetConductanceParameter(1.5)
        smoothed = diffusion.Execute(sitk.GetImageFromArray(windowed, isVector=False))
        windowed = sitk.GetArrayFromImage(smoothed)

    return scale_to_0_1(windowed), mask

# Gen pre-training data





In [6]:
from pathlib import Path
# Output folders for the pre-training patches (both live under data_pretrain/).
prevol_dir  = '/workspace/dataset/data_pretrain/VOLUME'
premask_dir = '/workspace/dataset/data_pretrain/VESSEL'

vol_path = Path(prevol_dir)
mask_path = Path(premask_dir)
# Start from a clean directory. Check both sub-folders: a leftover VESSEL/ on its own
# would otherwise survive and make mask_path.mkdir() fail.
if vol_path.exists() or mask_path.exists():
    rm_tree(vol_path.parent)

vol_path.mkdir(parents=True, exist_ok=True)
mask_path.mkdir(parents=True, exist_ok=True)
del vol_path, mask_path

In [7]:
def preprocess(vol, mask, minv, maxv):
    """Clip `vol` to [minv, maxv], map that window linearly onto [0, 1], and clip `mask` to [0, 1]."""
    window_width = maxv - minv
    scaled = (np.clip(vol, minv, maxv) - minv) / window_width
    return scaled, np.clip(mask, 0, 1)

In [None]:
pth = '/workspace/dataset/3dircadb_zoom_crop/Zoom_Crop_05'
size = (1, 64, 96, 96)   # (channels, D, H, W) of each 3-D patch
stride = (32, 32, 32)    # step between patch origins along (D, H, W)

# Held-out patients are skipped; the remaining ones are saved as volume_0 .. volume_{n-1}.
lst_test = [0, 3, 5, 10, 16, 18]
train_ids = [idx for idx in range(20) if idx not in lst_test]
for n, idx in enumerate(train_ids):
    volume_path = pth + f'/VOLUME/volume_{idx}.pth'
    mask_path   = pth + f'/VESSEL/volume_{idx}.pth'
    img_ori  = torch.load(volume_path)
    mask_ori = torch.load(mask_path)
    print(f'# Load volume {idx} ---> Shape of volume: ', img_ori.shape)

    img_ori, mask_ori = preprocess(img_ori, mask_ori, -100, 400)
    img, mask = generate_patches(
        img_ori, mask_ori, size=size, stride=stride,
        padding=True, remove=True, percent=0.001,
    )
    print(f'-- Done volume {idx} -- {img.shape}')

    torch.save(img,  prevol_dir + f'/volume_{n}.pth')
    torch.save(mask, premask_dir + f'/volume_{n}.pth')
    del img, mask, img_ori, mask_ori
# number of patients written
n = len(train_ids)

# Augment training data

In [None]:
from pathlib import Path

folder_train = 'data/train'
vol_dir = f'/workspace/dataset/{folder_train}/volumes'
mask_dir = f'/workspace/dataset/{folder_train}/masks'
# the number of training patients (presumably the volume_0..volume_13 files from the pre-training cell)
n = 14
# number of augmentation draws per patch
n_sampling = 1

vol_path = Path(vol_dir)
mask_path = Path(mask_dir)
vol_path.mkdir(parents=True, exist_ok=True)
mask_path.mkdir(parents=True, exist_ok=True)
del vol_path, mask_path

rnd_state = np.random.RandomState(24)
# aug_data / elastic_transform_3d draw from Python's global `random` module, which
# rnd_state does not control. Seed it too so a re-run reproduces the same dataset.
random.seed(24)

VOL_SIZE  = (64, 96, 96)
MASK_SIZE = VOL_SIZE

# Aug setting
ROTATE_PROB = 0.5
ELASTIC_PROB = 0.3
MAX_ROTATE_ANGLE_X = 10
MAX_ROTATE_ANGLE_Y = 10
MAX_ROTATE_ANGLE_Z = 45

data_pretrain_path = 'data_pretrain'
prevol_dir = f'/workspace/dataset/{data_pretrain_path}/VOLUME'
premask_dir = f'/workspace/dataset/{data_pretrain_path}/VESSEL'

fg_mean = []
count = 0
buffer = 400  # samples per shard folder (data0/, data1/, ...)


def _save_sample(vol_arr, mask_arr, sample_idx):
    """Save one volume/mask pair, with a leading singleton axis, into shard data{sample_idx // buffer}."""
    shard = f'/data{sample_idx // buffer}'
    Path(vol_dir + shard).mkdir(parents=True, exist_ok=True)
    Path(mask_dir + shard).mkdir(parents=True, exist_ok=True)
    # Copy so the saved tensor owns its memory rather than viewing the source array.
    torch.save(torch.from_numpy(vol_arr.copy()).unsqueeze(0), vol_dir + shard + f'/vol_{sample_idx}.pth')
    torch.save(torch.from_numpy(mask_arr.copy()).unsqueeze(0), mask_dir + shard + f'/vol_{sample_idx}.pth')


for patient_idx in range(n):
    volumes = torch.load(prevol_dir + f'/volume_{patient_idx}.pth')  # numpy array of patches
    masks = torch.load(premask_dir + f'/volume_{patient_idx}.pth')
    print(f"Patient #{patient_idx}", volumes.shape)
    for idx in range(masks.shape[0]):
        vol = volumes[idx, 0]
        mask = masks[idx, 0]

        # Keep the un-augmented patch with probability 0.4.
        if rnd_state.random_sample() < 0.4:
            fg_mean.append(fg_ratio(mask))
            _save_sample(vol, mask, count)
            count += 1

        # Each augmentation draw is skipped with probability 0.4.
        for _ in range(n_sampling):
            if rnd_state.random_sample() < 0.4:
                continue
            aug_vol, aug_mask = aug_data(vol, mask)
            _save_sample(aug_vol, aug_mask, count)
            count += 1

# Total number of saved samples, read back by the training code.
x = torch.tensor(count)
torch.save(x, f'/workspace/dataset/{folder_train}/info.pth')

In [None]:
# Use a new name instead of overwriting the `fg_mean` list, so this cell can be re-run safely.
fg_mean_tensor = torch.tensor(fg_mean)
print(x)  # total number of saved training samples
print(fg_mean_tensor.mean())  # mean foreground fraction of the un-augmented patches

# Load val data

In [28]:
def scale_to_0_1(arr):
    """Min-max scale `arr` linearly onto [0, 1].

    A constant array (max == min) maps to all zeros instead of the NaNs
    produced by 0 / 0.
    """
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return (arr - lo) * 0.0
    return (arr - lo) / span

def preprocess(volume, mask, hu_min, hu_max):
    """Window a CT volume to [hu_min, hu_max], min-max scale it, and binarise the mask at 0.5.

    Returns (scaled volume, binary mask) as new objects. The caller's `mask` is
    not modified; a thresholded copy with the same dtype is returned.
    """
    volume = scale_to_0_1(np.clip(volume, hu_min, hu_max))
    # Copy before thresholding so the caller's array is left untouched.
    binary_mask = mask.clone() if torch.is_tensor(mask) else mask.copy()
    binary_mask[binary_mask >= 0.5] = 1
    binary_mask[binary_mask < 0.5] = 0
    return volume, binary_mask

def load_all_subvols(subvols_path, sub_vol_shape, sub_vol_type=torch.float32):
    """Load every saved sub-volume of one patient into a single tensor.

    Expects `num.pth` (the sub-volume count) and `vol_{i}.pth` for i in [0, count),
    the layout written by the validation export loop in this notebook.

    Args:
        subvols_path: directory holding the files (str or path-like; the trailing
            separator is optional).
        sub_vol_shape: shape of one sub-volume (tuple, list or torch.Size); only its
            last three entries (D, H, W) are used.
        sub_vol_type: dtype of the returned tensor.
    Returns:
        torch.Tensor of shape (count, D, H, W).
    """
    n_subvols = int(torch.load(os.path.join(subvols_path, 'num.pth')))
    out_shape = (n_subvols,) + tuple(sub_vol_shape[-3:])
    subvols = torch.empty(out_shape, dtype=sub_vol_type)
    for i in range(n_subvols):
        subvols[i] = torch.load(os.path.join(subvols_path, f'vol_{i}.pth')).to(sub_vol_type)
    return subvols

In [29]:
def _patch_starts(length, size, stride, padding):
    """Return the sliding-window start offsets along one axis.

    Windows start at 0, stride, 2*stride, ... and stop after the first window
    that reaches `length`. Without padding, a final window that would overrun
    `length` is dropped. With padding it is kept, and the caller pads to fit it.
    """
    if stride <= 0:
        raise ValueError(f'stride must be positive, got {stride}')
    starts = []
    start = 0
    while True:
        end = start + size
        if padding or end <= length:
            starts.append(start)
        if end >= length:
            return starts
        start += stride


def generate_patches(volume, mask, size=(64, 192, 192), stride=(64, 192, 192), padding=False, remove=True, number=10):
    """Cut a (D, H, W) volume and its mask into sliding-window 3-D patches.

    volume, mask: numpy arrays (D, H, W) with identical shapes
    size: (d, h, w) patch shape
    stride: (stride_d, stride_h, stride_w); all must be positive
    padding: if True, zero-pad the end of each axis just enough that the last
             window is full-size (also when the axis is shorter than the window);
             if False, windows that would overrun the volume are skipped
    remove: if True, keep only patches whose mask sum exceeds `number`
    @return numpy arrays (n, 1, d, h, w) for volume and mask
            ((n, 1, h, w) when d == 1); two empty float32 arrays if nothing is kept
    """
    assert volume.shape == mask.shape, 'Shape of volume and mask are different'
    assert len(volume.shape) == 3, 'Invalid volume shape'
    assert len(size) == 3, 'Invalid size'
    assert len(stride) == 3, 'Invalid stride'
    size_d, size_h, size_w = size
    channels = 1 # Fixed
    stride_d, stride_h, stride_w = stride

    D, H, W = volume.shape
    d_starts = _patch_starts(D, size_d, stride_d, padding)
    h_starts = _patch_starts(H, size_h, stride_h, padding)
    w_starts = _patch_starts(W, size_w, stride_w, padding)

    if padding:
        # Pad exactly up to the end of the last window on each axis.
        pad_width = [
            (0, max(0, starts[-1] + window - length))
            for starts, window, length in ((d_starts, size_d, D), (h_starts, size_h, H), (w_starts, size_w, W))
        ]
        volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        mask = np.pad(mask, pad_width, mode='constant', constant_values=0)

    out_volume = []
    out_mask = []
    for start_d in d_starts:
        for start_h in h_starts:
            for start_w in w_starts:
                window = (slice(start_d, start_d + size_d),
                          slice(start_h, start_h + size_h),
                          slice(start_w, start_w + size_w))
                vol_voxel = volume[window]
                mask_voxel = mask[window]
                # Add a channel axis unless the patch depth already equals `channels`.
                if vol_voxel.shape[0] != channels:
                    vol_voxel = np.expand_dims(vol_voxel, axis=0)
                    mask_voxel = np.expand_dims(mask_voxel, axis=0)
                if (not remove) or np.sum(mask_voxel) > number:
                    out_volume.append(vol_voxel)
                    out_mask.append(mask_voxel)
    if len(out_volume) > 0:
        return np.stack(out_volume), np.stack(out_mask)
    return np.array([], dtype=np.float32), np.array([], dtype=np.float32)


In [30]:
# Validation export settings
SUB_VOL_SIZE = (64, 96, 96)                                      # (D, H, W) of each saved sub-volume
PTH_DIR = '/workspace/dataset/3dircadb_zoom_crop/Zoom_Crop_05'   # source volumes / vessel masks
SAVE_DIR = '/workspace/dataset/data/val/'                        # trailing slash: paths are built by concatenation

LST_IDX = [5, 16, 18]                                            # validation patient ids

In [31]:
for patient_idx in LST_IDX:
    ### Save sub-volumes for one patient
    # 1. Load volume and vessel mask
    volume = torch.load(PTH_DIR + '/VOLUME' + f'/volume_{patient_idx}.pth')
    mask = torch.load(PTH_DIR + '/VESSEL' + f'/volume_{patient_idx}.pth')
    # 2. Preprocess
    volume, mask = preprocess(volume, mask, -100, 400)
    # 3. Non-overlapping sub-volumes covering the whole (zero-padded) volume; keep all of them
    sub_vols, _ = generate_patches(volume, mask, size=SUB_VOL_SIZE, stride=SUB_VOL_SIZE, padding=True, remove=False)
    # 4. Save all sub-volumes. The folder is created before the loop so that num.pth,
    #    shape.pth and mask.pth always go to this patient's folder, even with zero sub-volumes.
    save_vol_path = SAVE_DIR + f'patient_{patient_idx}/'
    Path(save_vol_path).mkdir(parents=True, exist_ok=True)
    n_sub = sub_vols.shape[0]
    for sub_idx in range(n_sub):
        torch.save(torch.from_numpy(sub_vols[sub_idx]), save_vol_path + f'vol_{sub_idx}.pth')
    torch.save(n_sub, save_vol_path + f'num.pth')
    if n_sub > 0:
        torch.save(sub_vols[0].shape, save_vol_path + f'shape.pth')
    # 5. Save the full preprocessed mask (unpadded)
    torch.save(torch.from_numpy(mask), save_vol_path + f'mask.pth')
    print(f'Saved patient {patient_idx}')

Saved patient 5
Saved patient 16
Saved patient 18
