In [None]:
import os

# Restrict this notebook to one GPU. With PCI_BUS_ID ordering, index "2" refers to
# the same physical device that nvidia-smi lists as GPU 2. This must run before CUDA
# is initialised, so it lives in the very first cell.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="2")

In [None]:
import os
import time
import torch
import numpy as np

# Move to the parent of the notebook's directory so that the project modules
# (models, datasets, utils) can be imported. They are presumably resolved through
# the working directory; confirm against the repo layout.
# The notebook directory is remembered on the first run. Re-running this cell then
# lands in the same place instead of climbing one more level each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

from models import Cnet
from datasets import WSSBDatasetTest
from utils import od2rgb, rgb2od, random_ruifrok_matrix_variation, direct_deconvolution, peak_signal_noise_ratio, structural_similarity

In [None]:
# Seed every RNG the notebook draws from:
# - torch: model initialisation and the DIP input noise;
# - numpy: presumably used by random_ruifrok_matrix_variation, whose output is
#   concatenated with np. Confirm in utils.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

In [None]:
# Root of the Alsubaie dataset. It is only available on the Delfos server, not locally.
alsubaie_dataset_path = '/data/datasets/Alsubaie/Data'
organ_list = ['Lung', 'Breast']  # full set: ['Lung', 'Breast', 'Colon']
# Images per organ: 4 lung, 6 breast, 14 colon.

# Per-iteration metrics row. The training loop overwrites these values every
# iteration, and the key order is the column order of each metrics CSV.
metrics_dict = { 
    'epoch' : 0, 'loss': 0.0, 'time': 0.0,
    'mse_rec' : 0.0, 'psnr_rec': 0.0, 'ssim_rec': 0.0,
    'mse_gt_h': 0.0, 'mse_gt_e': 0.0, 'mse_gt': 0.0,
    'psnr_gt_h': 0.0, 'psnr_gt_e': 0.0, 'psnr_gt': 0.0,
    'ssim_gt_h': 0.0, 'ssim_gt_e': 0.0, 'ssim_gt': 0.0     
}

# Optimisation steps per image. The training loop runs NUM_ITERATIONS + 1 steps.
NUM_ITERATIONS = 4000

# Output folder for the metrics CSVs (one file per organ/image, written by the
# training loop). Only the folder is created here. exist_ok avoids the
# check-then-create race of an os.path.exists guard.
metrics_folder = "/home/modej/Deep_Var_BCD/results/dip/cnet_e1/"
os.makedirs(metrics_folder, exist_ok=True)

In [None]:
def _split_stains_od(M, C):
    """Return the per-stain optical-density images of the reconstruction of ``C`` by ``M``.

    ``M`` and ``C`` are contracted with the einsum labels ``bcs,bshw``. The stain
    axis ``s`` is kept rather than summed, and stains 0 and 1 are returned
    separately. This notebook treats stain 0 as haematoxylin and stain 1 as eosin.

    Args:
        M: colour matrix indexed ``[batch, channel, stain]``.
        C: concentration map indexed ``[batch, stain, height, width]``.

    Returns:
        ``(stain0_od, stain1_od)``, each indexed ``[batch, channel, height, width]``.
    """
    per_stain = torch.einsum('bcs,bshw->bschw', M, C)
    return per_stain[:, 0], per_stain[:, 1]


