In [1]:
import math
import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.utils import shuffle
from tensorflow.keras import backend as K
from tensorflow.keras.activations import softmax, relu
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10
from tensorflow.keras.initializers import RandomUniform
from tensorflow.keras.layers import Dense, Conv2D, Input, Flatten, Conv2DTranspose, Reshape, BatchNormalization, ReLU, \
    Layer, GlobalAveragePooling2D, Concatenate, Add
from tensorflow.keras.models import Model, load_model
from skimage.measure import shannon_entropy

In [2]:
class LAMAECallback(Callback):
    def __init__(self, saved_model_name):
        super().__init__()   
        self.saved_model_name = saved_model_name
        
    def on_epoch_end(self, epoch, logs=None):
        # Save the model with minimal reconstruction error on the validation set
        x_val_recon, _ = self.model.predict(x_val)
        recon_err_mean = np.mean(np.square(x_val - x_val_recon))
        
        global min_recon_err
        if recon_err_mean < min_recon_err:
            min_recon_err = recon_err_mean
            self.model.save(self.saved_model_name)
            print('save model, epoch={}, recon_err_mean={}'.format(epoch, recon_err_mean, 2))

In [4]:
def compute_cosine_distances(a, b):
    """Calculate the cosine distance between a and b"""
    a_normalized, _ = tf.linalg.normalize(a, ord=1, axis=-1)
    b_normalized, _ = tf.linalg.normalize(b, ord=1, axis=-1)
    b_normalized_transposed = tf.transpose(b_normalized)
    distance = tf.matmul(a_normalized, b_normalized_transposed)
            
    return distance


class MemoryUnit(Layer):
    def __init__(self, class_num, block_size, name=None, **kwargs):
        super(MemoryUnit, self).__init__(name=name)
        self.class_num = class_num
        self.block_size = block_size
        super(MemoryUnit, self).__init__(**kwargs)
    
    def get_config(self):
        config = super(MemoryUnit, self).get_config()
        config.update({
            "class_num": self.class_num,
            "block_size": self.block_size})
        
        return config
    
    def build(self, input_shape):        
        # shape: (class_num x block_size, latent dim) = (M, L)
        self.weight = self.add_weight(shape=(self.class_num * self.block_size, input_shape[-1]),
                                      initializer=RandomUniform(-5, 5),
                                      trainable=True)

    def call(self, z, y):
        y = tf.where(tf.equal(tf.reduce_max(y, axis=1, keepdims=True), y),
                     tf.constant(1.0),
                     tf.constant(0.0))
        att_weight = tf.tile(y, [1, self.block_size])
        
        # z x Mem^T -> (batchxL) x (LxM) = batchxM
        sim_weight = compute_cosine_distances(z, self.weight) 
        
        com_weight = att_weight * sim_weight
        com_weight = softmax(com_weight)

        output = tf.matmul(com_weight, self.weight)
        return output

In [4]:
def get_lamae_model(input_shape=(28, 28, 1), 
                    channels=(32, 64), 
                    latent_dim=32, 
                    skip_dim=0, 
                    class_num=9, 
                    block_size=10):
    x_in = Input(shape=input_shape)
    x = x_in
    # encoder
    for c in channels:
        x = Conv2D(c, 3, 2, padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

    shape = K.int_shape(x)
    x = Flatten()(x)
    
    z = Dense(latent_dim)(x)
    z = BatchNormalization()(z)
    
    if skip_dim != 0:
        sz = Dense(skip_dim)(z)
        sz = Dense(latent_dim)(sz)
    
    clf_out = Dense(class_num, activation='softmax', name='clf')(z)
    z = MemoryUnit(class_num, block_size)(z, clf_out)
    
    # skip connection
    if skip_dim != 0:
        z = Add()([z, sz])
    x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(z)
    x = Reshape((shape[1], shape[2], shape[3]))(x)
    
    # decoder
    for c in channels[::-1]:
        x = Conv2DTranspose(c, 3, 2, padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

    x_out = Conv2DTranspose(input_shape[2], 3, activation='sigmoid', padding='same', name='recon')(x)

    model = Model(x_in, [x_out, clf_out])

    model.compile(optimizer='adam', loss={'recon': 'mse', 'clf': 'sparse_categorical_crossentropy'})

    return model

# Mnist

In [5]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1)) / 255.0
x_test = x_test.reshape((10000, 28, 28, 1)) / 255.0
print(x_train.shape)
print(x_test.shape)


