In [4]:
import os
import nibabel as nib
import matplotlib.pyplot as plt
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math
import glob
import random

In [5]:
ls Task01_BrainTumour

 Volume in drive C is Windows
 Volume Serial Number is EA97-91A8

 Directory of C:\Users\maxmc\OneDrive\Desktop\Python\UNet\Task01_BrainTumour

01/26/2024  03:49 AM    <DIR>          .
02/04/2024  11:59 AM    <DIR>          ..
07/03/2018  10:30 AM               384 ._dataset.json
05/14/2018  11:56 PM               120 ._imagesTr
05/26/2018  06:47 AM               120 ._imagesTs
07/03/2018  10:02 AM               120 ._labelsTr
07/03/2018  10:30 AM            46,231 dataset.json
05/14/2018  11:56 PM    <DIR>          imagesTr
05/26/2018  06:47 AM    <DIR>          imagesTs
07/03/2018  10:02 AM    <DIR>          labelsTr
               5 File(s)         46,975 bytes
               5 Dir(s)  1,747,753,529,344 bytes free


In [6]:
task1 = 'Task01_BrainTumour'
train_dir = f'{task1}/imagesTr'
label_dir = f'{task1}/labelsTr'
test_dir = f'{task1}/imagesTs'

In [7]:
info = open(f'{task1}/dataset.json')
json.load(info)

