## Setting up

In [None]:
# Install the packages this notebook uses into the running kernel's environment.
# NOTE(review): only tensorflow and image-classifiers carry a version pin on this
# line; every other package resolves to whatever is current at install time.
%pip install tensorflow==2.16.1 nibabel matplotlib numpy opencv-python image-classifiers==1.0.0b1 keras_applications keras_preprocessing keras_cv albumentations

In [None]:
import os
import random
import shutil
from glob import glob

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from albumentations import HorizontalFlip, CoarseDropout, RandomBrightnessContrast, GaussNoise, RandomGamma, MixUp, ElasticTransform

In [None]:
# Notebook-level shape setting.
# NOTE(review): presumably (height, width, channels) of the model input —
# TODO confirm against the cell that consumes it.
input_shape = (630, 630, 1)

## Preprocessing

In [None]:
def read_nifti_file(filepath):
  """Load the file at `filepath` with nibabel and return its data.

  The file is opened with `nib.load` and the value produced by the loaded
  object's `get_fdata()` is returned unchanged.
  """
  volume = nib.load(filepath)
  return volume.get_fdata()


def load_data(path):
  """Collect the sorted file listings of the three dataset sub-folders.

  Returns a tuple `(images, lung_masks, masks)` holding the sorted glob
  results of `rp_im`, `rp_lung_msk` and `rp_msk` directly under `path`.
  """
  def listing(subdir):
    # Every entry of one sub-folder, in sorted order.
    return sorted(glob(os.path.join(path, subdir, '*')))

  return (listing('rp_im'), listing('rp_lung_msk'), listing('rp_msk'))


def parse(images, lung_masks, masks):
  """Read one image / lung-mask / mask path triple with `read_nifti_file`.

  The three paths are read in argument order and the loaded results are
  returned as a tuple in that same order.
  """
  return tuple(map(read_nifti_file, (images, lung_masks, masks)))

def to_uint8(data):
    """Min-max rescale `data` to the range 0..255 and return it as uint8.

    The input array is never modified; the rescaling works on a float64
    copy or on temporaries, so integer inputs are accepted as well.

    Args:
        data: array-like of numbers.

    Returns:
        A new `np.uint8` array with the same shape as `data`. The minimum of
        `data` maps to 0 and the maximum to 255. An empty input or an input
        whose values are all equal yields an all-zero array, because there
        is no value range to stretch.
    """
    values = np.asarray(data, dtype=np.float64)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)

    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # Constant input: dividing by the zero range would produce NaN.
        return np.zeros(values.shape, dtype=np.uint8)

    # Out-of-place arithmetic keeps the caller's array (often a view into a
    # larger volume) untouched.
    return ((values - lowest) / span * 255).astype(np.uint8)

In [None]:
# Reset the generated test/train trees and make sure every split has its
# images/ and masks/ folders. Written with os/shutil instead of shell aliases
# so the cell runs cleanly on a fresh checkout, on reruns and on any OS.
for stale_split in ('test', 'train'):
  stale_dir = os.path.join('dataset', stale_split)
  if os.path.isdir(stale_dir):
    shutil.rmtree(stale_dir)

# dataset/validation is only created here, never cleared.
# NOTE(review): confirm that keeping its contents across reruns is intended.
for split in ('test', 'train', 'validation'):
  for kind in ('images', 'masks'):
    os.makedirs(os.path.join('dataset', split, kind), exist_ok=True)

In [None]:
# Cut every volume into 2-D PNG slices under dataset/train/.
# For each slice index along axis 2 an image PNG and a label PNG with the same
# file name are written, but only when the slice contains label 2.
images, lung_masks, masks = load_data('dataset/MedSeg Covid Dataset 2')

# The three listings are paired purely by sorted position, so a missing file
# would silently truncate the run or pair volumes with the wrong masks.
if not (len(images) == len(lung_masks) == len(masks)):
  raise ValueError(
    f'volume listings differ in length: {len(images)} images, '
    f'{len(lung_masks)} lung masks, {len(masks)} masks'
  )

