In [None]:
# Imports and environment setup.
# Statement order matters in this cell: the project root must be on sys.path
# before the local `models`/`utils` imports, and the CUDA environment variables
# are set before `torch` is imported so they are in place before CUDA initialises.
import sys
# Make the parent directory importable so the local `models` and `utils`
# packages used below can be resolved.
sys.path.append('../')

import os
import matplotlib.pyplot as plt
import numpy as np

# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
# Release cached, unused CUDA memory held by this process. On a fresh kernel
# nothing has been allocated yet, so this presumably only matters when the
# cell is re-run in a kernel that already used the GPU.
torch.cuda.empty_cache()
from models.DVBCDModel import DVBCDModel, peak_signal_noise_ratio
from torchmetrics.functional import structural_similarity_index_measure
from utils.datasets import CamelyonDataset, WSSBDatasetTest
from utils.utils_BCD import undo_normalization, od2rgb_np, C_to_OD_torch, normalize_to1

In [None]:
print(torch.__version__)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Anomaly detection is a debugging aid for backward passes. It records extra
# information for every op and slows execution. This notebook only runs
# inference (nothing calls .backward()) and it times a forward pass, so keep it
# off. Setting it to False explicitly (instead of just deleting the call) also
# resets it if an earlier run in this kernel turned it on.
torch.autograd.set_detect_anomaly(False)

# Set USE_GPU = True to run on CUDA when available; otherwise the CPU is used.
USE_GPU = False
if USE_GPU and torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('using device:', DEVICE)

# Directory holding the trained weights ("best.pt" is loaded from here).
SAVE_MODEL_PATH = "/work/work_fran/Deep_Var_BCD/weights/mobilenetv3s_1pe_224ps_0.3theta_0.05sigmaRui_60000nsamples/"
# Output directory for exported figures.
SAVE_FIG_PATH = "/work/work_fran/Deep_Var_BCD/results/img/"

In [None]:
# Build the deconvolution model on the selected device and restore the
# weights stored in best.pt under SAVE_MODEL_PATH.
# NOTE(review): if DVBCDModel wraps nn.Modules containing BatchNorm/Dropout,
# it presumably should be switched to eval mode for inference — confirm
# against the DVBCDModel API.
model = DVBCDModel().to(DEVICE)
model.load(f"{SAVE_MODEL_PATH}best.pt")

# Camelyon17 patches: reconstruction and H&E stain separation

In [None]:
# Camelyon17 224x224 training patches from center 0 (n_samples=100).
dataset = CamelyonDataset(
    data_path="/data/BasesDeDatos/Camelyon/Camelyon17/training/patches_224/",
    centers=[0],
    patch_size=224,
    n_samples=100,
)
# Alternative source: WSSB (Alsubaie) colon images.
#dataset = WSSBDatasetTest("/data/BasesDeDatos/Alsubaie/Data/", organ_list=["Colon"])

In [None]:
# Pick a single patch by index and report its intensity range.
idx = 3
(img, mR) = dataset[idx]
print(img.min(), img.max())
# Unpacking for datasets that also return ground truth:
#img, od_img, mR, C_gt, M_gt = dataset[idx]

In [None]:
# Convert the RGB patch to optical density and run the model on a batch of one.
od_img = model._rgb2od(img).to(DEVICE)
# Inference only (nothing in this notebook calls .backward()), so don't build
# the autograd graph. This saves memory and the per-op bookkeeping. The outputs
# come back with requires_grad=False, and .detach() on them remains harmless.
with torch.no_grad():
    out_Mnet_mean, out_Mnet_var, out_Cnet, Y_rec = model.forward(od_img.unsqueeze(0))

In [None]:
# Map the reconstructed OD back to RGB and score it against the input patch.
Y_rec_rgb = model._od2rgb(Y_rec)
print(Y_rec_rgb.min(), Y_rec_rgb.max())
# Clamp to the 8-bit range rather than min-max stretching. Stretching rescales
# the reconstruction, so SSIM/PSNR would no longer measure how close the actual
# reconstruction is to the input, and it divides by zero on a constant image.
# Clamping matches how the WSSB reconstruction is handled.
Y_rec_rgb = torch.clamp(Y_rec_rgb, 0.0, 255.0)
print(Y_rec_rgb.min(), Y_rec_rgb.max())

img_batch = img.unsqueeze(0).to(DEVICE)
Y_rec_batch = Y_rec_rgb.to(DEVICE)
# SSIM takes (preds, target). data_range is pinned to the 8-bit range so it is
# not inferred from each image's own min/max, which varies between patches.
print(structural_similarity_index_measure(Y_rec_batch, img_batch, data_range=255.0))
print(peak_signal_noise_ratio(img_batch, Y_rec_batch))



