In [34]:
import os
import re
import nibabel as nib
import torch
import torch.utils.data as td
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import math

# Possible improvements

* Go 3D (e.g. V-Net)
* Change model parameters (more down/up layers)
* Fix data loader
 - Detect empty slices
 - Detect slices without matching labels
 - Data augmentation (cropping, deform)
* Try other optimizers
* Tune hyperparameters
 - k-fold cross-validation
* One Channel, Transfer Learning, Add another layer later
* Try other datasets (hippocampus, liver)
* Experiment with Loss Functions
 - https://lars76.github.io/neural-networks/object-detection/losses-for-segmentation/
 - https://github.com/JunMa11/SegLoss
* Accuracy metric (measure the overlap between predicted and true masks, e.g. Dice or IoU)
* Participate in Challenge

# Organization
* Comment code
* Try on Colab
* Upload to GitHub
* Executive Summary

In [2]:
# From d2l.ai
def try_gpu(i=0):  #@save
    """Select the i-th CUDA device if it is present, else fall back to the CPU.

    Args:
        i: Zero-based index of the wanted GPU.

    Returns:
        torch.device: ``cuda:{i}`` when at least ``i + 1`` CUDA devices are
        visible, ``cpu`` otherwise.
    """
    gpu_present = i + 1 <= torch.cuda.device_count()
    return torch.device(f'cuda:{i}') if gpu_present else torch.device('cpu')

In [3]:
class BRATS(td.Dataset):
    """Dataset of 2D slices taken from the NIfTI volumes below ``data_dir``.

    Image volumes are read from ``imagesTr`` (training) or ``imagesTs`` (test);
    in training mode the label volume with the same file name is read from
    ``labelsTr``. Only the loaded image objects are stored; individual slices
    are pulled out of ``dataobj`` on access.

    Each volume contributes ``num_slices - slice_offset_start - slice_offset_end``
    consecutive slices along its third axis, starting at ``slice_offset_start``.

    Args:
        data_dir: Directory containing ``imagesTr``/``labelsTr``/``imagesTs``.
        train: If true, load the training images together with their labels.
        slice_offset_start: Number of leading slices skipped in every volume.
        slice_offset_end: Number of trailing slices skipped in every volume.
        num_slices: Total number of slices per volume along the third axis.

    Raises:
        ValueError: If the offsets leave no slice to sample from.
    """

    def __init__(self, data_dir, train, slice_offset_start=40, slice_offset_end=40, num_slices=155):
        # Zero-argument super() actually runs the base-class initialiser;
        # the previous ``super(BRATS).__init__()`` only built an unbound proxy.
        super().__init__()
        self.train = train
        self.data_dir = data_dir
        self.data = []
        self.filenames = []
        # Dots are escaped so that only genuine ".nii.gz" names are accepted.
        match_brats_filename = re.compile(r"^BRATS_[0-9]{3}\.nii\.gz$")
        if self.train:
            self.labels = []
            image_dir = os.path.join(self.data_dir, 'imagesTr')
            label_dir = os.path.join(self.data_dir, 'labelsTr')
        else:
            image_dir = os.path.join(self.data_dir, 'imagesTs')
            label_dir = None

        # Sorted listing -> the index-to-volume mapping is the same on every run.
        for fn in sorted(os.listdir(image_dir)):
            abs_fn = os.path.join(image_dir, fn)
            if not (os.path.isfile(abs_fn) and match_brats_filename.match(fn)):
                continue
            if self.train:
                # Images and labels are loaded as a pair so that
                # self.data[k] and self.labels[k] always belong together.
                label_fn = os.path.join(label_dir, fn)
                if not os.path.isfile(label_fn):
                    print('Skipping {}: no matching label file'.format(fn))
                    continue
                self.labels.append(nib.load(label_fn))
            self.data.append(nib.load(abs_fn))
            self.filenames.append(abs_fn)

        self.slice_offset_start = slice_offset_start
        self.slice_offset_end = slice_offset_end
        self.slice_count = num_slices - self.slice_offset_start - self.slice_offset_end
        if self.slice_count <= 0:
            raise ValueError(
                'slice offsets ({} + {}) leave no slices out of {}'.format(
                    slice_offset_start, slice_offset_end, num_slices))

        # idx -> Normalize transform, or None for slices that are entirely zero.
        self.transform_cache = {}

        self.num_files = len(self.data)
        print('Loaded {} files'.format(self.num_files))
        if self.train:
            print('Loaded {} labels'.format(len(self.labels)))

    def __len__(self):
        """Return the total number of slices over all loaded volumes."""
        return self.num_files * self.slice_count

    def __getitem__(self, idx):
        """Return the slice with flat index ``idx``.

        Returns:
            In training mode a tuple ``(image, label)``; otherwise only
            ``image``. ``image`` is the selected slice with its first and last
            axes swapped (last axis of the volume first) and, unless the slice
            is all zeros, normalised per leading-axis entry to zero mean and
            unit standard deviation. ``label`` is the matching label slice
            with its two axes swapped.
        """
        sample_index = idx // self.slice_count
        slice_index = self.slice_offset_start + idx % self.slice_count

        nifti_data = np.asarray(self.data[sample_index].dataobj[:, :, slice_index, :])
        nifti_slice = torch.from_numpy(np.copy(nifti_data)).transpose(0, 2)

        if idx in self.transform_cache:
            nifti_normalize = self.transform_cache[idx]
        else:
            nifti_norm = np.linalg.norm(nifti_slice)
            if nifti_norm > 0.0:
                # The small epsilon keeps the division finite for constant channels.
                nifti_mean = torch.mean(nifti_slice, dim=(1, 2)) + 1e-9
                nifti_std = torch.std(nifti_slice, dim=(1, 2)) + 1e-9
                nifti_normalize = torchvision.transforms.Normalize(nifti_mean, nifti_std)
            else:
                nifti_normalize = None
            self.transform_cache[idx] = nifti_normalize

        if nifti_normalize is not None:
            nifti_slice = nifti_normalize(nifti_slice)

        if not self.train:
            # There are no labels for the test split; returning a label here
            # used to fail with an unbound local variable.
            return nifti_slice

        label_data = np.asarray(self.labels[sample_index].dataobj[:, :, slice_index])
        label_slice = torch.from_numpy(np.copy(label_data)).transpose(0, 1)
        return nifti_slice, label_slice

