In [None]:
import os
import time as t
import random as rn
import numpy as np
import tensorflow as tf
import keras.backend as K
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, accuracy_score
from keras import metrics, optimizers, regularizers
from keras.datasets import mnist, cifar10
from keras.models import Model, load_model
from keras.layers import Dense, Input, Reshape, Concatenate, Subtract, ZeroPadding2D
from keras.layers import LeakyReLU, BatchNormalization, Conv2D, Flatten, Dropout
from keras.layers import AveragePooling2D, Softmax
from keras.initializers import truncated_normal as tn
from tensorflow.keras.utils import to_categorical

In [None]:
def cal_openset_baccu(ground_truth=None, prediction=None, label_ref=None):
    """Compute the balanced accuracy of an open-set prediction.

    By convention label 0 marks abnormal (unknown-class) samples and is the
    negative class; every non-zero label is a known class and counts as
    positive. A positive sample only counts as correct when its exact known
    class is predicted (a diagonal entry of the confusion matrix).

    :param ground_truth: True labels.
    :param prediction: Predicted labels.
    :param label_ref: A list of class labels in ascending order, forwarded to
        ``confusion_matrix`` as ``labels``.
    :return: Mean of the true-negative rate and the true-positive rate.
    """
    cm = confusion_matrix(ground_truth, prediction, labels=label_ref)

    # Diagonal cell of the first label (0 by convention): abnormal predicted abnormal.
    correct_negatives = cm[0, 0]
    # Remaining diagonal cells: known-class samples assigned their own class.
    correct_positives = np.trace(cm) - correct_negatives

    positives = np.count_nonzero(ground_truth)
    negatives = len(ground_truth) - positives

    true_negative_rate = correct_negatives / negatives
    true_positive_rate = correct_positives / positives
    return 0.5 * (true_negative_rate + true_positive_rate)


def cal_closed_set_accu(ground_truth=None, prediction=None):
    """Compute, print and return the conventional (closed-set) accuracy.

    :param ground_truth: True labels.
    :param prediction: Predicted labels.
    :return: Fraction of samples whose predicted label equals the true label.
    """
    accuracy = accuracy_score(y_true=ground_truth, y_pred=prediction)
    print('Closed-set accuracy is %.4f' % accuracy)
    return accuracy


def cal_modified_auc(ground_truth=None, prediction=None):
    """Calculate modified AUC according to Neal et al.

    Column 0 of ``prediction`` scores the abnormal / unknown class and the
    remaining columns score the known classes. Each sample's unknown-ness is
    its abnormal score minus its best known-class score; the AUC measures how
    well that score separates label-0 samples (positives here) from the rest.

    :param ground_truth: True labels, array-like; 0 marks abnormal samples.
    :param prediction: Predicted logits values, 2-D array-like with at least
        two columns (column 0 = abnormal score).
    :return: Modified AUC.
    """
    # Accept plain Python lists too: ``list == 0`` would otherwise be a single
    # boolean and ``list[:, 0]`` is not valid indexing.
    ground_truth = np.asarray(ground_truth)
    prediction = np.asarray(prediction)

    pred_abnormal = prediction[:, 0]
    pred_normal = np.max(prediction[:, 1:], axis=-1)
    pred_score = pred_abnormal - pred_normal

    is_abnormal = (ground_truth == 0).astype(int)
    auc = roc_auc_score(is_abnormal, pred_score)
    print('AUC is %.4f' % auc)

    return auc

In [None]:
import scipy.io

def set_seed(first_seed=2018):
    """Seed the random number generators used by the experiment.

    NumPy is seeded with ``first_seed``; Python's ``random`` module and the
    TensorFlow graph-level seed use the fixed values 10 and 16. A TF1-style
    session restricted to one intra-op and one inter-op thread is then
    installed as the Keras session (presumably to limit thread-scheduling
    nondeterminism).

    Note: assigning ``PYTHONHASHSEED`` at runtime only affects child
    processes; the running interpreter's hash seed is fixed at start-up.

    :param first_seed: Integer number used to seed NumPy.
    """
    os.environ['PYTHONHASHSEED'] = '0'
    np.random.seed(first_seed)
    rn.seed(10)

    single_thread_config = tf.compat.v1.ConfigProto(
        intra_op_parallelism_threads=1,
        inter_op_parallelism_threads=1,
    )
    tf.compat.v1.set_random_seed(16)

    default_graph = tf.compat.v1.get_default_graph()
    session = tf.compat.v1.Session(graph=default_graph, config=single_thread_config)
    K.set_session(session)


