In [None]:
import os

# Restrict this process to a single GPU. These variables must be set before
# CUDA is initialised, which is why they live in the very first cell.
# CUDA_DEVICE_ORDER=PCI_BUS_ID numbers devices in PCI bus order, so index "2"
# refers to the same physical card as in nvidia-smi.
os.environ.update({
    'CUDA_DEVICE_ORDER': "PCI_BUS_ID",
    'CUDA_VISIBLE_DEVICES': "2",
})

In [None]:
import io
import time
import torch
import numpy as np
from tqdm.notebook import tqdm
from contextlib import redirect_stdout

# The project modules (models, datasets, utils) are imported relative to the
# parent of the directory the kernel starts in. Remember that starting directory
# so re-running this cell always lands in the same parent instead of climbing
# one more level per execution (which would break these imports and every
# relative results path used later).
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

from models.BCDnet import BCDnet
from models.cnet import Cnet
from datasets import WSSBDatasetTest
from utils import od2rgb, rgb2od, random_ruifrok_matrix_variation, direct_deconvolution, peak_signal_noise_ratio, structural_similarity, askforPyTorchWeightsviaGUI

In [None]:
NUM_ITERATIONS = 4000  # Optimisation steps per image

SAVE_WEIGHTS = False
RUN_FROM_WEIGHTS = False
START_FROM_IMAGE_ITSELF = False  # True: feed the OD image to the network; False: feed uniform noise

# Every approach the training loop knows how to run.
SUPPORTED_APPROACHES = ('bcdnet_e1', 'bcdnet_e2', 'bcdnet_e3', 'bcdnet_e4L1', 'bcdnet_e4L2', 'cnet_e2')
APPROACH_USED = 'bcdnet_e1'  # One of SUPPORTED_APPROACHES
BATCH_SIZE = 1  # Should always be 1
SIGMA_RUI_SQ = 0.05 # Prior hematoxylin/eosin variance of M
LEARNING_RATE = 1e-4
THETA_VAL = 0.5 # Weight of loss_kl (loss_rec gets 1 - THETA_VAL) in the total loss of bcdnet_e2, and of bcdnet_e3 after pre-training
THETA_VAL_COLORITER = 0.99 # Weight of loss_kl during pre-training in bcdnet_e3
COLORITER = 1000 # Number of iterations for pre-training in bcdnet_e3

ORGAN_LIST = ['Colon', 'Breast', 'Lung']
RUNNING_ON_DELFOS = True

# Fail fast on configuration mistakes, before folders are created or training
# starts. An unknown approach would otherwise only surface as a NameError deep
# inside the training loop.
if APPROACH_USED not in SUPPORTED_APPROACHES:
    raise ValueError(f'Unknown APPROACH_USED {APPROACH_USED!r}; expected one of {SUPPORTED_APPROACHES}')
if BATCH_SIZE != 1:
    raise ValueError('BATCH_SIZE should always be 1')
if NUM_ITERATIONS <= 0:
    raise ValueError('NUM_ITERATIONS must be a positive integer')

In [None]:
# Seed the global RNGs used by this notebook so runs are repeatable.
torch.manual_seed(0)
# The cnet_e2 colour matrix is built from NumPy arrays returned by
# random_ruifrok_matrix_variation; presumably it draws from NumPy's global RNG
# (TODO confirm), so seed that too.
np.random.seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Compare device.type: `device == 'cuda'` compares a torch.device with a str
# and never enabled this block.
if device.type == 'cuda':
    # Images are always the same size, so letting cuDNN benchmark and cache the
    # fastest convolution algorithms should speed things up.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
print('Using device:', device)

In [None]:
# Dataset root: an absolute path when RUNNING_ON_DELFOS, a relative one otherwise.
if RUNNING_ON_DELFOS:
    alsubaie_dataset_path = '/home/modej/Alsubaie_500x500'
else:
    alsubaie_dataset_path = '../Alsubaie_500x500'

# Per-iteration metrics written to each CSV. Insertion order is the CSV column
# order; 'epoch' is an int counter and every other column starts as a float.
metrics_dict = {'epoch': 0}
metrics_dict.update(dict.fromkeys(
    ['loss',
     'mse_rec', 'psnr_rec', 'ssim_rec',
     'mse_gt_h', 'mse_gt_e', 'mse_gt',
     'psnr_gt_h', 'psnr_gt_e', 'psnr_gt',
     'ssim_gt_h', 'ssim_gt_e', 'ssim_gt', 'time'],
    0.0,
))

