In [1]:
'''This notebook contains the implementations of the explicit inconsistency scores.'''

'This notebook contains the implementations of the explicit inconsistency scores.'

In [2]:
%matplotlib inline
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import torch
import segmentation_models_pytorch as smp
import albumentations as albu

import seaborn as sns
import pylab as py
import pandas as pd

from scripts.preprocessing import Dataset
from scripts.helper import plot_double_image, plot_triple_image, get_color_img, get_concat_h
from scripts.augmentation import get_training_augmentation, get_validation_augmentation, to_tensor, get_preprocessing, get_preprocessing_unlabeled

DEVICE = ('cuda:2' if torch.cuda.is_available() else 'cpu')

In [3]:
##### Assign the location of your datasets

DATA_DIR = '/raid/maruf/data-800/'
# train set
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'trainannot')
# validation set
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')
# test set
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')

In [4]:
# encoder: resnext50
# encoder weights: imagenet
# segmentation model: FPN 

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['dorsal', 'adipose', 'caudal', 'anal', 
           'pelvic', 'pectoral', 'head', 'eye', 
           'caudal-ray', 'alt-ray', 'alt-spine', 'trunk']
ACTIVATION = 'sigmoid'
# segmentation model with pretrained encoder
model = smp.FPN(
    encoder_name = ENCODER,
    encoder_weights = ENCODER_WEIGHTS,
    classes = len(CLASSES),
    activation = ACTIVATION,
)
# preprocessing function for this encoder
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [5]:
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation = get_training_augmentation(),
    preprocessing = get_preprocessing(preprocessing_fn),
    classes = CLASSES,
)

valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    augmentation = get_validation_augmentation(),
    preprocessing = get_preprocessing(preprocessing_fn),
    classes = CLASSES,
)

test_dataset = Dataset(
    x_test_dir,
    y_test_dir,
    augmentation = get_validation_augmentation(),
    preprocessing = get_preprocessing(preprocessing_fn),
    classes = CLASSES,
)

# loading data from the dataset
train_loader = DataLoader(train_dataset, batch_size = 8, shuffle=True, num_workers = 12)
valid_loader = DataLoader(valid_dataset, batch_size = 1, shuffle=False, num_workers = 4)
test_loader = DataLoader(test_dataset)

In [6]:
from scripts.helper import get_midpoint, get_row_col_mat, dorsal_to, anterior_to

In [7]:
def dorsal_to(pred_mask, sub_mask, obj_mask, row_mat): 
    max_sub = np.max(row_mat * sub_mask)
    max_obj = np.max(row_mat * obj_mask)

    return float(max(0, max_sub - max_obj) > 0)

In [8]:
def dorsal_to_batch(labelmaps, tensor=True):
    batch_size = labelmaps.shape[0]
    if tensor:
        y = labelmaps.detach().cpu().numpy()
    else:
        y = labelmaps
    penalty = 0
    for bs in range(batch_size):
        dorsal = y[bs, 0, :, :]
        adipose = y[bs, 1, :, :]
        caudal = y[bs, 2, :, :]
        anal = y[bs, 3, :, :]
        pelvic = y[bs, 4, :, :]
        pectoral = y[bs, 5, :, :]
        head = y[bs, 6, :, :]
        trunk = y[bs, 11, :, :]
        
        if dorsal.sum() == 0 or caudal.sum() == 0 or head.sum() == 0 or trunk.sum() == 0:
            continue
        dorsal_mid = get_midpoint(dorsal)
        caudal_mid = get_midpoint(caudal)
        head_mid = get_midpoint(head)
        trunk_mid = get_midpoint(trunk)
        row_mat, col_mat = get_row_col_mat(dorsal_mid, caudal_mid, head_mid, trunk_mid)
        
        if dorsal.sum() != 0 and pectoral.sum() != 0:
            penalty += dorsal_to(y[bs, :, :, :], dorsal, pectoral, row_mat)
            
        if dorsal.sum() != 0 and pelvic.sum() != 0:
            penalty += dorsal_to(y[bs, :, :, :], dorsal, pelvic, row_mat)
            
        if dorsal.sum() != 0 and anal.sum() != 0:
            penalty += dorsal_to(y[bs, :, :, :], dorsal, anal, row_mat)
            
    return penalty