def _extract_data(data=None, label=None, target_lb=None):
    """Extract dataset regarding given normal / abnormal labels-

    :param data: A numpy tensor. First axis should be the number of samples.
    :param label: The corresponding labels for the data.
    :param target_lb: An integer value standing for the only one known class.
    :return: normal data, abnormal data, normal labels, abnormal labels
    """

    index = 0
    if isinstance(target_lb, int):
        index = index + (label == target_lb) * 1
    else:
        for lb in target_lb:
            index = index + (label == lb) * 1

    normal_idx = np.where(index == 1)[0]
    abnormal_idx = np.where(index == 0)[0]

    data_normal = data[normal_idx]
    data_abnormal = data[abnormal_idx]

    label_normal = label[normal_idx]
    label_abnormal = label[abnormal_idx]

    return data_normal, data_abnormal, label_normal, label_abnormal


def _reshape_data(data=None, data_shape=None, num_channels=None):
    """Reshape image data into vectors / matrices / tensors.

    :param data: A numpy tensor. First axis should be the number of samples.
    :param data_shape: Desired data shape. It should be a string.
    :param num_channels: Number of the channels of the given data.
    :return: Reshaped data.
    """

    num_samples = data.shape[0]
    data = data.reshape(num_samples, -1)
    num_features = data.shape[-1]
    height = int(np.sqrt(num_features / num_channels))
    width = num_features // (num_channels*height)
    if not isinstance(width, int):
        raise ValueError('\nThe input images should be in square form...')

    if data_shape == 'vector':
        pass

    elif data_shape == 'matrix':
        if num_channels == 1:
            data = data.reshape(num_samples, height, width)
        elif num_channels == 3:
            data = data.reshape(num_samples, height, width, num_channels)
            # Transform RGB images into gray-scale images
            data = 0.2989 * data[:, :, :, 0] + 0.5870 * data[:, :, :, 1] + 0.1140 * data[:, :, :, 2]
            data = data.reshape(num_samples, height, width)
        else:
            raise ValueError('The input data should be either gray-scale images or color images...')

    elif data_shape == 'tensor':
        data = data.reshape(num_samples, height, width, num_channels)

    else:
        raise ValueError('\nNo suitable data shape is found. Please enter a desired data shape...')

    return data


