# Test Metrics on Model

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torchvision
import change_dataset_np
from torchvision import datasets, models, transforms
import models
from PIL import Image
import matplotlib.pyplot as plt
# Evaluation settings: network input resolution (pixels), number of output
# classes, and samples per loader batch.
img_size, num_classes, batch_size = 224, 2, 1
# NOTE(review): despite its name this variable points at the *training*
# pickle, so the "validation" results below are computed on training data
# unless that is intentional — confirm.
val_pickle_file = 'change_dataset_train.pkl'

from IPython.display import clear_output, display
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Prefer the first CUDA device when one is present, otherwise run on the CPU.
# (To force CPU execution, override `device` after this statement.)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

# More than one GPU enables nn.DataParallel when the model is built.
num_gpu = torch.cuda.device_count()
print('Number of available GPUs:', num_gpu)

In [None]:
# Per-channel mean / std used by both pipelines (the values commonly used
# with ImageNet-pretrained torchvision models).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = dict(
    # Training: random-scale/aspect crop and horizontal flip for augmentation.
    train=transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]),
    # Validation: deterministic resize followed by a centre crop.
    val=transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]),
)

In [None]:
# Dataset read from `val_pickle_file` with the deterministic 'val' transforms,
# plus a sequential (unshuffled) single-worker loader over it.
val_dataset = change_dataset_np.ChangeDatasetNumpy(
    val_pickle_file,
    data_transforms['val'],
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)

#### Initialize Model and Load Checkpoint

In [None]:
# `models` here is the project-local module: `import models` comes after
# `from torchvision import ... models` and rebinds the name.
change_net = models.ChangeNet(num_classes=num_classes)
if num_gpu > 1:
    # Replicate the network across all visible GPUs.
    change_net = nn.DataParallel(change_net)
change_net = change_net.to(device)

# map_location remaps every stored tensor onto the selected device. Without it,
# a checkpoint saved from CUDA cannot be opened on the CPU fallback.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# NOTE(review): a state_dict saved from an nn.DataParallel model has keys
# prefixed with 'module.'. Loading therefore only matches when the wrapping
# above agrees with how the checkpoint was saved — confirm for the target
# machine's GPU count.
checkpoint = torch.load('./best_model-NoRandCrop.pkl', map_location=device)
# Trailing semicolons suppress notebook display of the return values.
change_net.load_state_dict(checkpoint);
change_net.eval();

In [None]:
# Module-level state updated by the interactive explorer below, so later
# cells can inspect the most recently viewed sample.
iteractive_idx = 0
output = 0
label_img = 0
@interact(idx=widgets.IntSlider(min=0,max=len(val_dataset)-1))
def explore_validation_dataset(idx):
    """Predict on validation sample `idx` and plot its images and label.

    Updates the globals ``iteractive_idx`` (the index shown), ``output``
    (predicted class indices) and ``label_img`` (the sample's label).
    """
    global iteractive_idx
    global output
    global label_img
    sample = val_dataset[idx]
    reference_img = sample['reference']
    test_img = sample['test']
    label_img = sample['label']
    # Inference only: skip autograd bookkeeping. The inputs must live on the
    # same device as the model, which was moved to `device` when it was built.
    with torch.no_grad():
        preds = change_net([reference_img.unsqueeze(0).to(device),
                            test_img.unsqueeze(0).to(device)])
    # argmax over dim 1 gives per-pixel predicted class indices.
    # Assumes preds is (N, num_classes, H, W) — TODO confirm.
    _, output = torch.max(preds, 1)
    # Add a singleton leading dim; iou_binary squeezes dim 1 back out.
    output = output.unsqueeze(0)
    print(output.shape)
    # NOTE(review): the 'val' transform includes Normalize, so these images are
    # presumably outside [0, 1] and imshow will clip them — consider
    # de-normalizing for display.
    plt.imshow(reference_img.permute(1, 2, 0).numpy())
    plt.show()
    plt.imshow(test_img.permute(1, 2, 0).numpy())
    plt.show()
    plt.imshow(label_img.squeeze(0).numpy())
    plt.show()
    iteractive_idx = idx

In [None]:
# Shape of the prediction stored by the explorer cell.
output.size()

In [None]:
# Shape of the label stored by the explorer cell.
label_img.size()

In [None]:
SMOOTH = 1e-6

def iou_binary(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the threshold-averaged IoU of binary masks, averaged over the batch.

    For every sample, the IoU of the predicted and ground-truth masks is
    compared with the ten thresholds 0.5, 0.55, ..., 0.95. The sample's score
    is the fraction of thresholds the IoU strictly exceeds (up to
    floating-point rounding). The returned value is the mean of those scores.

    Any non-zero element counts as foreground. The masks may therefore be
    bool, integer (0/1, 0/255, ...) or float tensors holding 0/1.

    Args:
        outputs: predicted masks, ``BATCH x 1 x H x W`` or ``BATCH x H x W``.
        labels: ground-truth masks, ``BATCH x H x W``.

    Returns:
        0-dim float tensor in [0, 1].
    """
    # BATCH x 1 x H x W => BATCH x H x W (no-op when dim 1 is not size 1).
    outputs = outputs.squeeze(1)

    # Binarize first: bitwise &/| on raw integers miscounts values other than
    # 0/1 (e.g. 1 | 255 == 255 inflates the union) and fails on float tensors.
    outputs = outputs != 0
    labels = labels != 0

    intersection = (outputs & labels).float().sum((1, 2))  # zero if either mask is empty
    union = (outputs | labels).float().sum((1, 2))         # zero only if both masks are empty

    # SMOOTH avoids 0/0; two empty masks therefore score an IoU of 1.
    iou = (intersection + SMOOTH) / (union + SMOOTH)

    # ceil(clamp(20 * (iou - 0.5), 0, 10)) is the number of thresholds in
    # {0.5, 0.55, ..., 0.95} below iou; dividing by 10 gives the fraction.
    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10

    return thresholded.mean()

In [None]:
# Threshold-averaged IoU of the explorer's current prediction against its label.
iou_binary(outputs=output.to(device), labels=label_img.to(device))

In [None]:
# Largest predicted class index in the current prediction.
output.max()

In [None]:
# Largest value in the current label (checks the label encoding).
label_img.max()