# Synthetic NF1 MRI Images - Train the FCC-GAN

## 1. Project datasheet

| Category | Value |
| --: | :-- |
| Project | Synthetic NF1 MRI Images ` (syn26010238) ` |
| Team | DCD ` (3430042) ` |
| Competition | Hack4Rare 2021 |
| Description | script to train FCC-GAN |
| Additional requirements | you need to have the dataset ` syn20608511 ` |

## 2. Imports

Project level imports include check of folders. It produces an output.

In [None]:
from os import listdir, mkdir
from os.path import isdir, isfile, join
from pickle import dump, load
from random import shuffle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pydicom
import torch

from config import Folder
import constants
import enum_handler
from models import DCDDiscriminator, DCDGenerator
from training_tools import weight_init, yield_x_batches

## 3. Constants

These constants control the training workflow.

| Variable name | type | Effect | Comments |
| --: | :-: | :-- | :-- |
| **ALLOW_CUDA** | bool | Whether or not to train on CUDA | This switch has no effect if cuda is not available |
| **BATCH_SIZE** | int | Size of batches | In GPU mode size of GPU memory is a limitation. |
| **DEPTH_BASE** | int | Base value for channel multiplication | Used to construct convolutional and transposed convolutional layers |
| **EPOCH_COUNT** | int | Number of epochs to train | |
| **FLAT_SIZE** | int | Size of the flat vector | It is used at the end of the discriminator model |
| **GAUSSIAN_INIT** | bool | Whether or not to apply gaussian init on layers | It has no effect if models are loaded |
| **IMG_CHANNELS** | int | Count of color channels for images | It is applied both for input and output images |
| **IMG_HEIGHT** | int | Height of generated images | |
| **IMG_WIDTH** | int | Width of generated images | |
| **KERNEL_SIZE** | int | Size of the kernels to use | Convolutional and transposed convolutional layers use the same kernel size |
| **LEARNING_RATE** | float | Learning rate for optimizers | Both models are using the same learning rate |
| **LOAD_MODEL** | str | Model name to load | If an empty string is given, a training from scratch is started |
| **RANDOM_VECTOR_DIM0** | int | Dim 0 of the random vector | |
| **VECTOR_DIMS** | tuple(int, int) | Further dimensions of the random vector | |

In [None]:
# Training configuration; each switch is described in the table above.
ALLOW_CUDA = True  # train on CUDA when it is available
BATCH_SIZE = 4  # images per batch
DEPTH_BASE = 16  # base value for channel multiplication in the models
EPOCH_COUNT = 3000  # number of epochs to train
FLAT_SIZE = 992  # size of the flat vector at the end of the discriminator
GAUSSIAN_INIT = True  # apply gaussian init when training from scratch
IMG_CHANNELS = 1  # color channels of input and output images
IMG_HEIGHT = 530  # height of the images
IMG_WIDTH = 162  # width of the images
KERNEL_SIZE = 4  # kernel size shared by conv and transposed conv layers
LEARNING_RATE = 1e-4  # learning rate used by both optimizers
LOAD_MODEL = ''  # model name to load; empty string means train from scratch
RANDOM_VECTOR_DIM0 = 128  # dim 0 of the random vector
VECTOR_DIMS = (50, 4)  # further dimensions of the random vector

## 4. Dataset

### 4.1. Check or prepare dataset

#### 4.1.1. Functions to prepare or load the dataset if needed

