## Practical work in AI

1. Image preprocessing

The images are not sorted correctly, so it is not possible to tell which domain each image belongs to.
I wrote a script that sorts them based on the domain names present in the UNPROCESSED RETOUCH images.
Now I have a list of dicts for the source images and a list of dicts for the target images.
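A minimal sketch of that sorting step. It assumes the unprocessed folder is laid out as `<name_dir>/<domain>/<case>/...` and that every preprocessed file name starts with its case ID followed by an underscore (`<case>_<slice>.png`). Both conventions, and the helper names `build_case_to_domain` and `split_by_domain`, are hypothetical and not taken from the actual script:

```python
from pathlib import Path
from typing import Dict, List, Tuple


def build_case_to_domain(name_dir: Path) -> Dict[str, str]:
    """Map case-folder name -> domain, assuming the layout name_dir/<domain>/<case>/."""
    case_to_domain = {}
    for domain_dir in sorted(p for p in name_dir.iterdir() if p.is_dir()):
        for case_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
            case_to_domain[case_dir.name] = domain_dir.name
    return case_to_domain


def split_by_domain(
    image_paths: List[Path],
    case_to_domain: Dict[str, str],
    source_domain: str,
) -> Tuple[List[dict], List[dict]]:
    """Return (source_dicts, target_dicts); images whose case is unknown are skipped."""
    source, target = [], []
    for path in sorted(image_paths):
        case_id = path.stem.split("_")[0]  # hypothetical naming: <case>_<slice>.png
        domain = case_to_domain.get(case_id)
        if domain is None:
            continue
        entry = {"img": str(path), "domain": domain}
        (source if domain == source_domain else target).append(entry)
    return source, target
```

The `source_domain` argument decides which device plays the labelled source role; every other domain ends up in the target list.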

2. Image augmentation

    Main problem currently: SVD is slow (0.6 seconds per image just for the decomposition!).
    When iterating over the entire dataset in batches of 16, data loading alone takes more than 5 minutes!

    

3. Training

    How many output channels does my network need? 2? How many classes do I actually have? I guess 3, since the paper uses 3 biomarkers.
       
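For scale: with the 784 training images reported further below, 0.6 s × 784 ≈ 470 s, i.e. close to 8 minutes for a single decomposition per image. One possible speed-up, sketched here under assumptions: the target images never change, so their SVD factors can be computed once and cached, leaving only the source decomposition per sample. `noise_transfer` is a simplified SVDNA-style noise transfer (top-`k` components of the source plus the remaining components of the target, without the histogram-matching step) and assumes source and target have the same shape; `TargetSVDCache` and `thin_svd` are hypothetical helpers, not part of the existing `SVDNA` transform:

```python
import numpy as np


def thin_svd(img: np.ndarray):
    """Thin SVD of a 2-D image, so that img ≈ (u * s) @ vt."""
    return np.linalg.svd(img.astype(np.float64), full_matrices=False)


class TargetSVDCache:
    """Hypothetical helper: decompose each target image once, then reuse the factors."""

    def __init__(self):
        self._factors = {}
        self.computed = 0  # number of decompositions actually performed

    def get(self, key, img: np.ndarray):
        if key not in self._factors:
            self._factors[key] = thin_svd(img)
            self.computed += 1
        return self._factors[key]


def noise_transfer(source: np.ndarray, target_factors, k: int) -> np.ndarray:
    """Simplified SVDNA-style mix: top-k source components + remaining target components.

    Assumes source and target have the same shape; histogram matching is omitted.
    """
    u_s, s_s, vt_s = thin_svd(source)  # the only SVD still computed per sample
    u_t, s_t, vt_t = target_factors
    content = (u_s[:, :k] * s_s[:k]) @ vt_s[:k]
    noise = (u_t[:, k:] * s_t[k:]) @ vt_t[k:]
    return content + noise
```

The trade-off is memory: a cached thin SVD of an m×n image (m ≥ n) stores roughly m·n + n² + n values, so caching many large targets can become expensive; storing the factors as `float32` would halve that.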

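One way to reason about the channel count, assuming the label maps use 0 for background and 1, 2, 3 for the three fluid types annotated in RETOUCH (IRF, SRF and PED): with a sigmoid-based loss such as `DiceLoss(sigmoid=True)`, each fluid gets its own independent binary channel and background needs no channel, which gives 3 output channels. With a softmax over mutually exclusive classes, background would be an explicit fourth channel. The hypothetical helper below builds the 3-channel target; the custom `ConvertLabelMaskToChannel` transform may do this differently:

```python
import numpy as np

# Assumed label encoding: 0 = background, 1..3 = the three fluid types.
FLUID_LABELS = (1, 2, 3)


def label_map_to_channels(label: np.ndarray, labels=FLUID_LABELS) -> np.ndarray:
    """Return a (len(labels), H, W) stack with one binary channel per fluid type."""
    return np.stack([(label == value).astype(np.float32) for value in labels])
```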
In [155]:
import numpy as np

from pathlib import Path
import os
from typing import *

from tqdm.notebook import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from sklearn.model_selection import train_test_split

import monai
from monai.networks.nets import BasicUNetPlusPlus
from monai.transforms import *
from monai.config.type_definitions import KeysCollection

import albumentations as A

import wandb

# Imports from local files
from transforms import *
from dataset import *
from utils import *


# Set random seed
np.random.seed(99)
torch.manual_seed(99)

wandb.login()

True

## Paths to the data

In [156]:
name_dir = Path.cwd() / 'data/RETOUCH/TrainingSet-Release'  # image folders still sorted by domain (unprocessed OCT images)
train_dir = Path.cwd() / 'data/Retouch-Preprocessed/train'  # preprocessed OCT images, not sorted by domain (sorting happens in the dataset class)

## Image transformations

In [157]:
transforms = Compose([
    #GetMaskPositions(keys=['masks'], target_keys=["mask_positions"]), #We get the layer position, but on the original height
    #LayerPositionToProbabilityMap(["mask_positions"], target_size=(400,400), target_keys=["mask_probability_map"]),

    CustomImageLoader(keys=['img', 'label']),  # no SVDNA; to enable SVDNA, comment out this line and uncomment the next two
    #SVDNA(keys=['img'], histogram_matching_degree=.5),
    #CustomImageLoader(keys=['label']),
    ConvertLabelMaskToChannel(keys=['label'], target_keys=["masks"]),
    ExpandChannelDim(keys=['img', 'label']),
    ToTensord(keys=['img', 'label', 'masks']),
    NormalizeToZeroOne(keys=['img', 'label', 'masks']),
    #Debugging(keys=['img', 'label', 'masks']),
    Transposed(keys=['img', 'label', 'masks'], indices=[0, 2, 1]),
    Resized(keys=["img", "label", 'masks'], mode=["area", "nearest-exact", "nearest-exact"], spatial_size=[1024, 400]),
    #RandZoomd(keys=["img", "label", 'masks'], mode=["area", "nearest-exact", "nearest-exact"], prob=0.3, min_zoom=0.5, max_zoom=1.5),
    #RandAxisFlipd(keys=["img", "label", 'masks'], prob=0.3),
    #RandHistogramShiftd(keys=["img"], prob=0.3),
    #RandAffined(keys=["img", "label", 'masks'], 
    #            prob=0.3, 
    #            shear_range=[(-0.7, 0.7), (0.0, 0.0)], 
    #            translate_range=[(-300, 100), (0, 0)], 
    #            rotate_range=[20, (0, 0)],
    #            mode=["bilinear", "nearest", "nearest"], 
    #            padding_mode="zeros"),      
    #Resized(keys=["img", "label", "masks"], mode=["area", "nearest-exact", "nearest-exact"], spatial_size=[400, 400]),
    #Lambdad(keys=['mask_positions'], func = lambda x: x * 400 / 800), #We scale down the positions to have more accurate positions
    #Lambdad(keys=['img'], func = lambda x: np.clip((x - x.mean()) / x.std(), -1, 1)),
    #Lambdad(keys=['img'], func = lambda x: 2*(x - x.min()) / (x.max() - x.min()) - 1 ),
    #ImageVisualizer(keys=['img', 'label', 'masks']),        
    
])


print_some_imgs = False

if print_some_imgs:
    transforms_visualize = Compose([transforms, ImageVisualizer(keys=['img', 'label', 'masks'])])
    dataset = OCTDatasetPrep(train_dir, transform=transforms_visualize)

    for i in range(5):
 
        rand_num = np.random.randint(0, len(dataset))
        sample = dataset[rand_num]
        print("Sample: ", sample['img'].shape, sample['label'].shape, sample['masks'].shape)
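To see what the active part of the pipeline does to the array shapes, here is a NumPy-only sketch of three of its steps for a single 2-D mask: adding a channel axis, swapping the two spatial axes (as `Transposed(indices=[0, 2, 1])` does) and resizing to `[1024, 400]` with PyTorch-style `"nearest-exact"` indexing. It is an approximation written for illustration, not MONAI's implementation, and why the transpose is needed at all depends on how `CustomImageLoader` orders the axes:

```python
import numpy as np


def nearest_exact_indices(in_size: int, out_size: int) -> np.ndarray:
    """PyTorch-style "nearest-exact": src = floor((dst + 0.5) * in / out), clamped."""
    idx = np.floor((np.arange(out_size) + 0.5) * in_size / out_size).astype(int)
    return np.minimum(idx, in_size - 1)


def resize_mask_nearest_exact(mask: np.ndarray, out_hw) -> np.ndarray:
    """Resize a (C, H, W) mask by copying the nearest source pixels."""
    _, h, w = mask.shape
    rows = nearest_exact_indices(h, out_hw[0])
    cols = nearest_exact_indices(w, out_hw[1])
    return mask[:, rows[:, None], cols[None, :]]


def mask_shape_walkthrough(mask_2d: np.ndarray, out_hw=(1024, 400)) -> np.ndarray:
    m = mask_2d[None]               # like ExpandChannelDim: (1, H, W)
    m = np.transpose(m, (0, 2, 1))  # like Transposed(indices=[0, 2, 1]): (1, W, H)
    return resize_mask_nearest_exact(m, out_hw)  # like Resized(..., "nearest-exact")
```

Nearest-neighbour resizing only copies existing pixels, so a binary mask stays binary; that is why the masks use `"nearest-exact"` while the image uses `"area"`, which averages intensities.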



## Architecture, dataset, loss function, optimizer

In [158]:
os.makedirs('models', exist_ok=True)

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print("Device: ", device)

# Model hyperparameters are not specified in the SVDNA paper
model = monai.networks.nets.UNet(spatial_dims=2, 
                                 in_channels=1, 
                                 out_channels=3, 
                                 channels=(16, 32, 64, 128, 256),
                                 strides=(2, 2, 2, 2),
                                 kernel_size=3, 
                                 up_kernel_size=3, 
                                 num_res_units=2, 
                                 act='RELU', 
                                 norm='INSTANCE', 
                                 dropout=0.1, 
                                 bias=False, 
                                 adn_ordering='ADN')
model = model.to(device)

#BasicUNetPlusPlus(
#    spatial_dims=2,
#    in_channels=1,
#    out_channels=3,
#    features=(32, 64, 128, 256, 512, 1024)
#)


criterion = monai.losses.DiceLoss(sigmoid=True)
optimizer = optim.Adam(model.parameters(), lr=0.005)

dataset = OCTDatasetPrep(train_dir, transform=transforms)


Device:  mps
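A quick arithmetic check on the architecture: the four stride-2 stages downsample by 2^4 = 16 in total, and both spatial sizes produced by `Resized` are multiples of 16 (1024 / 16 = 64, 400 / 16 = 25). Keeping the input divisible by the product of the strides avoids shape mismatches between encoder and decoder feature maps at the skip connections. The helper names below are hypothetical:

```python
import math
from typing import Sequence, Tuple


def downsampling_factor(strides: Sequence[int]) -> int:
    """Total downsampling of an encoder made of the given stride stages."""
    return math.prod(strides)


def bottleneck_size(spatial_size: Sequence[int], strides: Sequence[int]) -> Tuple[int, ...]:
    """Spatial size at the bottleneck; raises if a dimension is not evenly divisible."""
    factor = downsampling_factor(strides)
    for s in spatial_size:
        if s % factor != 0:
            raise ValueError(f"size {s} is not divisible by {factor}")
    return tuple(s // factor for s in spatial_size)


print(bottleneck_size((1024, 400), (2, 2, 2, 2)))  # (64, 25)
```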

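To interpret the losses logged below, here is a NumPy sketch of a soft Dice loss on sigmoid probabilities, written to approximately mirror `monai.losses.DiceLoss(sigmoid=True)` with its default smoothing of 1e-5 and mean reduction over batch and channels (the MONAI documentation is the reference for the exact behaviour). Since the loss is 1 - Dice, values around 0.99 correspond to a mean Dice of only about 0.01:

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def soft_dice_loss(logits: np.ndarray, target: np.ndarray, smooth: float = 1e-5) -> float:
    """Mean over (batch, channel) of 1 - Dice, computed on sigmoid probabilities.

    logits, target: arrays of shape (B, C, H, W); target is binary.
    """
    probs = sigmoid(logits)
    spatial = (2, 3)
    intersection = (probs * target).sum(axis=spatial)
    denominator = probs.sum(axis=spatial) + target.sum(axis=spatial)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return float((1.0 - dice).mean())
```

A channel whose target is empty (a slice without that fluid) gets a Dice close to 0 as soon as the network predicts any noticeable probability there, so fluid-free slices may be one reason the average loss stays close to 1.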

## First tests

Let's first overfit on a single batch of 16 images, then on two batches of 16 images.
Try the learning rates 1e-4, 1e-3 and 5e-3.
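A minimal sketch of how that learning-rate comparison could be organised: each learning rate starts from a freshly initialised model, so earlier runs (or a previously saved checkpoint) do not influence later ones. `lr_sweep`, `make_model` and `run` are hypothetical names; in this notebook `make_model` would build the `UNet` above and `run` would create an `Adam` optimizer with the given `lr`, call `train` and return the final validation loss:

```python
from typing import Callable, Dict, Iterable, Tuple, TypeVar

M = TypeVar("M")


def lr_sweep(
    lrs: Iterable[float],
    make_model: Callable[[], M],
    run: Callable[[M, float], float],
) -> Tuple[float, Dict[float, float]]:
    """Train a fresh model per learning rate; return (best_lr, {lr: final_loss})."""
    results: Dict[float, float] = {}
    for lr in lrs:
        model = make_model()  # new weights each time, so runs are independent
        results[lr] = run(model, lr)
    best_lr = min(results, key=results.get)  # lowest final loss wins
    return best_lr, results
```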

In [126]:
trainset_len = 32
valset_len = 8
rest_len = len(dataset) - trainset_len - valset_len

# A separately seeded generator gives the same split every time this cell is re-run
split_generator = torch.Generator().manual_seed(99)
train_dataset, val_dataset, rest = random_split(dataset, lengths=[trainset_len, valset_len, rest_len], generator=split_generator)

print("Training dataset length: ", len(train_dataset), 
      "Validation dataset length", len(val_dataset),
      "Rest: ", len(rest))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)  # num_workers > 0 (multiprocessing) does not work here, reason unclear
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)  # validation order does not matter, so no shuffling

Training dataset length:  32 Validation dataset length 8 Rest:  744

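The next cell loads `models/model_two_batches.pth` before training. `load_model` comes from the local `utils` module and is not shown here; if it does not already handle a missing file, the very first run would fail. A defensive variant, assuming the checkpoint is a plain `state_dict` saved with `torch.save(model.state_dict(), path)` (`load_if_exists` is a hypothetical name):

```python
from pathlib import Path
from typing import Tuple, Union

import torch


def load_if_exists(
    model: torch.nn.Module, path: Union[str, Path], device: str = "cpu"
) -> Tuple[torch.nn.Module, bool]:
    """Load a saved state_dict into `model` if `path` exists; otherwise keep the current weights."""
    path = Path(path)
    if not path.is_file():
        return model, False
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model, True
```

The returned flag makes it explicit whether training resumes from a checkpoint or starts from scratch.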

In [127]:
model = load_model(model, 'models/model_two_batches.pth')
model = train(model, train_loader, val_loader, criterion, optimizer, device, epochs=200, save_path='models/model_two_batches.pth')

Training:   0%|          | 0/100 [00:00<?, ?epoch/s]

Epoch 1/100
Batch 1/2, Loss: 0.9922
Batch 2/2, Loss: 0.9846
Validation loss: 0.9956
Model saved.
Epoch 2/100
Batch 1/2, Loss: 0.9864
Batch 2/2, Loss: 0.9869
Validation loss: 0.9938
Model saved.
Epoch 3/100
Batch 1/2, Loss: 0.9860
Batch 2/2, Loss: 0.9823
Validation loss: 0.9937
Model saved.
Epoch 4/100
Batch 1/2, Loss: 0.9814
Batch 2/2, Loss: 0.9834
Validation loss: 0.9928
Model saved.
Epoch 5/100
Batch 1/2, Loss: 0.9760
Batch 2/2, Loss: 0.9843
Validation loss: 0.9922
Model saved.
Epoch 6/100
Batch 1/2, Loss: 0.9792
Batch 2/2, Loss: 0.9774
Validation loss: 0.9918
Model saved.
Epoch 7/100
Batch 1/2, Loss: 0.9763
Batch 2/2, Loss: 0.9765
Validation loss: 0.9913
Model saved.
Epoch 8/100
Batch 1/2, Loss: 0.9734
Batch 2/2, Loss: 0.9768
Validation loss: 0.9907
Model saved.
Epoch 9/100
Batch 1/2, Loss: 0.9681
Batch 2/2, Loss: 0.9788
Validation loss: 0.9902
Model saved.
Epoch 10/100
Batch 1/2, Loss: 0.9815
Batch 2/2, Loss: 0.9623
Validation loss: 0.9897
Model saved.
Epoch 11/100
Batch 1/2, Loss: