# 02-Validation-example  
Load a model and check the predictions

In [1]:
%load_ext autoreload
%autoreload 2

import gi_tract_seg as gt
from gi_tract_seg.models.module import GITractSegmentatonLitModule
import torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import seaborn as sns
from pytorch_toolbelt.inference import GeneralizedTTA, d4_image_augment, d4_image_deaugment
from torchvision.utils import make_grid
import torchvision

In [None]:
# Restore the full pickled model (fold 0 checkpoint), then switch it to
# inference mode and move it to the GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = torch.load('../models/deep_run_6_unet_fold0.pt')
model = model.eval().cuda()

In [None]:
# Shared cv2.putText settings used for the labels drawn on every overlay.
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 0.5
thickness = 2
# Bottom-left corner of the label text, in pixels (x, y).
org = (20, 20)
# Text colour; the overlays are float32 images, so channels are given as floats.
color = (1, 0, 0)

In [None]:
# Data module over the aggregated training table, raw scans and mask folder.
# shuffle_train=False keeps the loader order fixed; the per-sample scores
# computed from this loader are later joined to the training table by
# position, which only lines up when the loader is not shuffled.
datamodule = gt.data.datamodule.GITractDataModule(
    '../data/processed/train_df_agg.parquet',
    '../data/raw/train/',
    '../data/processed/masks/',
    shuffle_train=False,
    val_fold=0,
    num_workers=2,
    batch_size=128,
)
datamodule.prepare_data()
datamodule.setup()

In [None]:
# Run the model over the whole training loader, collecting one Dice score
# vector per sample and the sigmoid-activated predictions.
dice_meter = gt.models.metrics.DiceMeter('multilabel')
predictions = []
dice_scores = []
for batch in tqdm(datamodule.train_dataloader()):
    with torch.no_grad():
        batch_preds = model(batch['image'].cuda())
        # Score sample by sample on the raw model outputs (before sigmoid).
        for sample_mask, sample_pred in zip(batch['mask'], batch_preds):
            dice_scores.append(dice_meter.score_fn(sample_mask.cuda(), sample_pred))
        predictions.append(torch.sigmoid(batch_preds).cpu().numpy())

# One row per sample, one column per class
# (presumably large bowel / small bowel / stomach -- confirm against the metric).
dice_scores = pd.DataFrame(np.array(dice_scores), columns=["lb", "sb", "st"])
# Per-sample mean over the classes, ignoring NaN entries.
dice_scores['mean'] = np.nanmean(dice_scores, axis=1)
predictions = np.concatenate(predictions)

In [None]:
# Box plot of the per-sample Dice distribution for each class and for the
# per-sample mean; the title carries the NaN-ignoring average of each column.
sns.boxplot(
    data=dice_scores.melt(value_vars=["lb", "sb", "st", "mean"]),
    x='variable',
    y='value',
)
mean_values = np.nanmean(dice_scores, axis=0)
mean_values_str = ', '.join(
    f'{gi_c}:{val:.3f}' for gi_c, val in zip(dice_scores.columns, mean_values)
)
plt.title('GI tract DICE per class\n' + mean_values_str)
plt.xlabel('Class')
plt.ylabel('Dice')

In [None]:
# Attach the Dice scores to the training metadata. Both frames are matched on
# their positional index, so row i of the scores must describe row i of the
# training table.
train_df = datamodule.train_data.reset_index(drop=False)
data_dices = dice_scores.merge(train_df, left_index=True, right_index=True)
data_dices.head()

In [None]:
# Average the per-sample mean Dice within each (case, day) scan and order the
# scans from worst to best.
mean_across_case_day = (
    data_dices
    .groupby(['case', 'day'])['mean']
    .mean()
    .reset_index()
    .sort_values('mean')
)
mean_across_case_day.head()

