In [6]:
import os

import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader
import torchvision

from opacus import PrivacyEngine

import monai
from monai.data import ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, ToTensor

# ---- Data handeling ----

In [2]:
# Set to true if Data was not downloaded yet. Set to false afterwards!
# 4.5 GB / ~ 600 Pictures (3D) / 1 excel file for labels
DOWNLOAD = False

if DOWNLOAD:
    data_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar'
    compressed_file = os.sep.join(['Data', 'IXI-T1.tar'])
    data_dir = os.sep.join(['Data', 'IXI-T1'])
    
    # Data download
    monai.apps.download_and_extract(data_url, compressed_file, './Data/IXI-T1')

    # Labels document download
    labels_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI.xls'
    monai.apps.download_url(labels_url, './Data/IXI.xls')

In [3]:
images = [impath for impath in os.listdir('./Data/IXI-T1')]

demographic_info = pd.read_excel('./Data/IXI.xls')

# Getting lables TODO: Implement multiple possible labels, not just sex.
def make_labeled_data(df, images):
    data = []
    labels = []
    for i in images:
        ixi_id = int(i[3:6])
        row = df.loc[df['IXI_ID'] == ixi_id]
        if not row.empty:
            data.append(os.sep.join(['Data', 'IXI-T1', i]))
            labels.append(row.iat[0, 1] - 1) # Sex labels are 1/2 but need to be 0/1
         
    return data, labels


data, labels = make_labeled_data(demographic_info, images)

# Train - Test split
TEST_SIZE = 0.2 # How much percent of the data should be Test Data
SAMPLE_SIZE = 0.05 # How much of the whole data should be used (1.0 = 566 pictures, 0.1 = 56 pictures)
BATCH_SIZE = 2

def train_test_split(data, labels):
    size = int(len(data) * SAMPLE_SIZE)
    split = int(size * TEST_SIZE)
    
    test_data = data[:split]
    train_data = data[split:size]
    
    test_labels = labels[:split]
    train_labels = labels[split:size]
    
    return train_data, train_labels, test_data, test_labels
    
    
train_data, train_labels, test_data, test_labels = train_test_split(data, labels)

In [4]:
# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

# Create a training data loader
train_ds = ImageDataset(image_files=train_data, labels=train_labels, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=torch.cuda.is_available())

# Create a validation data loader
val_ds = ImageDataset(image_files=test_data, labels=test_labels, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=1, pin_memory=torch.cuda.is_available())

 # ---- Model Specification ----
 
 

In [5]:
# Create DenseNet, CrossEntropyLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.densenet(spatial_dims=3, in_channels=1, out_channels=2).to(device)
# TODO create own model !!!!

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

# Adding differential Privacy
privacy_engine = PrivacyEngine(model, 
                              batch_size=BATCH_SIZE,
                              sample_size=len(train_data),
                              alphas=range(2, 16),
                              noise_multiplier=1.3,
                              max_grad_norm=1.0)
privacy_engine.attach(optimizer)

The sample rate will be defined from ``batch_size`` and ``sample_size``.The returned privacy budget will be incorrect.
Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_rng`` turned on.


