<a href="https://colab.research.google.com/github/lglass/healthcare-hackathon-2020/blob/master/hh2020_basis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the third-party packages used by the cells below (IPython/Colab shell escapes).
# NOTE(review): versions are not pinned, yet later cells rely on torchio internals
# (torchio.datasets.ixi.sglob / get_subject_id and IXI._get_sample_dict_from_subject)
# that may change between releases — consider pinning the torchio version this
# notebook was developed against.
!pip install --quiet torchio
!pip install --quiet nibabel

In [2]:
import torch
import torchio
import random
import numpy as np
from pathlib import Path
from torchio import Subject, Image, INTENSITY
from torchio.datasets.ixi import sglob, get_subject_id
from torchio import (
    RandomMotion,
    RandomGhosting,
    RandomBlur,
    RandomNoise,
    Compose
)

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [3]:
# https://colab.research.google.com/drive/112NTL8uJXzcMw4PQbUvMQN-WHlVwQS3i#scrollTo=b0NdJiFW3Uy7
# Dataset
# Download the "ixi_tiny" zip archive and unpack it into the working directory,
# skipping both steps when the `ixi_tiny` directory already exists.
# The `!` lines are IPython shell escapes; `{name}` interpolates the Python variable.
# NOTE(review): in the visible notebook nothing reads `dataset_dir` afterwards —
# train() builds its dataset from a separate IXI download (MRIQADataset with
# download=True) — confirm whether this cell is still needed.
dataset_url = 'https://www.dropbox.com/s/ogxjwjxdv5mieah/ixi_tiny.zip?dl=0'
dataset_path = 'ixi_tiny.zip'
dataset_dir_name = 'ixi_tiny'
dataset_dir = Path(dataset_dir_name)
if not dataset_dir.is_dir():
    !curl --silent --output {dataset_path} --location {dataset_url} 
    !unzip -qq {dataset_path}

In [4]:


def _random_signed_range(low=5, spread=10):
    """Return a (min, max) range whose magnitudes lie in [low, low + spread] with a random sign.

    With probability 0.5 the range is ``(low, low + u * spread)``, otherwise the
    mirrored ``(-(low + u * spread), -low)``, where ``u = random.random()``.
    """
    if random.random() < 0.5:
        return (low, random.random() * spread + low)
    return (-(random.random() * spread + low), -low)


def simulate_artefacts(sample, artefacts=(0, 0, 0, 0)):
    """Apply the simulated MR artefacts selected by ``artefacts`` to ``sample``.

    Args:
        sample: input accepted by TorchIO's ``Compose`` / random transforms.
        artefacts: indexable flags; index 0 = motion, 1 = ghosting, 2 = blur,
            3 = noise. Every entry equal to 1 enables that artefact, so several
            artefacts can be combined (they are applied in that order). Entries
            after index 3 are ignored.

    Returns:
        ``sample`` passed through the composed transforms (unchanged by an
        empty ``Compose`` when no flag is set).
    """
    transforms = []
    if artefacts[0] == 1:
        # Rotation and translation ranges of magnitude 5..15 with a random sign;
        # keyword arguments are evaluated left to right, fixing the RNG draw order.
        transforms.append(RandomMotion(
            degrees=_random_signed_range(),
            translation=_random_signed_range(),
            num_transforms=random.randint(2, 15),
        ))
    # Independent checks (not elif): a multi-flag input enables every selected artefact.
    if artefacts[1] == 1:
        transforms.append(RandomGhosting(num_ghosts=(2, 10), intensity=(0.5, 0.75)))
    if artefacts[2] == 1:
        transforms.append(RandomBlur(std=(0.5, 2.)))
    if artefacts[3] == 1:
        transforms.append(RandomNoise(std=(0.01, 0.05)))
    return Compose(transforms)(sample)


