In [None]:
import numpy as np
import pandas as pd

import os

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import seaborn as sns

import imageio
from skimage import transform,io

from collections import Counter

import tensorflow.keras.backend as K

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.utils import to_categorical   

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
if False: # only for JupyterLab with GPUs environment
    import sys
    sys.path.insert(0, '/notebooks/')
    sys.path.insert(0, '../')
    from CapsuleLib.utils import gpu_config
    gpuid = gpu_config(False, True)
    # Memory growth is a per-device setting. Looping over the visible GPUs
    # configures every one of them and is a no-op (instead of an IndexError)
    # when TensorFlow sees no GPU at all.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

## Set paths

In [None]:
# Data locations -- adjust these to your environment.
# NOTE: SB3_DATA_PATH ends with '/', the other two paths do not.
SB3_DATA_PATH = "data/GLF_SB3/"
NORMAL_DATA_PATH = "data/normal-images"
OOD_DATA_PATH = "data/ood-images"

## Prepare data

In [None]:
# Run configuration.
TRAINING = False      # False: load the stored SB3 classifier instead of fine-tuning

BATCH_SIZE = 64
IM_SIZE_SB2 = 224     # square input size (pixels) fed to the network

TEMP = 10000          # softmax temperature used by the classifier head
FREEZE = 0            # number of leading SB2 layers frozen during fine-tuning
suffix = '-t' + str(TEMP)
suffix

In [None]:
# Index the labelled SB3 frames. Each sub-directory of SB3_DATA_PATH is named
# "<label>_<class name>" and contains that class's .png frames. The label is
# taken from the directory name, not from a fixed character offset in the full
# path, so it does not depend on how long SB3_DATA_PATH is.
class_records, frame_records = [], []
for class_dir in sorted(os.listdir(SB3_DATA_PATH)):
    class_path = os.path.join(SB3_DATA_PATH, class_dir)
    if not os.path.isdir(class_path):
        continue  # ignore stray files next to the class folders
    label, class_name = class_dir.split('_')[:2]
    class_records.append({"label": label, "class": class_name})
    for image in sorted(os.listdir(class_path)):
        if image.endswith('.png'):
            # NOTE(review): assumes frame file names start with the 6-character
            # video ID used in the train/validation split (e.g. '182SJH') -- confirm.
            frame_records.append({"filepath": os.path.join(class_path, image),
                                  "label": label,
                                  "video": image[:6]})

# Build each table once from records (DataFrame.append was removed in pandas 2).
classes = pd.DataFrame(class_records, columns=["label", "class"])
df = pd.DataFrame(frame_records, columns=["filepath", "label", "video"])

# Class '2' is left out of this experiment: drop it from both tables.
df = df[df.label != '2'].reset_index(drop=True)
classes = classes[classes.label != '2']
df.head()

In [None]:
# Split at the video level (test_size=0.2 of 10 videos -> 2 validation videos),
# presumably so frames of one video never land in both training and validation.
videos = ['182SJH', '070CBA', '079CMS', '078APM', '084SMS',
          '178EFB', '177SLG', '082VNR', '094IAV', '081SNR']
videos_train, videos_val = train_test_split(videos, test_size=0.2, random_state=0)
df_train = df.loc[df['video'].isin(videos_train)]
df_class_val = df.loc[df['video'].isin(videos_val)]
df_class_val.shape

In [None]:
def rotate_image(image):
    """Rotate `image` by a random multiple of 90 degrees (training augmentation).

    np.rot90 rotates the first two axes; k is drawn from {-1, 0, 1, 2}, which
    covers all four orientations.
    """
    return np.rot90(image, np.random.choice([-1, 0, 1, 2]))


# Training generator: rescale pixels to [0, 1] and apply random flips and
# 90-degree rotations.
datagen = ImageDataGenerator(
    rescale=1./255.,
    vertical_flip=True,
    horizontal_flip=True,
    preprocessing_function=rotate_image
)
# Validation generator: identical rescaling but no augmentation. The val_loss
# monitored by the training callbacks and the confusion matrix are then
# computed on the unmodified frames.
val_datagen = ImageDataGenerator(rescale=1./255.)