def get_data(dataset=None, normal_class=None, data_format=None, preprocess='minmax'):
    """Obtain the dataset in a desired form stored in a dictionary.

    Pixel values are scaled by 1/255 and then normalised with statistics
    computed on the TRAINING split only, so the test split never influences
    its own preprocessing.

    :param dataset: The name of desired dataset: mnist, fmnist or cifar10.
    :param normal_class: The class (or list of classes) considered known during
        training. When None, the full splits are returned unpartitioned.
    :param data_format: The desired data shape: vector, matrix or tensor.
    :param preprocess: The name of a preprocessing method: minmax or mean.
    :return: A dictionary containing training and testing samples. Keys are
        'x_train', 'y_train', 'x_test', 'y_test' when ``normal_class`` is None;
        otherwise '{x,y}_{train,test}_{normal,abnormal}'.
    :raises ValueError: For an unknown dataset or preprocessing method.
    """

    if dataset == 'mnist':
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        num_channel = 1
    elif dataset == 'fmnist':
        # Documented option; imported locally so the notebook's import cell is unchanged.
        from keras.datasets import fashion_mnist
        (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
        num_channel = 1
    elif dataset == 'cifar10':
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        num_channel = 3
    else:
        raise ValueError('\nNo datasets are found. Please select one dataset...')

    # Reshape data and its label into desired format (cifar10 labels come as (N, 1))
    y_train = np.asarray(y_train).reshape(-1)
    y_test = np.asarray(y_test).reshape(-1)

    x_train = _reshape_data(data=x_train, data_shape=data_format, num_channels=num_channel)
    x_test = _reshape_data(data=x_test, data_shape=data_format, num_channels=num_channel)

    # Data normalization
    x_train = (x_train / 255).astype('float32')
    x_test = (x_test / 255).astype('float32')

    if preprocess == 'minmax':
        train_min, train_max = np.min(x_train), np.max(x_train)
        x_train = (x_train - train_min) / (train_max - train_min)
        x_test = (x_test - train_min) / (train_max - train_min)
    elif preprocess == 'mean':
        train_mean = np.mean(x_train)
        x_train = x_train - train_mean
        x_test = x_test - train_mean
    else:
        raise ValueError('\nPlease give a valid preprocessing method...')

    if normal_class is None:
        data = {'x_train': x_train,
                'y_train': y_train,
                'x_test': x_test,
                'y_test': y_test}
    else:
        train_set = _extract_data(data=x_train, label=y_train, target_lb=normal_class)
        test_set = _extract_data(data=x_test, label=y_test, target_lb=normal_class)

        data = {'x_train_normal': train_set[0], 'x_train_abnormal': train_set[1],
                'y_train_normal': train_set[2], 'y_train_abnormal': train_set[3],
                'x_test_normal': test_set[0], 'x_test_abnormal': test_set[1],
                'y_test_normal': test_set[2], 'y_test_abnormal': test_set[3]}
    return data


# ==================== Image Processing ====================


def assign_label(normal_class=None, original_label=None, include_zero=None):
    """Map original class labels onto consecutive integer labels.

    The k-th entry of ``normal_class`` becomes label k, or k + 1 when
    ``include_zero`` is falsy. A label that is not listed in ``normal_class``
    maps to the same value as the first listed class, so callers are expected
    to pass known-class labels only.

    :param normal_class: A list of unique selected known classes labels.
    :param original_label: A list of selected known classes samples' labels.
    :param include_zero: Boolean variable. True for include zero as the starting label. Otherwise one.
    :return: Modified labels.
    """
    remapped = sum((original_label == cls) * position
                   for position, cls in enumerate(normal_class))

    # New labels begin with 1 unless zero is explicitly requested.
    return remapped if include_zero else remapped + 1


def split_data(model_name=None, data=None, rho=None, split_method=None, ground_truth=None, normal_class=None):
    """Partition known-class training samples into typical and atypical ones.

    With ``split_method='cnn'`` the model stored at
    ``./trained_models/cnn_for_ds_<model_name>.h5`` scores every sample. The
    similarity score of a sample is the model output at its true class when
    the argmax prediction is correct (floored at 0 by the max over the zeroed
    entries), and 0 otherwise. Samples scoring strictly above the ``rho``-th
    percentile (``rho`` in [0, 100]) are typical; the rest are atypical.

    NOTE(review): the model saved by ``train_logits_cnn`` exposes pre-softmax
    class scores, so despite the log message these are likely logits rather
    than probabilities.

    :param model_name: The name for the model used for splitting.
    :param data: Selected known classes training data.
    :param rho: Splitting ratio (percentile).
    :param split_method: The name of splitting method: cnn.
    :param ground_truth: Original label list for the selected training samples.
    :param normal_class: A list of selected known classes.
    :return: ``np.where`` index tuples of typical and atypical samples.
    :raises ValueError: For an unknown split method or a missing splitting model.
    """
    model_path = './trained_models/cnn_for_ds_%s.h5' % model_name

    print('\nSplitting data...')

    if split_method != 'cnn':
        raise ValueError('No suitable data splitting method...')
    if not os.path.isfile(model_path):
        raise ValueError('No suitable CNN for data splitting...')

    splitter = load_model(filepath=model_path, compile=False)

    print('\nCalculating categorical probability using trained CNN...')
    scores = splitter.predict(data, batch_size=128)
    predicted_onehot = to_categorical(np.argmax(scores, axis=-1), num_classes=scores.shape[-1])

    true_classes = assign_label(normal_class=normal_class, original_label=ground_truth, include_zero=True)
    true_onehot = to_categorical(true_classes, num_classes=len(normal_class))

    # Non-zero only at the true class, and only when it is also the predicted class.
    sim_score = np.max(true_onehot * predicted_onehot * scores, axis=-1)
    sim_thr = np.percentile(sim_score, rho)

    print('\nThe sim_thr is %.4f' % sim_thr)

    return np.where(sim_score > sim_thr), np.where(sim_score <= sim_thr)


In [None]:
def _conv_pair(x, filters, kernel_size, names, reg, use_bias, weights_init, activation='linear',
               padding='same', parallel=False):
    """Apply two equally-sized Conv2D layers and concatenate their outputs on the last axis.

    By default the second convolution consumes the first one's output
    (``conv_k -> conv_kk``); with ``parallel=True`` both read ``x`` directly.
    """
    def _conv(name):
        return Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, name=name,
                      padding=padding, kernel_regularizer=regularizers.l2(reg), use_bias=use_bias,
                      kernel_initializer=weights_init)

    first = _conv(names[0])(x)
    second = _conv(names[1])(x if parallel else first)
    return Concatenate()([first, second])