In [9]:
def dorsal_to_nobatch(labelmap, tensor=True):
    if tensor:
        y = labelmap.detach().cpu().numpy()
    else:
        y = labelmap
        
    penalty = 0
    
    dorsal = y[0, :, :]
    adipose = y[1, :, :]
    caudal = y[2, :, :]
    anal = y[3, :, :]
    pelvic = y[4, :, :]
    pectoral = y[5, :, :]
    head = y[6, :, :]
    trunk = y[11, :, :]

    if dorsal.sum() == 0 or caudal.sum() == 0 or head.sum() == 0 or trunk.sum() == 0:
        return 0
    dorsal_mid = get_midpoint(dorsal)
    caudal_mid = get_midpoint(caudal)
    head_mid = get_midpoint(head)
    trunk_mid = get_midpoint(trunk)
    row_mat, col_mat = get_row_col_mat(dorsal_mid, caudal_mid, head_mid, trunk_mid)

    if dorsal.sum() != 0 and pectoral.sum() != 0:
        penalty += dorsal_to(y[bs, :, :, :], dorsal, pectoral, row_mat)

    if dorsal.sum() != 0 and pelvic.sum() != 0:
        penalty += dorsal_to(y[bs, :, :, :], dorsal, pelvic, row_mat)

    if dorsal.sum() != 0 and anal.sum() != 0:
        penalty += dorsal_to(y[bs, :, :, :], dorsal, anal, row_mat)
        
    return penalty

In [10]:
explicit_penalty = 0
rn_penalty = 0

In [11]:
print('############# Explicit Penalty ##############')
print('Trainloader:')
batch_no = 0
for x, y in train_loader:
    batch_no += 1
    p_ = dorsal_to_batch(y)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    explicit_penalty += p_
    
print('Testloader:')
batch_no = 0
for x, y in test_loader:
    batch_no += 1
    p_ = dorsal_to_batch(y)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    explicit_penalty += p_
    
print('Validloader:')
batch_no = 0
for x, y in valid_loader:
    batch_no += 1
    p_ = dorsal_to_batch(y)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    explicit_penalty += p_

############# Explicit Penalty ##############
Trainloader:
Batch 1 penalty : 0.0
Batch 2 penalty : 0.0
Batch 3 penalty : 0.0
Batch 4 penalty : 0.0
Batch 5 penalty : 0.0
Batch 6 penalty : 0.0
Batch 7 penalty : 0.0
Batch 8 penalty : 0.0
Batch 9 penalty : 0.0
Batch 10 penalty : 0.0
Batch 11 penalty : 0.0
Batch 12 penalty : 0.0
Batch 13 penalty : 0.0
Batch 14 penalty : 0.0
Batch 15 penalty : 0.0
Batch 16 penalty : 0.0
Batch 17 penalty : 0.0
Batch 18 penalty : 0.0
Batch 19 penalty : 0.0
Batch 20 penalty : 0.0
Batch 21 penalty : 0.0
Batch 22 penalty : 0.0
Batch 23 penalty : 0.0
Batch 24 penalty : 0.0
Batch 25 penalty : 0.0
Batch 26 penalty : 0.0
Batch 27 penalty : 0.0
Batch 28 penalty : 0.0
Batch 29 penalty : 0.0
Batch 30 penalty : 1.0
Batch 31 penalty : 0.0
Batch 32 penalty : 0.0
Batch 33 penalty : 0.0
Batch 34 penalty : 0.0
Batch 35 penalty : 0.0
Batch 36 penalty : 0.0
Batch 37 penalty : 0.0
Testloader:
Batch 1 penalty : 0.0
Batch 2 penalty : 0.0
Batch 3 penalty : 0.0
Batch 4 penalty : 0.0

In [12]:
# Relational Network
from scripts.RN import RN
rn_model = RN().to(DEVICE)
rn_model.load_state_dict(torch.load('saved_models/rn_model.pt'))

<All keys matched successfully>

In [13]:
def dorsal_to_rn_batch(labelmaps, rn_model, tensor=True, device=DEVICE):
    if tensor == False:
        prediction = torch.tensor(labelmaps).float()
    else:
        prediction = labelmaps
    D = prediction[:, 0, :, :]
    Pel = prediction[:, 4, :, :]
    Pec = prediction[:, 5, :, :]
    An = prediction[:,3, :, :]
    H = prediction[:, 6, :, :]
    T = prediction[:, 2, :, :]
    
    penalty = 0
    
    # Dorsal-Pectoral
    if D.sum() != 0 and Pec.sum() != 0:
        rn_inp = torch.stack((H, T, D, Pec), dim=1).to(device)
        y_pred = rn_model(rn_inp)
        penalty += (y_pred<0.5).float().sum().item()
    # Dorsal-Pelvic
    if D.sum() != 0 and Pel.sum() != 0:
        rn_inp = torch.stack((H, T, D, Pel), dim=1).to(device)
        y_pred = rn_model(rn_inp)
        penalty += (y_pred<0.5).float().sum().item()
    # Dorsal-Anal
    if D.sum() != 0 and An.sum() != 0:
        rn_inp = torch.stack((H, T, D, An), dim=1).to(device)
        y_pred = rn_model(rn_inp)
        penalty += (y_pred<0.5).float().sum().item()
    return penalty