In [None]:
def extract_scan_predictions(scan_id):
    """Build labelled ground-truth and prediction overlays for one sample.

    Args:
        scan_id: Index into ``datamodule.train_dataset``.

    Returns:
        A tuple ``(seg_overlay_gt, seg_overlay_pred, dice_score)``: the image
        blended with its mask, the image blended with the sigmoid of the model
        output, and the Dice score from ``dice_meter.score_fn`` computed on
        the raw (pre-sigmoid) model output.
    """

    def _annotated_blend(base, layer, caption):
        # Blend the layer onto the base image, then stamp the caption on it.
        blended = cv2.addWeighted(
            src1=base,
            alpha=0.999,
            src2=layer,
            beta=0.35,
            gamma=0,
        )
        return cv2.putText(
            blended,
            caption,
            org,
            font,
            fontScale,
            color,
            thickness,
            cv2.LINE_AA,
        )

    sample = datamodule.train_dataset[scan_id]
    image_tensor = sample['image']
    mask_tensor = sample['mask']

    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0).cuda()).cpu()[0]
        dice_score = dice_meter.score_fn(mask_tensor, logits)

    # Channel-first tensors -> channel-last float32 arrays for OpenCV.
    # Tiling the last axis 3x assumes a single-channel image -- TODO confirm.
    image_hwc = image_tensor.cpu().numpy().transpose(1, 2, 0)
    image_rgb = np.tile(image_hwc, 3).astype(np.float32)
    mask_rgb = mask_tensor.cpu().numpy().transpose(1, 2, 0).astype(np.float32)
    prediction_rgb = torch.sigmoid(logits).numpy().transpose(1, 2, 0).astype(np.float32)

    seg_overlay_gt = _annotated_blend(image_rgb, mask_rgb, f'GT {scan_id}')
    seg_overlay_pred = _annotated_blend(image_rgb, prediction_rgb, f'PRED {scan_id}')
    return seg_overlay_gt, seg_overlay_pred, dice_score

def plot_scan(scan_id, figsize=(15, 7)):
    """Show the ground-truth and predicted overlays of one sample side by side.

    Args:
        scan_id: Index into the training dataset, forwarded to
            ``extract_scan_predictions``.
        figsize: Matplotlib figure size ``(width, height)`` in inches.
    """
    gt_overlay, pred_overlay, _ = extract_scan_predictions(scan_id)
    side_by_side = np.concatenate([gt_overlay, pred_overlay], axis=1)
    _, axis = plt.subplots(1, 1, figsize=figsize)
    axis.imshow(side_by_side)

def plot_around_id(scan_id, figsize=(15, 7), plot_around = 2, save=True, plot=True):
    """Plot a sample together with its neighbouring samples, stacked vertically.

    Each row shows the ground-truth overlay next to the prediction overlay of
    one sample, as produced by ``extract_scan_predictions``.

    Args:
        scan_id: Index into ``datamodule.train_dataset`` of the central sample.
        figsize: ``(width, height_per_row)`` in inches; the figure height is
            ``height_per_row`` times the number of rows plotted.
        plot_around: Number of neighbouring samples to include on each side.
        save: When true, write the figure to
            ``images/worst_cases/<scan_id>.jpg`` (the folder is created if
            it does not exist).
        plot: When false, close the figure so it is not displayed.

    Raises:
        IndexError: If ``scan_id`` is outside the training dataset.
    """
    import os  # local import: only needed to create the output folder

    n_scans = len(datamodule.train_dataset)
    if not 0 <= scan_id < n_scans:
        raise IndexError(f'scan_id {scan_id} is outside [0, {n_scans})')

    # Clamp the window to the dataset bounds: a negative index would otherwise
    # silently wrap around to the end of the dataset, and an index past the
    # end would raise.
    first_id = max(scan_id - plot_around, 0)
    last_id = min(scan_id + plot_around, n_scans - 1)

    overlays = []
    for scan_id_ in range(first_id, last_id + 1):
        seg_overlay_gt, seg_overlay_pred, _ = extract_scan_predictions(scan_id_)
        overlays.append(np.concatenate([seg_overlay_gt, seg_overlay_pred], axis=1))

    width, row_height = figsize
    f, ax = plt.subplots(1, 1, figsize=(width, row_height * len(overlays)))
    ax.imshow(np.concatenate(overlays))
    plt.tight_layout()
    if save:
        out_dir = 'images/worst_cases'
        os.makedirs(out_dir, exist_ok=True)
        f.savefig(f'{out_dir}/{scan_id}.jpg')
    if not plot:
        # Close this specific figure rather than whichever one is current.
        plt.close(f)
        