In [None]:
# Side-by-side view of the input patch and its reconstruction.
# Both images are clipped to [0, 255] and cast to uint8 (channels last) before
# display. Casting out-of-range floats directly to an unsigned type wraps
# around, and imshow interprets integer RGB data as 0-255.
img_np = np.clip(img.detach().cpu().numpy(), 0, 255).astype(np.uint8).transpose(1, 2, 0)
Y_rec_rgb_np = Y_rec_rgb.detach().cpu().numpy().squeeze()
Y_rec_rgb_np = np.clip(Y_rec_rgb_np, 0, 255).astype(np.uint8)
Y_rec_rgb_np = Y_rec_rgb_np.transpose(1, 2, 0)

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
for ax, image, title in zip(axes, (img_np, Y_rec_rgb_np), ("Original", "Reconstruction")):
    ax.imshow(image)
    ax.set_title(title)
#fig.savefig(SAVE_FIG_PATH + f"camelyon_{idx}idx.pdf", bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
def _stain_od_to_rgb_uint8(stain_od):
    """Turn one channels-first stain OD image into a displayable uint8 RGB image.

    Applies undo_normalization, maps OD back to RGB with od2rgb_np, clips to
    [0, 255], casts to uint8 and moves the channel axis last for matplotlib.
    """
    rgb = np.clip(od2rgb_np(undo_normalization(stain_od)), 0, 255)
    return rgb.astype(np.uint8).transpose(1, 2, 0)


# Per-stain OD images built from the predicted concentrations and color
# vectors. Index 0 is hematoxylin, index 1 is eosin.
C_OD = C_to_OD_torch(out_Cnet, out_Mnet_mean).cpu().detach().numpy().squeeze()
H_OD = C_OD[0, :, :, :]
E_OD = C_OD[1, :, :, :]
H_RGB = _stain_od_to_rgb_uint8(H_OD)
E_RGB = _stain_od_to_rgb_uint8(E_OD)

In [None]:
# Overview figure for the Camelyon patch: input, reconstruction and the two
# separated stains. The reconstruction array for this section is
# `Y_rec_rgb_np`; `Y_rec_np` only exists later, in the WSSB section.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
panels = [
    (img_np, "Original"),
    (Y_rec_rgb_np, "Reconstruction"),
    (H_RGB, "Hematoxylin"),
    (E_RGB, "Eosin"),
]
for ax, (image, title) in zip(axes.flat, panels):
    ax.imshow(image)
    ax.set_title(title)
# Distinct file name so it does not overwrite the reconstruction-only figure
# (camelyon_{idx}idx.pdf).
#fig.savefig(SAVE_FIG_PATH + f"camelyon_{idx}idx_stains.pdf", bbox_inches='tight', pad_inches=0)
plt.show()

# WSSB (Alsubaie data): reconstruction and H&E stain separation

In [None]:
# Load one WSSB (Alsubaie) image of the chosen organ and run the model on it.
organ = "Lung"
dataset = WSSBDatasetTest("/data/BasesDeDatos/Alsubaie/Data/", organ_list=[organ])
idx = 0
img, mR, M_gt = dataset[idx]
od_img = model._rgb2od(img).to(DEVICE)
# Inference only: no autograd graph is needed (nothing calls .backward()).
with torch.no_grad():
    out_Mnet_mean, out_Mnet_var, out_Cnet, Y_rec = model.forward(od_img.unsqueeze(0))

In [None]:
import time

# Time one forward pass on the current image.
# - time.perf_counter() is a monotonic, high-resolution clock meant for
#   measuring intervals. time.time() is wall-clock time and can jump.
# - CUDA kernels run asynchronously, so synchronize before reading the clock.
#   Otherwise only the launch overhead would be measured on a GPU.
# - Autograd is disabled so the measurement reflects pure inference.
if DEVICE.type == "cuda":
    torch.cuda.synchronize()
st = time.perf_counter()
with torch.no_grad():
    out_Mnet_mean, out_Mnet_var, out_Cnet, Y_rec = model.forward(od_img.unsqueeze(0))
if DEVICE.type == "cuda":
    torch.cuda.synchronize()
et = time.perf_counter()

elapsed_time = et - st
print('Execution time:', elapsed_time, 'seconds')

In [None]:
# Display the image files backing the WSSB dataset built for `organ`
# (presumably only that organ's files — confirm in WSSBDatasetTest).
dataset.image_files

