In [24]:
import time
import os
import keras
import numpy as np
from importlib import reload
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import VGG19
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.models import load_model
import pickle
import shutil

%config Completer.use_jedi = False
# ROOT_DIR = '/Users/halmagyi/Documents/MachineLearning/ML_Notes/BaysianNNets/ColdEnsembles'
ROOT_DIR = '/Users/halmagyi/Documents/MachineLearning/ML_Notes/BaysianNNets/ColdEnsembles'
DATA_DIR = os.path.join(ROOT_DIR, 'data')
CALLBACKS_DIR = os.path.join(DATA_DIR, 'callbacks')
os.chdir(ROOT_DIR)

import src.subleading; reload(src.subleading)
from src.subleading import *
import src.model; reload(src.model)
from src.model import *
import src.mnist; reload(src.mnist)
from src.mnist import *
import src.hessian; reload(src.hessian)
from src.hessian import *
import src.callbacks; reload(src.callbacks)
from src.callbacks import EvaluateAfterNBatch
import src.schedules; reload(src.schedules)
from src.schedules import lr_scheduler, StepDecay



MNIST_path = os.path.join(os.path.expanduser('~'), '.keras/datasets/mnist.npz')

num_classes = 10
x_train_flat, x_test_flat, x_train_flat_bias, x_test_flat_bias, Y_train, Y_test = make_mnist_data(num_classes)

In [8]:
# initial_learning_rate = 0.001
# lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
#     initial_learning_rate,
#     decay_steps=100,
#     decay_rate=0.96,
#     staircase=True)

# model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])

# Ensemble with different seeds

In [27]:
batch_size = 64
epochs = 60

num_classes = 10
num_layers = 3
data_length = x_train_flat.shape[1]
hidden_width = 500
optimizer = Adam()

callback_ensemble = []
histories = []
models = []

validation_data=(x_test_flat, Y_test)

num_models = 2

checkpoint_filepath = os.path.join(CALLBACKS_DIR, 'checkpoints')
if os.path.isdir(checkpoint_filepath):
    shutil.rmtree(checkpoint_filepath, ignore_errors=False, onerror=None)
os.mkdir(checkpoint_filepath)

for i in range(num_models):

    model = make_mnist_model(num_classes, num_layers, data_length, hidden_width, seed=3*i, output_l2=0.001)


    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    evaluate_Nbatch = EvaluateAfterNBatch(x_test_flat, Y_test, N=500)

    schedule = StepDecay(initAlpha=1e-3, factor=3**(-1), dropEvery=20)
#     callbacks = [evaluate_Nbatch, LearningRateScheduler(schedule)]

    checkpoint_modelnum_filepath = os.path.join(checkpoint_filepath, f'model_{i}')

    model_checkpoint_callback = ModelCheckpoint(
        filepath=checkpoint_modelnum_filepath,
        save_weights_only=False,
        monitor='loss',
        mode='min',
        save_best_only=True)


    callbacks = [LearningRateScheduler(schedule), model_checkpoint_callback]

    history = model.fit(x=x_train_flat, 
              y=Y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=2,
              callbacks=callbacks,
              validation_data=validation_data)
    
    models += [model]
    callback_ensemble += [callbacks]
    histories += [history]

