In [None]:
import numpy as np
import pandas as pd
import sys
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import numpy as np
from torch.utils.data import DataLoader
sys.path.append('..')
from load_data import load_data
from config import xvertseg_dir, verse2019_dir, resolution, patch_size, batch_size
from tiger.patches import PatchExtractor3D
import torch
from monai.transforms import Compose, RandGaussianNoise, RandRotate, RandGaussianSmooth, RandGaussianSharpen, Rand3DElastic

In [None]:
# Load the XVertSeg images, masks and score sheet from its data directory at the configured resolution.
(xvertseg_imgs,
 xvertseg_msks,
 xvertseg_scores) = load_data(xvertseg_dir, resolution)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """
    Dataset of single-vertebra patches and their binary fracture scores.

    Every sample is a 2-channel patch (channel 0: image, channel 1: binary mask of one
    vertebra) centred on that vertebra. For each patch the dataset also records where the
    vertebra was found: the source dataset, the case ID and the vertebra label.
    """

    def __init__(self, scores, images, masks, patch_size, transforms=False):
        """
        Extract one patch for every vertebra that is annotated in ``scores`` and present in its mask.

        Args:
            scores: 2D array (e.g. ``DataFrame.to_numpy()``) with one row per case. Column 0 is the
                source dataset and column 1 the case ID. The remaining 36 columns are reshaped to
                (18, 2) per-vertebra values (grade and case). NaN marks a vertebra without a score.
            images: sequence of image volumes, aligned row by row with ``scores``.
            masks: sequence of label volumes aligned with ``images``. Labels 8-25 denote T1-T12, L1-L6.
            patch_size: patch shape passed to ``PatchExtractor3D.extract_cuboid``.
            transforms: if truthy, ``__getitem__`` additionally returns a randomly augmented copy.
        """
        self.patches = []      # (N, 2, ...)  image and mask channel per vertebra
        self.scores = []       # (N,)  1 if any of the vertebra's two score values is non-zero (fractured), else 0
        self.sources = []      # (N,)  dataset the image was found in
        self.IDS = []          # (N,)  ID of the image in which this vertebra is found
        self.vertebrae = []    # (N,)  vertebra label 8-25: T1-T12, L1-L6

        # Augmentations used by __getitem__; None means "no augmentation".
        self.spatial_transforms = None
        self.other_transforms = None
        if transforms:
            # Spatial: random rotation about the first two axes (range_z=0), applied to image and
            # mask together. 'nearest' keeps the mask binary (it is applied to the image as well).
            self.spatial_transforms = Compose([
                RandRotate(range_x=1 / 6 * np.pi, range_y=1 / 6 * np.pi, range_z=0, prob=0.5, mode='nearest')
            ])
            # Intensity: applied to the image channel only (see __getitem__).
            self.other_transforms = Compose([
                RandGaussianNoise(prob=0.2),
                RandGaussianSharpen(prob=0.2),
                RandGaussianSmooth(prob=0.2)
            ])

        for row, mask in enumerate(masks):
            source = scores[row][0]
            case_id = scores[row][1]

            # 18 vertebrae x (grade, case); float so that missing annotations show up as NaN
            vert_scores = scores[row][2:].reshape(18, 2).astype(float)

            # Loop-invariant per case: the labels present in the mask and the patch extractors.
            labels_present = np.unique(mask)
            patch_extracter_img = PatchExtractor3D(images[row], pad_value=-1000)  # pad with -1000 (air)
            patch_extracter_msk = PatchExtractor3D(mask)

            for i, vert_score in enumerate(vert_scores):
                if np.isnan(vert_score).any():
                    continue                        # vertebra not annotated in the score sheet
                label = i + 8                       # because we skip the 7 C-vertebrae
                if label not in labels_present:
                    continue                        # annotated, but not present in this mask

                # centre of the vertebra: mean voxel coordinate, computed with an integer dtype
                centre = tuple(np.mean(np.argwhere(mask == label), axis=0, dtype=int))

                patch_img = patch_extracter_img.extract_cuboid(centre, patch_size)
                patch_msk = patch_extracter_msk.extract_cuboid(centre, patch_size)
                patch_msk = np.where(patch_msk == label, 1, 0)   # only contain this vertebra, binary

                # stack image and mask along a new leading channel axis
                patch = np.concatenate((np.expand_dims(patch_img, axis=0),
                                        np.expand_dims(patch_msk, axis=0)))

                self.patches.append(patch)
                self.scores.append(vert_score.any().astype(int))   # binarize: fractured or not
                self.sources.append(source)
                self.IDS.append(case_id)
                self.vertebrae.append(label)

    def __len__(self):
        """Return N, the number of vertebra patches in this dataset."""
        return len(self.patches)

    def __getitem__(self, i):
        """
        Return sample ``i``.

        Returns:
            ``(x, y)`` when no augmentation is configured, otherwise ``(x, x_trans, y)``.
            ``x`` is the float32 patch with channels (image, mask). ``x_trans`` is an augmented
            copy of ``x``. ``y`` is the binary score as a float32 tensor of shape (1,).
        """
        x = torch.tensor(self.patches[i], dtype=torch.float32)
        y = torch.tensor(self.scores[i], dtype=torch.float32).unsqueeze(0)

        if self.spatial_transforms is None:
            return x, y

        # Transform a copy: a skipped random transform may hand back a tensor sharing x's
        # storage, and the in-place write below would then also modify the returned x.
        x_trans = self.spatial_transforms(x.clone())

        # Intensity transforms on the image channel only. Slicing 0:1 keeps the channel axis so
        # channel-first transforms (e.g. Gaussian smoothing) filter over all spatial axes.
        x_trans[0:1] = self.other_transforms(x_trans[0:1])
        return x, x_trans, y