In [None]:
# Input image and its reconstruction as channels-last uint8 arrays for display.
# Clip before casting: converting out-of-range floats straight to an unsigned
# type wraps around. detach/cpu match the Camelyon section and keep this
# working if the tensor lives on the GPU or tracks gradients.
img_np = np.clip(img.detach().cpu().numpy(), 0, 255).astype(np.uint8).transpose(1, 2, 0)

Y_rec_np = model._od2rgb(Y_rec).detach().cpu().numpy().squeeze()
print(Y_rec_np.min(), Y_rec_np.max())
Y_rec_np = np.clip(Y_rec_np, 0.0, 255.0).astype(np.uint8)
Y_rec_np = Y_rec_np.transpose(1, 2, 0)
plt.imshow(Y_rec_np)
plt.show()

In [None]:
# Ground-truth stain images (disabled). This needs C_gt, but the WSSB sample
# above is unpacked as (img, mR, M_gt), so C_gt is not available here.
# C_GT_OD = C_to_OD_torch(C_gt, M_gt).cpu().detach().numpy().squeeze()
# H_OD_GT =  C_GT_OD[0, :, :]
# H_RGB_GT = od2rgb_np(H_OD_GT).transpose(1,2,0)
# H_RGB_GT = np.clip(H_RGB_GT, 0, 255).astype(np.uint8)
# E_OD_GT =  C_GT_OD[1, :, :]
# E_RGB_GT = od2rgb_np(E_OD_GT).transpose(1,2,0)
# E_RGB_GT = np.clip(E_RGB_GT, 0, 255).astype(np.uint8)

# Alternative that goes through the project helper instead of einsum:
#C_OD = C_to_OD_torch(out_Cnet, out_Mnet_mean).cpu().detach().numpy().squeeze()


def _wssb_stain_od_to_rgb_uint8(stain_od):
    """Map one channels-first stain OD tensor to a channels-last uint8 RGB array.

    Uses model._od2rgb, clamps to [0, 255], then moves the channel axis last
    and casts to uint8 for matplotlib.
    """
    rgb = torch.clamp(model._od2rgb(stain_od), 0.0, 255.0)
    return rgb.detach().numpy().transpose(1, 2, 0).astype(np.uint8)


# Per-stain OD images: for stain s and channel c, OD[s, c] = M[c, s] * C[s],
# i.e. each stain's color vector (out_Mnet_mean) scaled by its concentration
# map (out_Cnet). Index 0 is hematoxylin, index 1 is eosin.
C_OD = torch.einsum('bcs, bshw -> bschw', out_Mnet_mean, out_Cnet).to("cpu").squeeze()
H_OD = C_OD[0, :, :, :]
E_OD = C_OD[1, :, :, :]
H_RGB = _wssb_stain_od_to_rgb_uint8(H_OD)
E_RGB = _wssb_stain_od_to_rgb_uint8(E_OD)

In [None]:
# Sanity check: shape of the hematoxylin display image (channel axis last
# after the transpose(1, 2, 0) used to build it).
H_RGB.shape

In [None]:
# Show the hematoxylin image and also write it to H_RGB.png in the working directory.
ax_h = plt.subplots()[1]
ax_h.imshow(H_RGB)
plt.imsave("H_RGB.png", H_RGB)
plt.show()

In [None]:
# Show the eosin image in a 10x10 figure and also write it to E_RGB.png in the working directory.
fig, ax_e = plt.subplots(figsize=(10, 10))
ax_e.imshow(E_RGB)
plt.imsave("E_RGB.png", E_RGB)
plt.show()

In [None]:
# Summary figure for the WSSB image: input, reconstruction and both stains.
# The ground-truth stain panels are added only when H_RGB_GT / E_RGB_GT exist.
# The code that builds them is commented out because it needs C_gt, which the
# WSSB sample unpacking does not return. Referencing them unconditionally
# raises NameError on a fresh kernel.
gt_available = "H_RGB_GT" in globals() and "E_RGB_GT" in globals()
if gt_available:
    panels = [
        (img_np, "Original"), (Y_rec_np, "Reconstruction"),
        (H_RGB_GT, "H GT"), (H_RGB, "H"),
        (E_RGB_GT, "E GT"), (E_RGB, "E"),
    ]
else:
    panels = [
        (img_np, "Original"), (Y_rec_np, "Reconstruction"),
        (H_RGB, "H"), (E_RGB, "E"),
    ]

# Two panels per row, 5 inches of height per row (matches the original 10x15 for 3 rows).
n_rows = (len(panels) + 1) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows))
for ax, (image, label) in zip(axes.flat, panels):
    ax.imshow(image)
    ax.set_xlabel(label)
#fig.savefig(SAVE_FIG_PATH + f"wssb_{organ}_{idx}idx.pdf", bbox_inches='tight', pad_inches=0)

plt.show()