In [None]:
import numpy as np
import pandas as pd

import os

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

import imageio
from skimage import transform,io

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from collections import Counter 

from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler

from sklearn.feature_extraction.image import extract_patches_2d

from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Optional GPU setup. The branch is switched off (`if False`) and has to be
# enabled by hand when running in the JupyterLab + GPU environment.
if False: # only for JupyterLab with GPUs environment
    import sys
    # Extend the import path so that the CapsuleLib package can be imported.
    sys.path.insert(0, '/notebooks/')
    sys.path.insert(0, '../')
    from CapsuleLib.utils import gpu_config, availableGPU
    # NOTE(review): presumably selects/configures a GPU and returns its id —
    # confirm against CapsuleLib.utils.gpu_config.
    gpuid = gpu_config(False, True)

## Path settings

In [None]:
# Directory holding the in-distribution ("normal") PNG frames. The first three
# characters of each file name are later parsed as the integer video id.
NORMAL_IMAGES_PATH = "/data/capri/polyp_db_v3/normal_dbv1"  # set paths here!
# Directory holding the out-of-distribution (OOD) PNG images.
OOD_IMAGES_PATH = "/notebooks/Arnau/patch-self-supervised/data/ood-images"

## Data preparation

In [None]:
# Number of images per batch yielded by the data generators.
BATCH_SIZE = 64
# NOTE(review): not referenced anywhere in the visible code; sliding_window()
# hard-codes its 64x64 crop size instead of reading this value.
PATCH_SIZE = (64,64)
# Generators resize every image to IM_SIZE x IM_SIZE.
IM_SIZE = 256
# Embedded in the model file name as "t{TEMP}"; presumably the softmax
# temperature the classifier was trained/exported with — confirm.
TEMP = 1000
# Embedded in the model and result file names; its meaning is not shown in
# this notebook.
N = 20

In [None]:
# --- In-distribution frames --------------------------------------------------
# os.listdir() returns entries in arbitrary order. Sort them so the row order
# of the DataFrame (and therefore the seeded shuffling of every generator built
# from it) is the same on every machine and every run.
path = NORMAL_IMAGES_PATH
filepaths, videos = [], []
for image in sorted(os.listdir(path)):
    if image.endswith('.png'):
        image_id = image[:3]  # leading three characters of the file name = video id
        filepaths.append(os.path.join(path,image))
        videos.append(image_id)
df = pd.DataFrame({"path": filepaths, "video":videos})
df['video'] = df['video'].astype(int)  # raises if a file name does not start with digits

# Fold definition table; later cells read its 'video' column and use the
# 'fold-<k>' columns as row masks.
folds = pd.read_csv('3fold.csv')

# --- Out-of-distribution images ----------------------------------------------
path = OOD_IMAGES_PATH

filepaths, image_ids = [], []
for image in sorted(os.listdir(path)):
    if image.endswith('.png'):
        # Last three characters of the part of the file name before the first '-'.
        image_id = image.split('-')[0][-3:]
        filepaths.append(os.path.join(path,image))
        image_ids.append(image_id)
df_ood = pd.DataFrame({"path": filepaths, "image":image_ids})

def rotate_image(image):
    """
    Rotates an image by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from {-1, 0, 1, 2} with
    NumPy's global random state; the rotation acts on the first two axes.
    """
    quarter_turns = np.random.choice([-1, 0, 1, 2])
    return np.rot90(image, k=quarter_turns)

# Augmentation pipeline shared by every generator in this notebook: pixel
# values are scaled by 1/255 and images get random horizontal/vertical flips
# plus the random quarter-turn rotation from rotate_image().
datagen = ImageDataGenerator(
    rescale=1./255.,
    vertical_flip=True,
    horizontal_flip=True,
    preprocessing_function=rotate_image,
)

# Unlabelled (class_mode=None), shuffled stream of out-of-distribution images,
# resized to IM_SIZE x IM_SIZE and served in batches of BATCH_SIZE.
ood_gen = datagen.flow_from_dataframe(
    dataframe=df_ood,
    x_col="path",
    class_mode=None,
    target_size=(IM_SIZE, IM_SIZE),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1,
)

## Helper functions