# Give both generators the same class list so their one-hot columns line up
# even when the validation videos do not contain every class. Without it,
# flow_from_dataframe infers the classes from each dataframe separately.
class_labels = sorted(set(df_train['label']) | set(df_class_val['label']))

train_gen=datagen.flow_from_dataframe(
    dataframe=df_train,
    x_col="filepath",
    y_col="label",
    classes=class_labels,
    class_mode="categorical",
    target_size=(IM_SIZE_SB2,IM_SIZE_SB2),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1)
val_class_gen=val_datagen.flow_from_dataframe(
    dataframe=df_class_val,
    x_col="filepath",
    y_col="label",
    classes=class_labels,
    class_mode="categorical",
    target_size=(IM_SIZE_SB2,IM_SIZE_SB2),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1)

In [None]:
def visualize_images(generator, n=5):
    """Show up to an n x n grid of images from the next batch of `generator`.

    Works for generators that yield `(images, labels, ...)` tuples
    (e.g. class_mode="categorical") and for ones that yield a bare image array
    (class_mode=None). The grid is filled column by column. Cells beyond the
    batch size are left blank instead of raising an IndexError.
    """
    batch = generator.next()
    images = batch[0] if isinstance(batch, tuple) else batch

    # squeeze=False keeps `axes` 2-D even for n == 1.
    fig, axes = plt.subplots(n, n, figsize=(n, n), squeeze=False)
    # axes.T.flat walks (0,0), (1,0), ..., i.e. column-major order.
    for i, ax in enumerate(axes.T.flat):
        if i < len(images):
            ax.imshow(images[i])
        ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
visualize_images(generator=train_gen)  # sanity check: augmented training images

In [None]:
visualize_images(generator=val_class_gen)  # sanity check: validation images

## Transfer learning from SB2

In [None]:
# Evaluation-only run: load the stored SB3 classifier instead of fine-tuning one.
if not TRAINING:
    model = keras.models.load_model(
        'models/sb3-classifier-best.hdf5',
    )
    model.summary()

In [None]:
if TRAINING:
    # Start from the SB2 classifier and put a new 6-way, temperature-scaled
    # softmax head on top of it.
    model_sb2 = keras.models.load_model('models/sb2-classifier-best.hdf5')
    # Freeze the first FREEZE layers of the SB2 backbone.
    for layer in model_sb2.layers[:FREEZE]:
        layer.trainable = False
    # NOTE(review): layers[21] is assumed to be the SB2 feature layer the new
    # head should attach to -- confirm against model_sb2.summary().
    features = model_sb2.layers[21].output
    dropout_top = keras.layers.Dropout(0.5, name="dropout_top")(features)
    dense_top = keras.layers.Dense(6, name='logits')(dropout_top)
    # Divide the logits by the temperature. TEMP is passed via `arguments` so
    # its value is stored in the layer config, not looked up as a global name
    # each time the (possibly reloaded) layer is called.
    temp_top = keras.layers.Lambda(lambda x, temp: x / temp,
                                   arguments={'temp': TEMP})(dense_top)
    out = keras.layers.Activation('softmax', name='softmax')(temp_top)

    model = keras.models.Model(inputs=model_sb2.input, outputs=out)
    model.summary()

In [None]:
if TRAINING:
    # restore_best_weights=True: when early stopping triggers, `model` is rolled
    # back to the lowest-val_loss weights (the same criterion ModelCheckpoint
    # uses). The evaluation cells that follow then score the best epoch, not
    # the last one.
    early_stop = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10,
                               restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='auto', factor=0.5,
                                  verbose=1, patience=5, min_lr=1e-6)
    mcp_save = ModelCheckpoint('models/sb3-classifier.hdf5', save_best_only=True,
                               monitor='val_loss', mode='min')

    def scheduler(epoch, lr):
        """Multiply the current learning rate by 0.9 at the start of every epoch."""
        return 0.9*lr
    # NOTE(review): this decay compounds with ReduceLROnPlateau's halvings and
    # is not bounded by its min_lr=1e-6 -- confirm that is intended.
    lr_scheduler = LearningRateScheduler(scheduler, verbose=1)

    model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
                  loss="categorical_crossentropy", metrics=["categorical_accuracy"])

    # class_weight is keyed by the generators' class indices (0..5).
    history = model.fit(train_gen, epochs=100, validation_data=val_class_gen,
                        callbacks=[early_stop, reduce_lr, mcp_save, lr_scheduler],
                        class_weight={0:2,1:2,2:3,3:2,4:2,5:1})
    

