In [None]:
import os
import torch

import matplotlib.pyplot as plt
import numpy as np

from utils.callbacks import EarlyStopping, History

from utils.utils_data import get_train_dataloaders, get_test_dataloaders
from utils.utils_BCD import od2rgb_np, undo_normalization
from models.DVBCDModel import DVBCDModel

print(torch.__version__)

# --- Reproducibility --------------------------------------------------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNGs of all devices, CUDA included

# --- cuDNN / autograd -------------------------------------------------------
torch.backends.cudnn.enabled = True
# Let cuDNN autotune kernels for the (fixed) input size: faster, but not
# bit-for-bit reproducible across runs.
torch.backends.cudnn.benchmark = True
# Anomaly detection is a debugging aid that makes every backward pass much
# slower; enable it only while hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# --- Device selection -------------------------------------------------------
USE_GPU = True
GPU_INDEX = 3  # preferred GPU, numbered in PCI-bus order (matches nvidia-smi)
# Number GPUs by PCI bus ID instead of CUDA's default "fastest first" order so
# GPU_INDEX refers to the same card nvidia-smi shows. Must be set before CUDA
# is initialised in this kernel.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if USE_GPU and torch.cuda.is_available():
    # Fall back to the first GPU on machines with fewer than GPU_INDEX+1 GPUs
    # instead of failing later with "invalid device ordinal".
    _gpu_idx = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    DEVICE = torch.device(f'cuda:{_gpu_idx}')
else:
    DEVICE = torch.device('cpu')
print('using device:', DEVICE)

In [None]:
# Dataset roots on the shared data server.
camelyon_data_path = '/data/BasesDeDatos/Camelyon/Camelyon17/training/patches_224/'
wssb_data_path = '/data/BasesDeDatos/Alsubaie/Data/'  # not used in this cell

# NOTE(review): the three positional ints below are passed straight through to
# get_train_dataloaders. Judging by an older get_dataloaders(data_path,
# patch_size, batch_size, num_workers, ...) call they may be patch size, batch
# size and worker count -- confirm against utils.utils_data and switch to
# keyword arguments once the signature is checked.
loader_positional = (camelyon_data_path, 16, 32, 64)
train_dataloader, val_dataloader = get_train_dataloaders(
    *loader_positional,
    val_prop=0.2,      # presumably the fraction held out for validation
    n_samples=10000,
)

# len() of each loader (batches per epoch if these are torch DataLoaders).
print(*(len(loader) for loader in (train_dataloader, val_dataloader)))

In [None]:
# Hyper-parameters for DVBCDModel, passed through as keyword arguments.
dvbcd_config = dict(
    cnet_name="unet_6",
    mnet_name="resnet_18_in",
    sigmaRui_sq=torch.tensor([1e-03, 1e-03]),
    lambda_val=0.005,
    lr_cnet=1e-4,
    lr_mnet=1e-4,
    lr_decay=0.1,
    # 10e5 == 1e6. NOTE(review): confirm this was not meant to be 1e5.
    clip_grad_cnet=10e5,
    clip_grad_mnet=10e5,
    device=DEVICE,
)
model = DVBCDModel(**dvbcd_config)

# Two-stage schedule: a short run with pretraining=True, then the main run.
# The first positional argument of fit is presumably the epoch count.
for n_epochs, is_pretraining in ((2, True), (10, False)):
    model.fit(n_epochs, train_dataloader, val_dataloader, pretraining=is_pretraining)

# Optional (currently disabled): callbacks via
# model.set_callbacks([EarlyStopping(model, "val_loss_mse", patience=2), History()]),
# and persisting weights with model.save(path) / model.load(path) + model.to(DEVICE).

In [None]:
SAMPLE_IDX = 2  # which training sample to visualise

# Fetch the sample once: indexing the dataset twice reads it twice and, if the
# dataset applies random augmentation, could pair the input with a different
# patch's target.
sample = train_dataloader.dataset[SAMPLE_IDX]
img_prueba = sample[0].unsqueeze(0).to(DEVICE)  # model input, batch of 1
img_od = sample[1].unsqueeze(0).to(DEVICE)      # second sample element (presumably the OD image)

# Inference only: no autograd graph is needed (also skips anomaly-mode checks).
with torch.no_grad():
    mean, var, cnet, y_rec = model.forward(img_prueba)

cnet_h = cnet[0, 0, :, :].detach().cpu().numpy()
cnet_e = cnet[0, 1, :, :].detach().cpu().numpy()

fig, (ax_h, ax_e) = plt.subplots(1, 2)
ax_h.imshow(cnet_h)
ax_h.set_title("cnet channel 0 (H)")
ax_e.imshow(cnet_e)
ax_e.set_title("cnet channel 1 (E)")
plt.show()

In [None]:
def od_batch_to_rgb(od_batch: torch.Tensor) -> np.ndarray:
    """Convert the first image of a normalized OD batch into a displayable RGB array.

    Args:
        od_batch: 4-D tensor; the first item is taken and its axis 0 is moved
            last (presumably (batch, channel, H, W) -> (H, W, channel)).

    Returns:
        uint8 array suitable for ``imshow``.
    """
    od = od_batch[0, :, :, :].detach().cpu().numpy().transpose(1, 2, 0)
    # NOTE(review): undo_normalization(x, log(256), 0) presumably reverts the OD
    # scaling applied by the data pipeline -- confirm in utils.utils_BCD.
    od = undo_normalization(od, np.log(256.0), 0)
    rgb = od2rgb_np(od)
    # Clip before casting: astype(np.uint8) wraps (or is undefined for) values
    # outside [0, 255], which a model reconstruction may produce.
    return np.clip(rgb, 0, 255).astype(np.uint8)


rgb_orig = od_batch_to_rgb(img_od)  # dataset sample
rgb_rec = od_batch_to_rgb(y_rec)    # model output
# Keep the previous short names for any later cells that still use them.
c, b = rgb_orig, rgb_rec

fig, (ax_orig, ax_rec) = plt.subplots(1, 2)
ax_orig.imshow(rgb_orig)
ax_orig.set_title("original (img_od)")
ax_orig.axis("off")
ax_rec.imshow(rgb_rec)
ax_rec.set_title("reconstruction (y_rec)")
ax_rec.axis("off")
plt.show()