In [None]:
import os
from pathlib import Path

import torch
from torch import nn
import numpy as np
from torch.utils import data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import wandb
# from torchsummary import summary
from skimage.io import imread
import cv2

from models.incept_unet import InceptionedUNet
from datasets.blastocyst import SFUDataset, ClinicMultifocalDataset
from utils.trainer import Trainer
from utils.losses import DicePlusBCELoss

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

In [None]:
def find_convexHull(pred) -> list:
    """Return the convex hull of the longest contour found in a mask.

    Args:
        pred: 2-D array-like mask. Pixels with value > 0.5 are treated as
            foreground; everything else is background.

    Returns:
        ``[hull]`` -- a one-element list holding the ``cv2.convexHull`` point
        array (the list-of-contours form ``cv2.drawContours`` accepts), or
        ``[]`` when the mask contains no contour at all.
    """
    # Binarize *before* casting: a plain astype(np.uint8) truncates fractional
    # values (e.g. 0.7 -> 0) and would silently drop soft foreground pixels.
    # For 0/1 masks the result is identical to the previous threshold step.
    mask = np.where(np.asarray(pred) > 0.5, 255, 0).astype(np.uint8)

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 4.x; [-2] selects the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return []

    # CHAIN_APPROX_SIMPLE compresses straight runs down to their endpoints, so
    # len(contour) counts vertices rather than length. Rank by the actual
    # closed perimeter instead (ties keep the first contour, as argmax did).
    longest = max(contours, key=lambda c: cv2.arcLength(c, True))

    # clockwise=False, as in the original call.
    hull = cv2.convexHull(longest, False)
    return [hull]

In [None]:
def calc_thresh(preds, thresh):
    """Binarize a tensor of predictions.

    Returns a float32 tensor with the same shape as ``preds`` holding 1.0
    wherever ``preds`` is strictly greater than ``thresh`` and 0.0 elsewhere.
    """
    above = preds.gt(thresh)
    return above.float()

In [None]:
# Network input resolution. Guidance kept from the original notes:
#   288 for a 6-block model, 272 for 5 blocks or fewer.
# (288 = 9 * 2**5 and 272 = 17 * 2**4 -- presumably picked so the size stays
#  integral through the model's repeated downsampling; TODO confirm.)
IMG_SIZE = 288

# Resize so the shorter edge equals IMG_SIZE, then center-crop to a square.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
    ]
)

# Mask modes for the two datasets (values passed through to the dataset classes).
SFU_MASK = 'all'
CLINIC_MASK = 'random'

In [None]:
# Dataset roots. Each can be overridden with an environment variable so the
# notebook runs on machines where the data lives elsewhere; the defaults are
# the original hard-coded locations.
# clinic data
clinic = Path(os.environ.get('CLINIC_DATA_DIR', '/datasets/clinic_multifocal'))
# sfu data
sfu = Path(os.environ.get('SFU_DATA_DIR', '/datasets/sfu/BlastsOnline'))

In [None]:
# Evaluation data: no augmentation, circle mask on, resized/cropped by TRANSFORM.
# NOTE(review): mask='all' is passed literally here rather than CLINIC_MASK
# ('random') from the config cell -- confirm this is intentional.
dataset = ClinicMultifocalDataset(
    clinic,
    transform=TRANSFORM,
    mask='all',
    use_augmentations=False,
    use_circle_mask=True,
)
# batch_size=1 matters: the inference cells index x[0] and name[0].
dataloader = data.DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
# Colormap ramping from fully transparent white to opaque red, so a binary
# mask can be drawn on top of a grayscale image without hiding it.
c_white = colors.to_rgba('white', alpha=0)
c_red = colors.to_rgba('red', alpha=1)
cmap_red = colors.LinearSegmentedColormap.from_list('rb_cmap', [c_white, c_red], N=512)

In [None]:
x, _ = next(iter(dataloader))

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')

# Show every plane of every sample in the batch, one row per sample.
# Assumes x is (batch, planes, H, W); the plane count comes from the data
# instead of a hard-coded 11, and squeeze=False keeps axarr 2-D even for 1 plane.
n_planes = x.shape[1]
for j in range(x.shape[0]):
    f, axarr = plt.subplots(1, n_planes, figsize=(25, 15), squeeze=False)
    for i in range(n_planes):
        axarr[0][i].imshow(x[j][i])
    plt.show()

