## Visualize the outputs of the trained models

In [None]:
# Import statements

from operator import add

from tqdm import tqdm
from dataset import get_data

from models import *
from metrics import *
from utils import read_image, process_image
import random
import cv2
import numpy as np
import torch
from models import *
from glob import glob
import matplotlib.pyplot as plt
from dataset import *

In [None]:
# Preambles: pick the compute device once for the whole notebook.
# Prefer CUDA when PyTorch can see a GPU; otherwise run on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
# Experiment configuration shared by the cells below. It is passed to
# `get_data`, and its "result_path" and "data" entries are used to build the
# checkpoint paths in `get_batch_predictions`.
# NOTE(review): keys such as learning_rate / batch_size / epochs look like
# training settings that the visible cells do not read; presumably kept to
# mirror the training config -- confirm.
configs = {"model": "Unet",
           "data": "fives",
           "loss_fn": "clDice",
           "size": (224, 224),
           "augments": False,
           "data_path": 'datasets/',
           "result_path": 'results',
           # Follow the same CUDA-or-CPU choice used for `device` instead of
           # hard-coding "cuda", so the config stays valid on CPU-only hosts.
           "device": 'cuda' if torch.cuda.is_available() else 'cpu',
           "learning_rate": 0.001,
           "batch_size": 16,
           "epochs": 40,
           "num_workers": 2,
           }

In [None]:
# Load the data: only the test split (third and fourth return values) is
# needed for visualization.
_, _, test_x, test_y = get_data(configs)

# Number of random test samples to visualize. The same samples are reused by
# every model cell below, so their predictions can be compared side by side.
NUM_SAMPLES = 5

_test_pairs = list(zip(test_x, test_y))
if not _test_pairs:
    raise ValueError("get_data returned an empty test split; nothing to visualize.")
# random.sample raises if asked for more items than exist, so cap the count.
data, label = zip(*random.sample(_test_pairs, min(NUM_SAMPLES, len(_test_pairs))))

In [None]:
def return_pred(x, model, threshold=0.5):
    """Run ``model`` on ``x`` and return a binary mask for the first sample.

    Args:
        x: Model input as produced by ``process_image`` (presumably a
            ``(1, C, H, W)`` tensor already on ``device`` -- TODO confirm).
        model: Object exposing ``predict(x)`` that returns a tensor whose
            first element has a leading singleton axis (squeezed away here).
        threshold: Probability cut-off; values strictly greater than it
            become 1. Defaults to 0.5, the previously hard-coded value.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` containing only 0 and 1.
    """
    with torch.no_grad():
        pred_y = model.predict(x)

    # Keep the first item of the batch and move it to host memory; its
    # leading axis is presumably the single mask channel (squeeze fails
    # loudly if it is not of size 1).
    pred_y = pred_y[0].cpu().numpy()
    pred_y = np.squeeze(pred_y, axis=0)
    return (pred_y > threshold).astype(np.uint8)

In [None]:
def get_batch_predictions(x, model_name: str, model_constructor, loss_names):
    """Predict ``x`` with one checkpoint of ``model_name`` per loss function.

    For every entry of ``loss_names``, a fresh model is built with
    ``model_constructor()``. Its weights are loaded from
    ``{configs['result_path']}/{model_name}/{configs['data']}/No_Augmentation_{loss}``,
    and the thresholded mask from ``return_pred`` is collected.

    Args:
        x: Single model input passed through to ``return_pred``.
        model_name: Sub-directory under the result path holding the checkpoints.
        model_constructor: Zero-argument callable returning the model
            (assumed to be a ``torch.nn.Module`` -- it uses ``.to`` and
            ``load_state_dict``).
        loss_names: Iterable of loss-name suffixes identifying checkpoints.

    Returns:
        dict mapping each loss name to its binary prediction mask.
    """
    predictions = {}
    checkpoint_dir = f"{configs['result_path']}/{model_name}/{configs['data']}"
    for loss_fn in loss_names:
        # NOTE(review): the "No_Augmentation" prefix is fixed regardless of
        # configs["augments"]; confirm this matches how checkpoints were saved.
        checkpoint_path = f"{checkpoint_dir}/No_Augmentation_{loss_fn}"
        model = model_constructor().to(device)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        # A newly built nn.Module is in training mode; switch to eval so that
        # dropout / batch-norm layers behave as they should at inference.
        model.eval()
        predictions[loss_fn] = return_pred(x, model)

    return predictions



In [None]:
def visualize_sample(image, mask, predictions):
    """Show an image, its ground-truth mask and one panel per prediction.

    ``mask_parse`` is not defined in this notebook; presumably it comes from
    one of the wildcard imports (models / metrics / dataset) -- TODO confirm.

    Args:
        image: Image array passed straight to ``imshow``.
        mask: Ground-truth mask, displayed via ``mask_parse``.
        predictions: Mapping of label (e.g. loss name) -> binary mask, as
            returned by ``get_batch_predictions``. Each mask is displayed via
            ``mask_parse``, scaled by 255, with the "seismic" colormap.
    """
    fig, axes = plt.subplots(
        nrows=1, ncols=2 + len(predictions), figsize=(10, 5))

    axes[0].imshow(image)
    axes[0].title.set_text("Original Image")
    axes[1].imshow(mask_parse(mask))
    axes[1].title.set_text("Ground Truth")

    # Prediction panels start at column 2, after the image and ground truth.
    for ax, (name, pred) in zip(axes[2:], predictions.items()):
        ax.imshow(mask_parse(pred) * 255, cmap="seismic")
        ax.title.set_text(f"{name} Prediction")

    for ax in axes:
        ax.axis("off")

    # plt.show() renders the figure immediately; fig.show() only warns on
    # non-GUI backends such as the notebook inline backend.
    plt.show()

