In [1]:
import glob
import datetime

from tqdm import tqdm
from tqdm import trange
import nibabel as nib
import numpy as np
from sklearn.preprocessing import MinMaxScaler
scalar = MinMaxScaler()
import torch
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

from self_attention_cv.Transformer3Dsegmentation.tranf3Dseg import Transformer3dSeg
from self_attention_cv.transunet import TransUnet
from self_attention_cv import TransformerEncoder
from self_attention_cv import ViT, ResNet50ViT

In [2]:
class DiceLoss(torch.nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):        
        inputs = torch.sigmoid(inputs)       
        
        # flatten label and prediction tensors
        # inputs = inputs.view(-1)
        # targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

In [3]:
class FocalLoss(torch.nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2, reduction='sum'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight # weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        ce_loss = torch.nn.functional.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).sum()
        return focal_loss

In [4]:
def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return np.eye(num_classes, dtype='uint8')[y.astype(np.int16)]

In [5]:
def generate_brats_batch(prefix, 
                         contrasts, 
                         batch_size=32, 
                         tumour='*', 
                         patient_ids='*',
                         augment_size=None,
                         infinite=True):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    
    augment_size must be less than or equal to the batch_size, if None will not augment.
    
    """
    file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
    while True:
        n_classes = 4

        # get list of filenames for every contrast available
        keys = dict(prefix=prefix, tumour=tumour)
        filenames_by_contrast = {}
        for contrast in contrasts:
            filenames_by_contrast[contrast] = glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_ids, **keys)) if patient_ids == '*' else []
            if patient_ids != '*':
                contrast_files = []
                for patient_id in patient_ids:
                    contrast_files.extend(glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
                filenames_by_contrast[contrast] = contrast_files

        # get the shape of one 3D volume and initialize the batch lists
        arbitrary_contrast = contrasts[0]
        shape = nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape

        # initialize empty array of batches
        x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
        y_batch = np.empty((batch_size, ) + shape + (n_classes,)) #, dtype=np.int32)
        num_images = len(filenames_by_contrast[arbitrary_contrast])
        np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
        for bindex in trange(0, num_images, batch_size):
            filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
            for findex, filename in enumerate(filenames):
                for cindex, contrast in enumerate(contrasts):

                    # load raw image batches and normalize the pixels
                    tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                    tmp_img = scalar.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                    x_batch[findex, ..., cindex] = tmp_img

                    # load mask batches and change to categorical
                    tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                    tmp_mask[tmp_mask==4] = 3
                    tmp_mask = to_categorical(tmp_mask, num_classes = 4)
                    y_batch[findex] = tmp_mask

            if bindex + batch_size > num_images:
                x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
            if augment_size is not None:
                # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
                x_aug = None
                y_aug = None
                yield np.append(x_batch, x_aug), np.append(y_batch, y_aug)
            else:
                yield x_batch, y_batch
        if not infinite:
            break

Model Architecture Hyperparameters
---

In [6]:
prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
output_dir = prefix + '/transformer_models/'
batch_size = 16
contrasts = ['t1ce', 'flair', 't2', 't1']

In [7]:
brats_classes = 4
brats_contrasts = 4
brats_x = 240
brats_y = 240
brats_z = 155

block_side = 24 # W in the paper
patch_side = 8 # w in the paper, so n = W/w = 3, N = 27
embedding_size = 1024 # D
transformer_blocks = 5 # K
msa_heads = 4
mlp_size = 1024

dropout = 0.15
max_epochs = 50
learning_rate = 0.003

In [8]:
device

device(type='cuda', index=0)

In [9]:
model = Transformer3dSeg(subvol_dim=block_side, 
                         patch_dim=patch_side, 
                         num_classes=brats_classes,
                         in_channels=brats_contrasts,
                         dim=embedding_size,
                         blocks=transformer_blocks, 
                         heads=msa_heads, 
                         dim_linear_block=mlp_size,
                         dropout=dropout) #, transformer=TransformerEncoder)
model = model.to(device)

In [10]:
loss = torch.nn.CrossEntropyLoss()
focal_loss = FocalLoss(gamma=2) #, alpha=0.25, size_average=False)
dice_loss = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Training and Validation
---

In [11]:
import random, os

brats_dir = '/MICCAI_BraTS_2018_Data_Training/'

data_list_LGG = os.listdir(os.path.join(prefix+brats_dir,'LGG'))
data_list_HGG = os.listdir(os.path.join(prefix+brats_dir,'HGG'))
dataset_file_list = data_list_HGG + data_list_LGG

# shuffle and split the dataset file list
random.seed(42)
file_list_shuffled = dataset_file_list.copy()
random.shuffle(file_list_shuffled)
test_ratio = 0.2

train_file, test_file = file_list_shuffled[0:int(len(file_list_shuffled)*(1-test_ratio))], file_list_shuffled[int(len(file_list_shuffled)*(1-test_ratio)):]

while '.DS_Store' in train_file:
    train_file.remove('.DS_Store')
while '.DS_Store' in test_file:
    test_file.remove('.DS_Store')

In [12]:
def train():
    model.train()
    running_loss = 0
    count = 0
    for img, mask in generate_brats_batch(prefix, contrasts, batch_size=batch_size, patient_ids=train_file, infinite=False):
        # img (8, 240, 240, 155, 4) -> (8, 4, 240, 240, 155)
        img, mask = np.rollaxis(img, -1, 1), np.rollaxis(mask, -1, 1)
        img_gpu, mask_gpu = torch.FloatTensor(img).to(device), torch.FloatTensor(mask).to(device)
        for i in range(0, brats_x, block_side):
            for j in range(0, brats_y, block_side):
                for k in range(6, brats_z - block_side, block_side):
                    img_block = img_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                    mask_block = mask_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                    optimizer.zero_grad()
                    output = model(img_block)
                    
                    # slice near the (slightly off-centre) centre of the patch to choose the class
                    # patch_mask = mask_block[..., patch_side // 2::patch_side, patch_side // 2::patch_side, patch_side // 2::patch_side]
                    
                    # combined segmentation results (total counts for all classes across patch)
                    patch_mask = mask_block.reshape(-1, brats_classes, block_side // patch_side, patch_side, block_side // patch_side, patch_side, block_side // patch_side, patch_side).sum(axis=-1).sum(axis=-2).sum(axis=-3) / 2 ** (3 * block_side // patch_side)
                    # patch_mask = torch.round(patch_mask).to(torch.int64)
                    
                    # current_loss = loss(output, patch_mask.argmax(axis=1))
                    current_loss = focal_loss(output, patch_mask) + dice_loss(output, patch_mask)
                    
                    current_loss.backward()
                    optimizer.step()
                    running_loss += current_loss.item()
                    count += 1
        print(f'Training batch loss: {running_loss / count}')
        writer.add_scalar('Training batch loss', running_loss / count)
    return running_loss / count

In [13]:
def validate():      
    model.eval()
    with torch.no_grad():
        running_loss = 0
        count = 0
        for img, mask in generate_brats_batch(prefix, contrasts, batch_size=batch_size, patient_ids=test_file, infinite=False):
            # img (8, 240, 240, 155, 4) -> (8, 4, 240, 240, 155)
            img, mask = np.rollaxis(img, -1, 1), np.rollaxis(mask, -1, 1)
            img_gpu, mask_gpu = torch.FloatTensor(img).to(device), torch.FloatTensor(mask).to(device)
            for i in range(0, brats_x, block_side):
                for j in range(0, brats_y, block_side):
                    for k in range(6, brats_z - block_side, block_side):
                        img_block = img_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                        mask_block = mask_gpu[..., i:i+block_side, j:j+block_side, k:k+block_side]
                        output = model(img_block)

                        # slice near the (slightly off-centre) centre of the patch to choose the class
                        # patch_mask = mask_block[..., patch_side // 2::patch_side, patch_side // 2::patch_side, patch_side // 2::patch_side]

                        # combined segmentation results (total counts for all classes across patch)
                        patch_mask = mask_block.reshape(-1, brats_classes, block_side // patch_side, patch_side, block_side // patch_side, patch_side, block_side // patch_side, patch_side).sum(axis=-1).sum(axis=-2).sum(axis=-3) / 2 ** (3 * block_side // patch_side)
                        # patch_mask = torch.round(patch_mask).to(torch.int64)
                        
                        # current_loss = loss(output, patch_mask.argmax(axis=1))
                        current_loss = focal_loss(output, patch_mask) + dice_loss(output, patch_mask)

                        running_loss += current_loss.item()
                        count += 1
    return running_loss / count

In [14]:
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
best_val_loss = np.inf

for epoch in range(max_epochs):
    train_loss = train()
    val_loss = validate()
    
    print(f'Epoch {epoch}: LOSS train {train_loss}; validation {val_loss}')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training' : train_loss, 'Validation' : val_loss}, epoch)
    writer.flush()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model_path = output_dir + f'model_{timestamp}_{epoch}'
        torch.save(model.state_dict(), model_path)

  7%|██████▎                                                                                       | 1/15 [00:40<09:27, 40.56s/it]

Training batch loss: 62.681525887950166


 13%|████████████▌                                                                                 | 2/15 [01:20<08:41, 40.12s/it]

Training batch loss: 52.45108463408619


 20%|██████████████████▊                                                                           | 3/15 [01:59<07:58, 39.88s/it]

Training batch loss: 45.82528668188102


 27%|█████████████████████████                                                                     | 4/15 [02:39<07:18, 39.88s/it]

Training batch loss: 41.4151082617318


 33%|███████████████████████████████▎                                                              | 5/15 [03:19<06:38, 39.90s/it]

Training batch loss: 38.72588729434041


 40%|█████████████████████████████████████▌                                                        | 6/15 [03:59<05:59, 39.93s/it]

Training batch loss: 37.40398311541222


 47%|███████████████████████████████████████████▊                                                  | 7/15 [04:39<05:19, 39.99s/it]

Training batch loss: 36.239167474737236


 53%|██████████████████████████████████████████████████▏                                           | 8/15 [05:19<04:40, 40.01s/it]

Training batch loss: 35.816772397969714


 60%|████████████████████████████████████████████████████████▍                                     | 9/15 [06:00<04:00, 40.05s/it]

Training batch loss: 35.221254653520276


 67%|██████████████████████████████████████████████████████████████                               | 10/15 [06:40<03:20, 40.08s/it]

Training batch loss: 35.39871095958856


 73%|████████████████████████████████████████████████████████████████████▏                        | 11/15 [07:20<02:40, 40.01s/it]

Training batch loss: 34.82974914445379


 80%|██████████████████████████████████████████████████████████████████████████▍                  | 12/15 [07:59<01:59, 39.90s/it]

Training batch loss: 34.46639089821611


 87%|████████████████████████████████████████████████████████████████████████████████▌            | 13/15 [08:39<01:19, 39.96s/it]

Training batch loss: 34.686951077197


 93%|██████████████████████████████████████████████████████████████████████████████████████▊      | 14/15 [09:19<00:39, 39.94s/it]

Training batch loss: 34.37614883039152


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [09:34<00:00, 38.29s/it]


Training batch loss: 32.60390039670493


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [01:43<00:00, 25.91s/it]


Epoch 0: LOSS train 32.60390039670493; validation 38.440587958355124


  7%|██████▎                                                                                       | 1/15 [00:40<09:25, 40.39s/it]

Training batch loss: 30.032193303943302


 13%|████████████▌                                                                                 | 2/15 [01:20<08:42, 40.16s/it]

Training batch loss: 31.77421239723762


 20%|██████████████████▊                                                                           | 3/15 [02:00<08:02, 40.19s/it]

Training batch loss: 31.57218762432846


 27%|█████████████████████████                                                                     | 4/15 [02:40<07:22, 40.22s/it]

Training batch loss: 30.917413500943997


 33%|███████████████████████████████▎                                                              | 5/15 [03:21<06:42, 40.21s/it]

Training batch loss: 30.78353777543828


 40%|█████████████████████████████████████▌                                                        | 6/15 [04:01<06:01, 40.14s/it]

Training batch loss: 31.692633887784645


 47%|███████████████████████████████████████████▊                                                  | 7/15 [04:41<05:20, 40.07s/it]

Training batch loss: 31.76211533081959


 53%|██████████████████████████████████████████████████▏                                           | 8/15 [05:21<04:41, 40.19s/it]

Training batch loss: 31.615609678601395


 60%|████████████████████████████████████████████████████████▍                                     | 9/15 [06:01<04:01, 40.19s/it]

Training batch loss: 32.2220095830564


 67%|██████████████████████████████████████████████████████████████                               | 10/15 [06:41<03:20, 40.11s/it]

Training batch loss: 32.139558829386544


 73%|████████████████████████████████████████████████████████████████████▏                        | 11/15 [07:21<02:40, 40.15s/it]

Training batch loss: 31.567465552182878


 80%|██████████████████████████████████████████████████████████████████████████▍                  | 12/15 [08:01<02:00, 40.12s/it]

Training batch loss: 31.464198989812658


 87%|████████████████████████████████████████████████████████████████████████████████▌            | 13/15 [08:41<01:20, 40.06s/it]

Training batch loss: 31.262181725951063


 87%|████████████████████████████████████████████████████████████████████████████████▌            | 13/15 [09:06<01:24, 42.01s/it]


KeyboardInterrupt: 