class MRIQADataset(torchio.datasets.IXI):
    """IXI subset (Hammersmith Hospital scans only) yielding 2D slices for artefact classification.

    Each item is ``(image, label)``: a random slice of a randomly chosen T1 or T2
    volume, min-max normalized, to which one simulated artefact is applied with
    probability 0.8. ``label`` encodes the artefact:
    0 = motion, 1 = ghosting, 2 = blur, 3 = noise, 4 = no artefact.
    """

    # Override of the parent's subject discovery that keeps only Hammersmith
    # Hospital subjects (subject IDs containing '-HH-').
    @staticmethod
    def _get_subjects_list(root, modalities):
        """Collect HH subjects that have an image for every requested modality.

        Args:
            root: path-like dataset root; modality ``m`` is searched in ``root / m``.
            modalities: modality names; the first one is used to enumerate subjects.

        Returns:
            list of ``Subject`` objects, one per HH subject with all modalities present.

        Raises:
            AssertionError: if a modality directory holds more than one file for a subject.
        """
        first_modality = modalities[0]
        subjects = []
        for filepath in sglob(root / first_modality, '*.nii.gz'):
            subject_id = get_subject_id(filepath)
            # Filter before globbing the other modalities to avoid needless lookups.
            if '-HH-' not in subject_id:
                continue
            images_dict = dict(subject_id=subject_id)
            images_dict[first_modality] = Image(filepath, INTENSITY)
            for modality in modalities[1:]:
                globbed = sglob(root / modality, f'{subject_id}-{modality}.nii.gz')
                if not globbed:
                    break  # a modality is missing -> skip this subject
                assert len(globbed) == 1
                images_dict[modality] = Image(globbed[0], INTENSITY)
            else:
                # Only reached when no modality was missing.
                subjects.append(Subject(**images_dict))
        return subjects

    # 2D training samples
    def __getitem__(self, index: int):
        """Return one ``(image, label)`` training pair for subject ``index``.

        Steps:
            1. Pick T1 or T2 with equal probability.
            2. Min-max normalize the volume to [0, 1].
            3. Take a random slice along the last axis, skipping a 10% margin at both ends.
            4. With probability 0.8, apply one uniformly chosen artefact.

        Returns:
            ``image``: the (possibly corrupted) slice with the leading singleton
            dimension removed; ``label``: 0-dim index tensor (see class docstring).

        Raises:
            ValueError: if ``index`` is not an ``int``.
        """
        if not isinstance(index, int):
            raise ValueError(f'Index "{index}" must be int, not {type(index)}')
        subject = self.subjects[index]
        # torchio parent-class helper; presumably loads the subject's images — verify on upgrade.
        sample = self._get_sample_dict_from_subject(subject)

        # choose random MR contrast
        volume = sample['T1'].data if random.random() < 0.5 else sample['T2'].data

        # Min-max normalization, done out of place so the loaded tensor is not
        # mutated (and integer tensors don't fail on in-place true division).
        # A constant volume stays all zeros instead of becoming NaN.
        volume = volume - torch.min(volume)
        max_value = torch.max(volume)
        if max_value > 0:
            volume = volume / max_value

        # choose random slice of volume; randint is inclusive on both ends,
        # so the last valid index is depth - 1 - margin.
        depth = volume.shape[-1]
        margin = int(depth * 0.1)
        slice_number = random.randint(margin, depth - margin - 1)
        image = volume[..., slice_number].unsqueeze(dim=0)

        # One-hot over 5 classes: indices 0-3 select an artefact in
        # simulate_artefacts, index 4 means "no artefact". With p = 0.8 one of
        # the four artefacts is drawn uniformly, so each class has p = 0.2.
        artefact = torch.zeros(5)
        if random.random() > 0.2:
            artefact[np.random.randint(4)] = 1
            image = simulate_artefacts(image, artefact)
        else:
            artefact[-1] = 1
        _, label = artefact.max(0)
        image = image.squeeze(dim=0)
        return image, label

In [5]:
# ----- NETWORK -----

In [6]:
import torch
import torch.nn as nn


