In [None]:
%load_ext autoreload
%autoreload 2

import joblib
import warnings
import tensorflow as tf
import numpy as np

from src.models import LeNet
from src.data_utils import *

# Silence Python-level warnings so cell output stays readable.
warnings.filterwarnings("ignore")
# tf.get_logger() returns a standard `logging.Logger`, which accepts a level
# name. Passing "ERROR" avoids the `tf.logging` namespace, which is not
# available in TensorFlow 2.x, and selects the same level as before.
tf.get_logger().setLevel("ERROR")
# Seeds NumPy's global RNG only; no TensorFlow seed is set in this cell.
np.random.seed(42)

# Train LeNet

In [None]:
# Load the MNIST train/test pairs, then split the training pair so that
# `X_train_db` / `y_train_db` are set aside from the data used for training.
# `fold_size=0.2` is forwarded to `split_to_create_db`
# (presumably the fraction reserved for the db split -- confirm in src.data_utils).
(X_train, y_train), (X_test, y_test) = get_mnist()
(
    X_train,
    X_train_db,
    y_train,
    y_train_db,
) = split_to_create_db(X_train, y_train, fold_size=0.2)

In [None]:
# Hyper-parameters for this run; both values are also embedded in the saved
# model's file name below.
epochs = 20
mc_rate = 0.5

# Build the network from the per-example shape of the training data.
lenet = LeNet(input_shape=X_train.shape[1:], num_classes=10)

# Set the MC-dropout rate before training, then train and save the model.
# X_test / y_test are passed as the 3rd and 4th positional arguments of
# `train` (presumably used as validation data -- confirm in src.models).
lenet.set_mc_dropout_rate(mc_rate)
lenet.train(X_train, y_train, X_test, y_test, epochs=epochs, verbose=1)
lenet.save_model(f'Assets/lenet-{mc_rate}-{epochs}-4folds')

# To reuse a previously saved model instead of training, uncomment:
# lenet.load_model("Assets/lenet-0.5-20-4folds.h5")

In [None]:
def mc_dropout(net, X_train, batch_size=1000, dropout=0.5, T=100):
    """
    Run T stochastic forward passes and summarise them per example.

    The MC-dropout rate of `net` is set to `dropout`, `net.model.predict` is
    called T times on `X_train`, and the rate is then reset to 0. The reset
    happens even if a forward pass raises, so the network is never left with
    the MC-dropout rate still set to `dropout`.

    Parameters
    ----------
    net : object
        Must expose `set_mc_dropout_rate(rate)` and a `model` attribute whose
        `predict(X, batch_size)` returns an array (presumably a Keras model
        wrapper such as LeNet -- confirm against src.models).
    X_train : array-like
        Inputs forwarded unchanged to `model.predict`.
    batch_size : int
        Passed positionally as the second argument of `model.predict`.
    dropout : float
        Rate handed to `net.set_mc_dropout_rate` before the forward passes.
    T : int
        Number of forward passes; must be at least 1.

    Returns
    -------
    y_mc_dropout : np.ndarray
        Element-wise mean of the T predictions.
    score : np.ndarray
        NEGATED uncertainty: the variance across the T passes, averaged over
        the last axis (the classes) and multiplied by -1. Values are <= 0, and
        values closer to 0 mean the passes agreed more.

    Raises
    ------
    ValueError
        If `T` is smaller than 1 (there would be nothing to average).
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    net.set_mc_dropout_rate(dropout)
    try:
        # Fetch the model only after the rate was set, as the original did.
        model = net.model
        # Todo: parallelize the T independent forward passes
        repetitions = [model.predict(X_train, batch_size) for _ in range(T)]
    finally:
        # Always restore the rate, including when predict() fails.
        net.set_mc_dropout_rate(0)

    preds = np.stack(repetitions)  # T x data x pred

    # average over all passes
    y_mc_dropout = preds.mean(axis=0)

    # variance across passes for every (example, class) cell
    mc = np.var(preds, axis=0)
    # mean of the per-class variances -> one value per example
    mc_uncertainty = np.mean(mc, axis=-1)

    return y_mc_dropout, -mc_uncertainty

def create_db(net, X_train_db, max_mc_iters):
    """
    Sweep the number of MC-dropout passes and persist the results.

    For every pass count T from 2 up to and including `max_mc_iters`,
    `mc_dropout(net, X_train_db, T=T)` is evaluated and its returned pair is
    stored under key T. The resulting dict is written with joblib
    (compressed) to "Assets/db_<max_mc_iters>_iters.jblib" and also returned.

    Parameters
    ----------
    net : object
        Forwarded unchanged to `mc_dropout`.
    X_train_db : array-like
        Inputs forwarded unchanged to `mc_dropout`.
    max_mc_iters : int
        Largest pass count to evaluate; values below 2 give an empty dict.

    Returns
    -------
    dict
        Maps each pass count T to the pair returned by `mc_dropout`.
    """
    from tqdm import tqdm

    # Progress bar over the pass counts 2..max_mc_iters (inclusive).
    pass_counts = tqdm(range(2, max_mc_iters + 1))
    db = {
        n_passes: mc_dropout(net, X_train_db, T=n_passes)
        for n_passes in pass_counts
    }

    joblib.dump(db, f"Assets/db_{max_mc_iters}_iters.jblib", compress=True)
    return db

In [None]:
# With max_mc_iters=2 the sweep covers a single pass count (T=2) and the
# result is written to "Assets/db_2_iters.jblib".
ans = create_db(net=lenet, X_train_db=X_train_db, max_mc_iters=2)