In [None]:
def get_msp(image_batch):
    """
    Returns one uncertainty score per image of the batch: 1 - MSP.

    The score is one minus the largest value the model outputs for that image,
    so it is the complement of the maximum softmax probability (MSP), not the
    MSP itself; larger values mean a less confident prediction.

    Parameters
    ----------
    image_batch : array-like
        Batch of inputs accepted by the module-level `model`.

    Returns
    -------
    numpy.ndarray
        1-D array with one score per input.
    """
    # verbose=0: this function is called once per image, and recent Keras
    # versions otherwise print a progress bar for every single call.
    softmax_probs = model.predict(image_batch, verbose=0)
    return 1 - np.max(softmax_probs, axis=1)

def sliding_window(image):
    """
    Cuts a 4x4 grid of overlapping 64x64 crops out of a 256x256 image.

    Crop origins are 32, 76, 120 and 164 along both axes (a step of 44, which
    is smaller than the crop size, hence the overlap). Crops are returned in
    row-major order (top to bottom, then left to right) as a single array.
    """
    origins = range(32, 192, 44)
    crops = [
        image[top:top + 64, left:left + 64]
        for top in origins
        for left in origins
    ]
    return np.array(crops)

def perturb_images(images, epsilon=0.002):
    """
    Shifts every pixel of a batch by +/- epsilon in the direction that raises
    the model's largest output.

    The loss is minus the batch mean of the per-image maximum model output.
    Each pixel is moved by `epsilon` against the sign of the loss gradient and
    the result is clipped to [0, 255]. Reads the module-level `model`.

    Returns the perturbed batch as a tensor.
    """
    inputs = tf.Variable(images, trainable=True)

    with tf.GradientTape() as tape:
        tape.watch(inputs)
        outputs = model(inputs, training=False)
        loss = -tf.reduce_mean(tf.reduce_max(outputs, axis=1))

    grad = tape.gradient(loss, inputs)
    # Step direction: +1 where the gradient is >= 0, -1 elsewhere. Unlike
    # tf.sign, an exactly-zero gradient maps to +1.
    non_negative = tf.cast(tf.math.greater_equal(grad, 0), tf.float32)
    direction = (non_negative - 0.5) * 2

    perturbed = tf.convert_to_tensor(inputs) - epsilon * direction
    # NOTE(review): the generators in this notebook rescale pixels by 1/255,
    # so the upper bound of 255 presumably never triggers — confirm the
    # intended clipping range.
    return tf.clip_by_value(perturbed, 0., 255.)

def _summarise_patch_scores(patch_scores, summary):
    """Collapse the per-patch scores of one image into a single image score."""
    if summary == "max":
        return max(patch_scores)
    if summary == "avg":
        return sum(patch_scores) / len(patch_scores)
    if summary == "top5":
        return sum(sorted(patch_scores, reverse=True)[:5]) / 5.
    raise ValueError(f"summary must be 'top5', 'max' or 'avg', got {summary!r}")


def _score_images(generator, runs, epsilon, summary):
    """Draw `runs` batches from `generator` and return one summarised score per image."""
    scores = []
    for _ in range(runs):
        # Builtin next() works on every Keras iterator version.
        for image in next(generator):
            patches = sliding_window(image)
            if epsilon != 0:
                patches = perturb_images(patches, epsilon)
            score = _summarise_patch_scores(get_msp(patches), summary)
            assert 0 <= score <= 1
            scores.append(score)
    return scores


