# Train and test a set of models with different noise (or noise-free) types

In [None]:
#!/usr/bin/env python
# coding: utf-8

from __future__ import division

import numpy as np
import time
import os
import pickle

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

import import_ipynb
from data_generators import mnist_generator
from utilities import train_mnist
from model_architectures import build_model_cnn

In [None]:
# Train one CNN per (model set, null type) combination and save each trained
# model together with its training history.
#
# Null-type codes passed to mnist_generator; None selects the baseline
# configuration (p_null_class = 0.0, files tagged 'baseline').
# NOTE(review): the meaning of the letter codes is defined by
# mnist_generator -- confirm there.
model_null_types = [None, 'u', 's', 'm', 'us', 'um', 'sm', 'usm']

dir_models = './saved_models_mnist_sets/'
batch_size = 32
n_epochs = 20

# Download the MNIST data located here: http://yann.lecun.com/exdb/mnist/
# and set dir_mnist to the location of your downloaded data:
dir_mnist = './mnist'
# dir_mnist = '/home/mroos/Data/pylearn2data/mnist'

# Create the output directory up front so that a missing folder cannot cause
# a failure when results are written after a full training run.
os.makedirs(dir_models, exist_ok=True)

for i_model_set in range(0, 31):
    # The seeds are set once per model set (not once per model), so the models
    # within a set continue from the same NumPy / TensorFlow random streams.
    seed = i_model_set
    np.random.seed(seed)
    tf.random.set_seed(seed)
    for null_types in model_null_types:
        # Instantiate generators. One for training data and one for testing data.
        # The baseline uses p_null_class = 0.0; every other configuration uses 0.5.
        # NOTE(review): presumably the probability of drawing a null-class
        # sample -- confirm against mnist_generator.
        p_null_class = 0.0 if null_types is None else 0.5
        gen_train = mnist_generator(dir_mnist, batch_size=batch_size, dataset='train',
                                    random_order=True, null_types=null_types, p_null_class=p_null_class)
        gen_test = mnist_generator(dir_mnist, batch_size=batch_size, dataset='test',
                                   random_order=False, null_types=null_types, p_null_class=p_null_class)

        print('\n\n==========================================')
        print('Training model set %d, null_types=%s' % (i_model_set, null_types))
        print('==========================================\n')

        # A fresh model is built and compiled for every configuration; the null
        # class output is included even for the baseline.
        model = build_model_cnn(include_null_class=True)
        optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        loss_instance = SparseCategoricalCrossentropy(from_logits=True)  # uses categorical integer label encoding
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
        model.compile(optimizer=optimizer, loss=loss_instance, metrics=metrics)

        # Output files are tagged with the null-type code, or 'baseline' when
        # null_types is None, followed by the zero-padded model set index.
        model_tag = 'baseline' if null_types is None else null_types
        model_filename = os.path.join(
            dir_models, 'mnist_model_%s_%0.2d.h5' % (model_tag, i_model_set))
        history_filename = os.path.join(
            dir_models, 'mnist_model_history_%s_%0.2d.pkl' % (model_tag, i_model_set))

        model, history = train_mnist(model, gen_train=gen_train, gen_test=gen_test,
                                     n_epochs=n_epochs, model_filename=model_filename, verbose=0)

        # Use a context manager so the history file is flushed and closed
        # immediately instead of leaking one open handle per trained model.
        with open(history_filename, 'wb') as history_file:
            pickle.dump(history.history, history_file)