def dorsal_to_rn_nobatch(labelmaps, rn_model, tensor=True, device=DEVICE):
    if tensor == False:
        prediction = torch.tensor(labelmaps).float()
    else:
        prediction = labelmaps
    labelmaps = torch.zeros(1, prediction.shape[0], prediction.shape[1], prediction.shape[2])
    labelmaps[0, :, :, :] = prediction
    
    return dorsal_to_rn_batch(labelmaps, rn_model, tensor=True, device=DEVICE)

In [14]:
batch_no = 0
for x, y in train_loader:
    batch_no += 1
    print('Batch {} penalty : {}'.format(batch_no, dorsal_to_rn_batch(y, rn_model)))

Batch 1 penalty : 0.0
Batch 2 penalty : 0.0
Batch 3 penalty : 0.0
Batch 4 penalty : 0.0
Batch 5 penalty : 0.0
Batch 6 penalty : 0.0
Batch 7 penalty : 0.0
Batch 8 penalty : 0.0
Batch 9 penalty : 0.0
Batch 10 penalty : 0.0
Batch 11 penalty : 0.0
Batch 12 penalty : 1.0
Batch 13 penalty : 0.0
Batch 14 penalty : 0.0
Batch 15 penalty : 0.0
Batch 16 penalty : 0.0
Batch 17 penalty : 0.0
Batch 18 penalty : 0.0
Batch 19 penalty : 0.0
Batch 20 penalty : 0.0
Batch 21 penalty : 1.0
Batch 22 penalty : 0.0
Batch 23 penalty : 0.0
Batch 24 penalty : 0.0
Batch 25 penalty : 0.0
Batch 26 penalty : 0.0
Batch 27 penalty : 0.0
Batch 28 penalty : 0.0
Batch 29 penalty : 0.0
Batch 30 penalty : 0.0
Batch 31 penalty : 0.0
Batch 32 penalty : 0.0
Batch 33 penalty : 0.0
Batch 34 penalty : 0.0
Batch 35 penalty : 0.0
Batch 36 penalty : 0.0
Batch 37 penalty : 0.0


In [15]:
print('############# RN Penalty ##############')
print('Trainloader:')
batch_no = 0
for x, y in train_loader:
    batch_no += 1
    p_ = dorsal_to_rn_batch(y, rn_model)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    rn_penalty += p_
    
print('Testloader:')
batch_no = 0
for x, y in test_loader:
    batch_no += 1
    p_ = dorsal_to_rn_batch(y, rn_model)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    rn_penalty += p_
    
print('Validloader:')
batch_no = 0
for x, y in valid_loader:
    batch_no += 1
    p_ = dorsal_to_rn_batch(y, rn_model)
    print('Batch {} penalty : {}'.format(batch_no, p_))
    rn_penalty += p_

############# RN Penalty ##############
Trainloader:
Batch 1 penalty : 0.0
Batch 2 penalty : 0.0
Batch 3 penalty : 0.0
Batch 4 penalty : 0.0
Batch 5 penalty : 0.0
Batch 6 penalty : 0.0
Batch 7 penalty : 0.0
Batch 8 penalty : 0.0
Batch 9 penalty : 1.0
Batch 10 penalty : 1.0
Batch 11 penalty : 1.0
Batch 12 penalty : 0.0
Batch 13 penalty : 0.0
Batch 14 penalty : 0.0
Batch 15 penalty : 0.0
Batch 16 penalty : 1.0
Batch 17 penalty : 0.0
Batch 18 penalty : 0.0
Batch 19 penalty : 0.0
Batch 20 penalty : 0.0
Batch 21 penalty : 0.0
Batch 22 penalty : 0.0
Batch 23 penalty : 0.0
Batch 24 penalty : 0.0
Batch 25 penalty : 0.0
Batch 26 penalty : 0.0
Batch 27 penalty : 1.0
Batch 28 penalty : 0.0
Batch 29 penalty : 0.0
Batch 30 penalty : 0.0
Batch 31 penalty : 0.0
Batch 32 penalty : 0.0
Batch 33 penalty : 0.0
Batch 34 penalty : 0.0
Batch 35 penalty : 0.0
Batch 36 penalty : 0.0
Batch 37 penalty : 0.0
Testloader:
Batch 1 penalty : 0.0
Batch 2 penalty : 0.0
Batch 3 penalty : 0.0
Batch 4 penalty : 0.0
Batch

In [16]:
print('Explicit Penalty:', explicit_penalty)
print('Relational-Network Penalty:', rn_penalty)

Explicit Penalty: 3.0
Relational-Network Penalty: 5.0


