# Patch-based ODIN qualitative analysis

In [None]:
import numpy as np
import pandas as pd
import os

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import seaborn as sns

import imageio
from skimage import transform,io

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from collections import Counter 

from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler

from sklearn.feature_extraction.image import extract_patches_2d

from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Optional GPU setup. The branch is disabled (`if False`); switch it on only
# when running inside the JupyterLab GPU environment where CapsuleLib exists.
if False: # only for JupyterLab with GPUs environment
    import sys
    # Make the project-local CapsuleLib package importable.
    sys.path.insert(0, '/notebooks/')
    sys.path.insert(0, '../')
    from CapsuleLib.utils import gpu_config, availableGPU
    # NOTE(review): gpu_config presumably selects/configures a GPU and returns
    # its id -- confirm against CapsuleLib.utils.
    gpuid = gpu_config(False, True)

## Path settings

In [None]:
# Directory listed for the in-distribution ("normal") PNG images.
NORMAL_IMAGES_PATH = "/data/capri/polyp_db_v3/normal_dbv1"  # set paths here!
# Directory listed for the out-of-distribution (OOD) PNG images.
OOD_IMAGES_PATH = "/notebooks/Arnau/patch-self-supervised/data/ood-images"

## Loading data

In [None]:
# Number of images yielded per generator batch.
BATCH_SIZE = 64
# NOTE(review): not referenced in the visible code; the helpers hard-code 64.
PATCH_SIZE = (64,64)
# Side length the generators resize every image to (IM_SIZE x IM_SIZE).
IM_SIZE = 256
# NOTE(review): not referenced in the visible code; the model file name
# hard-codes "t1000" -- presumably the same temperature, confirm.
TEMP = 1000
# K and FOLD are interpolated into the model file name that gets loaded.
K = 20
FOLD = 2

In [None]:
# --- In-distribution images: one row per PNG, keyed by its video id ---------
path = NORMAL_IMAGES_PATH
filepaths, videos = [], []
for image in os.listdir(path):
    if not image.endswith('.png'):
        continue
    # The first three characters of the file name are stored as the video id.
    image_id = image[:3]
    filepaths.append(os.path.join(path, image))
    videos.append(image_id)
df = pd.DataFrame({"path": filepaths, "video": videos})
# The ids were collected as strings; keep them as integers in the frame.
df['video'] = df['video'].astype(int)

# Fold assignment table, read from the working directory.
folds = pd.read_csv('3fold.csv')

# --- Out-of-distribution images: one row per PNG, keyed by an image id ------
path = OOD_IMAGES_PATH

filepaths, image_ids = [], []
for image in os.listdir(path):
    if not image.endswith('.png'):
        continue
    # Id = last three characters of the part of the name before the first '-'.
    image_id = image.split('-')[0][-3:]
    filepaths.append(os.path.join(path, image))
    image_ids.append(image_id)
df_ood = pd.DataFrame({"path": filepaths, "image": image_ids})

def rotate_image(image):
    """
    Rotates an image by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from {-1, 0, 1, 2} using
    NumPy's global random state; the rotation acts on the first two axes.
    """
    quarter_turns = np.random.choice([-1, 0, 1, 2])
    rotated = np.rot90(image, quarter_turns)
    return rotated

# Augmenting generator: 1/255 rescaling, random flips on both axes and a
# random quarter-turn rotation supplied by rotate_image.
datagen = ImageDataGenerator(
    rescale=1./255.,
    vertical_flip=True,
    horizontal_flip=True,
    preprocessing_function=rotate_image,
)

# Both streams share the same settings: unlabeled images (class_mode=None),
# resized to IM_SIZE x IM_SIZE, shuffled with a fixed seed.
_flow_options = dict(
    x_col="path",
    class_mode=None,
    target_size=(IM_SIZE, IM_SIZE),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1,
)

normal_gen = datagen.flow_from_dataframe(dataframe=df, **_flow_options)

ood_gen = datagen.flow_from_dataframe(dataframe=df_ood, **_flow_options)

# Patch classifier checkpoint selected by K and FOLD.
model = keras.models.load_model(f'models/patches-classifier-{K}-t1000-f{FOLD}.hdf5')

## Helper functions