for organ in organ_list:
    # One dataset per organ keeps the per-image results grouped by organ.
    dataset = WSSBDatasetTest(alsubaie_dataset_path, organ_list=[organ], load_at_init=False)

    # Fit a fresh model to every image independently and log metrics at every iteration.
    for index, (image, M_gt) in enumerate(dataset):

        print(f"Organ: {organ} \t Image: {index}")

        # Fresh model and optimiser per image. The random input tensor stays fixed
        # for the whole fit.
        model = Cnet().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        input_noise = torch.rand(image.shape).unsqueeze(0).to(device)
        image_as_tensor = image.unsqueeze(0).to(device)
        _, _, H, W = image_as_tensor.shape
        M_gt = M_gt.unsqueeze(0)

        # Loop-invariant targets. They depend only on the image and its ground-truth
        # matrix, so compute them once per image rather than once per iteration.
        image_od = rgb2od(image_as_tensor)
        C_gt = direct_deconvolution(rgb2od(image), M_gt).unsqueeze(0)
        H_gt_od, E_gt_od = _split_stains_od(M_gt, C_gt)
        H_gt = torch.clamp(od2rgb(H_gt_od), 0.0, 255.0).to(device)
        E_gt = torch.clamp(od2rgb(E_gt_od), 0.0, 255.0).to(device)
        H_gt_od = H_gt_od.to(device)
        E_gt_od = E_gt_od.to(device)  # previously overwritten with the RGB image E_gt

        # Open the CSV once per image. buffering=1 (line buffering) still flushes
        # every row, so partial results survive an interrupted run.
        metrics_filepath = os.path.join(metrics_folder, f"{organ}_{index}.csv")
        with open(metrics_filepath, 'w', buffering=1) as metrics_file:
            metrics_file.write(','.join(metrics_dict.keys()) + '\n')

            for epoch in range(NUM_ITERATIONS+1):

                start_time = time.time()
                optimizer.zero_grad()

                C_matrix = model(input_noise)

                # Sample a colour matrix as a random variation (0.05) around the Ruifrok matrix.
                h_matrix, e_matrix = random_ruifrok_matrix_variation(0.05)
                M_matrix = np.concatenate((h_matrix, e_matrix), axis=1)
                # Presumably (3, 2) here, and (1, 3, 2) after unsqueeze. The 'bcs'
                # einsum label below requires a 3-D tensor.
                M_matrix = torch.from_numpy(M_matrix).float().unsqueeze(0).to(device)

                # Mix the stains back into a 3-channel OD image, then convert to RGB.
                reconstructed_od = torch.einsum('bcs,bshw->bchw', M_matrix, C_matrix)
                reconstructed = torch.clamp(od2rgb(reconstructed_od), 0, 255.0)
                loss = torch.nn.functional.mse_loss(reconstructed, image_as_tensor)
                loss.backward()
                optimizer.step()

                # CUDA kernels run asynchronously. Wait for them so the timing covers
                # the actual work, not just the kernel launches.
                if device.type == 'cuda':
                    torch.cuda.synchronize()
                execution_time = (time.time() - start_time) * 1000.0    # Milliseconds

                # NOTE(review): logs epochs 1..NUM_ITERATIONS+1. Confirm that this is intended.
                metrics_dict['epoch'] = epoch+1
                metrics_dict['loss'] = loss.item()
                metrics_dict['time'] = execution_time

                # Evaluation only, so skip building autograd graphs for the metrics.
                with torch.no_grad():
                    metrics_dict['mse_rec'] = torch.sum(torch.pow(reconstructed_od - image_od, 2)).item() / (3.0*H*W)
                    metrics_dict['psnr_rec'] = torch.sum(peak_signal_noise_ratio(reconstructed, image_as_tensor)).item()
                    metrics_dict['ssim_rec'] = torch.sum(structural_similarity(reconstructed, image_as_tensor)).item()

                    # Per-stain images predicted with this iteration's colour matrix.
                    H_rec_od, E_rec_od = _split_stains_od(M_matrix, C_matrix)
                    H_rec = torch.clamp(od2rgb(H_rec_od), 0.0, 255.0)
                    E_rec = torch.clamp(od2rgb(E_rec_od), 0.0, 255.0)

                    metrics_dict['mse_gt_h'] = torch.sum(torch.pow(H_gt_od - H_rec_od, 2)).item() / (3.0*H*W)
                    metrics_dict['mse_gt_e'] = torch.sum(torch.pow(E_gt_od - E_rec_od, 2)).item() / (3.0*H*W)

                    metrics_dict['psnr_gt_h'] = torch.sum(peak_signal_noise_ratio(H_gt, H_rec)).item()
                    metrics_dict['psnr_gt_e'] = torch.sum(peak_signal_noise_ratio(E_gt, E_rec)).item()

                    metrics_dict['ssim_gt_h'] = torch.sum(structural_similarity(H_gt, H_rec)).item()
                    metrics_dict['ssim_gt_e'] = torch.sum(structural_similarity(E_gt, E_rec)).item()

                # Combined scores are the plain average of the two stains.
                metrics_dict['mse_gt'] = (metrics_dict['mse_gt_h'] + metrics_dict['mse_gt_e'])/2.0
                metrics_dict['psnr_gt'] = (metrics_dict['psnr_gt_h'] + metrics_dict['psnr_gt_e'])/2.0
                metrics_dict['ssim_gt'] = (metrics_dict['ssim_gt_h'] + metrics_dict['ssim_gt_e'])/2.0

                metrics_file.write(','.join(str(val) for val in metrics_dict.values()) + '\n')

    # Release cached, unused GPU memory between organs.
    torch.cuda.empty_cache()