def _vgg_trunk(input_layer, img_height, first_kernel, bn_kwargs, reg, latent_fea, num_normal_class,
               acti_func, acti_alpha, set_bias, weights_init):
    """Build the conv + dense trunk shared by the 'modified_vgg' and 'logits_cnn' backbones.

    Layers are created in the same order (and with the same names) as the
    original hand-unrolled branches, so auto-generated layer names and weight
    initialisation order are unchanged.

    :return: ``(dense_10, lrelu_9)`` - the linear class-score tensor and the
        LeakyReLU activations of the ``latent_fea``-unit dense layer.
    """
    x = input_layer

    # Stages 1-4: conv pair -> LeakyReLU -> 2x2 average pooling -> batch norm.
    stages = ((32, first_kernel), (64, (3, 3)), (128, (3, 3)), (256, (3, 3)))
    for stage, (filters, kernel) in enumerate(stages, start=1):
        x = _conv_pair(x, filters, kernel, ('conv_%d' % stage, 'conv_%d%d' % (stage, stage)),
                       reg, set_bias, weights_init, activation=acti_func)
        x = LeakyReLU(alpha=acti_alpha)(x)
        x = AveragePooling2D(pool_size=(2, 2), name='pool_%d' % stage)(x)
        if stage == 2 and img_height == 28:
            # 28x28 inputs (mnist / fashion-mnist) are 7x7 here; pad by one pixel per side.
            x = ZeroPadding2D(padding=(1, 1))(x)
        x = BatchNormalization(name='bn_%d' % stage, **bn_kwargs)(x)

    # Stage 5: two parallel 1x1 convolutions on the same input, default batch norm.
    x = _conv_pair(x, 256, (1, 1), ('conv_5', 'conv_55'), reg, set_bias, weights_init,
                   activation=acti_func, padding='valid', parallel=True)
    x = LeakyReLU(alpha=acti_alpha)(x)
    x = BatchNormalization(name='bn_5')(x)
    x = Flatten()(x)

    def _dense(units, name, activation):
        # Dense layers use the opposite bias setting of the conv layers.
        return Dense(units=units, activation=activation, name=name,
                     kernel_regularizer=regularizers.l2(reg), use_bias=not set_bias,
                     kernel_initializer=weights_init)

    x = _dense(256, 'dense_8', acti_func)(x)
    x = LeakyReLU(alpha=acti_alpha)(x)
    x = Dropout(rate=0.5)(x)

    x = _dense(latent_fea, 'dense_9', acti_func)(x)
    lrelu_9 = LeakyReLU(alpha=acti_alpha)(x)
    x = Dropout(rate=0.5)(lrelu_9)

    dense_10 = _dense(num_normal_class, 'dense_10', 'linear')(x)
    return dense_10, lrelu_9