In [None]:
def copy_selected(patient_id: str, start: int = 13, stop: int = 15) -> None:
    """
    Copy and normalize selected slices from dicom files
    ===================================================
    
    Parameters
    ----------
    patient_id : str
        ID of the patient to copy images.
    start : int, optional (13 if omitted)
        First index of the dicomlist to copy.
    stop : int, optional (15 if omitted)
        Index in the dicomlist to stop copy.
    
    Notes
    -----
        Just like in case of any other Python slicing, stop index is not
        included.
        Every selected slice is resized to (IMG_WIDTH, IMG_HEIGHT) and
        pickled twice: once as resized raw pixels into the original image
        folder and once as a float32 array divided by 255 into the
        normalized image folder. The file names carry the patient ID and
        the index of the slice in the dicomlist.
    """

    dicom_files = get_dicomlist(patient_id)
    # Slicing an empty or too short list simply yields nothing, so no
    # explicit emptiness check is needed before the loop.
    for index, file_name in enumerate(dicom_files[start:stop], start=start):
        img = pydicom.dcmread(file_name).pixel_array
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        with open('{}/{}_{}_orig.pkl'
                  .format(Folder.IMG_ORIGINAL.value, patient_id, index),
                  'wb') as outstream:
            dump(img, outstream)
        # NOTE(review): convertScaleAbs with its default scale appears to
        # saturate values to the 8-bit range instead of rescaling them, so
        # intensities above 255 would be clipped - confirm this is intended.
        img = cv2.convertScaleAbs(img)
        img = np.asarray(img, dtype=np.float32) / 255.0
        with open('{}/{}_{}_normalized.pkl'
                  .format(Folder.IMG_NORMALIZED.value, patient_id, index),
                  'wb') as outstream:
            dump(img, outstream)


def get_dicomlist(patient_id: str) -> list:
    """
    Get list of existing DICOM files of a patient
    =============================================
    
    Parameters
    ----------
    patient_id : str
        ID of the patient to get list of files.
        
    Returns
    -------
    list
        List of files. It is empty if the patient has no DICOM folder.

    Notes
    -----
        Sub-directories and the files inside them are visited in sorted
        (lexicographic) name order. Callers select slices by list index, so
        the order has to be reproducible; the order returned by listdir is
        arbitrary.
    """

    result = []
    patient_root = join(Folder.CASES.value, 'WBMRI{}/DICOM/'.format(patient_id))
    if not isdir(patient_root):
        return result
    for dir_name in sorted(listdir(patient_root)):
        directory = join(patient_root, dir_name)
        if not isdir(directory):
            continue
        for file_name in sorted(listdir(directory)):
            full_path = join(directory, file_name)
            if isfile(full_path) and full_path.endswith('.dcm'):
                result.append(full_path)
    return result


def get_patient_ids() -> list:
    """
    Get list of patient IDs
    =======================
    
    Returns
    -------
    list
        List of patient IDs.

    Notes
    -----
        IDs are derived from the names of the '.xls' files in the report
        folder by removing the '.xls' extension and, when present, the
        'wbmri_' prefix. The prefix and the extension are removed as whole
        substrings: str.lstrip / str.rstrip would treat their argument as a
        set of characters and could eat leading or trailing characters of
        the ID itself (e.g. an ID ending with 's' or starting with 'b').
    """

    prefix, suffix = 'wbmri_', '.xls'
    report_root = Folder.REPORT.value
    result = []
    for file_name in listdir(report_root):
        if not (isfile(join(report_root, file_name))
                and file_name.endswith(suffix)):
            continue
        patient_id = file_name[:-len(suffix)]
        if patient_id.startswith(prefix):
            patient_id = patient_id[len(prefix):]
        result.append(patient_id)
    return result

def load_image_np(file_name : str) -> np.ndarray:
    """
    Load a pickled image as a batch-ready numpy.ndarray
    ===================================================
    
    Parameters
    ----------
    file_name : str
        Path of the pickle file to read.
    
    Returns
    -------
    np.ndarray
        The unpickled image with two new leading axes of size one.
    """
    
    with open(file_name, 'rb') as instream:
        image = load(instream)
    # Two leading axes are added one after the other.
    with_first_axis = np.expand_dims(image, axis=0)
    return np.expand_dims(with_first_axis, axis=0)

#### 4.1.2. Check or prepare in action

In [None]:
# Build the pickled image folders only when the normalized folder is missing
# or empty; otherwise reuse what is already on disk.
normalized_dir = Folder.IMG_NORMALIZED.value
if isdir(normalized_dir) and len(listdir(normalized_dir)) != 0:
    print('Image data already exists.')