In [None]:
# Path of the saved model file passed to torch.load. Set the INCEPT_UNET_PATH
# environment variable (or replace the placeholder) before loading.
incept_unet_save_path = Path(os.environ.get('INCEPT_UNET_PATH', 'PLACE_HOLDER'))

In [None]:
# Load the whole pickled model object (not a state_dict).
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine;
# without it torch.load fails when CUDA is unavailable.
# SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects full
# model objects; pass weights_only=False there if loading fails.
incept_unet = torch.load(incept_unet_save_path, map_location=device)
# Assigning (rather than a bare call) keeps the cell from echoing the module repr.
incept_unet = incept_unet.to(device)

In [None]:
# Inference / display settings.
thresh = 0.55          # cut-off applied to the raw model output
mean_score = False     # not referenced in the cells shown here
plot = False           # draw per-plane prediction figures during inference
display_name = True    # include the sample name in those figures' titles
# One label per focal plane: -75 .. 75 in steps of 15 (11 values).
focals = list(range(-75, 76, 15))

In [None]:
incept_unet.eval()

# user_id -> {focal offset -> (input plane tensor, thresholded mask ndarray)}.
# NOTE(review): user_id is the sample name up to the first '_'; a later sample
# with the same prefix overwrites an earlier one -- confirm ids are unique.
preds_icm = {}
with torch.no_grad():
    for x, name in dataloader:
        # x[0] / name[0] below rely on the loader's batch_size == 1.
        inputs = x.to(device)  # not `input`, which would shadow the builtin
        n_planes = x.shape[1]
        # Fail up front with a clear message instead of an IndexError on
        # focals[i] after this sample's results were partially stored.
        if n_planes > len(focals):
            raise ValueError(
                f'{name[0]} has {n_planes} focal planes but only '
                f'{len(focals)} focal offsets are configured'
            )

        user_id = name[0].split('_')[0]
        preds_icm[user_id] = {}
        for i in range(n_planes):
            # With batch_size 1, inputs[None, :, i] has shape (1, 1, H, W):
            # the sample's i-th plane fed as a one-channel image.
            preds = incept_unet(inputs[None, :, i, :, :])
            preds_thresh = calc_thresh(preds, thresh)

            preds_icm[user_id][focals[i]] = (x[0][i], preds_thresh[0, 0, :, :].cpu().numpy())

            if not plot:
                continue

            f, axarr = plt.subplots(1, 3, figsize=(11, 4))
            axarr[0].imshow(x[0][i])
            axarr[1].imshow(preds[0].cpu().permute(1, 2, 0))
            axarr[2].imshow(preds_thresh[0].cpu().permute(1, 2, 0))

            f.suptitle(f'{name[0]}_{focals[i]}' if display_name else '')
            axarr[0].set_title('Image')
            axarr[1].set_title('Prediction')
            axarr[2].set_title(f'Prediction with Threshold ({thresh})')
            plt.show()
            print('---------------------'*3)

In [None]:
# True: overlay the mask on the image and draw the hull on a copy of the image.
# False: show the raw mask and draw the hull on a blank canvas.
overlap = True

for user_id, planes in preds_icm.items():
    # One column per stored plane (rather than a hard-coded 11); squeeze=False
    # keeps axarr 2-D even when a sample has a single plane.
    n_cols = len(planes)
    fig, axarr = plt.subplots(2, n_cols, figsize=(30, 5), squeeze=False)

    print(user_id)
    # planes preserves insertion order, i.e. the order of `focals`.
    for col, (focal, (image, pred)) in enumerate(planes.items()):
        hull = find_convexHull(pred)

        if overlap:
            axarr[0][col].imshow(image, cmap='gray')
            axarr[0][col].imshow(pred, cmap=cmap_red, alpha=0.4)
            # cv2 draws in place, so draw on a copy of the image.
            # NOTE(review): value 1 renders white only if images are scaled
            # to [0, 1] -- confirm against the dataset's normalization.
            drawing = image.numpy().copy()
            cv2.drawContours(drawing, hull, -1, (1, 1, 1), 3)
        else:
            axarr[0][col].imshow(pred)
            drawing = np.zeros((image.shape[0], image.shape[1], 1), np.uint8)
            cv2.drawContours(drawing, hull, -1, 1, 2)
        axarr[0][col].set_title(focal)

        axarr[1][col].imshow(drawing, cmap='gray')
    plt.show()