for image, lung_mask, mask in zip(images, lung_masks, masks):
  image_slices, lung_mask_slices, mask_slices = parse(image, lung_mask, mask)
  # Output files are named after the image volume (last extension stripped).
  stem = os.path.splitext(os.path.basename(image))[0]

  for i in range(image_slices.shape[2]):
    mask_slice = np.rot90(np.uint8(mask_slices[:, :, i]), k=1, axes=(1, 0))

    # Label 2 is assigned exactly where the rp_msk slice is positive, so a
    # slice without any positive value would never be written; skip it early.
    positive = mask_slice > 0
    if not positive.any():
      continue

    image_slice = np.rot90(to_uint8(image_slices[:, :, i]), k=1, axes=(1, 0))
    lung_mask_slice = np.rot90(np.uint8(lung_mask_slices[:, :, i]), k=1, axes=(1, 0))

    # Label map: 0 everywhere, 1 where the lung mask equals 1, then 2 where
    # the rp_msk slice is positive (2 overrides 1 on overlap).
    label_slice = np.zeros_like(mask_slice)
    label_slice[lung_mask_slice == 1] = 1
    label_slice[positive] = 2

    for folder, pixels in (('images', image_slice), ('masks', label_slice)):
      target = f'dataset/train/{folder}/{stem}_{i}.png'
      # cv2.imwrite reports failure through its return value, not an exception.
      if not cv2.imwrite(target, pixels):
        raise IOError(f'could not write {target}')

## Augment

In [None]:
# Offline augmentation: for every training image/mask pair, write two augmented
# copies (n = 0, 1) next to the source files with an `aug_{n}_` name prefix.
# NOTE(review): this cell works under 'zip/dataset/train/...', while the slicing
# cell writes to 'dataset/train/...' — confirm which location is intended.
image_dir = 'zip/dataset/train/images'
mask_dir = 'zip/dataset/train/masks'

# Leave out files produced by an earlier run of this cell so that a rerun does
# not augment already-augmented images.
images = [
  path for path in sorted(glob(os.path.join(image_dir, '*')))
  if not os.path.basename(path).startswith('aug_')
]
# Image and mask PNGs share their file name, so each mask is looked up by name
# rather than by list position; positions drift apart as soon as the two
# folders hold a different set of files.
masks = [os.path.join(mask_dir, os.path.basename(path)) for path in images]

# Transforms are applied in this order, each with p=1; instances are reusable,
# so they are built once instead of once per image.
# NOTE(review): `alpha_affine` is deprecated in newer albumentations releases —
# verify against the installed version.
pipeline = (
  HorizontalFlip(p=1),
  RandomBrightnessContrast(p=1, brightness_limit=(-0.2, 0.2), contrast_limit=(0.0, 0.0)),
  RandomBrightnessContrast(p=1, brightness_limit=(0.0, 0.0), contrast_limit=(-0.2, 0.2)),
  ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
)

for n in range(2):
  for i, (image_path, mask_path) in enumerate(zip(images, masks)):
    # Read single-channel: the default imread mode would expand the PNGs to
    # three identical BGR channels and the copies would be saved that way.
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
      raise FileNotFoundError(f'could not read image {image_path}')
    if mask is None:
      raise FileNotFoundError(f'could not read mask {mask_path}')

    for transform in pipeline:
      augmented = transform(image=image, mask=mask)
      image, mask = augmented["image"], augmented["mask"]

    stem = os.path.splitext(os.path.basename(image_path))[0]
    image_out = f'zip/dataset/train/images/aug_{n}_{stem}_{i}.png'
    mask_out = f'zip/dataset/train/masks/aug_{n}_{stem}_{i}.png'
    # cv2.imwrite signals failure through its return value only.
    if not cv2.imwrite(image_out, image):
      raise IOError(f'could not write {image_out}')
    if not cv2.imwrite(mask_out, mask):
      raise IOError(f'could not write {mask_out}')

In [None]:
# For every image already placed in dataset/validation/images, move the mask
# with the same file name from dataset/train/masks to dataset/validation/masks.
images = sorted(glob(os.path.join('dataset/validation/images', '*')))
for image in images:
  name = os.path.basename(image)
  source = os.path.join('dataset/train/masks', name)
  target = os.path.join('dataset/validation/masks', name)

  # A rerun finds the mask already at its destination; skip it instead of
  # failing. A mask missing from both places still raises from os.rename.
  if not os.path.exists(source) and os.path.exists(target):
    continue

  os.rename(source, target)