# The Bayesian variants also log their individual loss terms; the e4 variants
# additionally log their concentration-norm regulariser.
_extra_loss_columns = {
    'bcdnet_e2': ['loss_rec', 'loss_kl'],
    'bcdnet_e3': ['loss_rec', 'loss_kl'],
    'bcdnet_e4L1': ['loss_rec', 'loss_kl', 'loss_l1'],
    'bcdnet_e4L2': ['loss_rec', 'loss_kl', 'loss_l2'],
}
metrics_dict.update(dict.fromkeys(_extra_loss_columns.get(APPROACH_USED, []), 0.0))

In [None]:
# Results are grouped under a folder whose name encodes this run's settings.
intermediate_folder = f'{APPROACH_USED}_lr{LEARNING_RATE}'
if APPROACH_USED == 'bcdnet_e2':
    intermediate_folder += f'_theta{THETA_VAL}'
elif APPROACH_USED == 'bcdnet_e3':
    intermediate_folder += f'_theta{THETA_VAL}_thetacolor_{THETA_VAL_COLORITER}_coloriters{COLORITER}'
intermediate_folder += '_fromimage' if START_FROM_IMAGE_ITSELF else '_fromnoise'
if RUN_FROM_WEIGHTS:
    intermediate_folder += '_fromweights'

folder_route = f'../../results/{intermediate_folder}/batch_training/'
# makedirs also creates folder_route itself; exist_ok keeps the cell re-runnable.
os.makedirs(os.path.join(folder_route, 'metrics'), exist_ok=True)

if RUN_FROM_WEIGHTS:
    if APPROACH_USED == 'cnet_e2':
        # The training loop loads pretrained_weights_filepath whenever
        # RUN_FROM_WEIGHTS is set, but no file is requested for cnet_e2.
        # Fail here instead of with a NameError (or a stale path left in the
        # kernel from a previous run) later on.
        raise ValueError('RUN_FROM_WEIGHTS is not supported for cnet_e2')
    pretrained_weights_filepath = askforPyTorchWeightsviaGUI()

In [None]:
# Reference Ruifrok colour matrix used as the prior mean of M: rows are the
# 3 colour channels, column 0 is hematoxylin and column 1 is eosin. It is
# replicated once per batch element and moved to the compute device.
ruifrok_matrix = (
    torch.tensor(
        [[0.6442, 0.0928],
         [0.7166, 0.9541],
         [0.2668, 0.2831]],
        dtype=torch.float32,
    )
    .repeat(BATCH_SIZE, 1, 1)  # (batch_size, 3, 2)
    .to(device)
)

In [None]:
def _bayesian_losses(M_matrix, M_variation, C_matrix, target_od):
    """Closed-form KL term and sampled reconstruction loss shared by bcdnet_e2/e3/e4*.

    Args:
        M_matrix: predicted mean colour matrix, (batch_size, 3, 2).
        M_variation: predicted colour variance; repeated 3 times along dim 1
            here (presumably (batch_size, 1, 2) on input -- TODO confirm).
        C_matrix: predicted concentrations, (batch_size, 2, H, W).
        target_od: observed image in optical density, (batch_size, 3, H, W).

    Returns:
        (loss_rec, loss_kl): 0-d tensors, each divided by BATCH_SIZE.
    """
    M_variation = M_variation.repeat(1, 3, 1)  # (batch_size, 3, 2)
    # Kullback-Leibler divergence to the Ruifrok prior, via its closed form.
    variance_ratio = M_variation / SIGMA_RUI_SQ
    loss_kl = ((0.5 / SIGMA_RUI_SQ) * torch.nn.functional.mse_loss(M_matrix, ruifrok_matrix, reduction='none')
               + 1.5 * (variance_ratio - torch.log(variance_ratio) - 1))  # (batch_size, 3, 2)
    loss_kl = torch.sum(loss_kl) / BATCH_SIZE
    # Re-parametrisation trick: sample M so gradients reach both mean and variance.
    M_sample = M_matrix + torch.sqrt(M_variation) * torch.randn_like(M_matrix)  # (batch_size, 3, 2)
    Y_rec = torch.einsum('bcs,bshw->bchw', M_sample, C_matrix)  # (batch_size, 3, H, W)
    loss_rec = torch.sum(torch.nn.functional.mse_loss(Y_rec, target_od)) / BATCH_SIZE
    return loss_rec, loss_kl


def _split_stains(M, C):
    """Reconstruct each stain on its own.

    M is a (batch, 3, 2) colour matrix and C a (batch, 2, H, W) concentration map.
    Returns (hema_od, hema_rgb, eosin_od, eosin_rgb): the OD images are
    (batch, 3, H, W) and the RGB versions are clamped to [0, 255].
    """
    per_stain_od = torch.einsum('bcs,bshw->bschw', M, C)  # (batch, 2, 3, H, W)
    hema_od, eosin_od = per_stain_od[:, 0], per_stain_od[:, 1]
    hema_rgb = torch.clamp(od2rgb(hema_od), 0.0, 255.0)
    eosin_rgb = torch.clamp(od2rgb(eosin_od), 0.0, 255.0)
    return hema_od, hema_rgb, eosin_od, eosin_rgb