In [4]:
# Build both splits from the same dataset root: the training split comes with
# labels, the test split is constructed with train=False.
brats_train = BRATS(data_dir='Task01_BrainTumour', train=True)
brats_test = BRATS(data_dir='Task01_BrainTumour', train=False)

Loaded 484 files
Loaded 484 labels
Loaded 266 files


In [5]:
# Loader settings shared by both splits.
batch_size = 8
num_workers = 6

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_iter = td.DataLoader(
    dataset=brats_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_iter = td.DataLoader(
    dataset=brats_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [59]:
class Down(nn.Module):
    """Encoder stage: 2x2 max-pool, then two (3x3 conv -> BatchNorm -> ReLU) units.

    The stage takes ``2**(size - 1)`` channels and produces ``2**size``
    channels. The convolutions use ``padding=1``, so only the pooling changes
    the spatial size (it halves it).

    Args:
        size: Base-2 logarithm of the number of output channels; must be > 2.
    """

    def __init__(self, size):
        super().__init__()
        assert(size > 2)
        channels_out = 1 << size
        channels_in = 1 << (size - 1)

        # Layers are created in the same order as they run, so indices inside
        # the Sequential (and therefore state-dict keys) are 0..6.
        layers = [nn.MaxPool2d(2)]
        for conv_in in (channels_in, channels_out):
            layers.append(nn.Conv2d(conv_in, channels_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(channels_out))
            layers.append(nn.ReLU())
        self.down = nn.Sequential(*layers)

    def forward(self, X):
        """Apply pooling and both convolution units to ``X``."""
        return self.down(X)

class Up(nn.Module):
    """Decoder stage: upsample, pad, concatenate the skip tensor, convolve twice.

    ``X`` (``2**(size + 1)`` channels) is upsampled by a stride-2 transposed
    convolution to ``2**size`` channels, reflection-padded by a fixed amount,
    concatenated with the skip tensor ``Z`` along the channel axis and passed
    through two unpadded (3x3 conv -> BatchNorm -> ReLU) units.

    Args:
        size: Base-2 logarithm of the number of output channels; must be > 2.
        padding: Total number of pixels added per spatial axis after
            upsampling; must be >= 0. For an odd value the extra pixel goes to
            the bottom/right side.
    """

    def __init__(self, size, padding=0):
        super().__init__()

        assert(size > 2)
        assert(padding >= 0)

        channels_out = 1 << size        # e.g. size=10 -> 1024
        channels_in = 1 << (size + 1)   # e.g. size=10 -> 2048

        # Same split for both axes: floor half before, remainder after.
        pad_before = padding // 2
        pad_after = padding - pad_before

        self.up = nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)
        # ReflectionPad2d expects (left, right, top, bottom).
        self.pad = nn.ReflectionPad2d(padding=(pad_before, pad_after, pad_before, pad_after))

        conv_layers = []
        for conv_in in (channels_in, channels_out):
            conv_layers.append(nn.Conv2d(conv_in, channels_out, kernel_size=3))
            conv_layers.append(nn.BatchNorm2d(channels_out))
            conv_layers.append(nn.ReLU())
        self.conv = nn.Sequential(*conv_layers)

    def forward(self, X, Z):
        """Upsample ``X``, pad it, concatenate it with ``Z`` and convolve."""
        upsampled = self.pad(self.up(X))
        merged = torch.cat([upsampled, Z], dim=1)
        return self.conv(merged)

class Net(nn.Module):
    """U-Net style encoder/decoder built from ``Down`` and ``Up`` stages.

    The encoder widens the channels 64 -> 128 -> 256 -> 512 -> 1024, the
    decoder narrows them back to 64 while consuming the encoder outputs as
    skip connections, and a final 1x1 convolution produces the output maps.
    The paddings passed to the ``Up`` stages and the final zero-padding are
    fixed values chosen for 240x240 inputs.

    Args:
        input_channels: Number of channels of the input tensor.
        num_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, input_channels=4, num_classes=1):
        super(Net, self).__init__()

        self.input_channels = input_channels

        # First stage has no pooling and uses unpadded convolutions.
        self.down1 = nn.Sequential(
            nn.Conv2d(self.input_channels, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        # 64  -> 128
        self.down2 = Down(7)
        # 128 -> 256
        self.down3 = Down(8)
        # 256 -> 512
        self.down4 = Down(9)
        # 512 -> 1024
        self.even5 = Down(10)
        # 1024 -> 512
        self.up4 = Up(9, 1)
        # 512  -> 256
        self.up3 = Up(8, 9)
        # 256  -> 128
        self.up2 = Up(7, 8)
        # 128  -> 64
        self.up1 = Up(6, 8)
        # Kept as a Sequential so the state-dict keys stay "out.0.*".
        self.out = nn.Sequential(
            nn.Conv2d(64, num_classes, kernel_size=1)
        )
        # NOTE: the previously built local ``downup`` Sequential was never
        # stored or used (and the Up stages need two inputs, so it could not
        # have been called); it has been removed.

    def forward(self, X):
        """Run the encoder, the decoder with skip connections and the output conv.

        Returns:
            Tensor of raw (un-activated) output maps, zero-padded by 4 pixels
            on every side.
        """
        H1 = self.down1(X)
        H2 = self.down2(H1)
        H3 = self.down3(H2)
        H4 = self.down4(H3)
        H5 = self.even5(H4)
        H6 = self.up4(H5, H4)
        H7 = self.up3(H6, H3)
        H8 = self.up2(H7, H2)
        H9 = self.up1(H8, H1)
        # For 240x240 inputs the decoder ends at 232x232; padding 4 pixels per
        # side brings the output back to the input resolution.
        out = F.pad(self.out(H9), pad=(4, 4, 4, 4))
        return out

def init_weights(m):
    """Initialise convolution layers; meant to be passed to ``Module.apply``.

    ``nn.Conv2d`` and ``nn.ConvTranspose2d`` modules get Xavier-normal weights
    and a constant bias of 0.01. All other module types are left untouched.

    Args:
        m: The module visited by ``apply``.
    """
    if type(m) in [nn.Conv2d, nn.ConvTranspose2d]:
        nn.init.xavier_normal_(m.weight)
        # Layers created with bias=False have ``bias is None``.
        if m.bias is not None:
            m.bias.data.fill_(0.01)

In [55]:
# https://github.com/pytorch/pytorch/issues/1249

def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss ``1 - (2*|P.T| + s) / (|P| + |T| + s)`` over all elements.

    Both tensors are flattened, so the loss is computed over the whole batch
    at once. The result lies in [0, 1] only when ``pred`` and ``target`` hold
    values in [0, 1]; pass probabilities (e.g. after a sigmoid), not logits.

    Args:
        pred: Predicted values, same number of elements as ``target``.
        target: Ground-truth values.
        smooth: Constant added to numerator and denominator; avoids a
            division by zero when both inputs are empty masks.

    Returns:
        Scalar tensor with the Dice loss.
    """
    # reshape() also works for non-contiguous inputs (e.g. transposed
    # tensors), where view(-1) raises a RuntimeError.
    iflat = pred.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

In [8]:
# Binary cross-entropy on raw logits, averaged over all elements; positive
# targets are weighted 5x.
bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]), reduction='mean')

