In [None]:
import os
import sys

# GPU selection must happen before CUDA is initialised (later changes to these
# variables are ignored); setting them before `import torch` guarantees that.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

# Make the project packages (utils/, models/) in the parent directory importable.
sys.path.append('../')

import matplotlib.pyplot as plt
import numpy as np
import torch

from utils.callbacks import EarlyStopping, History
from utils.utils_data import get_train_dataloaders, get_wssb_dataloader
from utils.utils_BCD import od2rgb_np, undo_normalization
from models.DVBCDModel import DVBCDModel

print(torch.__version__)

# ---- configuration ----
SEED = 0
USE_GPU = True
# Anomaly detection checks every backward op for NaN/inf and slows training
# considerably; switch it on only while debugging such problems.
DETECT_ANOMALY = False

torch.manual_seed(SEED)
np.random.seed(SEED)

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

DEVICE = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
print('using device:', DEVICE)

In [None]:
def psnr(A, B, max=None):
    """
    Peak signal-to-noise ratio (in dB), computed independently for each sample.

    input:
        A: reference tensor of shape (N, ...), e.g. (N, C, H, W). Every
           dimension after the first (batch) one is averaged into the MSE.
        B: tensor with the same shape as A.
        max: peak value / data range of the signal. If None, the maximum of A
             over the *whole batch* is used (not a per-sample maximum).
    return:
        psnr: tensor of shape (N, ). Identical samples give +inf.
    raises:
        ValueError: if A has fewer than 2 dimensions.
    """
    if A.dim() < 2:
        raise ValueError(
            f"psnr expects a batched tensor with >= 2 dims, got shape {tuple(A.shape)}"
        )
    # Integer images (e.g. uint8) would wrap around in A - B and in max**2
    # (255**2 overflows uint8), so promote them to floating point first.
    if not torch.is_floating_point(A):
        A = A.float()
    if not torch.is_floating_point(B):
        B = B.float()
    if max is None:
        max = torch.max(A)
    elif torch.is_tensor(max) and not torch.is_floating_point(max):
        max = max.float()
    # Reduce every non-batch dimension; for 4-D input this is dim=(1, 2, 3).
    mse = torch.mean((A - B) ** 2, dim=tuple(range(1, A.dim())))
    return 10 * torch.log10(max ** 2 / mse)

In [None]:
import torch
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure

# Sanity check: compare the hand-written `psnr` with torchmetrics on synthetic
# images in the [0, 255) range with small Gaussian noise added.
rng = torch.Generator().manual_seed(0)  # reproducible, independent of the global RNG
bs = 5
A = 255.0 * torch.rand(bs, 3, 256, 256, generator=rng)
print(A.max(), A.min())
B = A + 0.1 * torch.randn(bs, 3, 256, 256, generator=rng)
print(B.max(), B.min())

# dim=1 reduces each flattened image separately, giving one PSNR per sample
# that torchmetrics then averages -- the same quantity as psnr(...).mean().
# (dim=0 would reduce over the batch axis, i.e. one PSNR per pixel.)
psnr_torchmetrics = peak_signal_noise_ratio(
    A.reshape(bs, -1), B.reshape(bs, -1), dim=1, data_range=255.0
)
psnr_custom = psnr(A, B, max=255.0).mean()
print(psnr_torchmetrics)
print(psnr_custom)
assert abs(psnr_torchmetrics.item() - psnr_custom.item()) < 1e-2, \
    "custom psnr disagrees with torchmetrics"
print(structural_similarity_index_measure(A, B, data_range=255.0))


In [None]:
# Dataset locations (absolute paths on the training machine).
# A smaller toy split also exists at '/data/BasesDeDatos/Camelyon/Camelyon17/training/Toy/'.
camelyon_data_path = '/data/BasesDeDatos/Camelyon/Camelyon17/training/patches_224/'
wssb_data_path = '/data/BasesDeDatos/Alsubaie/Data/'

# Positional arguments of get_train_dataloaders; presumably
# (data_path, patch_size, batch_size, num_workers) -- TODO confirm against utils.utils_data.
PATCH_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 32

# The Camelyon validation split is kept under its own name instead of being
# silently overwritten; validation during training uses the WSSB loader below.
train_dataloader, camelyon_val_dataloader = get_train_dataloaders(
    camelyon_data_path, PATCH_SIZE, BATCH_SIZE, NUM_WORKERS,
    val_prop=0.1, n_samples=500, train_centers=[0],
)
# NOTE(review): whether 32 is the batch size or the worker count is not visible here -- confirm.
val_dataloader = get_wssb_dataloader(wssb_data_path, 32)

print(len(train_dataloader), len(camelyon_val_dataloader), len(val_dataloader))