In [None]:
x_unlabeled_test_dir = '/raid/maruf/inhs-800/img_crp/'
test_dataset_unlabeled = Dataset(
    x_unlabeled_test_dir,
    masks_dir=None,
    augmentation = get_validation_augmentation(),
    preprocessing = get_preprocessing_unlabeled(preprocessing_fn),
    classes = CLASSES,
)
test_loader_unlabeled = DataLoader(test_dataset_unlabeled, batch_size=1)
num_test_img = len(test_dataset_unlabeled)
unlabeled_inconsistency = 0
for test_no in range(num_test_img):
    image, mask = test_dataset_unlabeled[test_no]
    img_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pred_mask = best_model.predict(img_tensor)
    pred_mask = pred_mask.squeeze().cpu().numpy().round()
    unlabeled_inconsistency += explicit_dorsal_to_nobatch(labelmap=pred_mask, tensor=False)
    
    if test_no == 1000:
        break

In [17]:
# count = 0
# for x, y in train_loader:
#     batch_size = y.shape[0]
#     y = y.detach().cpu().numpy()
#     for bs in range(batch_size):
#         dorsal = y[bs, 0, :, :]
#         adipose = y[bs, 1, :, :]
#         caudal = y[bs, 2, :, :]
#         anal = y[bs, 3, :, :]
#         pelvic = y[bs, 4, :, :]
#         pectoral = y[bs, 5, :, :]
#         head = y[bs, 6, :, :]
#         trunk = y[bs, 11, :, :]
        
#         if dorsal.sum() == 0 or caudal.sum() == 0 or head.sum() == 0 or trunk.sum() == 0:
#             continue
            
#         dorsal_mid = get_midpoint(dorsal)
#         caudal_mid = get_midpoint(caudal)
#         head_mid = get_midpoint(head)
#         trunk_mid = get_midpoint(trunk)
        
#         row_mat, col_mat = get_row_col_mat(dorsal_mid, caudal_mid, head_mid, trunk_mid)
        
#         if dorsal.sum() == 0 or pectoral.sum() == 0:
#             print('Dorsal or Pectoral fin is missing.')
#         else:
#             penalty = dorsal_to(y[bs, :, :, :], dorsal, pectoral, row_mat)
#             if penalty != 0:
#                 print('Dorsal _ dorsal-to _ Pectoral: \t Penalty:', penalty)
#                 plt.imshow(get_color_img(y[bs, :, :, :], normal=False))
#                 plt.show()
            
#         if dorsal.sum() == 0 or pelvic.sum() == 0:
#             print('Dorsal or Pelvic fin is missing.')   
#         else:
#             penalty = dorsal_to(y[bs, :, :, :], dorsal, pelvic, row_mat)
#             if penalty != 0:
#                 print('Dorsal _ dorsal-to _ Pelvic: \t Penalty:', penalty)
#                 plt.imshow(get_color_img(y[bs, :, :, :], normal=False))
#                 plt.show()
            
#         if dorsal.sum() == 0 or anal.sum() == 0:
#             print('Dorsal or Anal fin is missing.')
#         else:
#             penalty = dorsal_to(y[bs, :, :, :], dorsal, anal, row_mat)
#             if penalty != 0:
#                 print('Dorsal _ dorsal-to _ Anal: \t Penalty:', penalty)
#                 plt.imshow(get_color_img(y[bs, :, :, :], normal=False))
#                 plt.show()
                
#     count += 1
#     print('Batch {} done'.format(count))


In [18]:
# def anterior_to(head, caudal, sub, obj):
#     if head[0] == -1:
#         head = sub
#     if caudal[0] == -1:
#         caudal = obj
        
#     v = head - caudal
#     u_1 = sub - caudal
#     u_2 = obj - caudal
#     sub_val = np.dot(v, u_1)
#     obj_val = np.dot(v, u_2)

#     return (sub_val<obj_val).astype(np.uint8)

# def dorsal_to(sub, obj, labelmap):
#     trunk = labelmap[11, :, :]
#     pectoral = labelmap[5, :, :]
#     dorsal = labelmap[0, :, :]
#     adipose = labelmap[1, :, :]
#     caudal = labelmap[2, :, :]
#     anal = labelmap[3, :, :]
#     pelvic = labelmap[4, :, :]
#     ray = labelmap[9, :, :]
#     spine = labelmap[10, :, :]
#     head = labelmap[6, :, :]
    
#     body = trunk + pectoral + dorsal + adipose + caudal + anal + pelvic + ray + spine
    
#     mid = ((sub+obj)/2).astype(int)

#     if (trunk[mid[1], mid[0]] == 0):
#         return (body[mid[1], mid[0]] == 0).astype(int)

#     else:
#         return 0