In [9]:
def criterion(pred, target):
    """Combined loss: weighted BCE on the logits plus soft Dice on probabilities.

    Args:
        pred: Raw network outputs (logits).
        target: Ground-truth mask with the same shape as ``pred``.

    Returns:
        Scalar tensor ``bce_criterion(pred, target) + dice_loss(sigmoid(pred), target)``.
    """
    # BCEWithLogitsLoss applies the sigmoid itself; the Dice term needs
    # values in [0, 1], so the logits are squashed explicitly for it.
    return bce_criterion(pred, target) + dice_loss(torch.sigmoid(pred), target)

In [11]:
# Training hyper-parameters: number of epochs and optimizer learning rate.
num_epochs, lr = 50, 1e-4

In [12]:
# Build a fresh model and run the custom initialiser over every submodule.
# Module.apply returns the module itself, so both steps chain.
net = Net().apply(init_weights)
# Bare last expression: display the model structure as the cell output.
net

Net(
  (down1): Sequential(
    (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (down2): Down(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
    )
  )
  (down3): Down(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)


In [13]:
def check():
    """Smoke-test the global ``net`` on a random 2x4x240x240 batch.

    The forward pass runs in eval mode without gradient tracking so that the
    random input does not alter the BatchNorm running statistics; the previous
    train/eval mode is restored afterwards.

    Raises:
        AssertionError: If the output does not have shape (2, 1, 240, 240).
    """
    X = torch.randn(size=(2, 4, 240, 240), dtype=torch.float32)
    # Feed the input on whatever device the model currently lives on.
    X = X.to(next(net.parameters()).device)
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            y = net(X)
    finally:
        net.train(was_training)
    assert tuple(y.shape) == (2, 1, 240, 240), 'unexpected output shape {}'.format(tuple(y.shape))

check()

In [14]:
#X = torch.randn(size=(8, 4, 240, 240), dtype=torch.float32)
#for layer in net.downup:
#    X = layer(X)
#    print(layer.__class__.__name__,'output shape: \t',X.shape)

In [60]:
#net = torch.load('checkpoint_1599.pt', map_location=torch.device('cpu'))
# Rebuild the architecture and restore the saved parameters. map_location
# remaps the stored tensors to the CPU, so a checkpoint written from a GPU
# session also loads on a machine without CUDA; the training cell moves the
# model to the selected device afterwards.
net = Net()
net.load_state_dict(torch.load('checkpoint_sd.pt', map_location=torch.device('cpu')))

<All keys matched successfully>

In [19]:
# Counter of training runs in this kernel session; the training cell increments
# it and uses it in the TensorBoard log directory name.
run_iter = 0

In [51]:
# --- Run setup ---------------------------------------------------------------
# Every execution of this cell is a new run with its own TensorBoard directory.
run_iter += 1
batch_size = 4
writer = SummaryWriter('runs/brats3_{}_{}'.format(lr, run_iter))
print('Writing graph')
# Create the dummy input on the device the model currently lives on, so that
# re-running this cell after the model was moved to a GPU does not fail.
X = torch.randn(size=(2, 4, 240, 240), dtype=torch.float32,
                device=next(net.parameters()).device)
writer.add_graph(net, X)
print('done')

device = try_gpu()
print('Using device {}'.format(device))

net.to(device)
# pos_weight is a buffer of the loss module and must sit on the same device
# as the logits it is applied to.
bce_criterion.to(device)

global_iter = 1

optimizer = torch.optim.SGD(net.parameters(), lr=lr)

# --- Split the training set into one chunk per epoch ---------------------------
datasize = len(brats_train)
data_per_epoch = datasize // num_epochs
print('total count = {}, num_epochs = {}, per epoch = {}'.format(datasize, num_epochs, data_per_epoch))

# random_split requires the lengths to add up to the dataset size; the last
# chunk absorbs the remainder when datasize is not divisible by num_epochs.
split_lengths = [data_per_epoch] * num_epochs
split_lengths[-1] += datasize - data_per_epoch * num_epochs
brats_train_perepoch = td.random_split(brats_train, split_lengths)

# --- Training ----------------------------------------------------------------
# Epoch k trains on chunk k and is evaluated on chunk k+1, which the model has
# not been trained on yet at that point.
for epoch in range(num_epochs - 1):
    train_loss_epoch = 0.0

    train_iter = td.DataLoader(brats_train_perepoch[epoch], batch_size, shuffle=True, num_workers=num_workers)
    test_iter = td.DataLoader(brats_train_perepoch[epoch + 1], batch_size, shuffle=True, num_workers=num_workers)

    batch_count = len(train_iter)

    print('Total batches: {}'.format(batch_count))

    # Evaluation at the end of the previous epoch switched to eval mode.
    net.train()
    for i, (X, y) in enumerate(train_iter):
        global_iter += 1

        X = X.float().to(device)
        # The network has a single output channel and BCEWithLogitsLoss expects
        # targets in [0, 1]: treat every non-zero label id as foreground.
        # (No-op when the labels are already binary.)
        y = (y > 0).float().to(device)
        y_hat = net(X).squeeze(1)

        bl = bce_criterion(y_hat, y)
        # Dice is defined on probabilities, not on raw logits.
        dl = dice_loss(torch.sigmoid(y_hat), y)
        l = bl + dl
        train_loss_epoch += float(l)

        optimizer.zero_grad()
        l.backward()
        nn.utils.clip_grad_value_(net.parameters(), 0.1)
        optimizer.step()

        if i > 0 and (i % 50 == 0):
            print('saving checkpoint...')
            torch.save(net, 'checkpoint_{}.pt'.format(i))
            print('done')
            for tag, value in net.named_parameters():
                tag = tag.replace('.', '/')
                writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_iter)
                # Parameters that did not take part in the backward pass have no gradient.
                if value.grad is not None:
                    writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_iter)

        writer.add_scalar('batch_loss', float(l), global_iter)
        writer.add_scalar('loss/total_loss', float(l), global_iter)
        writer.add_scalar('loss/bce_loss', float(bl), global_iter)
        writer.add_scalar('loss/dice_loss', float(dl), global_iter)
        writer.add_images('masks/0_base', X[:, 0:3, :, :], global_iter)
        y_us = y.unsqueeze(1)
        y_hat_us = y_hat.unsqueeze(1)
        # Thresholded prediction as a float mask (0.0 / 1.0).
        y_hat_us_sig = (torch.sigmoid(y_hat_us) > 0.5).float()
        writer.add_images('masks/1_true', y_us, global_iter)
        writer.add_images('masks/2_predicted', y_hat_us_sig, global_iter)
        writer.add_images('extra/raw', y_hat_us, global_iter)
        # Red channel = prediction, green channel = ground truth.
        overlaid = torch.cat([y_hat_us_sig, y_us, torch.zeros_like(y_us)], dim=1)
        writer.add_images('extra/overlaid', overlaid, global_iter)

        writer.flush()

        print('batch {:4}/{} batchloss {}'.format(i, batch_count, float(l)))

    # max(..., 1) guards against an empty chunk.
    train_loss_epoch /= max(batch_count, 1)

    writer.add_scalar('loss/train', train_loss_epoch, global_iter)

    # --- Evaluation on the next (still unseen) chunk ---------------------------
    net.eval()
    test_loss_epoch = 0.0
    with torch.no_grad():
        for X_test, y_test in test_iter:
            X_test = X_test.float().to(device)
            y_test = (y_test > 0).float().to(device)
            y_test_hat = net(X_test).squeeze(1)
            # Same loss as in training: BCE on logits + Dice on probabilities.
            b_l = bce_criterion(y_test_hat, y_test) + dice_loss(torch.sigmoid(y_test_hat), y_test)
            test_loss_epoch += float(b_l)
    test_loss_epoch /= max(len(test_iter), 1)

    writer.add_scalar('loss/test', test_loss_epoch, global_iter)

    print('epoch {}/{}, train loss {}, test loss {}'.format(epoch+1, num_epochs, train_loss_epoch, test_loss_epoch))

