In [1]:
import logging
import os
import sys

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader

import monai
from monai.data import ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, ToTensor

 ---- Data handeling ----

In [15]:
# Set to true if Data was not downloaded yet. Set to false afterwards!
# 4.5 GB / ~ 600 Pictures (3D) / 1 excel file for labels
DOWNLOAD = False

if DOWNLOAD:
    data_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar'
    compressed_file = os.sep.join(['Data', 'IXI-T1.tar'])
    data_dir = os.sep.join(['Data', 'IXI-T1'])
    
    # Data download
    monai.apps.download_and_extract(data_url, compressed_file, './Data/IXI-T1')

    # Labels document download
    labels_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI.xls'
    monai.apps.download_url(labels_url, './Data/IXI.xls')

In [3]:
images = [impath for impath in os.listdir('./Data/IXI-T1')]

demographic_info = pd.read_excel('./Data/IXI.xls')

# Getting lables TODO: Implement multiple possible labels, not just sex.
def make_labeled_data(df, images):
    data = []
    labels = []
    for i in images:
        ixi_id = int(i[3:6])
        row = df.loc[df['IXI_ID'] == ixi_id]
        if not row.empty:
            data.append(os.sep.join(['Data', 'IXI-T1', i]))
            labels.append(row.iat[0, 1] - 1) # Sex labels are 1/2 but need to be 0/1
         
    return data, labels


data, labels = make_labeled_data(demographic_info, images)

# Train - Test split
TEST_SIZE = 0.2 # How much percent of the data should be Test Data
SAMPLE_SIZE = 0.1 # How much of the whole data should be used (1.0 = 566 pictures, 0.1 = 56 pictures)
BATCH_SIZE = 2

def train_test_split(data, labels):
    size = int(len(data) * SAMPLE_SIZE)
    split = int(size * TEST_SIZE)
    
    test_data = data[:split]
    train_data = data[split:size]
    
    test_labels = labels[:split]
    train_labels = labels[split:size]
    
    return train_data, train_labels, test_data, test_labels
    
    
train_data, train_labels, test_data, test_labels = train_test_split(data, labels)

In [4]:
# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

# Define image dataset, data loader
check_ds = ImageDataset(image_files=data, labels=labels, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=BATCH_SIZE, num_workers=2)
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label)

# Create a training data loader
train_ds = ImageDataset(image_files=train_data, labels=train_labels, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Create a validation data loader
val_ds = ImageDataset(image_files=test_data, labels=test_labels, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=2)

<class 'torch.Tensor'> torch.Size([2, 1, 96, 96, 96]) tensor([1, 0])


 ---- Model Specification ----
 
 

In [5]:
# Create DenseNet, CrossEntropyLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.DenseNet(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

 ---- Model Training ----

In [6]:
EPOCHS  = 5

val_interval = 2
best_metric = -1
best_metric_epoch = -1
metric_values = list()

# Iterate through Epochs
for epoch in range(EPOCHS):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{5}")
    model.train()
    epoch_loss = 0
    step = 0
    
    # Iterate over the batches
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        
    epoch_loss /= step
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Evaluate the current model in regular interval
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "./Data/best_metric_model_classification3d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

----------
epoch 1/5
1/22, train_loss: 0.7606
2/22, train_loss: 0.4488
3/22, train_loss: 0.4486
4/22, train_loss: 0.4422
5/22, train_loss: 1.0372
6/22, train_loss: 1.0491
7/22, train_loss: 0.7550
8/22, train_loss: 0.9360
9/22, train_loss: 0.7819
10/22, train_loss: 1.0338
11/22, train_loss: 0.4390
12/22, train_loss: 1.0271
13/22, train_loss: 1.0085
14/22, train_loss: 0.7754
15/22, train_loss: 0.7408
16/22, train_loss: 0.7523
17/22, train_loss: 1.0160
18/22, train_loss: 0.4548
19/22, train_loss: 0.7242
20/22, train_loss: 0.4618
21/22, train_loss: 0.9993
22/22, train_loss: 0.6784
23/22, train_loss: 0.9695
epoch 1 average loss: 0.7713
----------
epoch 2/5
1/22, train_loss: 0.7916
2/22, train_loss: 0.6175
3/22, train_loss: 0.4732
4/22, train_loss: 0.7396
5/22, train_loss: 0.4812
6/22, train_loss: 0.6898
7/22, train_loss: 0.9260
8/22, train_loss: 0.4619
9/22, train_loss: 0.9864
10/22, train_loss: 0.5518
11/22, train_loss: 0.6121
12/22, train_loss: 0.9588
13/22, train_loss: 0.5789
14/22, trai