else:
    if not isdir(normalized_dir):
        # Nothing prepared yet: create every missing image folder.
        for folder in (Folder.IMG_ROOT, Folder.IMG_ORIGINAL,
                       Folder.IMG_NORMALIZED):
            if not isdir(folder.value):
                mkdir(folder.value)
    elif not isdir(Folder.IMG_ORIGINAL.value):
        # Normalized folder exists but is empty: only the folder for the
        # original images may still be missing.
        mkdir(Folder.IMG_ORIGINAL.value)
    for patient_id in get_patient_ids():
        copy_selected(patient_id)
    print('Image data was just created.')

### 4.2. Load the dataset

In fact, only the image paths are loaded here; the images themselves are read batch by batch during training, which saves memory.

In [None]:
# Collect the path of every regular file found in the normalized image folder.
normalized_dir = Folder.IMG_NORMALIZED.value
image_list = [
    path
    for path in (join(normalized_dir, name) for name in listdir(normalized_dir))
    if isfile(path)
]

## 5. Device

In [None]:
# CUDA is used only when it is both allowed and available.
use_cuda = ALLOW_CUDA and torch.cuda.is_available()
if use_cuda:
    device = 'cuda'
else:
    device = 'cpu'
print('Training on {}.'.format(device))

## 6. Model

### 6.1. Instantiate model

In [None]:
# Shape tuple handed to the generator: RANDOM_VECTOR_DIM0 followed by the
# entries of VECTOR_DIMS.
random_vector_shape = (RANDOM_VECTOR_DIM0,) + tuple(VECTOR_DIMS)
generator = DCDGenerator(IMG_CHANNELS, DEPTH_BASE, KERNEL_SIZE,
                         random_vector_shape)
discriminator = DCDDiscriminator(IMG_CHANNELS, DEPTH_BASE, KERNEL_SIZE,
                                 FLAT_SIZE)

### 6.2. Load pretrained if needed

In [None]:
# Load pretrained weights for both networks when a model name is configured.
# NOTE(review): the training loop saves the networks under different file
# names ('generator_last.state_dict' / 'discriminator_last.state_dict'), while
# a single LOAD_MODEL name is looked up in both folders here - the saved files
# presumably have to be renamed before they can be loaded; confirm.
if LOAD_MODEL != '':
    generator_path = join(Folder.GENERATOR.value, LOAD_MODEL)
    discriminator_path = join(Folder.DISCRIMINATOR.value, LOAD_MODEL)
    if isfile(generator_path) and isfile(discriminator_path):
        # map_location='cpu' lets a state dict saved from CUDA tensors be
        # loaded on a machine without CUDA; the networks are moved to the
        # selected device in a later step anyway.
        generator.load_state_dict(
            torch.load(generator_path, map_location='cpu'))
        discriminator.load_state_dict(
            torch.load(discriminator_path, map_location='cpu'))
    else:
        print('Trained models "{}" are not available here.'.format(LOAD_MODEL))
else:
    print('No trained models are loaded.')

### 6.3. Apply initialization

In [None]:
# Gaussian init is applied only for a from-scratch run with the switch on.
skip_gaussian_init = LOAD_MODEL != '' or not GAUSSIAN_INIT
if skip_gaussian_init:
    print('Gaussian init not applied.')
else:
    generator.apply(weight_init)
    discriminator.apply(weight_init)
    print('Gaussian init applied.')

### 6.4. Move model to device

In [None]:
# Move both networks to the selected device, generator first.
for network in (generator, discriminator):
    network.to(device)
print('Device {} is set.'.format(device))

### 6.5. Define loss function and add optimizers to models

In [None]:
# Binary cross-entropy is the shared loss of both training steps.
criterion = torch.nn.BCELoss()
# Each network gets its own Adam optimizer; both use LEARNING_RATE.
optimizer_g, optimizer_d = (
    torch.optim.Adam(network.parameters(), lr=LEARNING_RATE)
    for network in (generator, discriminator)
)

## 7. Training

### 7.1. Check and create folders

In [None]:
# Create every missing output folder, in the listed order.
for folder in (Folder.MODEL_ROOT, Folder.GENERATOR,
               Folder.DISCRIMINATOR, Folder.SAMPLES):
    if not isdir(folder.value):
        mkdir(folder.value)

### 7.2. Actual training

