In [None]:
import os
# Pin the notebook to a single physical GPU: number devices by PCI bus order and expose
# only device 2. CUDA reads these variables when it initialises, so this cell must run first.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '2',
})

In [None]:
import time
import torch
import numpy as np
import matplotlib.pyplot as plt

# Move to the parent directory so the local modules imported below (models, datasets, utils)
# can be found; the relative dataset path used later is also resolved from this directory.
# NOTE(review): not idempotent -- re-running this cell moves up one more directory each time.
os.chdir("../")

from models import Cnet
from datasets import WSSBDatasetTest
from utils import od2rgb, rgb2od, random_ruifrok_matrix_variation, direct_deconvolution, peak_signal_noise_ratio, structural_similarity

In [None]:
SAVE_IMAGES = True
RUNNING_ON_DELFOS = False
SEED = 0

# Seed every RNG the notebook draws from: torch (network initialisation, input noise) and
# numpy (the colour-matrix sampling returns numpy arrays; presumably it draws from
# np.random -- TODO confirm in utils.random_ruifrok_matrix_variation).
torch.manual_seed(SEED)
np.random.seed(SEED)

plt.rcParams['font.size'] = 14
plt.rcParams['toolbar'] = 'None'

In [None]:
# Prefer the (single visible) GPU and fall back to the CPU when CUDA is unavailable.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

In [None]:
# Absolute dataset location on the Delfos server, otherwise relative to the working directory.
if RUNNING_ON_DELFOS:
    alsubaie_dataset_path = '/data/datasets/Alsubaie/Data/'
else:
    alsubaie_dataset_path = '../Alsubaie/Data/'

# Colon images only; sample 0 yields the image and the matrix used below as ground-truth M.
dataset = WSSBDatasetTest(alsubaie_dataset_path, organ_list=['Colon'], load_at_init=False)
sample_image, sample_M = dataset[0]
original_image, M_gt = sample_image, sample_M
print('Image shape:', original_image.shape)

In [None]:
# Generate all images derived from the ground truth.
# NOTE(review): M_gt is rebound to its batched form below, so re-running this cell without
# re-running the dataset cell feeds an already-unsqueezed matrix to direct_deconvolution.

img_np = original_image.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')
img_od = rgb2od(original_image)

C_gt = direct_deconvolution(img_od, M_gt).unsqueeze(0)  # (1, 2, H, W)
M_gt = M_gt.unsqueeze(0)  # (1, 3, 2)

C_H_gt_np = C_gt[:, 0, :, :].squeeze().numpy()  # (H, W)
C_E_gt_np = C_gt[:, 1, :, :].squeeze().numpy()  # (H, W)

# Keeping the stain index 's' in the output yields one OD image per stain:
# (batch_size, 2, 3, H, W). Compute it once and slice out each stain.
stains_gt_od = torch.einsum('bcs,bshw->bschw', M_gt, C_gt)
H_gt_od = stains_gt_od[:, 0, :, :]  # (batch_size, 3, H, W)
H_gt = torch.clamp(od2rgb(H_gt_od), 0.0, 255.0)  # (batch_size, 3, H, W)
E_gt_od = stains_gt_od[:, 1, :, :]  # (batch_size, 3, H, W)
E_gt = torch.clamp(od2rgb(E_gt_od), 0.0, 255.0)  # (batch_size, 3, H, W)

H_gt_np = H_gt.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')
E_gt_np = E_gt.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')

# Move to the training device; these are compared against the model output during training.
img_od = img_od.to(device)
H_gt = H_gt.to(device)
H_gt_od = H_gt_od.to(device)
E_gt = E_gt.to(device)
E_gt_od = E_gt_od.to(device)  # must stay in OD space: it is compared to E_rec_od in mse_gt_e

In [None]:
NUM_ITERATIONS = 5000

# One CSV row per iteration; the key order below is the column order of the log.
# 'epoch' is an integer counter, every other metric starts as the float 0.0.
metrics_dict = {'epoch': 0}
metrics_dict.update(dict.fromkeys(
    (
        'loss',
        'mse_rec', 'psnr_rec', 'ssim_rec',
        'mse_gt_h', 'mse_gt_e', 'mse_gt',
        'psnr_gt_h', 'psnr_gt_e', 'psnr_gt',
        'ssim_gt_h', 'ssim_gt_e', 'ssim_gt',
        'time',
    ),
    0.0,
))

In [None]:
model = Cnet().to(device)
# Fixed random network input; unsqueezed to add the batch dimension -> (1, 3, H, W).
input_noise = torch.rand(original_image.shape).unsqueeze(0).to(device)
original_tensor = original_image.unsqueeze(0).to(device)
_, _, H, W = original_tensor.shape
# Loop-invariant: OD version of the target, only used for the mse_rec metric.
original_od = rgb2od(original_tensor)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Results location depends on the machine; everything else in the path is shared.
results_home = '/home/modej' if RUNNING_ON_DELFOS else '/home/modejota'
results_dir = os.path.join(results_home, 'Deep_Var_BCD', 'results', 'dip', 'cnet_e1', 'single_image')
os.makedirs(results_dir, exist_ok=True)
metrics_filepath = os.path.join(results_dir, 'metricas.csv')

# Mode 'w' truncates any log left by a previous run, so no explicit os.remove is needed.
with open(metrics_filepath, 'w') as file:
    file.write(','.join(metrics_dict.keys()) + '\n')


def _to_uint8_image(tensor):
    """Convert a clamped RGB tensor (optionally with a batch dim of 1) into an (H, W, 3) uint8 array."""
    return tensor.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')


def _show(ax, image, title, cmap=None):
    """Draw `image` on `ax` with `title` and hide the axis."""
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')