def handle_data(ood_class: int):
    # id = 1, ood = -1
    index = np.array([e != ood_class for e in y_train])
    new_x_train, new_y_train = x_train[index], y_train[index]
    for i in range(len(new_y_train)):
        new_y_train[i] = new_y_train[i] - 1 if new_y_train[i] > ood_class else new_y_train[i]
    
    new_ood_x_test_1 = x_train[index == False]
    
    index = np.array([e != ood_class for e in y_test])
    new_ood_x_test_2 = x_test[index == False]
    new_ood_x_test = list(new_ood_x_test_1) + list(new_ood_x_test_2)
    new_id_x_test = x_test[index][:len(new_ood_x_test)]
    
    new_x_test = np.array(list(new_id_x_test) + list(new_ood_x_test))
    new_y_test = np.array([1 for _ in range(len(new_id_x_test))] + [-1 for _ in range(len(new_ood_x_test))])
    new_x_test, new_y_test = shuffle(new_x_test, new_y_test)
    
    # new_y_train：[0,1,...,8], new_y_test: {-1, 1}
    return new_x_train, new_y_train, new_x_test, new_y_test

(60000, 28, 28, 1)
(10000, 28, 28, 1)


In [6]:
for ood_class in range(10):
    new_x_train, new_y_train, new_x_test, new_y_test = handle_data(ood_class)
    x_val = new_x_train[50000:]
    y_val = new_y_train[50000:]
    new_x_train = new_x_train[:50000]
    new_y_train = new_y_train[:50000]
    for ite in range(1):
        print('*' * 30, 'ood={}, ite={}'.format(ood_class, ite), '*' * 30)
        saved_model_name = './saved_model/mnist_ood{}_ite{}.h5'.format(ood_class, ite)

        min_recon_err = float('inf')
        model = get_lamae_model()
        model.fit(new_x_train, {'recon': new_x_train, 'clf': new_y_train},
              shuffle=True,
              epochs=200,
              batch_size=128,
              verbose=0,
              callbacks = [LAMAECallback(saved_model_name)])
        
        # Testing
        model = load_model(saved_model_name, custom_objects={'MemoryUnit': MemoryUnit})        
        x_test_recon, _ = model.predict(new_x_test)
        recon_err = np.mean(np.square(new_x_test - x_test_recon), axis=(1, 2, 3))
        
        # LAMAE
        auroc = round(100 * roc_auc_score(new_y_test, -recon_err), 2)
        auprc = round(100 * average_precision_score(new_y_test, -recon_err), 2)
        print('LAMAE test set, ood={}, ite={}, auroc={}%, auprc={}%'.format(ood_class, ite, auroc, auprc))
        
        # LAMAE+
        entropys = np.array([shannon_entropy(e) for e in new_x_test])
        sigma = 1e-9
        auroc = round(100 * roc_auc_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
        auprc = round(100 * average_precision_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
        print('LAMAE+ test set, ood={}, ite={}, auroc={}%, auprc={}%'.format(ood_class, ite, auroc, auprc))
        print()

****************************** ood=0, ite=0 ******************************
save model, epoch=0, recon_err_mean=0.07304661010102313
save model, epoch=1, recon_err_mean=0.06439831797065132
save model, epoch=2, recon_err_mean=0.06438556842932208
save model, epoch=4, recon_err_mean=0.06434732218505604
save model, epoch=6, recon_err_mean=0.06434575164646973
save model, epoch=8, recon_err_mean=0.06429845968085995
save model, epoch=10, recon_err_mean=0.062117923859081485
save model, epoch=11, recon_err_mean=0.05589562052728274
save model, epoch=12, recon_err_mean=0.05273641010001821
save model, epoch=18, recon_err_mean=0.0523946469219118
save model, epoch=23, recon_err_mean=0.05007775427277692
save model, epoch=30, recon_err_mean=0.039101891806901615
save model, epoch=39, recon_err_mean=0.03816913532914424
save model, epoch=46, recon_err_mean=0.037245903512328704
save model, epoch=48, recon_err_mean=0.03567006885536443
save model, epoch=54, recon_err_mean=0.03383220303225055
save model, epoch

save model, epoch=42, recon_err_mean=0.027877750996571908
save model, epoch=53, recon_err_mean=0.02490018672317298
save model, epoch=63, recon_err_mean=0.02431809533921181
save model, epoch=70, recon_err_mean=0.023469721182953467
save model, epoch=73, recon_err_mean=0.02344716367869373
save model, epoch=74, recon_err_mean=0.022827300066972235
save model, epoch=88, recon_err_mean=0.022404116831196058
save model, epoch=90, recon_err_mean=0.021693687409311883
save model, epoch=92, recon_err_mean=0.020819841653785063
save model, epoch=93, recon_err_mean=0.02014513450038573
save model, epoch=94, recon_err_mean=0.019977353165798854
save model, epoch=112, recon_err_mean=0.01985237892409896
save model, epoch=117, recon_err_mean=0.019571017807680266
save model, epoch=123, recon_err_mean=0.01914541652049515
save model, epoch=134, recon_err_mean=0.018816083604440502
save model, epoch=135, recon_err_mean=0.01837821801127934
save model, epoch=136, recon_err_mean=0.01808056687253058
save model, epoc

save model, epoch=87, recon_err_mean=0.02204214537820739
save model, epoch=91, recon_err_mean=0.02157740738596243
save model, epoch=95, recon_err_mean=0.02040621832119869
save model, epoch=98, recon_err_mean=0.018649069476635863
save model, epoch=111, recon_err_mean=0.017932608134838875
save model, epoch=123, recon_err_mean=0.017489046263264644
save model, epoch=128, recon_err_mean=0.017467449936994707
save model, epoch=130, recon_err_mean=0.01720763984906595
save model, epoch=154, recon_err_mean=0.01693595714564324
save model, epoch=164, recon_err_mean=0.016887437894140187
save model, epoch=172, recon_err_mean=0.016733937168752353
save model, epoch=179, recon_err_mean=0.01668523745025343
save model, epoch=187, recon_err_mean=0.01571389654619901
LAMAE test set, ood=8, ite=0, auroc=97.72%, auprc=98.19%
LAMAE+ test set, ood=8, ite=0, auroc=97.24%, auprc=97.88%

****************************** ood=9, ite=0 ******************************
save model, epoch=0, recon_err_mean=0.080745943970108

# Fashion Mnist

In [16]:
(x_train, y_train), (id_x, _) = fashion_mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1)) / 255.0
id_x = id_x.reshape((10000, 28, 28, 1)) / 255.0

