## Importing Libraries

In [1]:
import torch
import pickle
import random
import numpy as np
from tqdm.auto import tqdm
from src.models.unet import UNet
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from src.data.load_data import load_data
from src.constants import *


## Defining Helper Functions

In [2]:
def sig(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

    Accepts Python scalars or NumPy arrays and returns the same kind.
    For very negative inputs (below about -709 in float64), exp(-x) overflows
    to inf. The result still correctly saturates to 0.0, so the overflow is
    expected. NumPy's RuntimeWarning for it is suppressed to keep the output clean.
    """
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))

def calculate_iou(gt_mask, pred_mask, empty_score=1.0):
    """Intersection-over-Union (Jaccard index) between two masks.

    Args:
        gt_mask: ground-truth mask (array-like). Non-zero entries count as foreground.
        pred_mask: predicted mask, broadcast-compatible with ``gt_mask``.
            Non-zero entries count as foreground.
        empty_score: value returned when both masks are empty (union is 0).
            Defaults to 1.0, because an empty prediction of an empty target is a
            perfect match. Without this guard the ratio is 0/0 = NaN, which
            silently turns any running average into NaN.

    Returns:
        IoU as a Python float in [0, 1].
    """
    # Cast to bool so non-binary masks (e.g. 0/255) are treated as sets of pixels
    # rather than having their values multiplied together.
    gt = np.asarray(gt_mask, dtype=bool)
    pred = np.asarray(pred_mask, dtype=bool)

    union = np.logical_or(gt, pred).sum()
    if union == 0:
        return float(empty_score)

    intersection = np.logical_and(gt, pred).sum()
    return float(intersection) / float(union)

## Visualize Losses

In [None]:
# Loss histories saved by the training runs, one pickle per experiment.
# Each unpickled object is indexed by 'train_loss' and 'test_loss' below.
# NOTE: pickle.load can execute arbitrary code, so only load files produced by
# our own training scripts.
LOSS_FILES = {
    "base": 'outputs/losses/losses_base.pkl',
    "bigaug": 'outputs/losses/losses_bigaug.pkl',
    "autoaugment": 'outputs/losses/losses_autoaugment.pkl',
}

losses = {}
for experiment, path in LOSS_FILES.items():
    with open(path, 'rb') as fh:
        losses[experiment] = pickle.load(fh)

# Keep the original per-experiment names so later cells can still use them.
loss_base = losses["base"]
loss_bigaug = losses["bigaug"]
loss_autoaugment = losses["autoaugment"]

# (experiment, panel title) in left-to-right display order.
LOSS_PANELS = [
    ("base", "Train and Test Loss Over Epochs For Base UNet Model"),
    ("bigaug", "Train and Test Loss Over Epochs For UNet Model With BigAug"),
    ("autoaugment", "Train and Test Loss Over Epochs For UNet Model With AutoAugment"),
]

fig, axs = plt.subplots(1, len(LOSS_PANELS), figsize=(23, 5))

for ax, (experiment, title) in zip(axs, LOSS_PANELS):
    history = losses[experiment]
    ax.plot(history['train_loss'], label="Train Loss")
    ax.plot(history['test_loss'], label="Test Loss")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("BCE Loss")

# Every panel plots the same two series, so a single figure-level legend suffices.
handles, labels = axs[-1].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower right')
plt.show()

## Load Models, Visualize Segmentations, and Evaluate IoU

In [None]:
# Checkpoints for the three trained variants. They are mapped onto the CPU so the
# notebook runs without a GPU.
MODEL_CHECKPOINTS = {
    "base": 'outputs/models/unet_base_final.pt',
    "bigaug": 'outputs/models/unet_bigaug_final.pt',
    "autoaugment": 'outputs/models/unet_autoaug_final_new.pt',
}


def _load_unet(checkpoint_path):
    """Build a UNet(3, 1), load the state dict at checkpoint_path onto the CPU, and return it in eval mode."""
    model = UNet(3, 1)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    # Switch to inference mode. If UNet contains dropout or batch-norm layers
    # (not visible here), training mode would make predictions depend on the
    # current batch and update running statistics during evaluation.
    model.eval()
    return model


base_model = _load_unet(MODEL_CHECKPOINTS["base"])
bigaug_model = _load_unet(MODEL_CHECKPOINTS["bigaug"])
autoaugment_model = _load_unet(MODEL_CHECKPOINTS["autoaugment"])

In [None]:
# load_data() (from src.data.load_data) returns three objects. Further down,
# `dsets` is indexed by 'train' and 'test' to build DataLoaders, and
# `test_dataset[i]` yields an (image, mask) pair for the visualizations.
dsets, train_dataset, test_dataset = load_data()

In [None]:
# How many random test samples to show, and the seed that picks them.
# Set VIS_SEED = None to draw a fresh set on every run.
N_VIS_SAMPLES = 3
VIS_SEED = 0

# (display name, model) for each model column of the comparison grid.
VIS_MODELS = [
    ("Base", base_model),
    ("BigAug", bigaug_model),
    ("AutoAugment", autoaugment_model),
]

# Sample from the full index range. The previous range(1, ...) could never pick index 0.
samples = random.Random(VIS_SEED).sample(range(len(test_dataset)), N_VIS_SAMPLES)

for i in samples:
    image, mask = test_dataset[i]
    image_dis = image.cpu().permute(1, 2, 0)  # CHW -> HWC for imshow
    image = image.unsqueeze(0).float()        # add a batch dimension for the models

    # Column 0: the input and its ground truth. Then one column per model,
    # with probabilities on top and the thresholded mask below.
    seg_fig, seg_axes = plt.subplots(2, 1 + len(VIS_MODELS), figsize=(20, 8), dpi=80)

    seg_axes[0, 0].set_title("Original Image")
    seg_axes[0, 0].imshow(image_dis)
    seg_axes[1, 0].set_title("Ground Truth Mask")
    seg_axes[1, 0].imshow(mask.squeeze(0))

    for col, (name, model) in enumerate(VIS_MODELS, start=1):
        # Inference only: no_grad skips building an autograd graph.
        with torch.no_grad():
            pred = model(image)
        pred_np = sig(pred[0][0].cpu().numpy())  # probabilities for batch item 0, channel 0

        seg_axes[0, col].set_title(f"{name} Model Segmentation")
        seg_axes[0, col].imshow(pred_np)
        seg_axes[1, col].set_title(f"Thresholded {name} Model Segmentation")
        seg_axes[1, col].imshow(pred_np > 0.5)

    plt.show()

In [None]:
# DataLoaders for both splits (the train loader is kept for later cells).
# The test loader must NOT shuffle: the sk/sh split below is defined by a
# sample's position in the test set, and shuffling would mix the two subsets
# at random.
dset_loaders = {
    'train': torch.utils.data.DataLoader(dsets['train'], batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=0),
    'test': torch.utils.data.DataLoader(dsets['test'], batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=0),
}

test_loader = dset_loaders['test']

# The first N_SK_TEST_SAMPLES test samples form the "sk" subset; the rest form "sh".
# NOTE(review): 176 was originally compared against a *batch* index. That equals a
# sample index only with batch size 1 and no shuffling. Confirm that 176 is the
# size of the "sk" portion of the test set.
N_SK_TEST_SAMPLES = 176

EVAL_MODELS = {
    "base": base_model,
    "big_aug": bigaug_model,
    "auto_aug": autoaugment_model,
}
SUBSETS = ("sk", "sh")

iou_sums = {subset: {name: 0.0 for name in EVAL_MODELS} for subset in SUBSETS}
counts = {subset: 0 for subset in SUBSETS}

sample_idx = 0  # position of the current sample within the (unshuffled) test set
with torch.no_grad():
    for image, mask in tqdm(test_loader, desc="evaluating IoU"):
        mask_binary = mask.cpu().numpy() > 0.5
        # Threshold each model's probabilities for the whole batch at once.
        pred_binaries = {
            name: sig(model(image).cpu().numpy()) > 0.5
            for name, model in EVAL_MODELS.items()
        }
        # Score every sample in the batch against its own mask. Previously only
        # prediction 0 was compared, broadcast against every mask in the batch.
        for j in range(mask_binary.shape[0]):
            subset = "sk" if sample_idx < N_SK_TEST_SAMPLES else "sh"
            for name, pred_binary in pred_binaries.items():
                iou_sums[subset][name] += calculate_iou(mask_binary[j], pred_binary[j])
            counts[subset] += 1
            sample_idx += 1

# Mean IoU per subset and model. An empty subset yields NaN instead of raising
# ZeroDivisionError.
avg_iou = {
    subset: {
        name: (total / counts[subset] if counts[subset] else float('nan'))
        for name, total in iou_sums[subset].items()
    }
    for subset in SUBSETS
}

# Keep the original per-model / per-subset names for downstream cells.
count_sk, count_sh = counts["sk"], counts["sh"]
avg_iou_base_sk = avg_iou["sk"]["base"]
avg_iou_big_aug_sk = avg_iou["sk"]["big_aug"]
avg_iou_auto_aug_sk = avg_iou["sk"]["auto_aug"]
avg_iou_base_sh = avg_iou["sh"]["base"]
avg_iou_big_aug_sh = avg_iou["sh"]["big_aug"]
avg_iou_auto_aug_sh = avg_iou["sh"]["auto_aug"]

for subset in SUBSETS:
    scores = "  ".join(f"{name}: {avg_iou[subset][name]:.4f}" for name in EVAL_MODELS)
    print(f"[{subset}] n={counts[subset]}  {scores}")