def plot_train_case(case_data, case, day, dice_score, show, save):
    """Render every sample of one (case, day) scan as a grid of overlays.

    For each row of ``case_data`` the ground-truth overlay and the prediction
    overlay are added to the grid as consecutive tiles (8 tiles per grid row).

    Args:
        case_data: DataFrame whose index values are indices into
            ``datamodule.train_dataset``.
        case: Case identifier, used only in the output file name.
        day: Day identifier, used only in the output file name.
        dice_score: Score written into the output file name (3 decimals).
        show: When true, open the resulting image with ``PIL.Image.show``.
        save: When true, write the image to
            ``images/cases/<case>_<day>_<dice_score>.jpg`` (the folder is
            created if it does not exist).

    Returns:
        None. Nothing is rendered when ``case_data`` has no rows.
    """
    import os  # local import: only needed to create the output folder

    def _labelled_overlay(image_rgb, layer_rgb, label):
        # Blend the layer onto the image, stamp the label, and return the
        # result as a channel-first tensor for make_grid.
        blended = cv2.addWeighted(
            src1=image_rgb,
            alpha=0.999,
            src2=layer_rgb,
            beta=0.35,
            gamma=0,
        )
        blended = cv2.putText(
            blended,
            label,
            org,
            font,
            fontScale,
            color,
            thickness,
            cv2.LINE_AA,
        )
        return torch.from_numpy(blended.transpose(2, 0, 1))

    overlayed_images = []
    for inner_idx, idx in enumerate(case_data.index):
        image_data = datamodule.train_dataset[idx]
        # Tiling the last axis 3x assumes a single-channel image -- TODO confirm.
        image_rgb = np.tile(image_data['image'].cpu().numpy().transpose(1, 2, 0), 3).astype(np.float32)
        mask_rgb = image_data['mask'].cpu().numpy().transpose(1, 2, 0).astype(np.float32)
        with torch.no_grad():
            logits = model(image_data['image'][None, :].cuda())
        prediction_rgb = torch.sigmoid(logits).cpu().numpy()[0]
        prediction_rgb = prediction_rgb.transpose(1, 2, 0).astype(np.float32)

        overlayed_images.append(_labelled_overlay(image_rgb, mask_rgb, f'GT ({inner_idx})'))
        overlayed_images.append(_labelled_overlay(image_rgb, prediction_rgb, f'PRED ({inner_idx})'))

    if not overlayed_images:
        # make_grid cannot build a grid from an empty list.
        return

    grid = make_grid(overlayed_images, nrow=8)
    # The blend (0.999 * image + 0.35 * layer) can leave the [0, 1] range.
    # ToPILImage scales float tensors by 255 and casts to uint8 without
    # clipping, so out-of-range values would wrap around; clip them first.
    grid = grid.clamp(0, 1)
    img = torchvision.transforms.ToPILImage()(grid)
    if show:
        img.show()
    if save:
        out_dir = 'images/cases'
        os.makedirs(out_dir, exist_ok=True)
        img.save(f'{out_dir}/{case}_{day}_{dice_score:.3f}.jpg')

In [None]:
# Keep rows whose 'empty' flag is 0 (presumably samples with a non-empty
# mask -- confirm), then take the 100 lowest mean Dice scores.
selected_worst_cases = (
    data_dices
    .loc[data_dices['empty'] == 0]
    .sort_values('mean')
    .head(100)
)

In [None]:
# Save a neighbourhood plot around each of the worst-scoring samples; the
# figures are closed instead of displayed.
for case_idx in selected_worst_cases.index:
    plot_around_id(scan_id=case_idx, plot=False)

In [None]:
# Render one overlay grid per (case, day) scan, walking the scans from the
# lowest mean Dice upwards. Every grid is saved; only the first six are shown.
case_day_rows = zip(
    mean_across_case_day['case'].values,
    mean_across_case_day['day'].values,
    mean_across_case_day['mean'].values,
)
for idx, (case, day, dice_score) in tqdm(
    enumerate(case_day_rows), total=mean_across_case_day.shape[0]
):
    mask_case_day = (train_df['case'] == case) & (train_df['day'] == day)
    case_data = train_df[mask_case_day]
    print(f'========= {case}:{day} =======')
    show = idx <= 5
    plot_train_case(case_data, case, day, dice_score, show, True)

In [2]:
# Load the aggregated training table from disk.
train_data = pd.read_parquet(
    '../data/processed/train_df_agg.parquet',
)

In [3]:
# Drop the specific (case, day) scans selected for removal and save the
# cleaned table. A row is removed only when BOTH its case and its day match
# one of the pairs below; filtering with `(case != c) & (day != d)` would
# instead discard every row of that case and every row of that day.
bad_case_days = [(7, 0), (81, 30)]
is_bad_scan = pd.Series(False, index=train_data.index)
for bad_case, bad_day in bad_case_days:
    is_bad_scan |= (train_data['case'] == bad_case) & (train_data['day'] == bad_day)
train_data_cleaned = train_data.loc[~is_bad_scan, :]
train_data_cleaned.to_parquet('../data/processed/train_df_agg_cleaned.parquet')