In [1]:
import glob
import datetime

from tqdm import tqdm
import nibabel as nib
import numpy as np
from sklearn.preprocessing import MinMaxScaler
scalar = MinMaxScaler()
import torch
from torch.utils.tensorboard import SummaryWriter
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

from self_attention_cv.Transformer3Dsegmentation.tranf3Dseg import Transformer3dSeg
from self_attention_cv.transunet import TransUnet
from self_attention_cv import TransformerEncoder
from self_attention_cv import ViT, ResNet50ViT

In [2]:
def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return np.eye(num_classes, dtype='uint8')[y.astype(np.int16)]

In [3]:
def generate_brats_batch(prefix, 
                         contrasts, 
                         batch_size=32, 
                         tumour='*', 
                         patient_ids='*',
                         augment_size=None,
                         infinite=True):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    
    augment_size must be less than or equal to the batch_size, if None will not augment.
    
    """
    file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
    while True:
        n_classes = 4

        # get list of filenames for every contrast available
        keys = dict(prefix=prefix, tumour=tumour)
        filenames_by_contrast = {}
        for contrast in contrasts:
            filenames_by_contrast[contrast] = glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_ids, **keys)) if patient_ids == '*' else []
            if patient_ids != '*':
                contrast_files = []
                for patient_id in patient_ids:
                    contrast_files.extend(glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
                filenames_by_contrast[contrast] = contrast_files

        # get the shape of one 3D volume and initialize the batch lists
        arbitrary_contrast = contrasts[0]
        shape = nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape

        # initialize empty array of batches
        x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
        y_batch = np.empty((batch_size, ) + shape + (n_classes,)) #, dtype=np.int32)
        num_images = len(filenames_by_contrast[arbitrary_contrast])
        np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
        for bindex in tqdm(range(0, num_images, batch_size), total=num_images):
            filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
            for findex, filename in enumerate(filenames):
                for cindex, contrast in enumerate(contrasts):

                    # load raw image batches and normalize the pixels
                    tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                    tmp_img = scalar.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                    x_batch[findex, ..., cindex] = tmp_img

                    # load mask batches and change to categorical
                    tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                    tmp_mask[tmp_mask==4] = 3
                    tmp_mask = tmp_mask
                    tmp_mask = to_categorical(tmp_mask, num_classes = 4)
                    y_batch[findex] = tmp_mask

            if bindex + batch_size > num_images:
                x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
            if augment_size is not None:
                # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
                x_aug = None
                y_aug = None
                yield np.append(x_batch, x_aug), np.append(y_batch, y_aug)
            else:
                yield x_batch, y_batch
        if not infinite:
            break

Model Architecture Hyperparameters
---

In [19]:
prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
output_dir = prefix + '/transformer_models/'
batch_size = 8
contrasts = ['t1ce', 'flair', 't2', 't1']

In [20]:
brats_classes = 4
brats_contrasts = 4
brats_x = 240
brats_y = 240
brats_z = 155

block_side = 24 # W in the paper
patch_side = 8 # w in the paper, so n = W/w = 3, N = 27
embedding_size = 1024 # D
transformer_blocks = 5 # K
msa_heads = 3
mlp_size = 1024

dropout = 0.2
max_epochs = 5
learning_rate = 0.001

In [21]:
device

device(type='cuda', index=0)

In [22]:
model = Transformer3dSeg(subvol_dim=block_side, 
                         patch_dim=patch_side, 
                         num_classes=brats_classes,
                         in_channels=brats_contrasts,
                         dim=embedding_size,
                         blocks=transformer_blocks, 
                         heads=msa_heads, 
                         dim_linear_block=mlp_size,
                         dropout=dropout) #, transformer=TransformerEncoder)
model = model.to(device)

In [23]:
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Training and Validation
---

In [24]:
import random, os

brats_dir = '/MICCAI_BraTS_2018_Data_Training/'

data_list_LGG = os.listdir(os.path.join(prefix+brats_dir,'LGG'))
data_list_HGG = os.listdir(os.path.join(prefix+brats_dir,'HGG'))
dataset_file_list = data_list_HGG + data_list_LGG

# shuffle and split the dataset file list
random.seed(42)
file_list_shuffled = dataset_file_list.copy()
random.shuffle(file_list_shuffled)
test_ratio = 0.2

train_file, test_file = file_list_shuffled[0:int(len(file_list_shuffled)*(1-test_ratio))], file_list_shuffled[int(len(file_list_shuffled)*(1-test_ratio)):]

while '.DS_Store' in train_file:
    train_file.remove('.DS_Store')
while '.DS_Store' in test_file:
    test_file.remove('.DS_Store')

In [31]:
def train():
    model.train()
    running_loss = 0
    count = 0
    for img, mask in generate_brats_batch(prefix, contrasts, batch_size=batch_size, patient_ids=train_file):
        # img (8, 240, 240, 155, 4) -> (8, 4, 240, 240, 155)
        img, mask = np.rollaxis(img, -1, 1), np.rollaxis(mask, -1, 1)
        img_gpu, mask_gpu = torch.FloatTensor(img).to(device), torch.FloatTensor(mask).to(device)
        for i in range(0, brats_x, block_side):
            for j in range(0, brats_y, block_side):
                for k in range(6, brats_z - block_side, block_side):
                    img_block = img_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                    mask_block = mask_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                    optimizer.zero_grad()
                    output = model(img_block)
                    # TODO: get max (worst case) segment class within each patch, not just centre
                    current_loss = loss(output, mask_block[..., patch_side // 2::patch_side, patch_side // 2::patch_side, patch_side // 2::patch_side].argmax(axis=1))
                    current_loss.backward()
                    optimizer.step()
                    running_loss += current_loss.item()
                    count += 1
        print(f'Training batch loss: {running_loss / count}')
        writer.add_scalar('Training batch loss', running_loss / count)
    return running_loss / count

In [32]:
def validate():      
    model.eval()
    with torch.no_grad():
        running_loss = 0
        count = 0
        for img, mask in generate_brats_batch(prefix, contrasts, batch_size=batch_size, patient_ids=test_file):
            # img (8, 240, 240, 155, 4) -> (8, 4, 240, 240, 155)
            img, mask = np.rollaxis(img, -1, 1), np.rollaxis(mask, -1, 1)
            img_gpu, mask_gpu = torch.FloatTensor(img).to(device), torch.FloatTensor(mask).to(device)
            for i in range(0, brats_x, block_side):
                for j in range(0, brats_y, block_side):
                    for k in range(6, brats_z - block_side, block_side):
                        img_block = img_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                        mask_block = mask_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                        output = model(img_block)
                        current_loss = loss(output, mask_block[..., patch_side // 2::patch_side, patch_side // 2::patch_side, patch_side // 2::patch_side].argmax(axis=1))
                        running_loss += current_loss.item()
    return running_loss / count

In [33]:
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
best_val_loss = np.inf

for epoch in range(max_epochs):
    train_loss = train()
    val_loss = validate()
    
    print('Epoch {epoch}: LOSS train {train_loss}; validation {val_loss}')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training' : train_loss, 'Validation' : val_loss}, epoch)
    writer.flush()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model_path = output_dir + f'model_{timestamp}_{epoch_number}'
        torch.save(model.state_dict(), model_path)

  0%|▍                                                                                          | 1/227 [00:27<1:44:40, 27.79s/it]

Training batch loss: 0.09549589879705687


  1%|▊                                                                                          | 2/227 [00:51<1:36:12, 25.66s/it]

Training batch loss: 0.08971404869859612


  1%|█▏                                                                                         | 3/227 [01:16<1:33:26, 25.03s/it]

Training batch loss: 0.08543367897220706


  2%|█▌                                                                                         | 4/227 [01:40<1:31:42, 24.68s/it]

Training batch loss: 0.08326041705112497


  2%|██                                                                                         | 5/227 [02:04<1:30:40, 24.51s/it]

Training batch loss: 0.07809638480151701


  3%|██▍                                                                                        | 6/227 [02:28<1:29:48, 24.38s/it]

Training batch loss: 0.07424449357590031


  3%|██▍                                                                                        | 6/227 [02:50<1:44:22, 28.34s/it]


KeyboardInterrupt: 