In [None]:
def plot_image_predictions(model_name, model_constructor, data, label,
                           loss_names=("cldice", "DiceLoss")):
    """Visualize ``model_name`` predictions for every (image, mask) pair.

    Args:
        model_name: Sub-directory name used to locate the checkpoints.
        model_constructor: Zero-argument callable building the model.
        data: Sequence of image inputs for ``utils.read_image`` (presumably
            file paths from ``get_data`` -- TODO confirm).
        label: Sequence of mask inputs aligned with ``data``.
        loss_names: Loss-name suffixes of the checkpoints to compare.
            Defaults to the two losses that were previously hard-coded.
    """
    # Import explicitly: a later cell in this notebook defines its own
    # ``read_image(file_name, size)``. If cells are re-run out of order, that
    # definition would shadow the utils version and break the call below.
    from utils import read_image as _read_sample, process_image as _process_sample

    for image_path, mask_path in zip(data, label):
        image, mask = _read_sample(image_path, mask_path)
        model_input, _ = _process_sample(image, mask)
        # NOTE(review): checkpoints are reloaded for every sample; acceptable
        # for a handful of samples, but worth caching if this grows.
        predictions = get_batch_predictions(
            model_input, model_name, model_constructor, loss_names)
        visualize_sample(image, mask, predictions)


## Plotting image predictions for the U-Net model


In [None]:
# Visualize U-Net predictions for the sampled test images.
model_name, model_constructor = "Unet", Unet
plot_image_predictions(model_name, model_constructor, data, label)


## Plotting image predictions for the FPN model

In [None]:
# Visualize FPN predictions for the sampled test images.
model_name, model_constructor = "FPN", FPN
plot_image_predictions(model_name, model_constructor, data, label)


## Plotting image predictions for the MA-Net model


In [None]:
# Visualize MA-Net predictions for the sampled test images.
model_name, model_constructor = "MANet", MANet
plot_image_predictions(model_name, model_constructor, data, label)


## Plotting image predictions for the U-Net++ model


In [None]:
# Visualize U-Net++ predictions for the sampled test images.
model_name, model_constructor = "Unet++", UnetPlusPlus
plot_image_predictions(model_name, model_constructor, data, label)

## Plot the original image, the preprocessed image, and the mask

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:

# Display size for the preprocessing figures; cv2.resize expects (width, height).
H, W = 224, 224
size = (W, H)

def read_image(file_name, size):
    """Load ``file_name`` as an RGB array resized to ``size``.

    NOTE: this redefines the ``read_image`` imported from ``utils`` at the top
    of the notebook, which has a different signature. Re-run the import cell
    before re-running any cell that relies on the utils version.

    Args:
        file_name: Path of the image to read.
        size: Target ``(width, height)``, as expected by ``cv2.resize``.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``file_name``.
    """
    image = cv2.imread(file_name, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {file_name}")
    # OpenCV decodes to BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, size)

def process(image, scale=255.0):
    """Convert ``image`` to ``float32`` values divided by ``scale`` for display.

    Args:
        image: Array of pixel values (``uint8`` when coming from ``read_image``).
        scale: Divisor applied to every pixel; the default 255.0 maps 8-bit
            data to [0, 1].

    Returns:
        ``np.float32`` array with the same shape as ``image``.
    """
    # Divide first, then cast, so the result matches the original computation.
    return (image / scale).astype(np.float32)


def plot_processed_images(orig_path, proc_path, mask_path):
    """Show an original image, its preprocessed version and its mask side by side.

    Each file is read with ``read_image`` (RGB, resized to the module-level
    ``size``) and scaled with ``process``. The mask is also read as a
    3-channel image, which is sufficient for display.

    Args:
        orig_path: Path of the original image.
        proc_path: Path of the preprocessed image.
        mask_path: Path of the segmentation mask.
    """
    panels = [
        (orig_path, "Original Image"),
        (proc_path, "Processed Image"),
        (mask_path, "Image Mask"),
    ]
    # Read everything before creating the figure, so a bad path fails
    # without leaving a half-drawn figure behind.
    images = [process(read_image(path, size)) for path, _ in panels]

    _, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(12, 7))
    for ax, img, (_, title) in zip(axes, images, panels):
        ax.imshow(img)
        ax.title.set_text(title)
        ax.axis("off")
    # plt.show() renders immediately; fig.show() warns on non-GUI backends.
    plt.show()

In [None]:

# Preprocessing overview for sample 4.
sample_id = 4
orig_path = f'datasets/UCH/{sample_id}.jpg'
proc_path = f'datasets/preprocessed_train/{sample_id}.png'
mask_path = f'datasets/masks_train/{sample_id}.png'

plot_processed_images(orig_path, proc_path, mask_path)

In [None]:

# Preprocessing overview for sample 12.
sample_id = 12
orig_path = f'datasets/UCH/{sample_id}.jpg'
proc_path = f'datasets/preprocessed_train/{sample_id}.png'
mask_path = f'datasets/masks_train/{sample_id}.png'

plot_processed_images(orig_path, proc_path, mask_path)