# Flush and release the event file once training has finished.
writer.close()

Writing graph
diff_x = 1, diff_y = 1
diff_x = 9, diff_y = 9
diff_x = 8, diff_y = 8
diff_x = 8, diff_y = 8
diff_x = 1, diff_y = 1
diff_x = 9, diff_y = 9
diff_x = 8, diff_y = 8
diff_x = 8, diff_y = 8
diff_x = 1, diff_y = 1
diff_x = 9, diff_y = 9
diff_x = 8, diff_y = 8
diff_x = 8, diff_y = 8
done
Using device cpu
total count = 36300, num_epochs = 50, per epoch = 726
Total batches: 182


KeyboardInterrupt: 

In [46]:
# Persist only the state dict of the model (not the pickled module object);
# this is the file that the checkpoint-loading cell reads back.
torch.save(obj=net.state_dict(), f='checkpoint_sd.pt')

In [61]:
def export_onnx(filename):
    """Export the global ``net`` to ONNX using a random 1x4x240x240 input.

    Args:
        filename: Path of the ONNX file to write.
    """
    # The tracing input has to be on the same device as the model weights,
    # otherwise the export fails once the model has been moved to a GPU.
    model_device = next(net.parameters()).device
    dummy_input = torch.randn(1, 4, 240, 240, device=model_device)
    torch.onnx.export(net, dummy_input, filename, verbose=True)

export_onnx("net5.onnx")

graph(%input.1 : Float(1, 4, 240, 240),
      %down1.0.weight : Float(64, 4, 3, 3),
      %down1.0.bias : Float(64),
      %down1.1.weight : Float(64),
      %down1.1.bias : Float(64),
      %down1.1.running_mean : Float(64),
      %down1.1.running_var : Float(64),
      %down1.3.weight : Float(64, 64, 3, 3),
      %down1.3.bias : Float(64),
      %down1.4.weight : Float(64),
      %down1.4.bias : Float(64),
      %down1.4.running_mean : Float(64),
      %down1.4.running_var : Float(64),
      %down2.down.1.weight : Float(128, 64, 3, 3),
      %down2.down.1.bias : Float(128),
      %down2.down.2.weight : Float(128),
      %down2.down.2.bias : Float(128),
      %down2.down.2.running_mean : Float(128),
      %down2.down.2.running_var : Float(128),
      %down2.down.4.weight : Float(128, 128, 3, 3),
      %down2.down.4.bias : Float(128),
      %down2.down.5.weight : Float(128),
      %down2.down.5.bias : Float(128),
      %down2.down.5.running_mean : Float(128),
      %down2.down.5.runnin