In [None]:
# Adversarial training loop.
# Every epoch shuffles the image paths and, for each batch, first updates the
# discriminator on a real and a generated batch and then updates the
# generator so that its images are classified as real. After each epoch the
# sample-weighted mean losses and the discriminator accuracy are printed, the
# state dicts of both networks are saved and one generated image is stored.
best_epoch_loss_g, best_epoch_loss_d = np.inf, np.inf
if not image_list:
    # Without this check the run would end in a ZeroDivisionError when the
    # epoch averages are computed.
    raise RuntimeError('image_list is empty - prepare the dataset before training.')
for epoch in range(1, EPOCH_COUNT + 1):
    shuffle(image_list)
    epoch_loss_g, epoch_loss_d = 0.0, 0.0
    g_losses, d_losses = [], []
    b_count = 1
    good, count = 0, 0
    for image_files in yield_x_batches(image_list, BATCH_SIZE):
        # The last batch may be smaller than BATCH_SIZE, so every size below
        # is derived from the actual batch.
        dim0 = len(image_files)
        images = np.concatenate([load_image_np(i) for i in image_files])
        images = torch.from_numpy(images).to(device)
        # Targets are created directly on the device: 1.0 marks "real",
        # 0.0 marks "generated".
        targets = torch.full((dim0, 1), 1.0, dtype=torch.float32,
                             device=device)
        fake_targets = torch.full((dim0, 1), 0.0, dtype=torch.float32,
                                  device=device)
        noise = torch.randn(dim0, RANDOM_VECTOR_DIM0, device=device)
        # Discriminator training
        fake_images = generator(noise)
        optimizer_d.zero_grad()
        d_reals = discriminator(images)
        real_loss = criterion(d_reals, targets)
        real_loss.backward()
        # detach() keeps this backward pass out of the generator: its
        # gradients would be thrown away by optimizer_g.zero_grad() below.
        d_fakes = discriminator(fake_images.detach())
        fake_loss = criterion(d_fakes, fake_targets)
        fake_loss.backward()
        optimizer_d.step()
        # Discriminator accuracy: a prediction is good when its rounded
        # value equals the target.
        real_hits = torch.round(d_reals) == targets
        fake_hits = torch.round(d_fakes) == fake_targets
        good += int(real_hits.sum().item()) + int(fake_hits.sum().item())
        count += real_hits.numel() + fake_hits.numel()
        loss_item_d = real_loss.item() + fake_loss.item()
        # One entry per sample, so the epoch mean is weighted by batch size.
        d_losses.extend([loss_item_d] * dim0)
        # Generator training
        optimizer_g.zero_grad()
        # The same noise is passed through the generator again, now with the
        # graph needed for the generator update.
        fake_images = generator(noise)
        g_preds = discriminator(fake_images)
        # The generator is rewarded when its images are classified as real.
        generator_loss = criterion(g_preds, targets)
        generator_loss.backward()
        optimizer_g.step()
        loss_item_g = generator_loss.item()
        g_losses.extend([loss_item_g] * dim0)
        print('\r{:04d}/{:02d} - generator {:.8f} - discriminator {:.8f}'
            .format(epoch, b_count, loss_item_g, loss_item_d),
            end='')
        b_count += 1
    epoch_loss_d = sum(d_losses) / len(d_losses)
    epoch_loss_g = sum(g_losses) / len(g_losses)
    print_str = 'Epoch {:04d} - generator {:.8f} - discriminator {:.8f}'.format(epoch,
                                                                                epoch_loss_g,
                                                                                epoch_loss_d)
    print_str += ' - discriminator accuracy: {:.4f}'.format(good / count)
    print('\r{}'.format(print_str))
    # The latest weights overwrite the files of the previous epoch.
    torch.save(discriminator.state_dict(),
          join(Folder.DISCRIMINATOR.value, 'discriminator_last.state_dict'))
    torch.save(generator.state_dict(),
          join(Folder.GENERATOR.value, 'generator_last.state_dict'))
    # Store the first image of the last generated batch as a visual sample.
    plt.imsave(join(Folder.SAMPLES.value, 'e{:05d}.png'.format(epoch)),
        fake_images.detach().cpu()[0, 0], cmap='bone')