In [1]:
## Training Script
# Run this to train a bone suppression network ResNet-BS, introduced by Rajaraman et al. 2021.
# 
import numpy as np
import pandas as pd
from pathlib import Path
import os, sys, datetime, time, random, fnmatch, math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import skimage.metrics
from sklearn.model_selection import KFold

import torch
from torchvision import transforms as tvtransforms
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, ConcatDataset
import torchvision.utils as vutils
import torch.utils.tensorboard as tensorboard
import torch.nn as nn
import torch.optim as optim

import datasets, custom_transforms, RajaramanModel, pytorch_msssim

flag_debug = False
flag_load_previous_save = False

# Input Directories
#data_BSE = "D:/data/JSRT/augmented/train/target/"
#data_normal = "D:/data/JSRT/augmented/train/source/"
data_BSE = Path('G:/DanielLam/JSRT/HQ_JSRT_and_BSE-JSRT/177-20-20 split/trainValidate_LowRes/suppressed/')
data_normal = Path('G:/DanielLam/JSRT/HQ_JSRT_and_BSE-JSRT/177-20-20 split/trainValidate_LowRes/normal/')
#data_val_normal = 'G:/DanielLam/JSRT/HQ_JSRT_and_BSE-JSRT/177-20-20 split/validate/normal'
#data_val_BSE = 'G:/DanielLam/JSRT/HQ_JSRT_and_BSE-JSRT/177-20-20 split/validate/suppressed/'

# Save directories:
output_save_directory = Path("./runs/Rajaraman_ResNet/177-20-20/10KFold_4000perEpoch")
output_save_directory.mkdir(parents=True, exist_ok=True)
PATH_SAVE_NETWORK_INTERMEDIATE = os.path.join(output_save_directory, 'network_intermediate_{}.tar' )
PATH_SAVE_NETWORK_FINAL = os.path.join(output_save_directory, 'network_final_{}.pt')

# Image Size:
image_spatial_size = (480, 480)
_batch_size = 2
split_k_folds=10
sample_keys_images = ["source", "boneless"]

# Optimisation
lr_ini = 0.001
beta1 = 0.9
beta2 = 0.999

# Training
total_num_epochs_paper = 200
epoch_factor = 22.8764
total_num_epochs = int(epoch_factor*total_num_epochs_paper)