{'name': 'BRATS',
 'description': 'Gliomas segmentation tumour and oedema in on brain images',
 'reference': 'https://www.med.upenn.edu/sbia/brats2017.html',
 'licence': 'CC-BY-SA 4.0',
 'release': '2.0 04/05/2018',
 'tensorImageSize': '4D',
 'modality': {'0': 'FLAIR', '1': 'T1w', '2': 't1gd', '3': 'T2w'},
 'labels': {'0': 'background',
  '1': 'edema',
  '2': 'non-enhancing tumor',
  '3': 'enhancing tumour'},
 'numTraining': 484,
 'numTest': 266,
 'training': [{'image': './imagesTr/BRATS_457.nii.gz',
   'label': './labelsTr/BRATS_457.nii.gz'},
  {'image': './imagesTr/BRATS_306.nii.gz',
   'label': './labelsTr/BRATS_306.nii.gz'},
  {'image': './imagesTr/BRATS_206.nii.gz',
   'label': './labelsTr/BRATS_206.nii.gz'},
  {'image': './imagesTr/BRATS_449.nii.gz',
   'label': './labelsTr/BRATS_449.nii.gz'},
  {'image': './imagesTr/BRATS_318.nii.gz',
   'label': './labelsTr/BRATS_318.nii.gz'},
  {'image': './imagesTr/BRATS_218.nii.gz',
   'label': './labelsTr/BRATS_218.nii.gz'},
  {'image': './

In [8]:
modality = {0: 'FLAIR', 1: 'T1w', 2: 't1gd', 3: 'T2w'}
labels = {0: 'background',
  1: 'edema',
  2: 'non-enhancing tumor',
  3: 'enhancing tumour'}

In [9]:
# test_img = nib.load(f'{train_dir}/BRATS_457.nii.gz')
# test_label = nib.load(f'{label_dir}/BRATS_457.nii.gz')
# test_img_data = test_img.get_fdata()
# test_label_data = test_label.get_fdata()
# test_img_data.shape, test_label_data.shape

In [10]:
def plot_slices(slices):
    """ Function to display row of image slices """
    fig, ax = plt.subplots(5, 3, figsize = (10,10))
    for i, slice in enumerate(slices):
        if i<3:
            ax[0][i].imshow(slice.T, origin="lower")
            ax[0][i].set_xticks([])
            ax[0][i].set_yticks([])
        elif i<6:
            ax[1][i-3].imshow(slice.T, origin="lower")
            ax[1][i-3].set_xticks([])
            ax[1][i-3].set_yticks([])
        elif i<9:
            ax[2][i-6].imshow(slice.T, origin="lower")
            ax[2][i-6].set_xticks([])
            ax[2][i-6].set_yticks([])
        elif i<12:
            ax[3][i-9].imshow(slice.T, origin="lower")
            ax[3][i-9].set_xticks([])
            ax[3][i-9].set_yticks([])
        elif i<15:
            ax[4][i-12].imshow(slice.T, origin="lower")
            ax[4][i-12].set_xticks([])
            ax[4][i-12].set_yticks([])
            
def show_slices(x,y,z,scan_data, label_data):
    slices = [label_data[x,:,:], label_data[:,y,:], label_data[:,:,z]]
    for i in range(4):
        slices = slices + [scan_data[x,:,:,i], scan_data[:,y,:,i], scan_data[:,:,z,i]]
    plot_slices(slices)

# show_slices(51,86,50, test_img_data, test_label_data)
# print('Row 1 is slice of label data')
# for key in modality.keys():
#     print(f'Row {key+2} is slice of test data on the {modality[key]} channel')

In [11]:
for name in glob.glob(f'{train_dir}/BRATS*')[100:105]:
    print(name)
len(glob.glob(f'{train_dir}/BRATS*')), len(glob.glob(f'{label_dir}/BRATS*')), len(glob.glob(f'{test_dir}/BRATS*'))

Task01_BrainTumour/imagesTr\BRATS_101.nii.gz
Task01_BrainTumour/imagesTr\BRATS_102.nii.gz
Task01_BrainTumour/imagesTr\BRATS_103.nii.gz
Task01_BrainTumour/imagesTr\BRATS_104.nii.gz
Task01_BrainTumour/imagesTr\BRATS_105.nii.gz


(484, 484, 266)

# Utils
## Data transform and weight map of classes

In [12]:
def trans(in_tensor, seed = None, hw=120, d=64):
    if seed:
        random.seed(seed)

    in_tensor = torch.permute(in_tensor,(3,2,0,1)) #HxWxDxC --> CxDxHxW
    in_shape = in_tensor.shape
    
    start_d = random.randint(0,in_shape[1] - d)
    start_h = random.randint(0,in_shape[2] - hw)
    start_w = random.randint(0,in_shape[3] - hw)
    
    
    in_tensor = in_tensor[:, start_d: start_d + d, start_h:start_h + hw, start_w:start_w + hw]
    return in_tensor

def weight_map(label_tensor):
    wmap = torch.zeros_like(label_tensor)
    freq = []
    
    for class_label in labels.keys():
        if class_label == 0:
            freq.append(0)
            #continue
        else:
            num_zeros = label_tensor.view(-1).eq(0).sum()
            class_frequency = label_tensor.eq(class_label).sum()/(len(label_tensor.view(-1)-num_zeros))
            
            freq.append(class_frequency)
            #wmap = torch.where(label_tensor.eq(class_label), class_frequency, wmap)
    
    return torch.Tensor(freq) #wmap

# Set up datasets and dataloaders

In [13]:
#validation labels have not been released, so use 20% of train data as test data
train_split = .8

class ImageDataset(Dataset):
    def __init__(self, img_dir, train_split = 0.8, train = True, label_dir = None, transform = None, target_transform = None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        self.target_transform = target_transform
        self.train_split = train_split
        self.train = train
        
    def __len__(self):
        n=self.train_split
        if not self.train:
            n = 1-self.train_split
            
        return int(len(glob.glob(f'{self.img_dir}/BRATS*'))*n)
    
    def __getitem__(self, idx):
        if not self.train:
            idx += int(len(glob.glob(f'{self.img_dir}/BRATS*'))*train_split)
            
        #print(idx)
        #print(f'idx = {idx}')
        img_path = f'{train_dir}/BRATS_00{idx+1}.nii.gz'
        if 9 <= idx < 99:
            img_path = f'{train_dir}/BRATS_0{idx+1}.nii.gz'
        elif idx >= 99:
            img_path = f'{train_dir}/BRATS_{idx+1}.nii.gz'
            
        img = nib.load(img_path).get_fdata()#[40:-40, 40:-40, :, :]

        
        if label_dir:
            label_path = f'{label_dir}/BRATS_00{idx+1}.nii.gz'
            if 9 <= idx < 99:
                label_path = f'{label_dir}/BRATS_0{idx+1}.nii.gz'
            elif idx >= 99:
                label_path = f'{label_dir}/BRATS_{idx+1}.nii.gz'
            
            label = nib.load(label_path).get_fdata()#[40:-40, 40:-40, :]
            if(self.transform):
                img = torch.from_numpy(img)
                label = torch.from_numpy(label)
                label = label.unsqueeze(dim=-1)
                both = torch.cat((img,label),dim=3)
                both = trans(both)#, idx)
                img = both[:-1,:,:,:]
                label = both[-1,:,:,:]

        return img, label

# Building UNet3D

In [14]:
class convBlock(nn.Module):
    def __init__(self, in_channels, h_channels, out_channels, kernel_size=3, stride=1, padding='same'):
        super().__init__()
        self.in_channels = in_channels
        self.h_channels = h_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels = self.in_channels, out_channels=self.h_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding),
            nn.BatchNorm3d(num_features = self.h_channels),
            nn.ReLU()
        )
        
        self.block2 = nn.Sequential(
            nn.Conv3d(in_channels = self.h_channels, out_channels=self.out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding),
            nn.BatchNorm3d(num_features = self.out_channels),
            nn.ReLU()
        )
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x
    
class upConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 2, stride =2, padding =0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        self.upconv = nn.ConvTranspose3d(in_channels = self.in_channels, out_channels = self.out_channels, kernel_size=self.kernel_size,stride = self.stride, padding = self.padding)
        
    def forward(self, x):
        return self.upconv(x)

In [27]:
class UNet3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool3d(kernel_size=2, stride = 2, padding = 0)
        
        self.conv1 = convBlock(in_channels = 4, h_channels = 32, out_channels = 64)
        self.conv2 = convBlock(in_channels = 64, h_channels = 64, out_channels = 128)
        self.conv3 = convBlock(in_channels = 128, h_channels = 128, out_channels = 256)
        self.conv4 = convBlock(in_channels = 256, h_channels = 256, out_channels = 512)
                               
        self.deconv1 = upConvBlock(in_channels = 512, out_channels = 512)
        self.deconv2 = upConvBlock(in_channels = 256, out_channels = 256)
        self.deconv3 = upConvBlock(in_channels = 128, out_channels = 128)
        
        self.conv5 = convBlock(in_channels = 768, h_channels = 256, out_channels = 256)
        self.conv6 = convBlock(in_channels = 384, h_channels = 128, out_channels = 128)
        self.conv7 = convBlock(in_channels = 192, h_channels = 64, out_channels = 64)
        
        self.last_conv = nn.Conv3d(in_channels = 64, out_channels = 4, kernel_size=3, stride = 1, padding = 1)
    
    
    def forward(self, x):

        x = self.conv1(x)
        skip1 = x
        print(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        skip2 = x
        print(x)
        x = self.maxpool(x)
        x = self.conv3(x)
        skip3 = x
        print(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        x = self.deconv1(x)
        print(x)
        x = torch.cat((skip3,x), dim=1)
        x = self.conv5(x)
        x = self.deconv2(x)
        print(x)
        x = torch.cat((skip2, x), dim=1)
        x = self.conv6(x)
        x = self.deconv3(x)
        print(f'BEFORE LAST BLOCK: {x}') 
        x = torch.cat((skip1,x), dim=1)
        x = self.conv7(x)
        x = self.last_conv(x)
        print('FINAL OUTPUT')
        return x

In [16]:
torch.cuda.empty_cache()
!nvidia-smi

Sun Feb  4 11:59:55 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4070      WDDM  | 00000000:01:00.0  On |                  N/A |
|  0%   33C    P8               8W / 200W |    305MiB / 12282MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Training

In [28]:
EPOCHS = 15#00
MOMENTUM = 0.99
LR = 3e-4
BATCH_SIZE = 1


train_data = ImageDataset(img_dir = train_dir, train_split = 0.8, train = True, label_dir = label_dir,
                      transform = trans)
test_data = ImageDataset(img_dir = train_dir, train_split = 0.8, train = False, label_dir = label_dir,
                      transform = trans)

train_dataloader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = BATCH_SIZE, shuffle = False)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = UNet3D().to(device, dtype = torch.double)
optimizer = torch.optim.SGD(params = m.parameters(), lr = LR, momentum = MOMENTUM)

In [31]:
# t = torch.randint(0,500, size= (1,4,64,120,120)).to(device, dtype = torch.double)
# m(t)

In [30]:
from tqdm.auto import tqdm
#eval_epoch = 25
for epoch in tqdm(range(EPOCHS)):
    train_loss = 0
    for _, (x,y) in enumerate(train_dataloader):
        
        m.train()
        
        x = x.to(device, dtype = torch.double)
        y = y.to(device, dtype = torch.double)
        
        w = weight_map(y).to(device, dtype = torch.double)
        #print(w)
        optimizer.zero_grad()
        loss_fn = nn.CrossEntropyLoss(weight = w)
        
        print(x.shape)
        print(x)
        preds = m(x)
        print(preds.shape)
        print(preds)
        loss = loss_fn(preds, y.type(torch.long))
        train_loss += loss.detach().item()

        loss.backward()
        optimizer.step()
        
        
        train_loss /= len(train_dataloader)
        
    test_loss = 0
    m.eval()
    with torch.inference_mode():
        for x,y in test_dataloader:
            x = x.to(device, dtype = torch.double)
            y = y.to(device, dtype = torch.double)
            
            w = weight_map(y).to(device, dtype = torch.double)
            loss_fn = nn.CrossEntropyLoss(weight = w)
            
            test_preds = m(x)
            print(test_preds.shape)
            print(test_preds)
            loss = loss_fn(test_preds, y.type(torch.long))
            print(loss)
            test_loss += loss.detach().item()
            print(test_loss)


        test_loss /= len(test_dataloader)
            
    print(f'Epoch: {epoch} | Train loss: {train_loss} | Test loss: {test_loss}')


  0%|          | 0/15 [00:00<?, ?it/s]

torch.Size([1, 4, 64, 120, 120])
tensor([[[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           ...,
           [121., 142., 138.,  ...,   0.,   0.,   0.],
           [101., 129., 115.,  ...,   0.,   0.,   0.],
           [112., 125., 108.,  ...,   0.,   0.,   0.]],

          [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           ...,
           [133., 148., 141.,  ...,   0.,   0.,   0.],
           [104., 121., 123.,  ...,   0.,   0.,   0.],
           [127., 133., 121.,  ...,   0.,   0.,   0.]],

          [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
           ...,
           [146., 154., 145.,  ...,   0.,   0.,   0.],
           [109., 114., 128.,  ...,

OutOfMemoryError: CUDA out of memory. Tried to allocate 450.00 MiB. GPU 0 has a total capacty of 11.99 GiB of which 0 bytes is free. Of the allocated memory 25.01 GiB is allocated by PyTorch, and 89.08 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# UNet 2D

In [2]:
class conv_block_2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=0):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.layer = nn.Sequential([nn.Conv2d(in_channels = self.in_channels, out_channels = self.out_channels, kernel_size = self.kernel_size, padding = self.padding),
                                           nn.BatchNorm2d(num_features = self.out_channels),
                                           nn.ReLU(),
                                           nn.Conv2d(in_channels = self.out_channels, out_channels = self.out_channels,kernel_size = self.kernel_size, padding = self.padding),
                                           nn.BatchNorm2d(num_features = self.out_channels),
                                           nn.ReLU()])
        
    def forward(self, x):
        return self.layer(x)

In [None]:
class Encoder_2D(nn.Module):
    def __init__(self, in_channels=3, padding = 0):
        super().__init__()
        self.in_channels = in_channels
        self.padding = padding
        
        self.conv_block_1 = conv_block_2D(in_channels = in_channels, out_channels = 64, kernel_size=3, padding = padding)
        
        self.conv_block_2 = conv_block_2D(in_channels = 64, out_channels = 128, kernel_size=3, padding = padding)
        
        self.conv_block_3 = conv_block_2D(in_channels = 128, out_channels = 256, kernel_size=3, padding = padding)
        
        self.conv_block_4 = conv_block_2D(in_channels = 256, out_channels = 512, kernel_size=3, padding = padding)
        
        self.bottleneck_conv = conv_block_2D(in_channels = 512, out_channels = 124, kernel_size=3, padding = padding)
        
        self.mp = nn.MaxPool2d(kernel_size = 2, stride = 2)
        
    def forward(self,x):
        x = self.conv_block_1(x)
        skip4 = x
        
        x = self.mp(x)
        x = self.conv_block_2(x)
        skip3 = x
        
        x = self.mp(x)
        x = self.conv_block_3(x)
        skip2 = x
        
        x = self.mp(x)
        x = self.conv_block_4(x)
        skip1 = x
        
        x = self.bottleneck_conv(x)
        
        return x, skip1, skip2, skip3, skip4
    

class Decoder_2D(nn.Module):
    def __init__(self):
        super().__init__()
        