In [None]:
# DVBCD hyper-parameters, collected in one place so they are easy to tweak.
model_kwargs = dict(
    cnet_name="unet6",
    mnet_name="resnet18ft",
    sigmaRui_sq=torch.tensor([1e-03, 1e-03]),
    theta_val=0.5,
    lr=1e-4,
    lr_decay=0.1,
)
n_epochs = 1

model = DVBCDModel(**model_kwargs)
model.DP()  # NOTE(review): presumably wraps the networks for multi-GPU DataParallel -- confirm in DVBCDModel
model.to(DEVICE)

# Optional steps used during experimentation (disabled):
#   model.set_callbacks([EarlyStopping(model, "val_loss_mse", patience=2), History()])
#   model.fit(n_epochs, train_dataloader, val_dataloader, pretraining=True)
#   model.save("./weights/DVBCDModel_prueba.pth")
#   model.load("./weights/DVBCDModel_prueba.pth"); model.to(DEVICE)
model.fit(n_epochs, train_dataloader, val_dataloader, pretraining=False)

In [None]:
# Random inputs for the batched least-squares experiment in the next cell.
bs = 1
rng = torch.Generator().manual_seed(0)  # local generator: reproducible without touching the global RNG
Y = torch.randn(bs, 3, 2000, 2000, generator=rng)
M = torch.randn(bs, 3, 2, generator=rng)
print(Y.shape, M.shape)

In [None]:
# Solve M @ C = Y in the least-squares sense for all pixels at once: Y is
# flattened to (bs, 3, H*W) so each column is one right-hand side.
# A new name keeps the 4-D Y intact, so this cell can be re-run safely.
Y_flat = Y.reshape(bs, 3, -1)
C = torch.linalg.lstsq(M, Y_flat).solution
assert C.shape == (bs, M.shape[-1], Y_flat.shape[-1])
C.shape

In [None]:
# Run the model on one training sample and show the two channels of `cnet`.
SAMPLE_IDX = 2
# Fetch the sample once: two separate dataset[...] calls read it twice and,
# if the dataset applies random augmentation, could return mismatched tensors.
sample = train_dataloader.dataset[SAMPLE_IDX]
img_prueba = sample[0].unsqueeze(0).to(DEVICE)  # model input, batch of 1
img_od = sample[1].unsqueeze(0).to(DEVICE)      # second element of the sample, used later as the reference image
# Inference only: no autograd graph is needed (all outputs are detached below).
with torch.no_grad():
    mean, var, cnet, y_rec = model.forward(img_prueba)
cnet_h = cnet[0, 0, :, :].cpu().detach().numpy()
cnet_e = cnet[0, 1, :, :].cpu().detach().numpy()

fig, (ax_h, ax_e) = plt.subplots(1, 2)
ax_h.imshow(cnet_h)
ax_h.set_title("cnet channel 0 (cnet_h)")
ax_e.imshow(cnet_e)
ax_e.set_title("cnet channel 1 (cnet_e)")
plt.show()

In [None]:
# Model reconstruction -> displayable RGB: move the first axis last
# (assumes y_rec is (N, C, H, W)), undo the normalisation, convert OD to RGB.
b = y_rec[0, :, :, :].cpu().detach().numpy().transpose(1, 2, 0)
# NOTE(review): log(256) is presumably the scale used when normalising OD -- confirm in utils_BCD.
b = undo_normalization(b, np.log(256.0), 0)
# Nothing here bounds the RGB values to [0, 255]; clip before the cast, since
# converting out-of-range floats to uint8 wraps around / is undefined.
b = np.clip(od2rgb_np(b), 0, 255).astype(np.uint8)

In [None]:
# Reference OD image -> displayable RGB: move the first axis last
# (assumes img_od is (N, C, H, W)), undo the normalisation, convert OD to RGB.
c = img_od[0, :, :, :].cpu().detach().numpy().transpose(1, 2, 0)
# NOTE(review): log(256) is presumably the scale used when normalising OD -- confirm in utils_BCD.
c = undo_normalization(c, np.log(256.0), 0)
# Clip before the cast: converting out-of-range floats to uint8 wraps around / is undefined.
c = np.clip(od2rgb_np(c), 0, 255).astype(np.uint8)

In [None]:
# Side by side: reference image `c` (from img_od) vs. model reconstruction `b` (from y_rec).
fig, (ax_in, ax_rec) = plt.subplots(1, 2)
ax_in.imshow(c)
ax_in.set_title("Input")
ax_in.axis("off")
ax_rec.imshow(b)
ax_rec.set_title("Reconstruction")
ax_rec.axis("off")
plt.show()