In [None]:
def get_msp(image_batch):
    """
    Scores every image of a batch with the global `model`.

    The score is 1 minus the largest value of the model's prediction for that
    image (i.e. the complement of the maximum softmax probability, assuming
    the model ends in a softmax), so a higher score means a less confident
    prediction. Returns one score per image.
    """
    class_probs = model.predict(image_batch)
    top_prob = np.max(class_probs, axis=1)
    return 1 - top_prob

def sliding_windows(image, mode="16"):
    """
    Returns overlapping 64x64 patches from a 256x256 image

    Parameters:
        image: array indexed as image[row, column, ...]. The window
            coordinates are hard-coded for a 256x256 image; smaller inputs
            yield truncated windows.
        mode: "16" for a 4x4 grid (stride 44, starting at offset 32), or
            "28" for a 4x4 grid (stride 46, starting at offset 28) plus three
            extra windows along each of the four borders.

    Returns:
        np.ndarray stacking the windows. The order matches the one expected
        by get_16_heatmap / get_28_heatmap: for "28" it is top border, grid
        (row-major), bottom border, left border, right border.

    Raises:
        ValueError: if `mode` is neither "16" nor "28" (previously an empty
            array was returned silently and failed later inside the model).
    """
    if mode == "16":
        grid = range(32, 192, 44)
        origins = [(y, x) for y in grid for x in grid]
    elif mode == "28":
        border = range(51, 150, 46)
        grid = range(28, 192, 46)
        origins = (
            [(4, x) for x in border]                    # top border
            + [(y, x) for y in grid for x in grid]      # central grid
            + [(188, x) for x in border]                # bottom border
            + [(y, 4) for y in border]                  # left border
            + [(y, 188) for y in border]                # right border
        )
    else:
        raise ValueError(f'Unknown mode {mode!r}; expected "16" or "28"')

    # (y, x) is the top-left corner of each 64x64 window.
    return np.array([image[y:y + 64, x:x + 64] for y, x in origins])

def perturb_images(images, epsilon=0.0005):
    """
    Perturbs a batch of images by a signed-gradient step of size epsilon.

    Every pixel is moved by +/- epsilon in the direction that increases the
    mean, over the batch, of the largest output of the global `model`.
    Returns the perturbed batch as a tensor clipped to [0, 255].
    """
    inputs = tf.Variable(images, trainable=True)

    with tf.GradientTape() as tape:
        tape.watch(inputs)
        outputs = model(inputs, training=False)
        top_scores = tf.reduce_max(outputs, axis=1)
        # Minimising this loss maximises the top score.
        loss = -tf.reduce_mean(top_scores)

    grads = tape.gradient(loss, inputs)
    # Map the gradient to +1 where it is >= 0 and to -1 elsewhere
    # (zero gradients therefore count as +1).
    is_non_negative = tf.cast(tf.math.greater_equal(grads, 0), tf.float32)
    signs = (is_non_negative - 0.5) * 2

    perturbed = tf.convert_to_tensor(inputs) - epsilon * signs
    # NOTE(review): callers in this notebook pass images scaled by 1/255, so
    # the upper bound of 255 looks like it never clips -- confirm whether 1.0
    # was intended.
    return tf.clip_by_value(perturbed, 0., 255.)

def get_ordered_patches(image, epsilon=0.0005, mode="16"):
    """
    Returns list of patches, from most abnormal to most normal

    Parameters:
        image: image to cut into patches (see sliding_windows).
        epsilon: perturbation magnitude applied before scoring; 0 scores the
            raw patches without calling perturb_images.
        mode: window layout forwarded to sliding_windows.

    Returns:
        (msp_patches, ordered_patches): the scores in window order, and the
        original (unperturbed) patches sorted by decreasing score.
    """
    patches = sliding_windows(image, mode=mode)
    if epsilon == 0:
        msp_patches = get_msp(patches)
    else:
        msp_patches = get_msp(perturb_images(patches, epsilon))

    # Sort by score only. Sorting (score, patch) tuples falls back to
    # comparing the patch arrays whenever two scores tie, which raises
    # "truth value of an array is ambiguous".
    order = np.argsort(-np.asarray(msp_patches), kind="stable")
    return msp_patches, [patches[i] for i in order]