for organ in tqdm(ORGAN_LIST, desc="Organs", unit="organ"):
    # One dataset for each organ helps keep track of the results.
    dataset = WSSBDatasetTest(alsubaie_dataset_path, organ_list=[organ], load_at_init=False)

    # Each image is fitted independently with a freshly built model and optimiser.
    for index, (image, M_gt) in tqdm(enumerate(dataset), desc="Images", unit="image", leave=False):
        image = image.to(device)
        img_od = rgb2od(image)

        if 'bcdnet' in APPROACH_USED:
            model = BCDnet(cnet_name='unet_64_6', mnet_name='mobilenetv3s_50').to(device)
        elif 'cnet' in APPROACH_USED:
            model = Cnet().to(device)
        else:
            raise Exception('Approach not found.')

        if RUN_FROM_WEIGHTS:
            model.load_state_dict(torch.load(pretrained_weights_filepath, map_location=device))

        optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        with redirect_stdout(io.StringIO()):  # discard anything model.train() prints
            model.train()

        # Network input: the image itself in OD space, or uniform noise of the same
        # shape. unsqueeze(0) adds the batch dimension -> (1, 3, H, W).
        if START_FROM_IMAGE_ITSELF:
            net_input = img_od.unsqueeze(0).to(device)
        else:
            net_input = torch.rand(image.shape).unsqueeze(0).to(device)

        original_tensor = image.unsqueeze(0).to(device)
        original_tensor_od = rgb2od(original_tensor).to(device)
        _, _, H, W = original_tensor.shape
        M_gt = M_gt.to(device)

        # Ground-truth per-stain images from the dataset's colour matrix.
        C_gt = direct_deconvolution(img_od, M_gt).unsqueeze(0)  # (1, 2, H, W)
        M_gt = M_gt.unsqueeze(0)  # (1, 3, 2)
        H_gt_od, H_gt, E_gt_od, E_gt = (t.to(device) for t in _split_stains(M_gt, C_gt))

        # One CSV per image. The file is opened once; line buffering still pushes
        # every row out as soon as it is written.
        metrics_filepath = folder_route + f"/metrics/{organ}_{index}.csv"
        with open(metrics_filepath, 'w', buffering=1) as metrics_file:
            metrics_file.write(','.join(metrics_dict.keys()) + '\n')

            for iteration in tqdm(range(1, NUM_ITERATIONS + 1), desc="Iterations", unit="iteration", leave=False):
                start_time = time.perf_counter()
                optimizer.zero_grad()

                if 'bcdnet' in APPROACH_USED:
                    # BCDnet predicts the colour matrix, its variation and the concentrations.
                    M_matrix, M_variation, C_matrix = model(net_input)
                elif APPROACH_USED == 'cnet_e2':
                    # Cnet only predicts concentrations; the colour matrix is sampled
                    # around the Ruifrok matrix with variance SIGMA_RUI_SQ.
                    C_matrix = model(net_input)
                    h_matrix, e_matrix = random_ruifrok_matrix_variation(SIGMA_RUI_SQ)
                    M_matrix = np.concatenate((h_matrix, e_matrix), axis=1)
                    M_matrix = torch.from_numpy(M_matrix).float().unsqueeze(0).to(device)  # (1, 3, 2)
                else:
                    raise Exception('Approach not found.')

                # Recombine the stains into a 3-channel OD image and convert back to RGB.
                reconstructed_od = torch.einsum('bcs,bshw->bchw', M_matrix, C_matrix)  # (1, 3, H, W)
                reconstructed = torch.clamp(od2rgb(reconstructed_od), 0, 255.0)

                if APPROACH_USED in ('bcdnet_e1', 'cnet_e2'):
                    loss = torch.nn.functional.mse_loss(reconstructed_od, original_tensor_od)

                elif APPROACH_USED in ('bcdnet_e2', 'bcdnet_e3'):
                    loss_rec, loss_kl = _bayesian_losses(M_matrix, M_variation, C_matrix, original_tensor_od)
                    # bcdnet_e3 weights the KL term by THETA_VAL_COLORITER during its first
                    # COLORITER iterations (colour pre-training), then by THETA_VAL.
                    if APPROACH_USED == 'bcdnet_e3' and iteration < COLORITER:
                        theta = THETA_VAL_COLORITER
                    else:
                        theta = THETA_VAL
                    loss = (1.0 - theta) * loss_rec + theta * loss_kl
                    metrics_dict['loss_rec'] = (1.0 - theta) * loss_rec.item()
                    metrics_dict['loss_kl'] = theta * loss_kl.item()

                elif APPROACH_USED in ('bcdnet_e4L1', 'bcdnet_e4L2'):
                    # Concentration-norm regulariser, averaged over pixels. It must stay a
                    # tensor (not .item()) or it contributes no gradient to the loss.
                    if APPROACH_USED == 'bcdnet_e4L1':
                        l_norm = torch.sum(torch.abs(C_matrix), dim=1).sum() / (H * W)
                        metrics_dict['loss_l1'] = l_norm.item()
                    else:
                        # vector_norm's backward gives a zero subgradient where a pixel's
                        # concentrations are all 0, whereas d sqrt(x)/dx is infinite at 0.
                        l_norm = torch.linalg.vector_norm(C_matrix, ord=2, dim=1).sum() / (H * W)
                        metrics_dict['loss_l2'] = l_norm.item()

                    loss_rec, loss_kl = _bayesian_losses(M_matrix, M_variation, C_matrix, original_tensor_od)
                    loss = loss_rec + loss_kl + l_norm
                    metrics_dict['loss_rec'] = loss_rec.item()
                    metrics_dict['loss_kl'] = loss_kl.item()

                else:
                    raise Exception('Approach not found.')

                loss.backward()
                optimizer.step()

                # NOTE(review): CUDA work is asynchronous; unless something above
                # synchronised (e.g. a .item() call), this may under-report GPU time.
                metrics_dict['time'] = (time.perf_counter() - start_time) * 1000.0  # Milliseconds

                # Metrics are evaluation only; no autograd graph is needed for them.
                with torch.no_grad():
                    metrics_dict['epoch'] = iteration
                    metrics_dict['loss'] = loss.item()
                    metrics_dict['mse_rec'] = torch.sum(torch.pow(reconstructed_od - original_tensor_od, 2)).item() / (3.0 * H * W)
                    metrics_dict['psnr_rec'] = torch.sum(peak_signal_noise_ratio(reconstructed, original_tensor)).item()
                    metrics_dict['ssim_rec'] = torch.sum(structural_similarity(reconstructed, original_tensor)).item()

                    # Per-stain images produced by the model, compared with the ground truth.
                    H_rec_od, H_rec, E_rec_od, E_rec = _split_stains(M_matrix, C_matrix)

                    metrics_dict['mse_gt_h'] = torch.sum(torch.pow(H_gt_od - H_rec_od, 2)).item() / (3.0 * H * W)
                    metrics_dict['mse_gt_e'] = torch.sum(torch.pow(E_gt_od - E_rec_od, 2)).item() / (3.0 * H * W)

                    metrics_dict['psnr_gt_h'] = torch.sum(peak_signal_noise_ratio(H_gt, H_rec)).item()
                    metrics_dict['psnr_gt_e'] = torch.sum(peak_signal_noise_ratio(E_gt, E_rec)).item()

                    metrics_dict['ssim_gt_h'] = torch.sum(structural_similarity(H_gt, H_rec)).item()
                    metrics_dict['ssim_gt_e'] = torch.sum(structural_similarity(E_gt, E_rec)).item()

                metrics_dict['mse_gt'] = (metrics_dict['mse_gt_h'] + metrics_dict['mse_gt_e']) / 2.0
                metrics_dict['psnr_gt'] = (metrics_dict['psnr_gt_h'] + metrics_dict['psnr_gt_e']) / 2.0
                metrics_dict['ssim_gt'] = (metrics_dict['ssim_gt_h'] + metrics_dict['ssim_gt_e']) / 2.0

                metrics_file.write(','.join(str(val) for val in metrics_dict.values()) + '\n')

        # Save weights at the end in case we want to train further from this point.
        if SAVE_WEIGHTS:
            save_weights_folderpath = f'{folder_route}/weights'
            os.makedirs(save_weights_folderpath, exist_ok=True)
            # Never overwrite an earlier checkpoint: bump the iteration count in the
            # filename by NUM_ITERATIONS until an unused name is found.
            iteration_step = max(NUM_ITERATIONS, 1)  # guard against an endless loop
            completed_iterations = NUM_ITERATIONS
            save_weights_filepath = save_weights_folderpath + f"/{organ}_{index}_iteration_{completed_iterations}.pt"
            while os.path.exists(save_weights_filepath):
                completed_iterations += iteration_step
                save_weights_filepath = save_weights_folderpath + f"/{organ}_{index}_iteration_{completed_iterations}.pt"
            torch.save(model.state_dict(), save_weights_filepath)

    # Trying to reduce wasted memory
    torch.cuda.empty_cache()