In [None]:
# Collect ground truth and predictions over exactly one pass of the validation
# set. Indexing the generator visits each of its len() batches once. Iterating
# it directly never stops by itself, and the old break condition only fired
# after the first batch of the *next* epoch had already been counted, so some
# samples were scored twice.
true_chunks, pred_chunks = [np.empty(0, dtype=int)], [np.empty(0, dtype=int)]
for batch_idx in range(len(val_class_gen)):
    x_batch, y_batch = val_class_gen[batch_idx][:2]
    true_chunks.append(y_batch.argmax(axis=1))
    pred_chunks.append(model.predict(x_batch, verbose=0).argmax(axis=1))
y_true = np.concatenate(true_chunks)
y_pred = np.concatenate(pred_chunks)

In [None]:
# Rows/columns are ordered by the string label, which matches the class indices
# the generators assign (they sort the label strings).
labels = classes.sort_values('label')['class'].values
# Pass the class indices explicitly so the matrix is always
# len(labels) x len(labels), even if a class is missing from both y_true and y_pred.
matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))
df_cm = pd.DataFrame(matrix, labels, labels)
sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('predicted class')
plt.ylabel('true class')
plt.title("Confusion matrix")
plt.show()

In [None]:
# Normalise every row by its total: each cell becomes the fraction of that true
# class's samples that were predicted as the column's class.
perc_matrix = matrix / matrix.sum(axis=1, keepdims=True)
df_perc = pd.DataFrame(perc_matrix, index=labels, columns=labels)
sns.heatmap(df_perc, cmap='Blues', annot=True)
plt.title("Percentage confusion matrix")
plt.xlabel('predicted class')
plt.ylabel('true class')
plt.show()

## ODIN

In [None]:
def perturb_images(images, epsilon=0.002, clip_min=0., clip_max=1.):
    """ODIN input pre-processing: step each pixel by `epsilon` so the image's
    maximum softmax score increases.

    Args:
        images: batch of images from the data generators. Their pixels are in
            [0, 1] because the generators use rescale=1./255.
        epsilon: perturbation magnitude per pixel.
        clip_min, clip_max: valid pixel range for the result. The defaults
            match the generators' 1/255 rescaling (the previous hard-coded
            [0, 255] range never clipped values above 1).

    Returns:
        float32 tf.Tensor with the same shape as `images`.
    """
    inputs = tf.convert_to_tensor(images, dtype=tf.float32)

    with tf.GradientTape() as tape:
        tape.watch(inputs)
        # `model` is expected to end in a softmax (as the head built in this
        # notebook does), so these are probabilities, not logits.
        probs = model(inputs, training=False)
        loss = -tf.reduce_mean(tf.reduce_max(probs, axis=1))

    gradients = tape.gradient(loss, inputs)
    # sign(gradient) as +/-1, with exact zeros mapped to +1.
    signs = tf.cast(tf.math.greater_equal(gradients, 0), tf.float32) * 2. - 1.

    # Step against the gradient of the negated score, i.e. towards a higher
    # max-softmax score.
    perturbed = inputs - epsilon * signs
    return tf.clip_by_value(perturbed, clip_min, clip_max)

In [None]:
# Out-of-distribution frames: a flat folder of .png files. The image ID is the
# last 3 characters of the part of the file name before the first '-'.
ood_records = [
    {"path": os.path.join(OOD_DATA_PATH, image), "image": image.split('-')[0][-3:]}
    for image in sorted(os.listdir(OOD_DATA_PATH))
    if image.endswith('.png')
]
df_ood = pd.DataFrame(ood_records, columns=["path", "image"])