def plot_strip(image_strip):
    """
    Plot strip of images

    Draws the images side by side on a new one-row figure, without axes.

    Parameters:
        image_strip: sequence of images accepted by `imshow`.
    """
    n_images = len(image_strip)
    # squeeze=False always yields a 2-D array of axes; without it a strip of
    # one image returns a bare Axes object that cannot be indexed.
    fig, ax = plt.subplots(1, n_images, figsize=(n_images, 2), squeeze=False)
    for axis, im in zip(ax[0], image_strip):
        axis.imshow(im)
        axis.axis('off')

def get_16_heatmap(msp_patches):
    """
    Returns heatmap of the given patches

    Parameters:
        msp_patches: 16 scores, in the window order produced by
            sliding_windows(mode="16") (row-major 4x4 grid).

    Returns:
        256x256 array in which every pixel holds the mean score of the
        windows covering it; pixels covered by no window are 0.
    """
    heatmap = np.zeros((256, 256))
    count = np.zeros((256, 256))
    grid = range(32, 192, 44)
    origins = [(y, x) for y in grid for x in grid]
    for i, (y, x) in enumerate(origins):
        heatmap[y:y+64, x:x+64] += msp_patches[i]
        count[y:y+64, x:x+64] += 1
    # Avoid dividing by zero on pixels that no window covers.
    count = np.clip(count, 1., np.inf)

    return heatmap/count

def get_28_heatmap(msp_patches):
    """
    Builds the 256x256 heatmap for the 28-window layout.

    `msp_patches` holds one score per window, in this order: top border,
    4x4 grid (row-major), bottom border, left border, right border. Every
    pixel receives the mean score of the windows covering it; pixels covered
    by no window are 0.
    """
    border = range(51, 150, 46)
    grid = range(28, 192, 46)

    # Top-left corner (y, x) of every 64x64 window, in score order.
    origins = [(4, x) for x in border]
    origins += [(y, x) for y in grid for x in grid]
    origins += [(188, x) for x in border]
    origins += [(y, 4) for y in border]
    origins += [(y, 188) for y in border]

    heatmap = np.zeros((256, 256))
    count = np.zeros((256, 256))
    for i, (y, x) in enumerate(origins):
        heatmap[y:y+64, x:x+64] += msp_patches[i]
        count[y:y+64, x:x+64] += 1

    # Uncovered pixels keep a divisor of 1 so the division is well defined.
    count = np.clip(count, 1., np.inf)
    return heatmap/count

def plot_heatmap(image, epsilon=0.0005, mode="16"):
    """
    Draws the patch-score heatmap of an image with the image overlaid on it.

    The patches are scored through get_ordered_patches; the colour range is
    [0.1, 0.8] for mode "16" and [0.2, 1] for mode "28". Any other mode draws
    nothing.
    """
    msp_patches, _ = get_ordered_patches(image, epsilon, mode=mode)

    if mode == "16":
        score_map, low, high = get_16_heatmap(msp_patches), 0.1, 0.8
    elif mode == "28":
        score_map, low, high = get_28_heatmap(msp_patches), 0.2, 1
    else:
        return

    hmax = sns.heatmap(
        score_map,
        alpha=1,
        cmap="Reds",
        cbar=False,
        xticklabels=False,
        yticklabels=False,
        square=True,
        vmin=low,
        vmax=high,
    )
    # Semi-transparent image on top of the heatmap.
    hmax.imshow(image, alpha=0.6, zorder=1)

## Patches strip

In [None]:
# Draw one batch of OOD images. The builtin next() relies on the iterator
# protocol, which every Keras iterator implements, whereas the `.next()`
# method is not available on all versions.
batch = next(ood_gen)

In [None]:
# Show the batch on an 8x8 grid (one cell per image, up to 64 images).
fig, ax = plt.subplots(8, 8, figsize=(8, 8))
for i, image in enumerate(batch):
    ax.flat[i].imshow(image)
# Hide the axes of every cell, including the ones left empty when the batch
# holds fewer than 64 images.
for cell in ax.flat:
    cell.axis('off')