def get_hist_roc_patches(runs=100, epsilon=0, summary="max"):
    """
    Computes histogram and ROC for the given epsilon and summary function.

    Every image drawn from the validation and OOD generators is cut into
    patches, each patch is scored with get_msp() (after perturb_images() when
    `epsilon` is non-zero) and the patch scores are summarised per image.
    Validation images are the negative class, OOD images the positive class.

    Reads the module-level names `val_gen`, `ood_gen`, `N` and `suffix`
    (and, through the helpers it calls, `model`).

    Parameters
    ----------
    runs : int
        Number of batches drawn from each generator.
    epsilon : float
        Perturbation magnitude; 0 disables the perturbation.
    summary : {"max", "top5", "avg"}
        How the patch scores of one image are combined.

    Returns
    -------
    (spec, sens) : tuple of numpy.ndarray
        Specificity and sensitivity at every threshold of the sweep.

    Raises
    ------
    ValueError
        If `summary` is not one of the supported names.

    Side effects
    ------------
    Writes 'results/{N}/hist{suffix}.png' and 'results/{N}/roc{suffix}.csv',
    creating the directory when needed.
    """
    # Fail before any inference instead of hitting a NameError (or silently
    # reusing a stale value) deep inside the scoring loop.
    if summary not in ("top5", "max", "avg"):
        raise ValueError(f"summary must be 'top5', 'max' or 'avg', got {summary!r}")

    msp_list = [_score_images(generator, runs, epsilon, summary)
                for generator in (val_gen, ood_gen)]

    # Histogram and ROC are built from the same samples, including images
    # from a final, smaller-than-BATCH_SIZE batch.
    y_probs = np.asarray(msp_list[0] + msp_list[1], dtype=float)
    y_true = np.concatenate([np.zeros(len(msp_list[0])), np.ones(len(msp_list[1]))])

    os.makedirs(f'results/{N}', exist_ok=True)

    fig, ax = plt.subplots(figsize=(10,6))
    logbins = np.logspace(-3,0,80)
    ax.hist(msp_list[0], bins=logbins, alpha=0.4, label="Validation", color="blue")
    ax.hist(msp_list[1], bins=logbins, alpha=0.5, label="OOD", color="orange")
    ax.set_xlabel("1-MSP")
    ax.set_xscale('log')
    ax.legend()
    ax.set_title("MSP histograms")
    fig.savefig(f'results/{N}/hist{suffix}.png')
    # Display now, then release the figure: this function is called many times
    # in a row and unclosed figures would pile up in memory.
    plt.show()
    plt.close(fig)

    threshold_range = np.concatenate([[0.],np.logspace(-8,-4,50), np.logspace(-4,-1,100), np.logspace(-1,0,100)])

    # One row per threshold, one column per image: True where the image is
    # flagged as OOD at that threshold.
    is_ood = y_true == 1
    flagged = y_probs[np.newaxis, :] >= threshold_range[:, np.newaxis]
    sens = (flagged & is_ood).sum(axis=1) / is_ood.sum()
    spec = (~flagged & ~is_ood).sum(axis=1) / (~is_ood).sum()

    df_roc = pd.DataFrame({'spec':spec, 'sens':sens})
    df_roc.to_csv(f'results/{N}/roc{suffix}.csv')

    return spec, sens

In [None]:
# Build one validation generator and load one trained classifier per fold.
# The model files are selected with the module-level N and TEMP.
generators, models = [], []
for fold in (1, 2, 3):

    # The 'fold-<k>' column is used as a row mask (presumably boolean) that
    # picks the videos belonging to this fold.
    fold_mask = folds[f'fold-{fold}']
    videos_val = folds[fold_mask]['video'].values
    df_val = df[df.video.isin(videos_val)]
    val_gen = datagen.flow_from_dataframe(
        dataframe=df_val,
        x_col="path",
        class_mode=None,
        target_size=(IM_SIZE, IM_SIZE),
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=1,
    )
    generators.append(val_gen)

    model_name = f'models/patches-classifier-{N}-t{TEMP}-f{fold}.hdf5'
    model = keras.models.load_model(model_name)
    models.append(model)

## 3-fold ODIN evaluation

In [None]:
# Evaluate every (N, summary, epsilon) configuration on each fold and write
# one text report per N.
#
# get_hist_roc_patches() reads the module-level names `N`, `suffix`, `model`
# and `val_gen`, which is why they are rebound right before each call.
#
# NOTE(review): only folds 1 and 2 are evaluated although the fold table is
# called '3fold.csv' — confirm that leaving out fold 3 is intentional.
EVAL_FOLDS = [1, 2]

# np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz


def _log(report, message):
    """Print `message` and append it to `report`, the list of report lines."""
    print(message)
    report.append(message)