# In-distribution frames: a flat folder of .png files whose first 3 characters
# identify the video.
normal_records = [
    {"path": os.path.join(NORMAL_DATA_PATH, image), "video": image[:3]}
    for image in sorted(os.listdir(NORMAL_DATA_PATH))
    if image.endswith('.png')
]
df_normal = pd.DataFrame(normal_records, columns=["path", "video"])

# Evaluation uses the training rescaling but no random flips/rotations, so the
# OOD scores depend only on the frames themselves.
eval_datagen = ImageDataGenerator(rescale=1./255.)

val_gen=eval_datagen.flow_from_dataframe(
    dataframe=df_normal,
    x_col="path",
    class_mode=None,
    target_size=(IM_SIZE_SB2,IM_SIZE_SB2),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1)

ood_gen=eval_datagen.flow_from_dataframe(
    dataframe=df_ood,
    x_col="path",
    class_mode=None,
    target_size=(IM_SIZE_SB2,IM_SIZE_SB2),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=1)

In [None]:
def visualize_images_2(generator, n=5):
    """Show up to an n x n grid of images from the next batch of an unlabelled
    generator (class_mode=None, so each batch is a bare image array).

    The grid is filled column by column. Cells beyond the batch size are left
    blank instead of raising an IndexError.
    """
    images = generator.next()

    # squeeze=False keeps `axes` 2-D even for n == 1.
    fig, axes = plt.subplots(n, n, figsize=(n, n), squeeze=False)
    # axes.T.flat walks (0,0), (1,0), ..., i.e. column-major order.
    for i, ax in enumerate(axes.T.flat):
        if i < len(images):
            ax.imshow(images[i])
        ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
visualize_images_2(generator=val_gen)  # in-distribution (normal) frames

In [None]:
visualize_images_2(generator=ood_gen)  # out-of-distribution frames

### Plot OOD-score histograms (score = 1 - MSP)

In [None]:
def get_msp(image_batch):
    """Return the OOD score 1 - max softmax probability for each image.

    Assumes `model` ends in a softmax. Higher scores mean the classifier is
    less confident, i.e. the image looks more out-of-distribution.
    """
    softmax_probs = model.predict(image_batch)
    return 1-np.max(softmax_probs, axis=1)

def plot_hist(runs=50, epsilon=0):
    """Plot overlaid histograms of the 1 - MSP score for validation vs. OOD images.

    Args:
        runs: number of batches drawn from each of `val_gen` and `ood_gen`.
        epsilon: ODIN perturbation magnitude; 0 scores the raw images.

    The figure is saved to results/hist-e{epsilon}-F{FREEZE}.png, then shown.
    """
    msp_list = []
    for generator in [val_gen, ood_gen]:
        chunks = [np.empty(0)]
        for _ in range(runs):
            batch = generator.next()
            if epsilon != 0:
                batch = perturb_images(batch, epsilon)
            chunks.append(get_msp(batch))
        msp_list.append(np.concatenate(chunks))

    fig, ax = plt.subplots(figsize=(10,6))
    logbins = np.logspace(-7,0,80)
    ax.hist(msp_list[0], bins=logbins, alpha=0.4, label="Validation", color="blue")
    ax.hist(msp_list[1], bins=logbins, alpha=0.5, label="OOD", color="orange")
    ax.set_xscale('log')
    # get_msp returns 1 - MSP, so label the axis accordingly.
    ax.set_xlabel("1 - MSP")
    ax.set_title("MSP histograms")
    ax.legend()

    # Save before showing: the inline backend closes figures once shown.
    os.makedirs('results', exist_ok=True)
    fig.savefig(f'results/hist-e{epsilon}-F{FREEZE}.png')
    plt.show()

In [None]:
# One histogram without input perturbation (0) and two with ODIN perturbation.
for eps in (0, 0.001, 0.0025):
    plot_hist(epsilon=eps)