# Weight Initialisation
def weights_init(m):
    """Initialise one module in place; intended for use with ``net.apply(weights_init)``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02),
      bias (when the layer has one) set to 0.
    * ``nn.BatchNorm2d`` / ``nn.InstanceNorm2d`` with ``affine=True``: scale drawn
      from N(1, 0.02), shift set to 0.
    * Any other module is left untouched.

    Args:
        m: the module handed over by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0., 0.02)
        #nn.init.kaiming_normal_(m.weight.data,0)
        # Layers built with bias=False have ``bias is None``; check for that
        # explicitly instead of hiding every possible error behind a bare except.
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)
    if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        # Without affine parameters there is no weight/bias tensor to initialise.
        if m.affine:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.)

## Code for putting things on the GPU
ngpu = 1 #torch.cuda.device_count()
# NOTE(review): CUDA_VISIBLE_DEVICES is normally only honoured if it is set before
# CUDA is initialised in this process — confirm it takes effect this late.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(device)
if (torch.cuda.is_available()):
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    # Single-GPU training is pinned to the second card, but only when a second card
    # is actually visible; otherwise keep the default "cuda" device instead of
    # failing later with an invalid device ordinal.
    if ngpu == 1 and torch.cuda.device_count() > 1:
        device = torch.device('cuda:1')

cuda
GeForce RTX 2080 Ti


In [2]:
# Current date, formatted YYYY-MM-DD; used as a prefix for saved figure filenames.
current_date = datetime.date.today().isoformat()

# Data Loader
# Keys into each sample dictionary: the network input and the regression target.
original_key, target_key = "source", "boneless"
discriminator_keys_images = [original_key, target_key]

#ds_training = datasets.JSRT_CXR(data_normal, data_BSE,
#                         transform=tvtransforms.Compose([
#                             # custom_transforms.HistogramEqualisation(discriminator_keys_images),-- check if training data is already equalised
#                             custom_transforms.Resize(discriminator_keys_images, image_spatial_size),
#                             custom_transforms.ToTensor(discriminator_keys_images),
#                             ])
#                      )
#ds_val = datasets.JSRT_CXR(data_val_normal, data_val_BSE,
#                         transform=tvtransforms.Compose([
#                             #custom_transforms.HistogramEqualisation(discriminator_keys_images),
#                             custom_transforms.Resize(discriminator_keys_images, image_spatial_size),
#                             custom_transforms.ToTensor(discriminator_keys_images),
#                             ])
#                      )



In [3]:
# Rajaraman Loss
def criterion_Rajaraman(testImage, referenceImage, alpha=0.84):
    """
    Combined MS-SSIM / mean-absolute-error loss over a mini-batch of images.

    total = (1 - alpha) * MAE + alpha * (1 - MS-SSIM)

    Args:
        testImage: network output for the mini-batch.
        referenceImage: target images, same shape as ``testImage``.
        alpha: weight given to the MS-SSIM term; the MAE term gets ``1 - alpha``.

    Returns:
        Tuple ``(total_loss, mae_loss, msssim_loss)``. The MAE is averaged over
        all elements; the MS-SSIM term is built with ``size_average=True``.

    Reference:
    Rajaraman, S., Zamzmi, G., Folio, L., Alderson, P., & Antani, S. (2021). Chest X-Ray Bone Suppression for Improving Classification of Tuberculosis-Consistent Findings. In Diagnostics (Vol. 11, Issue 5). https://doi.org/10.3390/diagnostics11050840
    """
    # Mean absolute error (L1) between prediction and target.
    mae_loss = nn.functional.l1_loss(testImage, referenceImage)
    # Single-channel MS-SSIM with an 11-pixel window; turned into a loss as 1 - MS-SSIM.
    msssim_metric = pytorch_msssim.MSSSIM(window_size=11, size_average=True, channel=1, normalize='relu')
    msssim_loss = 1 - msssim_metric(testImage, referenceImage)
    total_loss = (1-alpha)*mae_loss + alpha*msssim_loss
    return total_loss, mae_loss, msssim_loss

In [None]:
# Training

# Data augmentation transformations
# Training pipeline; the steps are applied in the order listed and all of them
# operate on the sample keys in `sample_keys_images`.
_train_pipeline = tvtransforms.Compose([
    custom_transforms.RandomAutocontrast(sample_keys_images, cutoff_limits=(0.1,0.1)),
    custom_transforms.CenterCrop(sample_keys_images,256),
    custom_transforms.RandomHorizontalFlip(sample_keys_images, 0.5),
    custom_transforms.RandomAffine(sample_keys_images, degrees=10,translate=(0.1,0.1),scale=(0.9,1.1)),
    custom_transforms.ToTensor(sample_keys_images),
    custom_transforms.ImageComplement(sample_keys_images),
])
# Test pipeline: tensor conversion only.
_test_pipeline = tvtransforms.Compose([
    custom_transforms.ToTensor(sample_keys_images),
])
data_transforms = dict(train_transforms=_train_pipeline, test_transforms=_test_pipeline)

# K-FOLD VALIDATION
# Shuffled split with a fixed seed, so fold membership is the same on every run.
splits = KFold(n_splits=split_k_folds, shuffle=True, random_state=42)
# Filled by the fold loop: 'fold<N>' -> per-epoch loss / SSIM history.
foldperf = dict()

print(target_key)
# Transform-free dataset; the fold loop only uses its length to build index splits.
dataset = datasets.JSRT_CXR(data_normal, data_BSE, transform=None)
# K-fold training loop. For every fold: build the loaders, create a fresh network,
# train for `total_num_epochs`, record per-epoch loss and SSIM for the training and
# validation subsets into `foldperf`, then save a checkpoint for that fold.
for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(dataset)))):
    print('Fold {}'.format(fold + 1))
    history = {'loss':[], 'ssim_acc':[]}
    # PATH_SAVE_NETWORK_INTERMEDIATE is a '{}' template. Resolve it once per fold so
    # that resuming and saving always refer to the same per-fold file.
    path_save_fold = PATH_SAVE_NETWORK_INTERMEDIATE.format(fold+1)

    # Subset sample from dataset
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    dl_training = DataLoader(datasets.JSRT_CXR(data_normal, data_BSE, transform=data_transforms['train_transforms']),
                             batch_size=_batch_size, sampler=train_sampler, num_workers=0)
    # NOTE(review): the validation loader also uses 'train_transforms', so validation
    # images are randomly augmented. 'test_transforms' is not a drop-in replacement
    # (it has no CenterCrop / ImageComplement) — confirm what is intended.
    dl_validation = DataLoader(datasets.JSRT_CXR(data_normal, data_BSE, transform=data_transforms['train_transforms']),
                             batch_size=_batch_size, sampler=test_sampler, num_workers=0)

    # Print example of transformed image (first image of the first training batch)
    sample = next(iter(dl_training))
    image = torch.squeeze(sample["source"][0,:])
    plt.imshow(image)

    ## Implementation of network and losses
    input_array_size = (_batch_size, 1, image_spatial_size[0], image_spatial_size[1])
    net = RajaramanModel.ResNet_BS(input_array_size)
    # Initialise weights
    net.apply(weights_init)

    # Multi-GPU
    if (device.type == 'cuda') and (ngpu > 1):
        print("Neural Net on GPU")
        net = nn.DataParallel(net, list(range(ngpu)))
    net = net.to(device)

    # Optimiser
    optimizer = optim.Adam(net.parameters(), lr=lr_ini, betas=(beta1, beta2))

    # Learning Rate Scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=int(10*epoch_factor), 
                                                     threshold=0.0001, min_lr=0.00001, verbose=True)
    
    # Per-fold bookkeeping
    epochs_list = []
    img_list = []
    training_loss_list = []
    reals_shown = []
    #validation_loss_per_epoch_list = []
    #training_loss_per_epoch_list = []
    loss_per_epoch={"training":[], "validation":[]}
    ssim_average={"training":[], "validation":[]}


    # optionally resume from a checkpoint
    if flag_load_previous_save:
        if os.path.isfile(path_save_fold):
            print("=> loading checkpoint '{}'".format(path_save_fold))
            checkpoint = torch.load(path_save_fold)
            start_epoch = checkpoint['epoch_next']
            reals_shown_now = checkpoint['reals_shown']
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            net.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}, reals shown {})".format(path_save_fold, 
                                                                                start_epoch, reals_shown_now))
            print(scheduler)
        else:
            print("=> NO CHECKPOINT FOUND AT '{}'" .format(path_save_fold))
            raise RuntimeError("No checkpoint found at specified path.")
    else:
        print("FLAG: NO CHECKPOINT LOADED.")
        reals_shown_now = 0
        start_epoch=0

    # Loop variables
    flag_break = False # when debugging, this will automatically go to True
    iters = 0
    # Defined up front so the final save still works if the epoch loop runs zero
    # times (e.g. when resuming from an already finished checkpoint).
    epochs_completed = start_epoch
    loss = None
    for param in net.parameters():
        param.requires_grad = True
    for epoch in range(start_epoch, total_num_epochs):
        # The metric passes at the end of each epoch switch to eval mode, so
        # training mode has to be re-enabled at the start of every epoch.
        net.train()
        print(optimizer.param_groups[0]['lr'])
        sum_loss_in_epoch = 0
        for i, data in enumerate(dl_training):
            # Training
            net.zero_grad()
            noisy_data = data[original_key].to(device)
            cleaned_data = net(noisy_data)
            loss, maeloss, msssim_loss = criterion_Rajaraman(cleaned_data, data[target_key].to(device))
            loss.backward() # calculate gradients
            optimizer.step() # optimiser step along gradients

            # Output training stats
            if epoch % 1000 == 0:
                print('[%d/%d][%d/%d]\tTotal Loss: %.4f\tMAELoss: %.4f\tMSSSIM Loss: %.4f'
                      % (epoch, total_num_epochs, i, len(dl_training),
                         loss.item(), maeloss.item(), msssim_loss.item()))
            # Record generator output
            #if reals_shown_now%(100*_batch_size)==0:
            #    with torch.no_grad():
            #        val_cleaned = net(fixed_val_sample["source"].to(device)).detach().cpu()
            #    print("Printing to img_list")
            #    img_list.append(vutils.make_grid(val_cleaned, padding=2, normalize=True))
            #iters +=1
            #if flag_debug and iters>=10:
            #    flag_break = True
            #    break

            # Running counter of reals shown
            reals_shown_now += _batch_size
            reals_shown.append(reals_shown_now)

            # Training loss list for loss per minibatch
            training_loss_list.append(loss.item()) # training loss
            # Weight by the actual batch length so a short final batch is averaged correctly.
            sum_loss_in_epoch += loss.item()*len(cleaned_data)
        # Turn the sum loss in epoch into a loss-per-epoch
        loss_per_epoch["training"].append(sum_loss_in_epoch/len(train_idx))

        # Metric passes: inference behaviour and no gradient tracking.
        net.eval()
        with torch.no_grad():
            # Training Accuracy per EPOCH
            ssim_training_list = []
            for train_count, data in enumerate(dl_training):
                noisy_training_data = data[original_key].to(device)
                true_training_data = data[target_key]
                cleaned_training_data = net(noisy_training_data)

                for ii, image in enumerate(cleaned_training_data):
                    clean_training_numpy = image.cpu().detach().numpy()
                    true_training_numpy = true_training_data[ii].numpy()
                    # Move the leading (channel) axis last for skimage.
                    clean_training_numpy = np.moveaxis(clean_training_numpy, 0, -1)
                    true_training_numpy = np.moveaxis(true_training_numpy, 0, -1)
                    ssim_training = skimage.metrics.structural_similarity(clean_training_numpy, true_training_numpy, multichannel=True)
                    ssim_training_list.append(ssim_training) # SSIM per image
            ssim_average["training"].append(np.mean(ssim_training_list))

            # Validation Loss and Accuracy per EPOCH
            sum_loss_in_epoch =0
            ssim_val_list = []
            for val_count, sample in enumerate(dl_validation):
                noisy_val_data = sample[original_key].to(device)
                cleaned_val_data = net(noisy_val_data)

                # Loss
                true_val_data = sample[target_key]
                val_loss, maeloss, msssim_loss = criterion_Rajaraman(cleaned_val_data, true_val_data.to(device))
                sum_loss_in_epoch += val_loss.item()*len(cleaned_val_data)

                # Accuracy
                for ii, image in enumerate(cleaned_val_data):
                    clean_val_numpy = image.cpu().detach().numpy()
                    true_val_numpy = true_val_data[ii].numpy()
                    clean_val_numpy = np.moveaxis(clean_val_numpy, 0, -1)
                    true_val_numpy = np.moveaxis(true_val_numpy, 0, -1)
                    ssim_val = skimage.metrics.structural_similarity(clean_val_numpy, true_val_numpy, multichannel=True)
                    ssim_val_list.append(ssim_val) # SSIM per image
            # After considering all validation images
            loss_per_epoch["validation"].append(sum_loss_in_epoch/len(val_idx))
            ssim_average["validation"].append(np.mean(ssim_val_list))
        epochs_list.append(epoch)
        # LR Scheduler after epoch: step on the epoch-mean validation loss rather
        # than on the loss of whichever mini-batch happened to come last.
        scheduler.step(loss_per_epoch["validation"][-1])
        epochs_completed = epoch + 1

        # Save the network in indications
        #if epoch % 5 == 0:
        #    if not flag_debug:
        #        #torch.save(net.state_dict(), PATH_SAVE_NETWORK_INTERMEDIATE)
        #        torch.save({
        #        'epochs_completed': epoch+1,
        #        'epoch_next': epoch+1,
        #        'model_state_dict': net.state_dict(),
        #        'optimizer_state_dict': optimizer.state_dict(),
        #        'loss': loss,
        #        'scheduler_state_dict': scheduler.state_dict(),
        #        'reals_shown': reals_shown_now
        #        }, PATH_SAVE_NETWORK_INTERMEDIATE)
        #        print("Saved Intermediate: "+ str(PATH_SAVE_NETWORK_INTERMEDIATE))
        #if flag_break:
        #    break
        
    #After all epochs, save results:
    history["loss"] = loss_per_epoch
    history["ssim_acc"] = ssim_average
    
    foldperf['fold{}'.format(fold+1)] = history
    
    # Final Save for each fold
    if not flag_debug:
        torch.save({
                'epochs_completed': epochs_completed,
                'epoch_next': epochs_completed,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'scheduler_state_dict': scheduler.state_dict(),
                'reals_shown': reals_shown_now
                }, path_save_fold)
        print("Saved Intermediate: "+ str(path_save_fold))
    

boneless
Fold 1
FLAG: NO CHECKPOINT LOADED.
0.001
[0/4575][0/89]	Total Loss: 0.7398	MAELoss: 0.3184	MSSSIM Loss: 0.8201
[0/4575][1/89]	Total Loss: 0.6066	MAELoss: 0.2979	MSSSIM Loss: 0.6654
[0/4575][2/89]	Total Loss: 0.5838	MAELoss: 0.2955	MSSSIM Loss: 0.6387
[0/4575][3/89]	Total Loss: 0.4591	MAELoss: 0.2504	MSSSIM Loss: 0.4989
[0/4575][4/89]	Total Loss: 0.2803	MAELoss: 0.2146	MSSSIM Loss: 0.2928
[0/4575][5/89]	Total Loss: 0.3430	MAELoss: 0.3673	MSSSIM Loss: 0.3383
[0/4575][6/89]	Total Loss: 0.2927	MAELoss: 0.3878	MSSSIM Loss: 0.2746
[0/4575][7/89]	Total Loss: 0.2081	MAELoss: 0.3374	MSSSIM Loss: 0.1834
[0/4575][8/89]	Total Loss: 0.2266	MAELoss: 0.1683	MSSSIM Loss: 0.2376
[0/4575][9/89]	Total Loss: 0.2486	MAELoss: 0.1938	MSSSIM Loss: 0.2590
[0/4575][10/89]	Total Loss: 0.2459	MAELoss: 0.1688	MSSSIM Loss: 0.2605
[0/4575][11/89]	Total Loss: 0.1661	MAELoss: 0.1495	MSSSIM Loss: 0.1692
[0/4575][12/89]	Total Loss: 0.1821	MAELoss: 0.0972	MSSSIM Loss: 0.1983
[0/4575][13/89]	Total Loss: 0.1458	MA

0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.001
0.00

[1000/4575][42/89]	Total Loss: 0.0098	MAELoss: 0.0091	MSSSIM Loss: 0.0099
[1000/4575][43/89]	Total Loss: 0.0071	MAELoss: 0.0085	MSSSIM Loss: 0.0068
[1000/4575][44/89]	Total Loss: 0.0059	MAELoss: 0.0062	MSSSIM Loss: 0.0059
[1000/4575][45/89]	Total Loss: 0.0097	MAELoss: 0.0126	MSSSIM Loss: 0.0091
[1000/4575][46/89]	Total Loss: 0.0152	MAELoss: 0.0200	MSSSIM Loss: 0.0143
[1000/4575][47/89]	Total Loss: 0.0073	MAELoss: 0.0062	MSSSIM Loss: 0.0076
[1000/4575][48/89]	Total Loss: 0.0077	MAELoss: 0.0086	MSSSIM Loss: 0.0075
[1000/4575][49/89]	Total Loss: 0.0103	MAELoss: 0.0093	MSSSIM Loss: 0.0105
[1000/4575][50/89]	Total Loss: 0.0056	MAELoss: 0.0053	MSSSIM Loss: 0.0056
[1000/4575][51/89]	Total Loss: 0.0064	MAELoss: 0.0070	MSSSIM Loss: 0.0063
[1000/4575][52/89]	Total Loss: 0.0067	MAELoss: 0.0062	MSSSIM Loss: 0.0068
[1000/4575][53/89]	Total Loss: 0.0112	MAELoss: 0.0105	MSSSIM Loss: 0.0113
[1000/4575][54/89]	Total Loss: 0.0132	MAELoss: 0.0129	MSSSIM Loss: 0.0133
[1000/4575][55/89]	Total Loss: 0.0078	

6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6.25e-05
6

[2000/4575][49/89]	Total Loss: 0.0059	MAELoss: 0.0081	MSSSIM Loss: 0.0055
[2000/4575][50/89]	Total Loss: 0.0076	MAELoss: 0.0078	MSSSIM Loss: 0.0076
[2000/4575][51/89]	Total Loss: 0.0055	MAELoss: 0.0053	MSSSIM Loss: 0.0055
[2000/4575][52/89]	Total Loss: 0.0102	MAELoss: 0.0093	MSSSIM Loss: 0.0104
[2000/4575][53/89]	Total Loss: 0.0100	MAELoss: 0.0104	MSSSIM Loss: 0.0100
[2000/4575][54/89]	Total Loss: 0.0076	MAELoss: 0.0091	MSSSIM Loss: 0.0073
[2000/4575][55/89]	Total Loss: 0.0070	MAELoss: 0.0083	MSSSIM Loss: 0.0067
[2000/4575][56/89]	Total Loss: 0.0080	MAELoss: 0.0100	MSSSIM Loss: 0.0076
[2000/4575][57/89]	Total Loss: 0.0135	MAELoss: 0.0252	MSSSIM Loss: 0.0113
[2000/4575][58/89]	Total Loss: 0.0058	MAELoss: 0.0064	MSSSIM Loss: 0.0057
[2000/4575][59/89]	Total Loss: 0.0088	MAELoss: 0.0113	MSSSIM Loss: 0.0084
[2000/4575][60/89]	Total Loss: 0.0072	MAELoss: 0.0084	MSSSIM Loss: 0.0070
[2000/4575][61/89]	Total Loss: 0.0077	MAELoss: 0.0090	MSSSIM Loss: 0.0075
[2000/4575][62/89]	Total Loss: 0.0061	

1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-0

[3000/4575][82/89]	Total Loss: 0.0069	MAELoss: 0.0070	MSSSIM Loss: 0.0069
[3000/4575][83/89]	Total Loss: 0.0042	MAELoss: 0.0042	MSSSIM Loss: 0.0041
[3000/4575][84/89]	Total Loss: 0.0075	MAELoss: 0.0090	MSSSIM Loss: 0.0073
[3000/4575][85/89]	Total Loss: 0.0090	MAELoss: 0.0079	MSSSIM Loss: 0.0091
[3000/4575][86/89]	Total Loss: 0.0118	MAELoss: 0.0119	MSSSIM Loss: 0.0118
[3000/4575][87/89]	Total Loss: 0.0054	MAELoss: 0.0072	MSSSIM Loss: 0.0051
[3000/4575][88/89]	Total Loss: 0.0051	MAELoss: 0.0059	MSSSIM Loss: 0.0050
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e-05
1e

In [None]:
print("Training complete")
# Find the best fold


# AVERAGE PERFORMANCE
# Summarise whatever folds were actually recorded in `foldperf` instead of assuming
# a hard-coded 10, so the cell still works when split_k_folds changes or training
# was stopped early.
k = len(foldperf)
if k == 0:
    raise RuntimeError("foldperf is empty: run the training cell before summarising results.")
fold_names = ['fold{}'.format(f) for f in range(1, k+1)]

# Per-fold means over all recorded epochs.
testl_f,tl_f,testa_f,ta_f=[],[],[],[]
for name in fold_names:
    tl_f.append(np.mean(foldperf[name]['loss']['training']))
    testl_f.append(np.mean(foldperf[name]['loss']['validation']))

    ta_f.append(np.mean(foldperf[name]['ssim_acc']['training']))
    testa_f.append(np.mean(foldperf[name]['ssim_acc']['validation']))

print('Performance of {} fold cross validation'.format(k))
print("Average Training Loss: {:.3f} \t Average Test Loss: {:.3f} \t Average Training Acc: {:.2f} \t Average Test Acc: {:.2f}"
      .format(np.mean(tl_f),np.mean(testl_f),np.mean(ta_f),np.mean(testa_f)))
# "Best" = highest epoch-averaged validation SSIM.
print("Best fold: {}".format(np.argmax(testa_f)+1))

# Averaging accuracy and loss
# Only average over the epochs every fold has a record for; a fold resumed from a
# checkpoint holds fewer than total_num_epochs entries.
num_epochs_recorded = min(len(foldperf[name]['loss']['training']) for name in fold_names)
diz_ep = {'train_loss_ep':[],'test_loss_ep':[],'train_acc_ep':[],'test_acc_ep':[]}
for i in range(num_epochs_recorded):
    diz_ep['train_loss_ep'].append(np.mean([foldperf[name]['loss']['training'][i] for name in fold_names]))
    diz_ep['test_loss_ep'].append(np.mean([foldperf[name]['loss']['validation'][i] for name in fold_names]))
    diz_ep['train_acc_ep'].append(np.mean([foldperf[name]['ssim_acc']['training'][i] for name in fold_names]))
    diz_ep['test_acc_ep'].append(np.mean([foldperf[name]['ssim_acc']['validation'][i] for name in fold_names]))

# Plot losses (mean across folds, log-scaled y axis)
plt.figure(figsize=(10,8))
plt.semilogy(diz_ep['train_loss_ep'], label='Train')
plt.semilogy(diz_ep['test_loss_ep'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.grid()
plt.legend()
# Save before show(): the figure may be gone once it has been displayed.
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_loss"+".png"))
plt.show()

# Plot accuracies (mean SSIM across folds)
plt.figure(figsize=(10,8))
plt.semilogy(diz_ep['train_acc_ep'], label='Train')
plt.semilogy(diz_ep['test_acc_ep'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('SSIM')
#plt.grid()
plt.legend()
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_accuracy"+".png"))
plt.show()

In [None]:
%matplotlib inline
import math
current_date=datetime.datetime.today().strftime('%Y-%m-%d')

# Accuracy
plt.figure(figsize=(10,5))
plt.title("SSIM")
plt.plot(epochs_list, ssim_average["training"], label='training')
plt.plot(epochs_list, ssim_average["validation"] , label='validation')
plt.xlabel("Epochs")
plt.ylabel("SSIM")
plt.legend()
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_accuracy"+".png"))
    
# All losses
plt.figure(figsize=(10,5))
plt.title("Losses")
plt.plot(epochs_list, loss_per_epoch["training"], label='training')
plt.plot(epochs_list, loss_per_epoch["validation"] , label='validation')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_loss"+".png"))

# Loss for training
plt.figure(figsize=(10,5))
plt.title("Loss during training")
plt.plot(reals_shown, [math.log10(y) for y in training_loss_list])
plt.xlabel("reals_shown")
plt.ylabel("Training Log10(Loss)")
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_training_loss"+".png"))


    
# Final Model:
with torch.no_grad():
    input_image = fixed_val_sample['source']
    input_images = vutils.make_grid(input_image[0:1,:,:,:], padding=2, normalize=True)
    target_images = vutils.make_grid(fixed_val_sample['boneless'][0:1,:,:,:], padding=2, normalize=True)
    net = net.cpu()
    output_image = net(input_image[0:1,:,:,:]).detach().cpu()
    output_images = vutils.make_grid(output_image, padding=2, normalize=True)
print(str(torch.max(output_images)) + "," + str(torch.min(output_images)))
plt.figure(1)
fig, ax = plt.subplots(1,3, figsize=(15,15))
ax[0].imshow(np.transpose(input_images, (1,2,0)), vmin=0, vmax=1)
ax[0].set_title("Source")
ax[0].axis("off")
ax[1].imshow(np.transpose(output_images, (1,2,0)), vmin=0, vmax=1)
ax[1].set_title("Suppressed")
ax[1].axis("off")
ax[2].imshow(np.transpose(target_images, (1,2,0)), vmin=0, vmax=1)
ax[2].set_title("Ideal Bone-suppressed")
ax[2].axis("off")
plt.show
if not flag_debug:
    plt.savefig(os.path.join(output_save_directory, current_date + "_validation_ComparisonImages"+".png"))

# ANIMATED VALIDATION IMAGE
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
plt.axis("off")
ims = []
training_ims_shown = []
for i, im in enumerate(img_list):
    if i % 50 == 0:  # controls how many images are printed into the animation
        training_ims_shown = i*(100*_batch_size)
        frame = ax.imshow(np.transpose(im,(1,2,0)))
        t = ax.annotate("Reals shown: {}".format(training_ims_shown), (0.5,1.02), xycoords="axes fraction")
        ims.append([frame, t])
ani = animation.ArtistAnimation(fig, ims, interval=200, repeat_delay=1000, blit=True)
if not flag_debug:
    ani.save(os.path.join(output_save_directory, current_date+"_animation.mp4"), dpi=300)

In [None]:
# Sanity check of the loss on random data: a batch of 15 single-channel 256x256 images.
clean_matrix = torch.from_numpy(np.random.rand(15,1,256,256)).float()
true_matrix = torch.from_numpy(np.random.rand(15,1,256,256)).float()
print(criterion_Rajaraman(clean_matrix, true_matrix, alpha=0.84))
# len() of a tensor is the size of its first dimension (15 here).
print(len(true_matrix))
# `ds_training` only exists in commented-out code; report the size of the dataset
# object the K-fold loop actually uses instead.
print(len(dataset))

In [None]:
# Show the per-epoch mean validation SSIM values. `ssim_average` is re-created for
# every fold in the training cell, so this displays the most recently trained fold.
ssim_average["validation"]