# Cross validating Keras models

In [None]:
import os

import matplotlib.pyplot as plt

from keras.layers.core import Activation
from keras.layers.core import Dense
from keras.models import Sequential
from keras.utils import np_utils
from sklearn.datasets import load_digits
from sklearn.model_selection import ShuffleSplit

from faculty_xval.validation import JobsCrossValidator
from faculty_xval.utilities import job_name_to_job_id

**Note**: Before running this notebook, create the directory `/project/{USER_NAME}/temp`, where the results of cross validation will be stored. The next cell raises an `OSError` if this directory does not exist.

In [None]:
# Names and paths are derived from the USER_NAME environment variable;
# os.environ raises KeyError if it is not set.
JOB_NAME = "cross_validation_{}".format(os.environ["USER_NAME"])
REFERENCE_DIR = "/project/{}/temp/".format(os.environ["USER_NAME"])

# Fail early when the reference directory is missing, instead of later
# when the cross validator is handed the path.
if not os.path.isdir(REFERENCE_DIR):
    raise OSError(
        "Path {} cannot be found or is not a directory".format(REFERENCE_DIR)
    )

# Splitting: number of shuffled splits and fraction held out for testing.
N_SPLITS = 10
TEST_SIZE = 0.25

# Network: hidden-layer width, then [hidden, output] activations.
N_NODES = 40
ACTIVATIONS = ["relu", "softmax"]

# Training: compile settings and keyword arguments forwarded as fit_kwargs.
LOSS = "categorical_crossentropy"
OPTIMIZER = "adam"
FIT_KWARGS = dict(epochs=16, batch_size=32, verbose=0)

# Passed positionally to cross_validator.run() after the split generator.
# NOTE(review): presumably the number of jobs/sub-runs the splits are
# divided across -- confirm against the faculty_xval documentation.
NUM_SUBRUNS = 3

## Initialise the cross validator

In [None]:
# Look up the job identifier that corresponds to JOB_NAME, then build the
# cross validator from that identifier and the reference directory.
job_id = job_name_to_job_id(JOB_NAME)
cross_validator = JobsCrossValidator(
    job_id,
    REFERENCE_DIR,
)

## Load the data

In [None]:
# Load the scikit-learn digits dataset and pull out the feature matrix
# and the target labels.
dataset = load_digits()
features, targets = dataset.data, dataset.target

In [None]:
# Convert targets to one-hot encoding.
# Encode from the raw labels held in `dataset` rather than from `targets`
# itself: re-running this cell would otherwise one-hot encode the already
# encoded matrix and silently produce an array with an extra dimension.
targets = np_utils.to_categorical(dataset["target"])

In [None]:
# Display one example (index 3): the feature vector reshaped into an
# 8x8 image, followed by its target vector.
_digit_image = features[3].reshape((8, 8))

print("Features:")
plt.imshow(_digit_image)
plt.show()

print("Targets:")
print(targets[3])

In [None]:
# Build the shuffle-split strategy first, then ask it for the splits of
# the feature matrix. No random_state is given, so the splits are not
# fixed between runs.
_shuffle_split = ShuffleSplit(n_splits=N_SPLITS, test_size=TEST_SIZE)
split_generator = _shuffle_split.split(features)

## Define the Keras model

In [None]:
# Network with one hidden layer: the input width is the number of feature
# columns and the output width is the number of target columns.
model = Sequential([
    Dense(N_NODES, input_shape=(features.shape[1],)),
    Activation(ACTIVATIONS[0]),
    Dense(targets.shape[1]),
    Activation(ACTIVATIONS[1]),
])

In [None]:
# Configure the model for training with the optimizer and loss chosen in
# the configuration cell.
model.compile(optimizer=OPTIMIZER, loss=LOSS)

## Perform cross validation

In [None]:
# Hand the compiled model, the data and the split generator to the cross
# validator. Features and targets are each wrapped in a single-element list.
# NOTE(review): presumably this dispatches the work as jobs split into
# NUM_SUBRUNS sub-runs -- confirm against the faculty_xval documentation.
cross_validator.run(
    model, [features], [targets], split_generator, NUM_SUBRUNS,
    fit_kwargs=FIT_KWARGS,
)