In [None]:
def plot_hist_epsilon(epsilon_list, runs=200):
    """For each epsilon, plot a compact histogram of the 1 - MSP score for
    validation vs. OOD images.

    Args:
        epsilon_list: ODIN perturbation magnitudes to plot (0 = no perturbation).
        runs: batches drawn from each of `val_gen` and `ood_gen` per epsilon
            (default 200).

    Each figure is saved to results/tfg-hist-eps{epsilon}.png, then shown.
    """
    for epsilon in epsilon_list:
        print(epsilon)

        msp_list = []
        for generator in [val_gen, ood_gen]:
            chunks = [np.empty(0)]
            for _ in range(runs):
                batch = generator.next()
                if epsilon != 0:
                    batch = perturb_images(batch, epsilon)
                chunks.append(get_msp(batch))
            msp_list.append(np.concatenate(chunks))

        fig, ax = plt.subplots(figsize=(4,3), dpi=100)
        logbins = np.logspace(-6,0,50)
        ax.hist(msp_list[0], bins=logbins, alpha=0.4, label="Validation", color="blue")
        ax.hist(msp_list[1], bins=logbins, alpha=0.5, label="OOD", color="orange")
        # Raw string: "\m" is not a valid escape sequence in a normal literal.
        ax.set_xlabel(r"1-$\mathcal{S}(x)$")
        ax.set_yticks([])
        ax.set_xscale('log')
        ax.legend()

        # Save before showing: the inline backend closes figures once shown.
        os.makedirs('results', exist_ok=True)
        fig.savefig(f'results/tfg-hist-eps{epsilon}.png',  bbox_inches='tight')
        plt.show()

### Plot ROC curve

In [None]:
def _collect_ood_scores(runs, batch_size, epsilon):
    """Draw `runs` batches from `val_gen` and `ood_gen` and return (scores, is_ood).

    Only batches with exactly `batch_size` images are kept. OOD images get
    label 1, validation images label 0.
    """
    score_chunks, label_chunks = [np.empty(0)], [np.empty(0)]
    for _ in range(runs):
        val_batch = val_gen.next()
        ood_batch = ood_gen.next()
        for batch, ood_flag in ((val_batch, 0.), (ood_batch, 1.)):
            if epsilon != 0:
                batch = perturb_images(batch, epsilon)
            msp = get_msp(batch)
            if msp.shape[0] == batch_size:
                score_chunks.append(msp)
                label_chunks.append(np.full(batch_size, ood_flag))
    return np.concatenate(score_chunks), np.concatenate(label_chunks)


def _roc_points(scores, is_ood, thresholds):
    """Return (1 - specificity, sensitivity) arrays, one entry per threshold.

    An image is flagged as OOD when its score is >= the threshold.
    """
    is_ood = np.asarray(is_ood).astype(bool)
    n_pos = is_ood.sum()
    n_neg = is_ood.size - n_pos
    revspec, sens = [], []
    for threshold in thresholds:
        flagged = scores >= threshold
        sens.append(np.sum(flagged & is_ood) / n_pos)
        revspec.append(1 - np.sum(~flagged & ~is_ood) / n_neg)
    return np.array(revspec), np.array(sens)