In [None]:
# Pick one image of the current batch (index 11, presumably hand-picked).
image = batch[11]
# Apply one more random quarter-turn rotation.
image = rotate_image(image)
plt.imshow(image)
plt.axis('off')
# Save the image itself; the target directory must already exist.
plt.savefig('qualitative/tfg-strip-wrong-image-2.png', bbox_inches='tight')
# epsilon=0 scores the raw patches (no perturbation); element [1] of the
# result is the list of patches sorted by decreasing score.
plot_strip(get_ordered_patches(image, epsilon=0)[1])
# plot_strip opened a new figure, so this saves the strip, not the image.
plt.savefig('qualitative/tfg-strip-wrong-patches-2.png', bbox_inches='tight')

## Heatmap

In [None]:
# Heatmap of the image selected above, with the default epsilon and the
# 16-window layout.
plot_heatmap(image)
plt.savefig('qualitative/tfg-wrong-heatmap-2.png', bbox_inches='tight')

## Examples

In [None]:
# Hand-picked OOD examples; the trailing comments describe what each shows.
example_image_paths = ['04562.png', # white tube
 '05720.png', # white tube + white blob
 '067_1_1_04043.png', # ???
 '04607.png', # white tube
 '00679.png', # ?????
 '064_1_2_04812.png', # metal tube?
 '003_1_021215.png', # ?????
 '04579.png', # white tube
 '05865.png', # white tube
 '064_2_010151.png'] # blood
path = OOD_IMAGES_PATH

example_images = []
for image in example_image_paths:
    # The context manager closes the file handle once the pixels have been
    # copied into the array (the handle used to be left open).
    with Image.open(os.path.join(path, image)) as pic:
        example_images.append(np.array(pic, dtype="float32")/255)

# One heatmap per example on a 2x5 grid, with the image overlaid.
# NOTE(review): the images are not resized here -- the 16-window layout
# assumes they are already 256x256; confirm.
fig, ax = plt.subplots(2, 5, figsize=(10, 4))
for i, image in enumerate(example_images):
    msp_patches, _ = get_ordered_patches(image, 0.0005)

    hmax = sns.heatmap(get_16_heatmap(msp_patches), alpha=1,
            cmap="Reds", cbar=False, xticklabels=False, yticklabels=False,
                   square=True, vmin=0, vmax=0.8, ax=ax[i//5, i%5])
    hmax.imshow(image, alpha=0.7, zorder=1)

## Highest- and lowest-score images

In [None]:
# Accumulators for the scored images and their image-level scores.
images = np.empty((0,256,256,3))
scores = np.empty((0,))
N = 1  # number of OOD batches to score
for n in range(N):
    # Builtin next() instead of `.next()`: the method is not available on
    # every Keras iterator version, the iterator protocol is.
    batch = next(ood_gen)
    batch_scores = []
    for image in batch:
        patches = sliding_windows(image)
        # Note: epsilon is 0.001 here, not the 0.0005 default used elsewhere.
        msp_patches = get_msp(perturb_images(patches, 0.001))

        # Image score = mean of the five highest patch scores.
        summary_patches = sum(sorted(msp_patches, reverse=True)[:5])/5.
        batch_scores.append(summary_patches)

    images = np.concatenate([images, batch])
    scores = np.concatenate([scores, batch_scores])

# Images ordered by increasing score (lowest first, highest last).
sorted_images = list(images[np.argsort(scores)])

In [None]:
# top 5: montage of the n highest-scoring images, side by side
n = 5
top_image = np.zeros((256, 256*n, 3))
# sorted_images is in increasing score order, so the last n score highest.
top_images = sorted_images[-n:]
for i, image in enumerate(top_images):
    top_image[:, i*256:(i+1)*256, :] += image
plt.imshow(top_image)
plt.axis('off')
# Plain string: the former f-string had no placeholders.
plt.savefig('qualitative/patch-top-images.png', dpi=300, bbox_inches='tight', pad_inches=0)

In [None]:
# bot 5: montage of the n lowest-scoring images, side by side
n = 5
bot_image = np.zeros((256, 256*n, 3))
# sorted_images is in increasing score order, so the first n score lowest.
bot_images = sorted_images[:n]
for i, image in enumerate(bot_images):
    bot_image[:, i*256:(i+1)*256, :] += image
plt.imshow(bot_image)
plt.axis('off')
# Plain string: the former f-string had no placeholders.
plt.savefig('qualitative/patch-bot-images.png', dpi=300, bbox_inches='tight', pad_inches=0)