In [None]:
import torch
from models import BCDnet
import wandb

In [None]:
# Weights & Biases run to inspect, given as "<entity>/<project>/<run_id>".
run_path = 'francastro-team/DVBCD/cj1p1dw1'

# Public-API handle to the run. Its `config` carries the training hyper-parameters
# that the cells below read (model names, data paths, patch size).
api = wandb.Api()
run = api.run(path=run_path)
config = run.config
print(config)

In [None]:
# Fetch the run's best checkpoint into ./weights/, overwriting any stale local copy.
# `File.download` hands back an open handle to the downloaded file. Only its path
# is needed, so the handle is closed immediately instead of being leaked.
with run.file('weights/best.pt').download(replace=True) as weights_file:
    weights_path = weights_file.name

# map_location='cpu': the model below is built on CPU, and this lets a checkpoint
# that was saved from GPU tensors load on a machine without CUDA.
weights_dict = torch.load(weights_path, map_location='cpu')

In [None]:
# Rebuild the architecture from the run's config and restore the trained weights.
model = BCDnet(cnet_name=config['cnet_name'], mnet_name=config['mnet_name'])
# The notebook only runs inference, so put any dropout / batch-norm layers in eval
# behaviour. This is done before loading so the cell still displays the load report.
model.eval()
model.load_state_dict(weights_dict)

In [None]:
from datasets import CamelyonDataset

# Small Camelyon sample used for a visual reconstruction check.
# It uses center 0 only, 100 samples, and patches cut at the training patch size.
# load_at_init=False presumably defers reading images until they are indexed (TODO confirm).
dataset = CamelyonDataset(
    config['camelyon_data_path'],
    centers=[0],
    patch_size=config['patch_size'],
    n_samples=100,
    load_at_init=False,
)

In [None]:
from utils import od2rgb, rgb2od
import matplotlib.pyplot as plt

# Reconstruct one Camelyon patch through the model:
# - map RGB into OD space with `rgb2od`;
# - predict the stain matrix M and per-stain concentrations C;
# - rebuild the OD image as M @ C (summing over stains);
# - map it back to RGB for display.
img = dataset[3]
img_np = img.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')  # (C, H, W) -> (H, W, C) for imshow
img_od = rgb2od(img)

# Pure inference: no autograd graph is needed for the forward pass.
with torch.no_grad():
    M_mean, M_var, C_mean = model(img_od.unsqueeze(0))  # (1, 3, 2), (1, 1, 2), (1, 2, H, W)
img_rec_od = torch.einsum('bcs,bshw->bchw', M_mean, C_mean)  # (1, 3, H, W)
img_rec = torch.clamp(od2rgb(img_rec_od), 0.0, 255.0)  # keep values inside the uint8 range before casting
img_rec_np = img_rec.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')

fig, ax = plt.subplots(1, 2)
ax[0].imshow(img_np)
ax[0].set_title('Original')
ax[0].axis('off')

ax[1].imshow(img_rec_np)
ax[1].set_title('Reconstructed')
ax[1].axis('off')
plt.show()

In [None]:
from datasets import WSSBDatasetTest

# WSSB evaluation data (presumably the test split, per the class name), Colon organ only.
# This rebinds `dataset`, so the Camelyon dataset above is no longer reachable by that name.
dataset = WSSBDatasetTest(
    config['wssb_data_path'],
    organ_list=['Colon'],
    load_at_init=False,
)

In [None]:
# Same reconstruction check on a WSSB image. Items are unpacked as (image, M_gt).
# M_gt is presumably the reference stain matrix; it is not used in this cell.
img, M_gt = dataset[3]
img_np = img.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')  # (C, H, W) -> (H, W, C) for imshow
img_od = rgb2od(img)

# Pure inference: no autograd graph is needed for the forward pass.
with torch.no_grad():
    M_mean, M_var, C_mean = model(img_od.unsqueeze(0))  # (1, 3, 2), (1, 1, 2), (1, 2, H, W)
img_rec_od = torch.einsum('bcs,bshw->bchw', M_mean, C_mean)  # (1, 3, H, W)
img_rec = torch.clamp(od2rgb(img_rec_od), 0.0, 255.0)  # keep values inside the uint8 range before casting
img_rec_np = img_rec.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')

fig, ax = plt.subplots(1, 2)
ax[0].imshow(img_np)
ax[0].set_title('Original')
ax[0].axis('off')

ax[1].imshow(img_rec_np)
ax[1].set_title('Reconstructed')
ax[1].axis('off')
plt.show()


In [None]:
from utils import direct_deconvolution

# Compare per-stain images from the model's factorisation against reference images.
# Stain 0 is shown as Hematoxylin and stain 1 as Eosin.
# The reference images come from the dataset's M_gt and concentrations obtained
# from it via `direct_deconvolution`.
img, M_gt = dataset[0]
img_od = rgb2od(img)

# Pure inference: no autograd graph is needed for the forward pass.
with torch.no_grad():
    M_mean, M_var, C_mean = model(img_od.unsqueeze(0))  # (1, 3, 2), (1, 1, 2), (1, 2, H, W)

# Per-stain OD contributions, shape (batch_size, n_stains, 3, H, W).
# Computed once and then indexed per stain, rather than re-running the einsum per stain.
rec_stains_od = torch.einsum('bcs,bshw->bschw', M_mean, C_mean)
H_rec_od = rec_stains_od[:, 0]  # (batch_size, 3, H, W)
H_rec = torch.clamp(od2rgb(H_rec_od), 0.0, 255.0)  # (batch_size, 3, H, W)
E_rec_od = rec_stains_od[:, 1]  # (batch_size, 3, H, W)
E_rec = torch.clamp(od2rgb(E_rec_od), 0.0, 255.0)  # (batch_size, 3, H, W)

# squeeze drops the size-1 batch axis; transpose goes (3, H, W) -> (H, W, 3) for imshow.
H_rec_np = H_rec.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')
E_rec_np = E_rec.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')

C_gt = direct_deconvolution(img_od, M_gt)
M_gt = M_gt.unsqueeze(0)  # (1, 3, 2)
C_gt = C_gt.unsqueeze(0)  # (1, 2, H, W)

gt_stains_od = torch.einsum('bcs,bshw->bschw', M_gt, C_gt)  # (batch_size, n_stains, 3, H, W)
H_gt_od = gt_stains_od[:, 0]  # (batch_size, 3, H, W)
H_gt = torch.clamp(od2rgb(H_gt_od), 0.0, 255.0)  # (batch_size, 3, H, W)
E_gt_od = gt_stains_od[:, 1]  # (batch_size, 3, H, W)
E_gt = torch.clamp(od2rgb(E_gt_od), 0.0, 255.0)  # (batch_size, 3, H, W)

H_gt_np = H_gt.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')
E_gt_np = E_gt.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype('uint8')

# Panels listed in row-major order to match ax.flat: ground truth left, model right.
panels = [
    (H_gt_np, 'Ground truth Hematoxylin'),
    (H_rec_np, 'Reconstructed Hematoxylin'),
    (E_gt_np, 'Ground truth Eosin'),
    (E_rec_np, 'Reconstructed Eosin'),
]

fig, ax = plt.subplots(2, 2)
for panel_ax, (panel_img, panel_title) in zip(ax.flat, panels):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
plt.show()