def build_cnn(img_height=None, num_channel=None, reg=None, latent_fea=None, num_normal_class=None, cnn_type=None):
    """Build CNN or OSRNET.

    Both backbones share one architecture; they differ only in the first
    convolution kernel (3x3 vs 7x7), batch-norm settings of stages 1-4, the
    softmax layer name, and which tensor is exposed as the latent output
    ('modified_vgg': the LeakyReLU after ``dense_9``; 'logits_cnn': the
    pre-softmax ``dense_10`` scores).

    :param img_height: Input image height (width should be equal to this).
    :param num_channel: Number of image channels.
    :param reg: Decay of regularization terms.
    :param latent_fea: Number of latent features.
    :param num_normal_class: Number of output classes of the classifier head.
    :param cnn_type: The name of the CNN used as a backbone: 'modified_vgg' or 'logits_cnn'.
    :return: Tuple ``(cnn, cnn_latent, joint_cnn, cnn_logits)``: the softmax
        classifier, its latent-feature model, a two-output model adding a
        ``num_normal_class - 1``-way softmax head on the latent output, and
        the pre-softmax score model.
    :raises ValueError: For an unknown ``cnn_type``.
    """

    # ==================== Constants Definition ====================
    acti_func = 'linear'
    acti_alpha = 0.2
    set_bias = False

    weights_init = tn(mean=0, stddev=0.01)

    bn_eps = 1e-3
    bn_m = 0.99

    # ==================== General Input Layer ====================
    input_layer = Input(shape=(img_height, img_height, num_channel), name='input_layer')

    trunk_args = dict(reg=reg, latent_fea=latent_fea, num_normal_class=num_normal_class,
                      acti_func=acti_func, acti_alpha=acti_alpha, set_bias=set_bias,
                      weights_init=weights_init)

    if cnn_type == 'modified_vgg':
        dense_10, lrelu_9 = _vgg_trunk(input_layer, img_height, (3, 3),
                                       {'momentum': bn_m, 'epsilon': bn_eps}, **trunk_args)
        class_probs = Softmax()(dense_10)
        latent_source = lrelu_9
    elif cnn_type == 'logits_cnn':
        dense_10, _ = _vgg_trunk(input_layer, img_height, (7, 7), {}, **trunk_args)
        class_probs = Softmax(name='softmax')(dense_10)
        # The latent output of this backbone is the class-score layer itself.
        latent_source = dense_10
    else:
        raise ValueError('No suitable CNN architecture...')

    top_layer = Reshape(target_shape=(-1,), name='top_layer')(class_probs)
    latent_layer = Reshape(target_shape=(-1,), name='latent_layer')(latent_source)
    logits_layer = Reshape(target_shape=(-1,), name='logits_layer')(dense_10)

    cnn = Model(inputs=input_layer, outputs=top_layer, name=cnn_type)
    cnn_latent = Model(inputs=input_layer, outputs=latent_layer, name=cnn_type + '_latent')
    cnn_logits = Model(inputs=input_layer, outputs=logits_layer, name=cnn_type + '_logits')

    # ==================== Intra-Class Networks ====================
    # Built but not returned. Kept so that layer/model auto-naming and weight
    # initialisation order match previous runs.
    input_1 = Input(shape=(img_height, img_height, num_channel), name='input_1')
    input_2 = Input(shape=(img_height, img_height, num_channel), name='input_2')

    lat_1 = cnn_latent(input_1)
    lat_2 = cnn_latent(input_2)

    latent_dist = Subtract(name='latent_dist')([lat_1, lat_2])

    dense_ly = Dense(units=1, activation='sigmoid', name='dense_ly',
                     kernel_regularizer=regularizers.l2(reg))(latent_dist)

    ic_network = Model(inputs=[input_1, input_2], outputs=dense_ly)  # noqa: F841 (unused)

    # ==================== Joint layers ====================
    dense_11 = Dense(units=num_normal_class-1, activation='softmax', name='dense_joint',
                     kernel_regularizer=regularizers.l2(reg), use_bias=not set_bias,
                     kernel_initializer=weights_init)(latent_layer)
    joint_layer = Reshape(target_shape=(-1,), name='joint_layer')(dense_11)

    joint_cnn = Model(inputs=input_layer, outputs=[top_layer, joint_layer])

    return cnn, cnn_latent, joint_cnn, cnn_logits


def train_logits_cnn(data=None, label=None, normal_class=None, reg=None, epoch=None, batch_size=None, name=None):
    """Train a plain classifier on the known classes and keep its class-score sub-model.

    The 'logits_cnn' backbone exposes its pre-softmax ``dense_10`` scores as
    the latent output; that ``cnn_latent`` model is what gets saved and later
    loaded by ``split_data``.

    :param data: Training data in a 4D tensor (samples, height, width, channels).
    :param label: Corresponding original labels for the training data.
    :param normal_class: A list of selected known classes labels.
    :param reg: Decay for the regularization term.
    :param epoch: Training epochs.
    :param batch_size: The size of batch sizes.
    :param name: The name of the CNN for saving.
    :return: None when ``name`` is given (the latent model is saved to
        ``./trained_models/cnn_for_ds_<name>.h5``); otherwise ``(cnn, cnn_latent)``.
    """

    num_normal_class = len(normal_class)
    # Pin the one-hot width so it always matches the classifier head, even if a class is absent.
    label = to_categorical(assign_label(normal_class=normal_class, original_label=label, include_zero=True),
                           num_classes=num_normal_class)
    img_height, num_channel = data.shape[1], data.shape[-1]

    model_set = build_cnn(img_height=img_height, num_channel=num_channel, reg=reg,
                          latent_fea=256, num_normal_class=num_normal_class, cnn_type='logits_cnn')

    cnn = model_set[0]
    cnn_latent = model_set[1]

    # `learning_rate` replaces the deprecated `lr` keyword (same value).
    customized_optimizer = optimizers.RMSprop(learning_rate=1e-4, decay=1e-9)
    cnn.compile(optimizer=customized_optimizer,
                loss='categorical_crossentropy',
                metrics=['accuracy'])
    cnn_latent.compile(optimizer=customized_optimizer,
                       loss='mse')

    if name is not None:
        # Make sure the save location exists before spending time on training.
        os.makedirs('./trained_models', exist_ok=True)

    cnn.fit(x=data,
            y=label,
            batch_size=batch_size,
            epochs=epoch,
            verbose=2)

    if name is not None:
        cnn_latent.save('./trained_models/cnn_for_ds_' + name + '.h5')
        return
    return cnn, cnn_latent


