In [1]:
import glob

from tqdm import tqdm
import nibabel as nib
import numpy as np
from sklearn.preprocessing import MinMaxScaler
scalar = MinMaxScaler()
import torch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

from self_attention_cv.Transformer3Dsegmentation.tranf3Dseg import Transformer3dSeg
from self_attention_cv.transunet import TransUnet
from self_attention_cv import TransformerEncoder
from self_attention_cv import ViT, ResNet50ViT

In [2]:
def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return np.eye(num_classes, dtype='uint8')[y.astype(np.int16)]

In [3]:
def generate_brats_batch(prefix, 
                         contrasts, 
                         batch_size=32, 
                         tumour='*', 
                         patient_ids='*',
                         augment_size=None,
                         infinite=True):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    
    augment_size must be less than or equal to the batch_size, if None will not augment.
    
    """
    file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
    while True:
        n_classes = 4

        # get list of filenames for every contrast available
        keys = dict(prefix=prefix, tumour=tumour)
        filenames_by_contrast = {}
        for contrast in contrasts:
            filenames_by_contrast[contrast] = glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_ids, **keys)) if patient_ids == '*' else []
            if patient_ids != '*':
                contrast_files = []
                for patient_id in patient_ids:
                    contrast_files.extend(glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
                filenames_by_contrast[contrast] = contrast_files

        # get the shape of one 3D volume and initialize the batch lists
        arbitrary_contrast = contrasts[0]
        shape = nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape

        # initialize empty array of batches
        x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
        y_batch = np.empty((batch_size, ) + shape + (n_classes,)) #, dtype=np.int32)
        num_images = len(filenames_by_contrast[arbitrary_contrast])
        np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
        for bindex in tqdm(range(0, num_images, batch_size), total=num_images):
            filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
            for findex, filename in enumerate(filenames):
                for cindex, contrast in enumerate(contrasts):

                    # load raw image batches and normalize the pixels
                    tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                    tmp_img = scalar.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                    x_batch[findex, ..., cindex] = tmp_img

                    # load mask batches and change to categorical
                    tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                    tmp_mask[tmp_mask==4] = 3
                    tmp_mask = tmp_mask
                    tmp_mask = to_categorical(tmp_mask, num_classes = 4)
                    y_batch[findex] = tmp_mask

            if bindex + batch_size > num_images:
                x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
            if augment_size is not None:
                # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
                x_aug = None
                y_aug = None
                yield np.append(x_batch, x_aug), np.append(y_batch, y_aug)
            else:
                yield x_batch, y_batch
        if not infinite:
            break

Model Architecture Hyperparameters
---

In [4]:
prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
batch_size = 4
contrasts = ['t1ce', 'flair', 't2', 't1']

In [5]:
brats_classes = 4
brats_contrasts = 4
brats_x = 240
brats_y = 240
brats_z = 155

block_side = 24 # W in the paper
patch_side = 8 # w in the paper, so n = W/w = 3, N = 27
embedding_size = 1024 # D
transformer_blocks = 5 # K
msa_heads = 3
mlp_size = 1024

dropout = 0.2
max_epochs = 5
learning_rate = 0.001

In [6]:
device

device(type='cuda', index=0)

In [7]:
model = Transformer3dSeg(subvol_dim=block_side, 
                         patch_dim=patch_side, 
                         num_classes=brats_classes,
                         in_channels=brats_contrasts,
                         dim=embedding_size,
                         blocks=transformer_blocks, 
                         heads=msa_heads, 
                         dim_linear_block=mlp_size,
                         dropout=dropout) #, transformer=TransformerEncoder)
model = model.to(device)

In [8]:
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
running_loss = 0
for epoch in range(max_epochs):
    for img, mask in generate_brats_batch(prefix, contrasts, batch_size=batch_size):
        # for img, mask in zip(imgs, masks):
        for i in range(0, brats_x, block_side):
            for j in range(0, brats_y, block_side):
                for k in range(6, brats_z - block_side, block_side):
                    img_block = img[:, i:i+block_side, j:j+block_side, k:k+block_side]
                    mask_block = mask[:, i:i+block_side, j:j+block_side, k:k+block_side]
                        
        # img_block, mask_block = torch.FloatTensor(np.rollaxis(img_block, -1, 1)[..., 6:150]), torch.FloatTensor(np.rollaxis(mask_block, -1, 1)[..., 6:150])
                    img_block, mask_block = torch.FloatTensor(np.rollaxis(img_block, -1, 1)), torch.FloatTensor(np.rollaxis(mask_block, -1, 1))
                    img_block, mask_block = img_block.to(device), mask_block.to(device)

                    optimizer.zero_grad()
                    output = model(img_block)
                    # TODO: get max (worst case) segment class within each patch, not just centre
                    current_loss = loss(output, mask_block[..., patch_side // 2::patch_side, patch_side // 2::patch_side, patch_side // 2::patch_side].argmax(axis=1))
                    current_loss.backward()
                    optimizer.step()
                    running_loss += current_loss.item()

 25%|███████████████████████▏                                                                    | 72/285 [20:03<59:20, 16.72s/it]
 25%|███████████████████████▏                                                                    | 72/285 [20:11<59:43, 16.82s/it]
 25%|██████████████████████▋                                                                   | 72/285 [20:18<1:00:03, 16.92s/it]
 25%|██████████████████████▋                                                                   | 72/285 [20:26<1:00:27, 17.03s/it]
 25%|██████████████████████▋                                                                   | 72/285 [20:31<1:00:43, 17.11s/it]
 25%|██████████████████████▋                                                                   | 72/285 [20:28<1:00:33, 17.06s/it]
 25%|███████████████████████▏                                                                    | 72/285 [19:56<58:58, 16.61s/it]
  3%|██▊                                                                           