Epoch 1/60
938/938 - 6s - loss: 0.2151 - accuracy: 0.9364 - val_loss: 0.1258 - val_accuracy: 0.9621 - lr: 0.0010 - 6s/epoch - 6ms/step
Epoch 2/60
938/938 - 5s - loss: 0.0917 - accuracy: 0.9740 - val_loss: 0.0991 - val_accuracy: 0.9725 - lr: 0.0010 - 5s/epoch - 5ms/step
Epoch 3/60
938/938 - 5s - loss: 0.0676 - accuracy: 0.9802 - val_loss: 0.0985 - val_accuracy: 0.9713 - lr: 0.0010 - 5s/epoch - 5ms/step
Epoch 4/60
938/938 - 5s - loss: 0.0514 - accuracy: 0.9851 - val_loss: 0.0753 - val_accuracy: 0.9790 - lr: 0.0010 - 5s/epoch - 5ms/step
Epoch 5/60
938/938 - 5s - loss: 0.0418 - accuracy: 0.9880 - val_loss: 0.0748 - val_accuracy: 0.9805 - lr: 0.0010 - 5s/epoch - 5ms/step
Epoch 6/60
938/938 - 5s - loss: 0.0363 - accuracy: 0.9898 - val_loss: 0.0837 - val_accuracy: 0.9789 - lr: 0.0010 - 5s/epoch - 6ms/step
Epoch 7/60
938/938 - 5s - loss: 0.0319 - accuracy: 0.9915 - val_loss: 0.0813 - val_accuracy: 0.9802 - lr: 0.0010 - 5s/epoch - 5ms/step
Epoch 8/60
938/938 - 5s - loss: 0.0246 - accuracy: 0.99

KeyboardInterrupt: 

In [25]:
models = []
for i in range(num_models):
    checkpoint_modelnum_filepath = os.path.join(checkpoint_filepath, f'model_{i}')
    models += [load_model(checkpoint_modelnum_filepath)]

In [26]:
models

[<keras.engine.functional.Functional at 0x7ffc1a8547f0>,
 <keras.engine.functional.Functional at 0x7ffc1a8a12b0>]

In [51]:
def pred_ensemble(models, histories, X):
    preds = []
    for model, history in zip(models, histories):
        loss = history.history['loss'][-1]
        preds += [model.predict(X) * np.exp(-loss)]
        
    preds = np.sum(preds, axis=0)
#     preds = np.array()
    return preds

preds = np.argmax(pred_ensemble(models, histories, x_test_flat), axis=1)
np.sum(preds == np.argmax(Y_test,axis=1))

9881

In [52]:
[history.history['loss'][-1] for history in histories]

[1.6213876733672805e-05,
 1.940501169883646e-05,
 2.138576928700786e-05,
 1.897824949992355e-05,
 1.8695041944738477e-05,
 2.0125420633121394e-05,
 2.087637949443888e-05,
 2.010149364650715e-05,
 2.108018634316977e-05,
 2.042125015577767e-05]

In [50]:
def get_truefalse_preds(model, X, Y):
    Y_true = np.array(list(map(np.argmax, Y)))

    y_preds = model.predict(X)
    Y_preds = np.array(list(map(np.argmax, y_preds)))
    
    Y_preds_prob = np.array(list(map(np.max, y_preds)))

    true_args = np.where((Y_preds == Y_true) == True)[0]
    false_args = np.where((Y_preds == Y_true) == False)[0]
    
    true_preds = Y_preds[true_args]
    false_preds = Y_preds[false_args]
    
    true_probs = np.array(Y_preds_prob[true_args])
    false_probs = np.array(Y_preds_prob[false_args])

    return true_args, false_args, true_preds, false_preds, true_probs, false_probs


True_args = []; False_args=[]; True_preds=[]; False_preds=[]; True_probs=[]; False_probs=[];

for i in range(num_models):
    true_args, false_args, true_preds, false_preds, true_probs, false_probs = \
        get_truefalse_preds(models[i], x_test_flat, Y_test)
    
    True_args += [true_args]
    False_args += [false_args]
    True_preds += [true_preds]
    False_preds += [false_preds]
    True_probs += [true_probs]
    False_probs += [false_probs]

In [51]:
FA_flat = np.array(sorted(list(set([f for e in False_args for f in e]))))

#probs
probs = np.transpose(np.array([list(map(np.max, model.predict(x_test_flat[FA_flat]))) for model in models]))
model_nums = list(map(np.argmax, probs))

len(model_nums)

203

In [52]:
preds = np.transpose(np.array([list(map(np.argmax, model.predict(x_test_flat[FA_flat]))) for model in models]))

preds_ensemble = np.array([p[model_num] for p, model_num in zip(preds, model_nums)])

y_true_fa = list(map(np.argmax, Y_test[FA_flat]))

len(preds_ensemble) - sum(preds_ensemble == y_true_fa)

128