def plot_roc(runs=30, threshold_range=np.concatenate([[0.],np.logspace(-8,-4,30), np.logspace(-4,-1,150), np.logspace(-1,0,50)]),
                   label="MSP", batch_size = BATCH_SIZE, epsilon=0):
    """Estimate, plot and summarise the ROC curve of the 1 - MSP (ODIN) OOD detector.

    OOD images are the positive class.

    Args:
        runs: number of batches drawn from each of `val_gen` (in-distribution)
            and `ood_gen` (OOD).
        threshold_range: score thresholds at which the curve is evaluated.
        label: unused; kept for backward compatibility.
        batch_size: only batches of exactly this size are used, so a shorter
            batch at the end of a generator epoch is skipped.
        epsilon: ODIN perturbation magnitude; 0 scores the raw images.

    Side effects:
        - Draws onto the current matplotlib axes, so repeated calls in one
          cell overlay their curves.
        - Saves the figure to results/roc-e{epsilon}-F{FREEZE}.png and the
          (spec, sens) table to results/roc-e{epsilon}.csv.
        - Prints AUROC, FPR80 and FPR90.
    """
    scores, is_ood = _collect_ood_scores(runs, batch_size, epsilon)
    revspec, sens = _roc_points(scores, is_ood, threshold_range)
    spec = 1 - revspec

    # Raw f-string: "\e" is not a valid escape sequence in a normal literal.
    plt.step(revspec, sens, where='post', label=rf'$\epsilon$={epsilon}')
    plt.xlabel("1 - specificity")
    plt.ylabel("sensitivity")
    plt.plot([0,1], [0,1], '--', color="gray", linewidth=2)
    plt.legend()
    plt.title("WCE ODIN ROC")
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/roc-e{epsilon}-F{FREEZE}.png')

    df_roc = pd.DataFrame({'spec':spec, 'sens':sens})
    df_roc.to_csv(f'results/roc-e{epsilon}.csv')

    # Metrics. Integrating spec over sens gives the area under the ROC curve;
    # abs() because sens decreases as the threshold grows.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz  # np.trapz is deprecated in NumPy 2
    print("AUROC:", np.abs(trapezoid(spec, sens)))
    # FPR at the threshold whose sensitivity is closest to 80% / 90%.
    print("FPR80:", revspec[np.argmin(np.abs(sens-0.80))])
    print("FPR90:", revspec[np.argmin(np.abs(sens-0.90))])

In [None]:
plot_roc(runs=30, epsilon=0.002)  # single ROC curve at epsilon = 0.002

In [None]:
# plot_roc draws onto the current axes, so these curves are overlaid in one figure.
for eps in (0, 0.001, 0.0025, 0.005):
    plot_roc(epsilon=eps)

## Qualitative analysis

In [None]:
# Score N batches of OOD frames with ODIN (epsilon = 0.002) and rank them by
# their 1 - MSP score, ascending: least OOD-looking first, most OOD-looking last.
N = 4
image_chunks = [np.empty((0, IM_SIZE_SB2, IM_SIZE_SB2, 3))]
score_chunks = [np.empty((0,))]
for _ in range(N):
    batch = ood_gen.next()
    image_chunks.append(batch)
    score_chunks.append(get_msp(perturb_images(batch, 0.002)))

# Concatenate once instead of re-copying the growing arrays on every iteration.
images = np.concatenate(image_chunks)
scores = np.concatenate(score_chunks)
sorted_images = list(images[np.argsort(scores)])

In [None]:
# Top 5: the n OOD frames with the highest 1 - MSP score (the ones the
# classifier is least confident about), tiled side by side in random order.
n = 5
top_images = sorted_images[-n:]
np.random.shuffle(top_images)
top_image = np.zeros((IM_SIZE_SB2, IM_SIZE_SB2 * n, 3))
for i, image in enumerate(top_images):
    top_image[:, i * IM_SIZE_SB2:(i + 1) * IM_SIZE_SB2, :] = image
plt.imshow(top_image)
plt.axis('off')
os.makedirs('qualitative', exist_ok=True)
plt.savefig('qualitative/top_images.png', dpi=300)

In [None]:
# Bottom 5: the n OOD frames with the lowest 1 - MSP score (the ones the
# classifier is most confident about), tiled side by side in random order.
n = 5
bot_images = sorted_images[:n]
np.random.shuffle(bot_images)
bot_image = np.zeros((IM_SIZE_SB2, IM_SIZE_SB2 * n, 3))
for i, image in enumerate(bot_images):
    bot_image[:, i * IM_SIZE_SB2:(i + 1) * IM_SIZE_SB2, :] = image
plt.imshow(bot_image)
plt.axis('off')
os.makedirs('qualitative', exist_ok=True)
plt.savefig('qualitative/bot_images.png', dpi=300)