## Data Preprocessing

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import nibabel as nib

In [2]:
datapath = 'data/'

### Dataset Creation

In [3]:
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
class LVDataset(Dataset):
    # animal_nums is a list of numbers indicating which animals' images to include in the dataset
    
    # to try: tuple(zip(images, masks))
    def __init__(self, animal_nums, transform=None):
        image_folders = []
        types = ['Baseline', 'PostGel', 'PostMI']
        for num in animal_nums:
            for type in types:
                image_folders.append(datapath + 'PSEA' + str(num) + ' ' + type + '/')
        self.image_depth = 37
        self.images = []
        self.masks = []
        for folder in image_folders:
            files = np.array(os.listdir(folder))
            images = np.sort(files[[('Mask' not in name and name != '.DS_Store') for name in files]])
            images = [folder + image for image in images]
            self.images.extend(images)
            
            masks = np.sort(files[['Mask' in name for name in files]])
            masks = [folder + mask for mask in masks]
            self.masks.extend(masks)

        if len(self.images) != len(self.masks):
            print('Different number of images and masks')

        self.transform = transform
    def __len__(self):
        return len(self.images) * self.image_depth
    
    # input must be a list
    # returns tensors that represent images/masks
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = nib.load(self.images[idx // self.image_depth]).get_fdata()
        image = image[:,:,idx % self.image_depth]
        
        mask = nib.load(self.masks[idx // self.image_depth]).get_fdata()
        mask = mask[:,:,idx % self.image_depth]
        sample = {'image': image, 'mask': mask}
        
        if self.transform:
            sample = self.transform(sample)

        return (sample['image'], sample['mask'])

### Transforms
- patching
- normalization
- slicing

In [4]:
# To use the full U-Net without losing data,
# we need the dimensions to be a multiple of 16
class CropTensor(object):
    def __init__(self, output_size):
        self.output_size = output_size
        
    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']
        orig_shape = list(image.shape)
        start_ind = orig_shape[0] // 2 - self.output_size // 2
        end_ind = orig_shape[0] // 2 + self.output_size // 2
        
        # Channels for cross entropy loss
        image = image[np.newaxis, start_ind:end_ind, start_ind:end_ind]
        mask = mask[start_ind:end_ind, start_ind:end_ind]
        return {'image': image, 'mask': mask}

### Train Test Split

In [5]:
print(os.listdir(datapath))

['PSEA12 Baseline', 'PSEA12 PostGel', 'PSEA12 PostMI', 'PSEA13 Baseline', 'PSEA13 PostGel', 'PSEA13 PostMI', 'PSEA14 Baseline', 'PSEA14 PostGel', 'PSEA18 Baseline', 'PSEA18 PostGel', 'PSEA18 PostMI', 'PSEA20 PostGel', 'PSEA20 PostMI', 'PSEA25 Baseline', 'PSEA25 PostGel', 'PSEA25 PostMI', 'PSEA27 Baseline', 'PSEA27 PostGel', 'PSEA27 PostMI']


In [6]:
all_animals = [12, 13, 18, 25, 27]

In [7]:
train_set = LVDataset(all_animals[:4], transform=CropTensor(288))
test_set = LVDataset(all_animals[5:], transform=CropTensor(288))

In [8]:
train_set = LVDataset([12], transform=CropTensor(288))
test_set = LVDataset([13], transform=CropTensor(288))

In [9]:
print('Size of training set: ', len(train_set))
print('Size of test set: ', len(test_set))
print('Image Shape: ', train_set[0][0].shape)
print('Mask Shape: ', train_set[0][1].shape)

Size of training set:  1110
Size of test set:  1110
Image Shape:  (1, 288, 288)
Mask Shape:  (288, 288)


## Model

In [10]:
# modelled after the original UNet
class UNet(nn.Module):
    # This conv/relu combination results in no change in dimension for full image restoration
    def conv_relu(self, in_channels, out_channels, kernel_size=3, padding=1, padding_mode='reflect'):
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                padding_mode=padding_mode
            ),
            nn.ReLU()
        )
    
    # This transpose doubles the dimensions
    def conv_transpose(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding
        )
    
    def first_block(self, in_channels, out_channels):
        return nn.Sequential(
            self.conv_relu(in_channels, out_channels),
            self.conv_relu(out_channels, out_channels)
        )
    
    # Output: (x-4)/2
    def contract_block(self, in_channels, out_channels):
        # Testing: adding BatchNorm2d(out_channels) after ReLU layers
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            self.conv_relu(in_channels, out_channels),
            self.conv_relu(out_channels, out_channels)
        )
    
    def bottleneck_block(self, in_channels, mid_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            self.conv_transpose(mid_channels, out_channels)
        )
        
    # Output: (x-4)*2
    def expand_block(self, in_channels, mid_channels, out_channels):
        return nn.Sequential(
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            self.conv_transpose(mid_channels, out_channels)
        )

    def final_block(self, in_channels, mid_channels, out_channels):
        return nn.Sequential(
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
        )
    
    def __init__(self):
        super().__init__()
        self.contraction = nn.ModuleList([
            # 288
            self.first_block(1, 64),
            # 288
            self.contract_block(64, 128),
            # 144
            self.contract_block(128, 256),
            # 72
            self.contract_block(256, 512),
            # 36
        ])
        
        self.bottleneck = self.bottleneck_block(512, 1024, 512)
        
        self.expansion = nn.ModuleList([
            # 36
            self.expand_block(1024, 512, 256),
            # 72
            self.expand_block(512, 256, 128),
            # 144
            self.expand_block(256, 128, 64),
            # 288
            self.final_block(128, 64, 2)
            # 288
        ])
        
        self.contraction_outputs = []

    def forward(self, image):
        for layer in self.contraction:
            image = layer(image)
            self.contraction_outputs.append(image)
        
        image = self.bottleneck(image)
        for i in range(4):
            image = torch.cat((self.contraction_outputs[3 - i], image), dim=1)
            image = self.expansion[i](image)
        self.contraction_outputs = []
        return image

## Training

### Loop

In [11]:
def train_model(model, optimizer, criterion, epochs):
    writer = SummaryWriter()

    for epoch in range(epochs):
        print('EPOCH', epoch)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            total_loss = 0
            for batch_id, (image, mask) in enumerate(dataloaders[phase]):
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.long)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    output_mask = model(image)
                    loss = criterion(output_mask, mask)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                total_loss += loss.item()
                
                del image
                del mask
                del output_mask
                del loss
                torch.cuda.empty_cache()
            
            print(phase, ' loss: ', total_loss)
            writer.add_scalar('loss/' + phase, total_loss, epoch)

## Training

In [12]:
NUM_EPOCHS = 100
LEARNING_RATE = 0.0001
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 64

In [13]:
train_loader = DataLoader(
    train_set,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)

test_loader = DataLoader(
    test_set,
    batch_size=TEST_BATCH_SIZE,
    shuffle=True,
)

dataloaders = {'train': train_loader, 'val': test_loader}

In [14]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

cuda:0


In [15]:
train_model(model, optimizer, criterion, NUM_EPOCHS)

EPOCH 0


RuntimeError: CUDA out of memory. Tried to allocate 1.27 GiB (GPU 0; 5.94 GiB total capacity; 3.99 GiB already allocated; 614.06 MiB free; 4.01 GiB reserved in total by PyTorch)