def _od_mse(a, b):
    """Sum of squared differences between two (1, 3, H, W) tensors divided by 3*H*W."""
    return torch.sum(torch.pow(a - b, 2)).item() / (3.0 * H * W)


# Top row of every snapshot: the ground truth never changes, so build it once.
gt_panels = [
    (img_np, 'Original Image', None),
    (C_H_gt_np, 'Original Hematoxylin\nConcentration', 'gray'),
    (C_E_gt_np, 'Original Eosin\nConcentration', 'gray'),
    (H_gt_np, 'Original Hematoxylin', None),
    (E_gt_np, 'Original Eosin', None),
]

for epoch in range(NUM_ITERATIONS + 1):
    start_time = time.time()
    optimizer.zero_grad()

    C_matrix = model(input_noise)  # (1, 2, H, W)

    # Sample the colour matrix as a random variation of the Ruifrok matrix.
    h_matrix, e_matrix = random_ruifrok_matrix_variation(0.05)
    M_matrix = np.concatenate((h_matrix, e_matrix), axis=1)                 # (3, 2)
    M_matrix = torch.from_numpy(M_matrix).float().unsqueeze(0).to(device)   # (1, 3, 2)

    # Mix the stain concentrations into a 3-channel OD image and convert it back to RGB.
    reconstructed_od = torch.einsum('bcs,bshw->bchw', M_matrix, C_matrix)   # (1, 3, H, W)
    reconstructed = torch.clamp(od2rgb(reconstructed_od), 0, 255.0)
    loss = torch.nn.functional.mse_loss(reconstructed, original_tensor)
    loss.backward()
    optimizer.step()

    # CUDA kernels are launched asynchronously; wait for them so the timing covers the step.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    execution_time = (time.time() - start_time) * 1000.0    # Milliseconds

    # Evaluation only: no autograd graph is needed for anything below.
    with torch.no_grad():
        # One OD image per stain: (1, 2, 3, H, W); computed once, then sliced.
        stains_rec_od = torch.einsum('bcs,bshw->bschw', M_matrix, C_matrix)
        H_rec_od = stains_rec_od[:, 0, :, :]
        E_rec_od = stains_rec_od[:, 1, :, :]
        H_rec = torch.clamp(od2rgb(H_rec_od), 0.0, 255.0)
        E_rec = torch.clamp(od2rgb(E_rec_od), 0.0, 255.0)

        metrics_dict['epoch'] = epoch + 1
        metrics_dict['loss'] = loss.item()
        metrics_dict['time'] = execution_time
        metrics_dict['mse_rec'] = _od_mse(reconstructed_od, original_od)
        metrics_dict['psnr_rec'] = torch.sum(peak_signal_noise_ratio(reconstructed, original_tensor)).item()
        metrics_dict['ssim_rec'] = torch.sum(structural_similarity(reconstructed, original_tensor)).item()

        metrics_dict['mse_gt_h'] = _od_mse(H_gt_od, H_rec_od)
        metrics_dict['mse_gt_e'] = _od_mse(E_gt_od, E_rec_od)

        metrics_dict['psnr_gt_h'] = torch.sum(peak_signal_noise_ratio(H_gt, H_rec)).item()
        metrics_dict['psnr_gt_e'] = torch.sum(peak_signal_noise_ratio(E_gt, E_rec)).item()

        metrics_dict['ssim_gt_h'] = torch.sum(structural_similarity(H_gt, H_rec)).item()
        metrics_dict['ssim_gt_e'] = torch.sum(structural_similarity(E_gt, E_rec)).item()

    metrics_dict['mse_gt'] = (metrics_dict['mse_gt_h'] + metrics_dict['mse_gt_e']) / 2.0
    metrics_dict['psnr_gt'] = (metrics_dict['psnr_gt_h'] + metrics_dict['psnr_gt_e']) / 2.0
    metrics_dict['ssim_gt'] = (metrics_dict['ssim_gt_h'] + metrics_dict['ssim_gt_e']) / 2.0

    # Append (and close) each iteration so the log can be followed while training runs.
    with open(metrics_filepath, mode='a') as file:
        file.write(','.join(str(val) for val in metrics_dict.values()) + '\n')

    if SAVE_IMAGES and epoch % 200 == 0:
        # Host-side copies are only needed for plotting, so build them only when saving.
        C_mean = C_matrix.detach().cpu()
        img_rec_np = _to_uint8_image(reconstructed)
        C_H_rec_np = C_mean[:, 0, :, :].squeeze().numpy()
        C_E_rec_np = C_mean[:, 1, :, :].squeeze().numpy()
        H_rec_np = _to_uint8_image(H_rec)
        E_rec_np = _to_uint8_image(E_rec)

        rec_panels = [
            (img_rec_np, 'Reconstructed Image', None),
            (C_H_rec_np, 'Reconstructed Hematoxylin\nConcentration', 'gray'),
            (C_E_rec_np, 'Reconstructed Eosin\nConcentration', 'gray'),
            (H_rec_np, 'Reconstructed Hematoxylin', None),
            (E_rec_np, 'Reconstructed Eosin', None),
        ]

        # Row 0: ground truth, row 1: current model output.
        fig, ax = plt.subplots(2, 5, figsize=(20, 9))
        for row, panels in enumerate((gt_panels, rec_panels)):
            for col, (image, title, cmap) in enumerate(panels):
                _show(ax[row, col], image, title, cmap)

        # Note: file names use the 0-based loop index, while the CSV 'epoch' column is 1-based.
        fig.savefig(os.path.join(results_dir, f'{epoch}.png'), transparent=True)
        plt.close(fig)