# Alzheimer MRI images data notebook 
by Martin Closter Jespersen from Deloitte Consulting

## About the data 
The data was generated from a real patient cohort of MRI images both with and without alzheimer. The real images are in the dimension of 128x128 and consists of the following distribution (i.e. real distribution):
- 2560 Non alzheimer
- 1792 Very mild alzheimer 
- 717  Mild alzheimer
- 52   Moderate alzheimer

You are provided an evenly distributed synthetic dataset of ~3000 128x128 synthetic MRI images of each class. 

The data was generated using a simple (non state-of-the-art) <b>Conditional Generative Adversial Network (cGAN)</b>. cGANs are generally data hungry and considering this small dataset with great class imbalance, the data quality can be <u><b>limited</b></u>. Though image applications of machine learning has developed far, creating diverse synthetic images is still the main bottleneck. More sophisticated methods have improved this substantially but was not used here due to time and dataset size.

Reading material:
* Analysis of using GANs to replace real biomedical images in classification https://arxiv.org/pdf/1904.08688.pdf
* Synthetic COVID X ray images https://arxiv.org/pdf/2009.12478.pdf
* Synthetizing chest X ray images for model development https://www.researchgate.net/publication/328945795_Synthesizing_Chest_X-Ray_Pathology_for_Training_Deep_Convolutional_Neural_Networks


Investigate how the data can be useful!  