IncompatibleModuleException: Model contains incompatible modules.
Some modules are not valid.: ['Main.features.norm0', 'Main.features.denseblock1.denselayer1.layers.norm1', 'Main.features.denseblock1.denselayer1.layers.norm2', 'Main.features.denseblock1.denselayer2.layers.norm1', 'Main.features.denseblock1.denselayer2.layers.norm2', 'Main.features.denseblock1.denselayer3.layers.norm1', 'Main.features.denseblock1.denselayer3.layers.norm2', 'Main.features.denseblock1.denselayer4.layers.norm1', 'Main.features.denseblock1.denselayer4.layers.norm2', 'Main.features.denseblock1.denselayer5.layers.norm1', 'Main.features.denseblock1.denselayer5.layers.norm2', 'Main.features.denseblock1.denselayer6.layers.norm1', 'Main.features.denseblock1.denselayer6.layers.norm2', 'Main.features.transition1.norm', 'Main.features.denseblock2.denselayer1.layers.norm1', 'Main.features.denseblock2.denselayer1.layers.norm2', 'Main.features.denseblock2.denselayer2.layers.norm1', 'Main.features.denseblock2.denselayer2.layers.norm2', 'Main.features.denseblock2.denselayer3.layers.norm1', 'Main.features.denseblock2.denselayer3.layers.norm2', 'Main.features.denseblock2.denselayer4.layers.norm1', 'Main.features.denseblock2.denselayer4.layers.norm2', 'Main.features.denseblock2.denselayer5.layers.norm1', 'Main.features.denseblock2.denselayer5.layers.norm2', 'Main.features.denseblock2.denselayer6.layers.norm1', 'Main.features.denseblock2.denselayer6.layers.norm2', 'Main.features.denseblock2.denselayer7.layers.norm1', 'Main.features.denseblock2.denselayer7.layers.norm2', 'Main.features.denseblock2.denselayer8.layers.norm1', 'Main.features.denseblock2.denselayer8.layers.norm2', 'Main.features.denseblock2.denselayer9.layers.norm1', 'Main.features.denseblock2.denselayer9.layers.norm2', 'Main.features.denseblock2.denselayer10.layers.norm1', 'Main.features.denseblock2.denselayer10.layers.norm2', 'Main.features.denseblock2.denselayer11.layers.norm1', 'Main.features.denseblock2.denselayer11.layers.norm2', 'Main.features.denseblock2.denselayer12.layers.norm1', 'Main.features.denseblock2.denselayer12.layers.norm2', 'Main.features.transition2.norm', 'Main.features.denseblock3.denselayer1.layers.norm1', 'Main.features.denseblock3.denselayer1.layers.norm2', 'Main.features.denseblock3.denselayer2.layers.norm1', 'Main.features.denseblock3.denselayer2.layers.norm2', 'Main.features.denseblock3.denselayer3.layers.norm1', 'Main.features.denseblock3.denselayer3.layers.norm2', 'Main.features.denseblock3.denselayer4.layers.norm1', 'Main.features.denseblock3.denselayer4.layers.norm2', 'Main.features.denseblock3.denselayer5.layers.norm1', 'Main.features.denseblock3.denselayer5.layers.norm2', 'Main.features.denseblock3.denselayer6.layers.norm1', 'Main.features.denseblock3.denselayer6.layers.norm2', 'Main.features.denseblock3.denselayer7.layers.norm1', 'Main.features.denseblock3.denselayer7.layers.norm2', 'Main.features.denseblock3.denselayer8.layers.norm1', 'Main.features.denseblock3.denselayer8.layers.norm2', 'Main.features.denseblock3.denselayer9.layers.norm1', 'Main.features.denseblock3.denselayer9.layers.norm2', 'Main.features.denseblock3.denselayer10.layers.norm1', 'Main.features.denseblock3.denselayer10.layers.norm2', 'Main.features.denseblock3.denselayer11.layers.norm1', 'Main.features.denseblock3.denselayer11.layers.norm2', 'Main.features.denseblock3.denselayer12.layers.norm1', 'Main.features.denseblock3.denselayer12.layers.norm2', 'Main.features.denseblock3.denselayer13.layers.norm1', 'Main.features.denseblock3.denselayer13.layers.norm2', 'Main.features.denseblock3.denselayer14.layers.norm1', 'Main.features.denseblock3.denselayer14.layers.norm2', 'Main.features.denseblock3.denselayer15.layers.norm1', 'Main.features.denseblock3.denselayer15.layers.norm2', 'Main.features.denseblock3.denselayer16.layers.norm1', 'Main.features.denseblock3.denselayer16.layers.norm2', 'Main.features.denseblock3.denselayer17.layers.norm1', 'Main.features.denseblock3.denselayer17.layers.norm2', 'Main.features.denseblock3.denselayer18.layers.norm1', 'Main.features.denseblock3.denselayer18.layers.norm2', 'Main.features.denseblock3.denselayer19.layers.norm1', 'Main.features.denseblock3.denselayer19.layers.norm2', 'Main.features.denseblock3.denselayer20.layers.norm1', 'Main.features.denseblock3.denselayer20.layers.norm2', 'Main.features.denseblock3.denselayer21.layers.norm1', 'Main.features.denseblock3.denselayer21.layers.norm2', 'Main.features.denseblock3.denselayer22.layers.norm1', 'Main.features.denseblock3.denselayer22.layers.norm2', 'Main.features.denseblock3.denselayer23.layers.norm1', 'Main.features.denseblock3.denselayer23.layers.norm2', 'Main.features.denseblock3.denselayer24.layers.norm1', 'Main.features.denseblock3.denselayer24.layers.norm2', 'Main.features.transition3.norm', 'Main.features.denseblock4.denselayer1.layers.norm1', 'Main.features.denseblock4.denselayer1.layers.norm2', 'Main.features.denseblock4.denselayer2.layers.norm1', 'Main.features.denseblock4.denselayer2.layers.norm2', 'Main.features.denseblock4.denselayer3.layers.norm1', 'Main.features.denseblock4.denselayer3.layers.norm2', 'Main.features.denseblock4.denselayer4.layers.norm1', 'Main.features.denseblock4.denselayer4.layers.norm2', 'Main.features.denseblock4.denselayer5.layers.norm1', 'Main.features.denseblock4.denselayer5.layers.norm2', 'Main.features.denseblock4.denselayer6.layers.norm1', 'Main.features.denseblock4.denselayer6.layers.norm2', 'Main.features.denseblock4.denselayer7.layers.norm1', 'Main.features.denseblock4.denselayer7.layers.norm2', 'Main.features.denseblock4.denselayer8.layers.norm1', 'Main.features.denseblock4.denselayer8.layers.norm2', 'Main.features.denseblock4.denselayer9.layers.norm1', 'Main.features.denseblock4.denselayer9.layers.norm2', 'Main.features.denseblock4.denselayer10.layers.norm1', 'Main.features.denseblock4.denselayer10.layers.norm2', 'Main.features.denseblock4.denselayer11.layers.norm1', 'Main.features.denseblock4.denselayer11.layers.norm2', 'Main.features.denseblock4.denselayer12.layers.norm1', 'Main.features.denseblock4.denselayer12.layers.norm2', 'Main.features.denseblock4.denselayer13.layers.norm1', 'Main.features.denseblock4.denselayer13.layers.norm2', 'Main.features.denseblock4.denselayer14.layers.norm1', 'Main.features.denseblock4.denselayer14.layers.norm2', 'Main.features.denseblock4.denselayer15.layers.norm1', 'Main.features.denseblock4.denselayer15.layers.norm2', 'Main.features.denseblock4.denselayer16.layers.norm1', 'Main.features.denseblock4.denselayer16.layers.norm2', 'Main.features.norm5']

# ---- Model Training ----

In [None]:
EPOCHS  = 5
DELTA = 1e-5

val_interval = 2
best_metric = -1
best_metric_epoch = -1
metric_values = list()

# Iterate through Epochs
for epoch in range(EPOCHS):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{5}")
    model.train()
    epoch_loss = 0
    step = 0
    
    # Iterate over the batches
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        
    epoch_loss /= step
    epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(DELTA) 
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}"
          f"(ε = {epsilon:.2f}, δ = {delta}) for α = {best_alpha}")

    # Evaluate the current model in regular interval
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "./Data/best_metric_model_classification3d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")