In [1]:
%matplotlib widget
import numpy as np
import pandas as pd
from pathlib import Path
import os, sys, datetime, time, random, fnmatch, math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import skimage.metrics

import torch
from torchvision import transforms as tvtransforms
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.utils as vutils
import torch.utils.tensorboard as tensorboard
import torch.nn as nn

import datasets, transforms, GusarevModel, pytorch_msssim

flag_debug = True

# Input Directories
data_BSE = "D:/data/JSRT/BSE_JSRT"
data_normal = "D:/data/JSRT/JSRT"

# Image Size:
image_spatial_size = (440,440)
_batch_size = 4
test_length = 10

# Optimisation
lr_ini = 0.001
beta1 = 0.9
beta2 = 0.999

# Training
num_reals_per_epoch_paper = 4000 # in Gusarev et al. 2017
total_num_epochs_paper = 150
num_epochs_decay_lr_paper = 100
lr_decay_ratio = 0.25

# Weight Initialisation
def weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.normal_(m.weight.data, 0., 0.02)
        try:
            nn.init.constant_(m.bias.data, 0.)
        except:
            pass
    if isinstance(m, nn.BatchNorm2d):
        if m.affine:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.)

## Code for putting things on the GPU
ngpu = 2
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(device)

cpu


In [2]:
# Current date:
current_date=datetime.datetime.today().strftime('%Y-%m-%d')

# Data Loader
discriminator_keys_images = ["source", "boneless"]
target_key = "boneless"
ds = datasets.JSRT_CXR(data_normal, data_BSE,
                         transform=tvtransforms.Compose([
                             transforms.CLAHE(discriminator_keys_images),
                             transforms.ZScoreNormalisation(discriminator_keys_images),
                             transforms.RescalingNormalisation(discriminator_keys_images,(0,1)),
                             transforms.RandomHorizontalFlip(discriminator_keys_images, probability=0.5),
                             transforms.RandomVerticalFlip(discriminator_keys_images, probability=0.5),
                             transforms.IntensityJitter(discriminator_keys_images,source_image_key="source", rescale_factor_limits=(0.75,1.0), window_motion_limits=(-1,1)),
                             transforms.RandomIntensityComplement(discriminator_keys_images, probability=0.5),
                             transforms.RandomRotation(discriminator_keys_images),
                             transforms.Rescale(image_spatial_size, discriminator_keys_images, None),
                             transforms.ToTensor(discriminator_keys_images),
                             ])
                      )

# SPLIT DATA INTO TRAINING/VALIDATION SET
lengths=(len(ds)-test_length, test_length)
ds_training, ds_val = torch.utils.data.random_split(ds, lengths)

dl_training = DataLoader(ds_training, batch_size=_batch_size,
                         shuffle=True, num_workers=0)

fixed_val_sample = next(iter(ds_val))

In [3]:
## Implementation of network and losses
input_array_size = (_batch_size, 1, image_spatial_size[0], image_spatial_size[1])
net = GusarevModel.Autoencoder(input_array_size)
# Initialise weights
net.apply(weights_init)

# Multi-GPU
if (device.type == 'cuda') and (ngpu > 1):
    net = nn.DataParallel(net, list(range(ngpu)))
    

# Optimiser
optimizer = torch.optim.Adam(net.parameters(), lr=lr_ini, betas=(beta1, beta2))
# Learning Rate Scheduler
epoch_factor = num_reals_per_epoch_paper//len(ds_training) # need to have this factor as many epochs as that described in the paper
total_num_epochs = total_num_epochs_paper*epoch_factor
num_epochs_decay_lr = num_epochs_decay_lr_paper*epoch_factor
def lambda_rule(epoch, lr_ini=lr_ini, num_epochs_decay_lr=num_epochs_decay_lr, lr_decay_ratio=lr_decay_ratio):
    lr = lr_ini*((1-lr_decay_ratio)**(epoch//num_epochs_decay_lr))
    return lr

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

# Gusarev Loss
def criterion_Gusarev(testImage, referenceImage, alpha=0.84):
    """
    Gusarev et al. 2017. Deep learning models for bone suppression in chest radiographs.  IEEE Conference on Computational Intelligence in Bioinformatics and Computational Biology.
    """
    mseloss = nn.MSELoss() # L2 used for easier optimisation

    msssim = pytorch_msssim.MSSSIM(window_size=11, size_average=True, channel=1, normalize="relu")
    msssim_loss = 1 - msssim(testImage, referenceImage)
    total_loss = (1-alpha)*mseloss(testImage, referenceImage) + alpha*msssim_loss
    return total_loss

In [5]:
# Training
img_list = []
loss_list = []
reals_shown = []
reals_shown_now = 0

# For each epoch
iters = 0
torch.autograd.set_detect_anomaly(True)
for epoch in range(total_num_epochs ):
    for i, data in enumerate(dl_training):
        iters +=1
        optimizer.zero_grad()
        noisy_data = data["source"].to(device)
        cleaned_data = net(noisy_data)
        loss = criterion_Gusarev(cleaned_data, data[target_key].to(device))
        loss.backward() # calculate gradients
        optimizer.step() # optimiser step along gradients
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.4f'
                  % (epoch, total_num_epochs, i, len(dl_training),
                     loss.item()))
        # Record generator output
        if (iters % 100 == 0) or ((epoch == total_num_epochs-1) and (i == len(dl_training)-1)):
            with torch.no_grad():
                val_cleaned = net(fixed_val_sample["source"].to(device)).detach().cpu()
            img_list_pretraining.append(vutils.make_grid(val_cleaned, padding=2, normalize=True))
        
        iters += 1
        reals_shown_now += _batch_size
        reals_shown.append(reals_shown_now)
        loss_list.append(loss.item())
        
        if flag_debug and iters>=2:
            break
    if flag_debug and iters>=2:
        break


[0/2550][0/58]	Loss_G: 0.7478


NameError: name 'batch_size' is not defined

In [None]:
print(ds_training)