In [None]:
def train_joint_osrnet(data=None, name=None, rho=None, reg=None, latent_fea=None, num_epoch=None, batch_size=64,
                       split_method='cnn', normal_class=None, backbone='modified_vgg',
                       loss_weights=None, pretrain_ep=None):
    """Train an OSRNET on a typical / atypical split of the known-class training data.

    Steps:
      1. ``split_data`` divides the known-class training samples. Typical samples
         keep their known class (relabelled 1..K); atypical samples get label 0,
         the same label used for unknown-class samples at test time.
      2. If ``pretrain_ep`` is given, the classifier is pre-trained on the true
         known-class labels shifted to 1..K.
      3. The joint model is trained for ``num_epoch`` epochs: head 1 predicts
         the open-set label (0..K), head 2 the closed-set class (0..K-1).
      4. At epoch 1 and every 5th epoch the open-set balanced accuracy, modified
         AUC and closed-set accuracy are computed on the test set; the models
         with the best balanced accuracy are saved, and the final model is saved
         after training.

    :param data: Dictionary from ``get_data`` built with a ``normal_class`` split.
    :param name: Base name used to load the splitting CNN and to save models.
    :param rho: Percentile (0-100) of the similarity scores used as split threshold.
    :param reg: Decay of the L2 regularization terms.
    :param latent_fea: Number of latent features.
    :param num_epoch: Number of joint training epochs.
    :param batch_size: Mini-batch size.
    :param split_method: Splitting method forwarded to ``split_data``.
    :param normal_class: A list of known class labels.
    :param backbone: ``cnn_type`` forwarded to ``build_cnn``.
    :param loss_weights: Weights of the two joint losses; defaults to ``[1., 1.]``.
    :param pretrain_ep: Number of pre-training epochs, or None to skip it.
    :return: Dictionary with the recorded metric histories ('baccu', 'auc',
        'cs_accu', 'train_accu') and the best scores ('best_baccu', 'best_auc',
        'best_cs_accu').
    """
    if loss_weights is None:
        # Equal weighting (Keras' behaviour without loss_weights); also needed for the file name below.
        loss_weights = [1., 1.]

    # ==================== split training data ====================
    typical_index, atypical_index = split_data(model_name=name,
                                               data=data['x_train_normal'],
                                               rho=rho,
                                               split_method=split_method,
                                               ground_truth=data['y_train_normal'],
                                               normal_class=normal_class)

    # ==================== assign labels to typical and atypical data ====================
    typical_label = data['y_train_normal'][typical_index]
    atypical_label = data['y_train_normal'][atypical_index] * 0

    typical_label = assign_label(normal_class=normal_class, original_label=typical_label, include_zero=False)
    print('\nThere are %d typical normal samples and %d atypical normal samples...' % (len(typical_label),
                                                                                       len(atypical_label)))

    # Closed-set labels (0..K-1), ordered typical-first to line up with normal_x below.
    normal_label = assign_label(normal_class=normal_class, original_label=data['y_train_normal'], include_zero=True)
    normal_label = np.concatenate([normal_label[typical_index], normal_label[atypical_index]])

    # ==================== shuffle the training data ====================
    normal_x = np.vstack([data['x_train_normal'][typical_index], data['x_train_normal'][atypical_index]])
    normal_y = np.concatenate([typical_label, atypical_label])

    training_idx = np.random.permutation(np.arange(0, len(normal_y)))
    normal_x = normal_x[training_idx]
    normal_y = normal_y[training_idx]
    normal_label = normal_label[training_idx]

    # ==================== create and compile network ====================
    img_height = normal_x.shape[1]
    num_ch = normal_x.shape[-1]
    num_train = normal_x.shape[0]
    num_all_cls = 1 + len(normal_class)  # K known classes plus the unknown label 0
    idx = np.arange(0, num_train)

    model_set = build_cnn(img_height=img_height, num_channel=num_ch, reg=reg, latent_fea=latent_fea,
                          num_normal_class=num_all_cls, cnn_type=backbone)

    # `learning_rate` replaces the deprecated `lr` keyword (same value).
    customized_optimizer = optimizers.Adam(learning_rate=3e-4, beta_1=0.5, clipvalue=1.0, decay=1e-10)

    osrnet = model_set[0]
    osrnet_latent = model_set[1]
    osrnet_joint = model_set[2]
    osrnet_logits = model_set[3]

    osrnet.compile(optimizer=customized_optimizer,
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])
    osrnet_latent.compile(optimizer=customized_optimizer, loss='mse')
    osrnet_joint.compile(optimizer=customized_optimizer,
                         loss=['categorical_crossentropy', 'categorical_crossentropy'],
                         loss_weights=loss_weights)
    osrnet_logits.compile(optimizer=customized_optimizer, loss='mse')

    # One-hot targets are built once, with explicit widths matching the model heads.
    pretrain_targets = to_categorical(normal_label + 1, num_classes=num_all_cls)
    open_set_targets = to_categorical(normal_y, num_classes=num_all_cls)
    closed_set_targets = to_categorical(normal_label, num_classes=len(normal_class))

    # ==================== prepare test set ====================
    x_test = np.vstack([data['x_test_normal'], data['x_test_abnormal']])
    y_test_normal = assign_label(normal_class=normal_class, original_label=data['y_test_normal'], include_zero=False)
    y_test_gt = np.concatenate([y_test_normal, data['y_test_abnormal']*0])

    # ==================== train the network ====================
    best_baccu = 0
    best_auc = 0
    best_cs_accu = 0

    res_baccu = []
    res_auc = []
    res_cs_accu = []
    res_train_accu = []

    record_step = 5

    # pre-train the network
    np.random.shuffle(idx)
    if pretrain_ep is not None:
        osrnet.fit(x=normal_x[idx],
                   y=pretrain_targets[idx],
                   batch_size=batch_size,
                   epochs=pretrain_ep,
                   verbose=2)

    name = name + '_l1_%s_l2_%s' % (str(loss_weights[0]), str(loss_weights[1]))
    for i in range(num_epoch):

        if (i + 1) % record_step == 0 or i == 0:
            print('\nTraining for epoch %d' % (i+1))
            y_test_pred = np.argmax(osrnet.predict(x_test, batch_size=128), axis=-1)
            baccu = cal_openset_baccu(y_test_gt, y_test_pred, label_ref=np.arange(0, num_all_cls))

            y_test_decision_function = osrnet_logits.predict(x_test, batch_size=128)
            auc = cal_modified_auc(y_test_gt, y_test_decision_function)

            y_test_cs_pred = np.argmax(osrnet.predict(data['x_test_normal'], batch_size=128), axis=-1)
            cs_accu = cal_closed_set_accu(y_test_normal, y_test_cs_pred)

            # Training accuracy on a fixed 500-sample subset, against 1..K labels.
            y_train_pred = np.argmax(osrnet.predict(normal_x[:500], batch_size=128), axis=-1)
            train_cs_accu = cal_closed_set_accu(normal_label[:500]+1, y_train_pred)

            res_baccu.append(baccu)
            res_auc.append(auc)
            res_cs_accu.append(cs_accu)
            res_train_accu.append(train_cs_accu)

            if baccu > best_baccu:
                best_baccu = baccu
                best_auc = auc
                best_cs_accu = cs_accu
                osrnet.save(filepath='./trained_models/osrnet_best_%s_rho_%d.h5' % (name, rho))
                osrnet_logits.save(filepath='./trained_models/osrnet_logits_best_%s_rho_%d.h5' % (name, rho))

            print('\nThe best baccu is %.4f' % best_baccu)
            print('The best auc is %.4f' % best_auc)
            print('The best cs accu is %.4f' % best_cs_accu)

        osrnet_joint.fit(x=normal_x[idx],
                         y=[open_set_targets[idx],
                            closed_set_targets[idx]],
                         batch_size=batch_size,
                         epochs=1,
                         verbose=0)
        np.random.shuffle(idx)

    osrnet.save(filepath='./trained_models/osrnet_end_%s_rho_%d.h5' % (name, rho))

    return {'baccu': res_baccu, 'auc': res_auc, 'cs_accu': res_cs_accu, 'train_accu': res_train_accu,
            'best_baccu': best_baccu, 'best_auc': best_auc, 'best_cs_accu': best_cs_accu}