## Possible challenges
- Can you train a model on synthetic data to predict alzheimers on real data (binary = No alzheimer or not)
    - Easiest task of prediction
    - Play with best distribution / synthetic data size needed (most likely a small subsample is sufficient). 
    - Removing redundancy (too similar images) might be needed  (https://github.com/JohannesBuchner/imagehash)
    - Does the performance on real data increase if training only using synthetic data a subset of the alzheimer classes (i.e. leave out less frequent ones)?
- How well does the synthetic data behave on each class on the real data?
    - More complicated task of prediction
- How does a model trained on the real data behave on the synthetic data?
    - This can be useful if one wants to scale a model to new country and want to evaluate if it would succeed
- Do the model trained on real and synthetic data behave similarly? 
    - I.e. do they predict same targets the same classes, or do they use the same part of the images to classify (Grad-CAM or similar method)
- Can you improve the model trained on real data by augmenting it with synthetic data in the training?
- Explore and compare the real and synthetic data
    - Average pixel of each separate class or other create ideas?
- Can you build a model which can predict real from synthetic images, and if so can you understand why it can differentiate?

### import libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

In [4]:
from utils import load_images, set_global_seed
from dataset import MRIDataset, Unsqueeze, Repeat
from train import train, validate_epoch, metrics_callback
from convnet import ConvDropoutNet, ConvBatchNormNet, make_resnet18, make_pretrained_resnet18

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
set_global_seed(42)

### loading in the images

In [7]:
# load images
synth_images, synth_labels = load_images('data/SyntheticDataset', verbose=True)
real_images, real_labels = load_images('data/RealDataset', verbose=True)
real_final_images, real_final_labels = load_images('data/RealTestDataset', verbose=True)

100%|██████████| 12000/12000 [00:03<00:00, 3811.84it/s]
100%|██████████| 5121/5121 [00:01<00:00, 3836.60it/s]
100%|██████████| 1279/1279 [00:00<00:00, 3946.00it/s]


In [8]:
# print shapes
print(f'synth_images shape: {synth_images.shape}')
print(f'synth_labels shape: {synth_labels.shape}')
print()
print(f'real_images shape: {real_images.shape}')
print(f'real_labels shape: {real_labels.shape}')
print()
print(f'real_final_images shape: {real_final_images.shape}')
print(f'real_final_labels shape: {real_final_labels.shape}')

synth_images shape: (12000, 128, 128)
synth_labels shape: (12000,)

real_images shape: (5121, 128, 128)
real_labels shape: (5121,)

real_final_images shape: (1279, 128, 128)
real_final_labels shape: (1279,)


In [9]:
# print labels distribution
print(f'synth_labels: {Counter(synth_labels)}')
print(f'real_labels: {Counter(real_labels)}')
print(f'real_final_labels: {Counter(real_final_labels)}')

synth_labels: Counter({0: 3047, 1: 3030, 3: 3023, 2: 2900})
real_labels: Counter({0: 2560, 1: 1792, 2: 717, 3: 52})
real_final_labels: Counter({0: 640, 1: 448, 2: 179, 3: 12})


In [10]:
# binary classification
synth_labels = np.clip(synth_labels, 0, 1)
real_labels = np.clip(real_labels, 0, 1)
real_final_labels = np.clip(real_final_labels, 0, 1)

In [11]:
# print labels distribution after binary classification
print(f'synth_labels: {Counter(synth_labels)}')
print(f'real_labels: {Counter(real_labels)}')
print(f'real_final_labels: {Counter(real_final_labels)}')

synth_labels: Counter({1: 8953, 0: 3047})
real_labels: Counter({1: 2561, 0: 2560})
real_final_labels: Counter({0: 640, 1: 639})


### dataset and dataloader

In [12]:
# train/test split
synth_train_img, synth_test_img, synth_train_label, synth_test_label = train_test_split(
    synth_images, synth_labels, test_size=0.20, random_state=42,
)
real_train_img, real_test_img, real_train_label, real_test_label = train_test_split(
    real_images, real_labels, test_size=0.20, random_state=42,
)

In [13]:
# image transformation
transform = transforms.Compose([
    Unsqueeze(axis=-1),
    Repeat(n_channel=3, axis=-1),
    transforms.ToTensor(),
])

In [14]:
# make datasets

# synthetic
synth_train_dataset = MRIDataset(
    images=synth_train_img,
    labels=synth_train_label,
    transform=transform,
)
synth_test_dataset = MRIDataset(
    images=synth_test_img,
    labels=synth_test_label,
    transform=transform,
)

# real
real_train_dataset = MRIDataset(
    images=real_train_img,
    labels=real_train_label,
    transform=transform,
)
real_test_dataset = MRIDataset(
    images=real_test_img,
    labels=real_test_label,
    transform=transform,
)

# real_final
real_final_dataset = MRIDataset(
    images=real_final_images,
    labels=real_labels,
    transform=transform,
)

In [15]:
# make dataloaders

# synthetic
synth_train_loader = DataLoader(
    synth_train_dataset,
    batch_size=64,
    shuffle=True,
)
synth_test_loader = DataLoader(
    synth_test_dataset,
    batch_size=64,
    shuffle=False,
)

# real
real_train_loader = DataLoader(
    real_train_dataset,
    batch_size=64,
    shuffle=True,
)
real_test_loader = DataLoader(
    real_test_dataset,
    batch_size=64,
    shuffle=False,
)

# real_final
real_final_loader = DataLoader(
    real_final_dataset,
    batch_size=1,
    shuffle=False,
)

### synthetic vs synthetic

In [16]:
model = ConvDropoutNet(in_channels=3, n_classes=2).to(device)

In [17]:
print(f'num of params: {sum(p.numel() for p in model.parameters())}')

num of params: 172434


In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [19]:
train(
    model=model,
    trainloader=synth_train_loader,
    valloader=synth_test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    n_epochs=10,
    verbose=False,
)

In [20]:
# save model
torch.save(model.state_dict(), 'models/synthetic_model.pth')

### real vs real

In [21]:
model = make_resnet18(num_classes=2).to(device)

In [22]:
print(f'num of params: {sum(p.numel() for p in model.parameters())}')

num of params: 11177538


In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [24]:
train(
    model=model,
    trainloader=real_train_loader,
    valloader=real_test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    n_epochs=10,
    verbose=False,
)

In [25]:
# save model
torch.save(model.state_dict(), 'models/real_model.pth')

### synthetic vs real

In [26]:
model = make_resnet18(num_classes=2).to(device)

In [27]:
print(f'num of params: {sum(p.numel() for p in model.parameters())}')

num of params: 11177538


In [28]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [29]:
train(
    model=model,
    trainloader=synth_train_loader,
    valloader=real_train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    n_epochs=5,
    verbose=True,
)

epoch [1/5]

train accuracy: 0.9540

train metrics:
              precision    recall  f1-score   support

           0       0.92      0.89      0.91      2434
           1       0.96      0.97      0.97      7166

    accuracy                           0.95      9600
   macro avg       0.94      0.93      0.94      9600
weighted avg       0.95      0.95      0.95      9600


val accuracy: 0.6626

val metrics:
              precision    recall  f1-score   support

           0       0.66      0.69      0.67      2061
           1       0.67      0.64      0.65      2035

    accuracy                           0.66      4096
   macro avg       0.66      0.66      0.66      4096
weighted avg       0.66      0.66      0.66      4096



epoch [2/5]

train accuracy: 0.9871

train metrics:
              precision    recall  f1-score   support

           0       0.97      0.97      0.97      2434
           1       0.99      0.99      0.99      7166

    accuracy                           0

In [30]:
metrics = validate_epoch(
    model=model,
    dataloader=real_final_loader,
    criterion=criterion,
    device=device,
)
accuracy, report = metrics_callback(metrics=metrics)

In [31]:
# final metrics
print(f'accuracy: {accuracy:.4f}\n')
print(f'metrics:\n{report}\n')

accuracy: 0.4910

metrics:
              precision    recall  f1-score   support

           0       0.53      0.18      0.27       663
           1       0.48      0.82      0.61       616

    accuracy                           0.49      1279
   macro avg       0.50      0.50      0.44      1279
weighted avg       0.51      0.49      0.43      1279




### adversarial validation

In [32]:
# concal all data (synthetic: 0, real: 1)
all_images = np.concatenate([synth_images, real_images])
all_labels = np.concatenate([np.zeros(len(synth_images)), np.ones(len(real_images))]).astype(np.int64)

In [33]:
adv_val_dataset = MRIDataset(
    images=all_images,
    labels=all_labels,
    transform=transform,
)

In [34]:
train_len = int(len(adv_val_dataset) * 0.8)
val_len = len(adv_val_dataset) - train_len

In [35]:
adv_val_train_dataset, adv_val_test_dataset = random_split(
    dataset=adv_val_dataset,
    lengths=[train_len, val_len],
)

In [36]:
adv_val_train_loader = DataLoader(
    adv_val_train_dataset,
    batch_size=64,
    shuffle=True,
)
adv_val_test_loader = DataLoader(
    adv_val_test_dataset,
    batch_size=64,
    shuffle=False,
)

In [37]:
model = ConvDropoutNet(in_channels=3, n_classes=2).to(device)

In [38]:
print(f'num of params: {sum(p.numel() for p in model.parameters())}')

num of params: 172434


In [39]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [40]:
train(
    model=model,
    trainloader=adv_val_train_loader,
    valloader=adv_val_test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    n_epochs=2,
    verbose=False,
)

### filter bad images

In [41]:
synth_dataset = MRIDataset(
    images=synth_images,
    labels=synth_labels,
    transform=transform,
)
real_dataset = MRIDataset(
    images=real_images,
    labels=real_labels,
    transform=transform,
)

In [42]:
# get image scores
scores = []
for img, _ in synth_dataset:
    with torch.no_grad():
        img = img.unsqueeze(0).to(device)
        score = torch.softmax(model(img), dim=1)[0][1].item()
    scores.append(score)

In [43]:
print(f'most synthetic image probability: {np.min(scores)}')
print(f'least synthetic image probability: {np.max(scores)}')

most synthetic image probability: 1.4090519029341664e-10
least synthetic image probability: 0.5821159482002258


In [44]:
# filter bad images
indices = np.argsort(scores)[-6000:]

good_synth_images = synth_images[indices]
good_synth_labels = synth_labels[indices]

In [45]:
print(f'good_synth_labels: {Counter(good_synth_labels)}')

good_synth_labels: Counter({1: 4031, 0: 1969})


In [46]:
good_synth_dataset = MRIDataset(
    images=good_synth_images,
    labels=good_synth_labels,
    transform=transform,
)

In [47]:
good_synth_loader = DataLoader(
    good_synth_dataset,
    batch_size=64,
    shuffle=True,
)

### synthetic vs real v2.0

In [48]:
model = make_resnet18(num_classes=2).to(device)

In [49]:
print(f'num of params: {sum(p.numel() for p in model.parameters())}')

num of params: 11177538


In [50]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [51]:
train(
    model=model,
    trainloader=good_synth_loader,
    valloader=real_train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    n_epochs=5,
    verbose=True,
)

epoch [1/5]

train accuracy: 0.9410

train metrics:
              precision    recall  f1-score   support

           0       0.93      0.89      0.91      1969
           1       0.95      0.97      0.96      4031

    accuracy                           0.94      6000
   macro avg       0.94      0.93      0.93      6000
weighted avg       0.94      0.94      0.94      6000


val accuracy: 0.6375

val metrics:
              precision    recall  f1-score   support

           0       0.63      0.70      0.66      2061
           1       0.65      0.58      0.61      2035

    accuracy                           0.64      4096
   macro avg       0.64      0.64      0.64      4096
weighted avg       0.64      0.64      0.64      4096



epoch [2/5]

train accuracy: 0.9743

train metrics:
              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1969
           1       0.98      0.98      0.98      4031

    accuracy                           0

In [52]:
metrics = validate_epoch(
    model=model,
    dataloader=real_final_loader,
    criterion=criterion,
    device=device,
)
accuracy, report = metrics_callback(metrics=metrics)

In [53]:
# final metrics
print(f'accuracy: {accuracy:.4f}\n')
print(f'metrics:\n{report}\n')

accuracy: 0.4988

metrics:
              precision    recall  f1-score   support

           0       0.52      0.56      0.54       663
           1       0.48      0.44      0.46       616

    accuracy                           0.50      1279
   macro avg       0.50      0.50      0.50      1279
weighted avg       0.50      0.50      0.50      1279