In [None]:
# Build a small demo dataset from the first XVertSeg case only, with augmentation enabled,
# and wrap it in a DataLoader so that one batch can be inspected below.
np_scores = xvertseg_scores.to_numpy()
dataset = Dataset(
    np_scores[:1],
    xvertseg_imgs[:1],
    xvertseg_msks[:1],
    patch_size,
    transforms=True,
)
loader = DataLoader(dataset, batch_size=8, num_workers=8, shuffle=False)

In [None]:
# Visual sanity check of the augmentation: for every sample in the first batch, show the middle
# slice of the patch before and after the transforms, with the vertebra mask overlaid.
x, x_trans, y = next(iter(loader))
num_samples = x.shape[0]


def show_augmentation(img, msk, img_trans, msk_trans, score):
    """
    Plot the middle slice (along the first spatial axis) of one patch before and after augmentation.

    Args:
        img, msk: original image and mask volumes of one sample.
        img_trans, msk_trans: augmented image and mask volumes of the same sample.
        score: the sample's one-element label tensor, shown in the figure title.
    """
    mid_slice = img.shape[0] // 2
    fig, axes = plt.subplots(1, 2, figsize=(10, 10))
    panels = [(img, msk, 'Before transforms'), (img_trans, msk_trans, 'After transforms')]
    for ax, (panel_img, panel_msk, title) in zip(axes, panels):
        panel_img = np.asarray(panel_img)
        panel_msk = np.asarray(panel_msk)
        # hide background voxels so only the vertebra is tinted by the overlay
        overlay = np.ma.masked_where(panel_msk == 0, panel_msk)
        ax.imshow(panel_img[mid_slice, :, :], cmap='gray')
        ax.imshow(overlay[mid_slice, :, :], alpha=0.2)
        ax.set_title(title)
    fig.suptitle(f'label (fractured) = {int(score.item())}')
    plt.show()
    plt.close(fig)   # one figure per sample; release it once rendered


for s in range(num_samples):
    show_augmentation(x[s, 0], x[s, 1], x_trans[s, 0], x_trans[s, 1], y[s])