In [None]:
def run_osrnet(normal_class=None, dataset=None, rho=None):
    """Run the full OSRNET pipeline for one choice of known classes.

    Seeds the RNGs, loads ``dataset`` as 4-D tensors split by ``normal_class``,
    trains and saves the splitting CNN, then trains the OSRNET itself.

    :param normal_class: A list of class labels that are considered as known classes during training.
    :param dataset: The name of a desired dataset: mnist, fmnist or cifar10.
    :param rho: Percentile threshold for splitting typical / atypical training samples.
    """
    set_seed()

    # ========== constants ==========
    epochs = 100
    batch = 64
    weight_decay = 1e-3

    data = get_data(dataset,
                    normal_class=normal_class,
                    data_format='tensor')

    # e.g. 'mnist_134679'
    name = dataset + '_' + ''.join(str(cls) for cls in normal_class)

    # Train a network for splitting the known classes
    train_logits_cnn(data=data['x_train_normal'],
                     label=data['y_train_normal'],
                     normal_class=normal_class,
                     reg=weight_decay,
                     epoch=epochs,
                     batch_size=batch,
                     name=name)

    # Train an OSRNET
    train_joint_osrnet(data=data,
                       name=name,
                       rho=rho,
                       reg=weight_decay,
                       latent_fea=256,
                       num_epoch=epochs,
                       batch_size=batch,
                       split_method='cnn',
                       normal_class=normal_class,
                       backbone='modified_vgg',
                       loss_weights=[1., 1.],
                       pretrain_ep=35)