x_val = x_train[50000:]
y_val = y_train[50000:]
new_x_train = x_train[:50000]
new_y_train = y_train[:50000]

(_, _), (ood_x, _) = mnist.load_data()
ood_x = ood_x.reshape((10000, 28, 28, 1)) / 255.0

new_x_test = np.array(list(id_x) + list(ood_x))
new_y_test = np.array([1 for _ in range(len(id_x))] + [-1 for _ in range(len(ood_x))])

for ite in range(1):
    print('*' * 30, 'ite={}'.format(ite), '*' * 30)
    saved_model_name = './saved_model/fashion_mnist_ite{}.h5'.format(ite)

    min_recon_err = float('inf')
    model = get_lamae_model(class_num=10)
    model.fit(new_x_train, {'recon': new_x_train, 'clf': new_y_train},
          shuffle=True,
          epochs=200,
          batch_size=128,
          verbose=0,
          callbacks = [LAMAECallback(saved_model_name)])

    # Testing
    model = load_model(saved_model_name, custom_objects={'MemoryUnit': MemoryUnit})        
    x_test_recon, _ = model.predict(new_x_test)
    recon_err = np.mean(np.square(new_x_test - x_test_recon), axis=(1, 2, 3))

    # LAMAE
    auroc = round(100 * roc_auc_score(new_y_test, -recon_err), 2)
    auprc = round(100 * average_precision_score(new_y_test, -recon_err), 2)
    print('LAMAE test set, ite={}, auroc={}%, auprc={}%'.format(ite, auroc, auprc))

    # LAMAE+
    entropys = np.array([shannon_entropy(e) for e in new_x_test])
    sigma = 1e-9
    auroc = round(100 * roc_auc_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
    auprc = round(100 * average_precision_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
    print('LAMAE+ test set, ite={}, auroc={}%, auprc={}%'.format(ite, auroc, auprc))
    print()

****************************** ite=0 ******************************
save model, epoch=0, recon_err_mean=0.10740130426265644
save model, epoch=1, recon_err_mean=0.08917269044603907
save model, epoch=3, recon_err_mean=0.044029929882564225
save model, epoch=9, recon_err_mean=0.0427236279861599
save model, epoch=12, recon_err_mean=0.041053776987119255
save model, epoch=15, recon_err_mean=0.0400308906560903
save model, epoch=25, recon_err_mean=0.039979674336593406
save model, epoch=31, recon_err_mean=0.03202413620805771
save model, epoch=42, recon_err_mean=0.027136426937306543
save model, epoch=51, recon_err_mean=0.026271429139769424
save model, epoch=56, recon_err_mean=0.024075263515351372
save model, epoch=59, recon_err_mean=0.024062929630682173
save model, epoch=68, recon_err_mean=0.023059636714099503
save model, epoch=71, recon_err_mean=0.022979406500009788
save model, epoch=87, recon_err_mean=0.022644260849935566
save model, epoch=89, recon_err_mean=0.02221079706781963
save model, epoc

# CIFAR10

In [19]:
def handle_fashion(ood_x):
    # 28 * 28 -> 32 * 32 * 3
    images = []

    for t in range(len(ood_x)):
        img = np.zeros((32, 32, 3), dtype=int)
        for i in range(28):
            for j in range(28):
                img[i + 2][j + 2][0] = ood_x[t][i][j]
                img[i + 2][j + 2][1] = ood_x[t][i][j]
                img[i + 2][j + 2][2] = ood_x[t][i][j]
        img = img / 255.0
        images.append(img)

    return np.array(images)

In [28]:
(x_train, y_train), (id_x, _) = cifar10.load_data()
x_train, id_x = x_train / 255.0, id_x / 255.0
y_train = y_train.reshape(y_train.shape[0])

x_val = x_train[45000:]
y_val = y_train[45000:]
new_x_train = x_train[:45000]
new_y_train = y_train[:45000]

(_, _), (ood_x, _) = fashion_mnist.load_data()
ood_x = handle_fashion(ood_x)

new_x_test = np.array(list(id_x) + list(ood_x))
new_y_test = np.array([1 for _ in range(len(id_x))] + [-1 for _ in range(len(ood_x))])

for ite in range(1):
    print('*' * 30, 'ite={}'.format(ite), '*' * 30)
    saved_model_name = './saved_model/cifar10_ite{}.h5'.format(ite)

    min_recon_err = float('inf')
    model = get_lamae_model(input_shape=(32, 32, 3), 
                            channels=(64, 128, 128, 256), 
                            latent_dim=64, 
                            skip_dim=16, 
                            class_num=10, 
                            block_size=50)
    model.fit(new_x_train, {'recon': new_x_train, 'clf': new_y_train},
          shuffle=True,
          epochs=200,
          batch_size=128,
          verbose=0,
          callbacks = [LAMAECallback(saved_model_name)])

    # Testing
    model = load_model(saved_model_name, custom_objects={'MemoryUnit': MemoryUnit})        
    x_test_recon, _ = model.predict(new_x_test)
    recon_err = np.mean(np.square(new_x_test - x_test_recon), axis=(1, 2, 3))

    # LAMAE
    auroc = round(100 * roc_auc_score(new_y_test, -recon_err), 2)
    auprc = round(100 * average_precision_score(new_y_test, -recon_err), 2)
    print('LAMAE test set, ite={}, auroc={}%, auprc={}%'.format(ite, auroc, auprc))

    # LAMAE+
    entropys = np.array([shannon_entropy(e) for e in new_x_test])
    sigma = 1e-9
    auroc = round(100 * roc_auc_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
    auprc = round(100 * average_precision_score(new_y_test, -recon_err / (sigma + np.log(entropys + 1))), 2)
    print('LAMAE+ test set, ite={}, auroc={}%, auprc={}%'.format(ite, auroc, auprc))
    print()

****************************** ite=0 ******************************
save model, epoch=0, recon_err_mean=0.045423547576021316
save model, epoch=1, recon_err_mean=0.02964047663231754
save model, epoch=2, recon_err_mean=0.0271723156510119
save model, epoch=6, recon_err_mean=0.026033273632340592
save model, epoch=7, recon_err_mean=0.02571689254719984
save model, epoch=8, recon_err_mean=0.02442935849129246
save model, epoch=10, recon_err_mean=0.02429117083186729
save model, epoch=11, recon_err_mean=0.024230629314131013
save model, epoch=13, recon_err_mean=0.023326100364181663
save model, epoch=18, recon_err_mean=0.022671189151649514
save model, epoch=25, recon_err_mean=0.022185325189091304
save model, epoch=28, recon_err_mean=0.021717012363867944
save model, epoch=37, recon_err_mean=0.02142850466613022
save model, epoch=43, recon_err_mean=0.021220117146877424
save model, epoch=56, recon_err_mean=0.021129029082184586
save model, epoch=57, recon_err_mean=0.021126819115736294
save model, epoch