class ClassicCNN(nn.Module):
    """Small fully-convolutional classifier that ends in global average pooling.

    Six Conv-BatchNorm-ReLU blocks (with two 2x2 max-pools in between) shrink
    the input, a 1x1 convolution maps the features to ``num_classes`` channels,
    and the spatial mean turns them into one logit per class.

    Args:
        input_channels: number of channels of the input image.
        num_classes: number of output logits.
        batch_size: only stored as ``self.batch_size``; not used by the layers.
    """

    # Submodule attribute names in the order forward() applies them. They are
    # also the state_dict key prefixes of saved checkpoints, so keep them stable.
    _FORWARD_ORDER = ('block1', 'block2', 'maxpool1', 'block3', 'block4',
                      'maxpool2', 'block5', 'block6', 'block7')

    def __init__(self, input_channels=1, num_classes=4, batch_size=8):  # original batch_size=8
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        make_block = self.classic_cnn_block
        # Built in forward order; the list literal keeps construction (and hence
        # parameter-initialization) order fixed.
        layers = [
            make_block(input_channels, 32, 3, stride=2),
            make_block(32, 64, 3, stride=2),
            nn.MaxPool2d(2),
            make_block(64, 128, 3),
            make_block(128, 256, 3, stride=2),
            nn.MaxPool2d(2),
            make_block(256, 128, 3),
            make_block(128, 64, 3, stride=2),
            nn.Conv2d(64, num_classes, 1),
        ]
        for attr_name, layer in zip(self._FORWARD_ORDER, layers):
            setattr(self, attr_name, layer)

    def forward(self, x):
        """Map a (N, C, H, W) batch to (N, num_classes) logits."""
        for attr_name in self._FORWARD_ORDER:
            x = getattr(self, attr_name)(x)
        # global average pooling over the two spatial dimensions
        return x.mean(dim=(2, 3))

    def classic_cnn_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Return ``Sequential(Conv2d, BatchNorm2d, ReLU)`` for the given conv settings."""
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride, padding=padding)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU())


# Add your own networks here :)

In [7]:
# ----- TRAINING -----

In [8]:
# GPU choice
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import random
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm




# Reproducibility: seed every RNG the notebook draws from (Python, NumPy,
# PyTorch CPU and CUDA) and force cuDNN to deterministic algorithm selection.
SEED = 21062020
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(SEED)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


def _train_one_epoch(net, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader`` and return the mean mini-batch loss."""
    net.train()
    total_loss = 0.0
    for sample, label in tqdm(loader, total=len(loader)):
        sample, label = sample.to(device), label.to(device)
        # PyTorch accumulates gradients, so they must be cleared before every step.
        optimizer.zero_grad()
        loss = criterion(net(sample), label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


def _evaluate(net, loader, criterion, device):
    """Return the mean mini-batch loss of ``net`` on ``loader`` without updating it."""
    net.eval()
    total_loss = 0.0
    with torch.no_grad():
        for sample, label in tqdm(loader, total=len(loader)):
            sample, label = sample.to(device), label.to(device)
            total_loss += criterion(net(sample), label).item()
    return total_loss / max(len(loader), 1)


def train(num_epochs=20, batch_size=8, data_root='/Z/AI/', num_train=100,
          checkpoint_path='checkpoint_best', loss_csv_path='losses.csv'):
    """Train ClassicCNN to classify simulated MR artefacts and checkpoint the best model.

    Args:
        num_epochs: passes over the training split (the original run used 500).
        batch_size: mini-batch size of both loaders.
        data_root: directory where MRIQADataset stores / downloads the IXI data.
        num_train: subjects in the training split; all remaining subjects form
            the validation split.
        checkpoint_path: file receiving the checkpoint with the lowest mean
            validation loss seen so far.
        loss_csv_path: CSV file receiving ``epoch,training,validation`` mean losses.

    Raises:
        ValueError: if ``num_train`` leaves no subjects for training or validation.

    Runs on CUDA when available, otherwise on the CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create dataset (automatically downloads IXI at first run)
    dataset = MRIQADataset(
        data_root,    # path to save data to
        modalities=('T1', 'T2'),
        download=True,
    )

    # split data into training and validation sets
    num_subjects = len(dataset)
    if not 0 < num_train < num_subjects:
        raise ValueError(
            f'num_train must be between 1 and {num_subjects - 1}, got {num_train}')
    train_set, validation_set = torch.utils.data.random_split(
        dataset, (num_train, num_subjects - num_train))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation order does not affect the mean loss, so no shuffling is needed.
    validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)

    net = ClassicCNN(num_classes=5).to(device)
    optimizer = optim.Adam(net.parameters())
    ce = CrossEntropyLoss()

    best_val_loss = float('inf')
    with open(loss_csv_path, 'w') as loss_csv:
        loss_csv.write('epoch,training,validation\n')
        for epoch in range(num_epochs):
            train_loss = _train_one_epoch(net, train_loader, optimizer, ce, device)
            print('[{}] train-loss: {}'.format(epoch, train_loss))

            validation_loss = _evaluate(net, validation_loader, ce, device)
            print('[{}] validation-loss: {}'.format(epoch, validation_loss))

            loss_csv.write('{},{},{}\n'.format(epoch, train_loss, validation_loss))
            loss_csv.flush()

            # save best model (compared and stored as the mean validation loss)
            if validation_loss <= best_val_loss:
                torch.save({'epoch': epoch,
                            'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': validation_loss}, checkpoint_path)
                best_val_loss = validation_loss
    print('DONE.')


# In a notebook kernel __name__ is '__main__', so running this cell starts training.
if __name__ == '__main__':
    train()

 15%|█▌        | 2/13 [00:08<00:44,  4.03s/it]

KeyboardInterrupt: ignored