for N in [10,15,20]:

    print(f"\n ************** N={N} ***************")

    # Creates both 'results/' (text report) and 'results/{N}/' (figures and
    # ROC tables) so a fresh checkout does not fail after the inference ran.
    os.makedirs(f'results/{N}', exist_ok=True)

    generators, models = [], []
    for fold in EVAL_FOLDS:

        videos_val = folds[folds[f'fold-{fold}']]['video'].values
        df_val = df[df.video.isin(videos_val)]
        val_gen=datagen.flow_from_dataframe(dataframe=df_val, x_col="path", class_mode=None,
        target_size=(IM_SIZE,IM_SIZE), batch_size=BATCH_SIZE, shuffle=True, seed=1)
        generators.append(val_gen)

        model_name = f'models/patches-classifier-{N}-t{TEMP}-f{fold}.hdf5'
        model = keras.models.load_model(model_name)
        models.append(model)

    report = []
    for summ in [ "top5", "max", "avg"]:
        _log(report, summ)
        for eps in [0, 0.0005,0.001]:

            auroc_list = []
            _log(report, f"  epsilon={eps}")

            # Index by position so the fold list does not have to be 1..k.
            for position, fold in enumerate(EVAL_FOLDS):

                suffix = f'-{N}-{summ}-t{TEMP}-eps{eps}-f{fold}'
                model = models[position]
                val_gen = generators[position]

                spec, sens = get_hist_roc_patches(epsilon=eps, summary=summ, runs=50)
                # Area between the specificity curve and the sensitivity axis;
                # abs() makes the result independent of the sweep direction.
                auroc = round(np.abs(_trapezoid(spec, np.array(sens))),5)
                auroc_list.append(auroc)
                _log(report, f"    fold {fold} AUROC: {auroc}")

            # The trailing newline leaves a blank line after each epsilon block.
            _log(report, f"  Mean AUROC: {np.mean(auroc_list)}\n")

    output = "".join(f"{line}\n" for line in report)
    with open(f'results/output-{N}.txt', 'w') as f:
        f.write(output)

## Plot clean histograms

In [None]:
# Select and load the classifier used for the histogram figures. K fills the
# same slot of the model file name that N fills in the earlier cells.
# get_msp() and perturb_images() read the module-level `model`, so this
# replaces whichever model the previous cells left bound.
# NOTE(review): `val_gen` is not rebuilt here; plot_hist_methods() uses
# whichever generator earlier cells left bound — confirm it belongs to FOLD.
K = 20
FOLD = 2
model_name = f'models/patches-classifier-{K}-t{TEMP}-f{FOLD}.hdf5'
model = keras.models.load_model(model_name)

In [None]:
def plot_hist_methods(epsilon, runs=50):
    """
    Plots validation-vs-OOD score histograms for the three summary functions.

    `runs` batches are drawn from each of the module-level generators
    `val_gen` and `ood_gen`. Every image is cut into patches, the patches are
    scored with get_msp() (after perturb_images() when `epsilon` is non-zero),
    and the patch scores of each image are summarised with "max", "top5" and
    "avg". The model is run once per image and all three summaries are derived
    from the same patch scores, so the three figures describe the same images.

    Parameters
    ----------
    epsilon : float
        Perturbation magnitude; 0 disables the perturbation.
    runs : int
        Number of batches drawn from each generator.

    Side effects
    ------------
    Prints each summary name, shows one figure per summary and saves it to
    'results/tfg-hist-{summary}.png' (the directory is created when needed).
    """
    # Insertion order fixes the order in which the figures are produced.
    summarisers = {
        "max": max,
        "top5": lambda scores: sum(sorted(scores, reverse=True)[:5]) / 5.,
        "avg": lambda scores: sum(scores) / len(scores),
    }

    # patch_scores[0] -> validation images, patch_scores[1] -> OOD images;
    # each entry holds one array of patch scores per image.
    patch_scores = []
    for generator in (val_gen, ood_gen):
        per_image = []
        for _ in range(runs):
            # Builtin next() works on every Keras iterator version.
            for image in next(generator):
                patches = sliding_window(image)
                if epsilon != 0:
                    patches = perturb_images(patches, epsilon)
                per_image.append(get_msp(patches))
        patch_scores.append(per_image)

    os.makedirs('results', exist_ok=True)
    logbins = np.logspace(-1,0,50)

    for summary, summarise in summarisers.items():

        print(summary)

        val_scores = [summarise(scores) for scores in patch_scores[0]]
        ood_scores = [summarise(scores) for scores in patch_scores[1]]

        fig, ax = plt.subplots(figsize=(4,3), dpi=100)
        ax.hist(val_scores, bins=logbins, alpha=0.4, label="Validation", color="blue")
        ax.hist(ood_scores, bins=logbins, alpha=0.5, label="OOD", color="orange")
        # Raw string: "\m" is not a valid escape sequence in a normal literal.
        ax.set_xlabel(r"1-$\mathcal{S}(x)$")
        ax.set_yticks([])
        ax.set_xscale('log')
        ax.legend()
        fig.savefig(f'results/tfg-hist-{summary}.png', bbox_inches='tight')
        # plt.show() also renders on the notebook's inline backend, where
        # fig.show() only emits a "non-GUI backend" warning.
        plt.show()
        plt.close(fig)

In [None]:
# Render the three summary histograms for a perturbation of epsilon = 1e-4.
plot_hist_methods(epsilon=0.0001, runs=50)