In [None]:
# Known classes: MNIST digits 1, 3, 4, 6, 7 and 9; every other digit is treated as unknown.
# rho=10: training samples at or below the 10th percentile of the splitter's similarity score are labelled atypical.
run_osrnet(normal_class=[1, 3, 4, 6, 7, 9], dataset='mnist', rho=10)

  "The `lr` argument is deprecated, use `learning_rate` instead.")


Epoch 1/100
576/576 - 12s - loss: 0.3553 - accuracy: 0.9500
Epoch 2/100
576/576 - 9s - loss: 0.1579 - accuracy: 0.9885
Epoch 3/100
576/576 - 9s - loss: 0.1184 - accuracy: 0.9903
Epoch 4/100
576/576 - 9s - loss: 0.0953 - accuracy: 0.9912
Epoch 5/100
576/576 - 9s - loss: 0.0797 - accuracy: 0.9926
Epoch 6/100
576/576 - 8s - loss: 0.0696 - accuracy: 0.9932
Epoch 7/100
576/576 - 8s - loss: 0.0624 - accuracy: 0.9937
Epoch 8/100
576/576 - 8s - loss: 0.0577 - accuracy: 0.9941
Epoch 9/100
576/576 - 8s - loss: 0.0547 - accuracy: 0.9944
Epoch 10/100
576/576 - 8s - loss: 0.0515 - accuracy: 0.9944
Epoch 11/100
576/576 - 9s - loss: 0.0490 - accuracy: 0.9947
Epoch 12/100
576/576 - 9s - loss: 0.0474 - accuracy: 0.9951
Epoch 13/100
576/576 - 9s - loss: 0.0468 - accuracy: 0.9950
Epoch 14/100
576/576 - 9s - loss: 0.0437 - accuracy: 0.9954
Epoch 15/100
576/576 - 9s - loss: 0.0433 - accuracy: 0.9952
Epoch 16/100
576/576 - 8s - loss: 0.0427 - accuracy: 0.9954
Epoch 17/100
576/576 - 8s - loss: 0.0405 - accur

In [None]:
# Known classes: CIFAR-10 class indices 1, 3, 4, 6, 7 and 9; every other class is treated as unknown.
# rho=20: training samples at or below the 20th percentile of the splitter's similarity score are labelled atypical.
run_osrnet(normal_class=[1, 3, 4, 6, 7, 9], dataset='cifar10', rho=20)

  "The `lr` argument is deprecated, use `learning_rate` instead.")


Epoch 1/100
469/469 - 10s - loss: 1.3255 - accuracy: 0.5178
Epoch 2/100
469/469 - 7s - loss: 0.9608 - accuracy: 0.6911
Epoch 3/100
469/469 - 7s - loss: 0.8283 - accuracy: 0.7451
Epoch 4/100
469/469 - 7s - loss: 0.7272 - accuracy: 0.7833
Epoch 5/100
469/469 - 7s - loss: 0.6541 - accuracy: 0.8133
Epoch 6/100
469/469 - 7s - loss: 0.5936 - accuracy: 0.8329
Epoch 7/100
469/469 - 7s - loss: 0.5409 - accuracy: 0.8540
Epoch 8/100
469/469 - 7s - loss: 0.5017 - accuracy: 0.8714
Epoch 9/100
469/469 - 7s - loss: 0.4676 - accuracy: 0.8837
Epoch 10/100
469/469 - 7s - loss: 0.4353 - accuracy: 0.8951
Epoch 11/100
469/469 - 7s - loss: 0.4044 - accuracy: 0.9080
Epoch 12/100
469/469 - 7s - loss: 0.3796 - accuracy: 0.9164
Epoch 13/100
469/469 - 7s - loss: 0.3602 - accuracy: 0.9243
Epoch 14/100
469/469 - 7s - loss: 0.3395 - accuracy: 0.9359
Epoch 15/100
469/469 - 7s - loss: 0.3227 - accuracy: 0.9397
Epoch 16/100
469/469 - 7s - loss: 0.3112 - accuracy: 0.9461
Epoch 17/100
469/469 - 7s - loss: 0.3020 - accur