<a href="https://colab.research.google.com/github/danielelbrecht/CAP5610-HW-2/blob/master/CAP5610_HW2_kfold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implement k-fold cross-validation on the final model and compare its validation accuracy to simple holdout validation

In [3]:
from keras.datasets import cifar10
import numpy as np
import sklearn
import tensorflow as tf
from tensorflow.keras import layers, utils

Using TensorFlow backend.


In [4]:
# Fetch CIFAR-10 (downloaded on first use) as one (images, labels) pair per split
cifar_train, cifar_test = cifar10.load_data()
train_images, train_labels = cifar_train
test_images, test_labels = cifar_test

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
# Preprocess data
def _scale_pixels(images):
  """Return `images` shaped (N, 32, 32, 3) as float32 scaled by 1/255.

  Integer-typed input (assumed to be raw 8-bit pixel values, as returned by
  the dataset loader) is cast to float32 and divided by 255. Input that is
  already floating point is assumed to have been scaled by an earlier run of
  this cell and is returned without rescaling, so re-running the cell does
  not divide by 255 a second time.
  """
  # -1 lets numpy infer the sample count instead of hard-coding 50000 / 10000
  images = images.reshape((-1, 32, 32, 3))
  if np.issubdtype(images.dtype, np.integer):
    # float32 halves the memory of the float64 that `uint8 / 255.0` would produce
    images = images.astype('float32') / 255.0
  return images

train_images = _scale_pixels(train_images)
test_images = _scale_pixels(test_images)

# One-hot encode the integer class labels (10 classes) under new names, so the
# original label arrays stay available and this step is safe to re-run
train_labels_categorical = utils.to_categorical(train_labels, num_classes=10, dtype='float32')
test_labels_categorical = utils.to_categorical(test_labels, num_classes=10, dtype='float32')



In [0]:
#k-fold cross validation function

def kfold(k, model, data, labels, epochs=1):
  """Estimate validation accuracy with k-fold cross validation.

  `data` / `labels` are cut into `k` contiguous folds in their existing order
  (no shuffling is done here). For every fold the model is reset to the
  weights it had when this function was called, trained on the remaining
  k-1 folds, and scored on the held-out fold.

  Args:
    k: Number of folds; must satisfy 2 <= k <= len(data).
    model: Compiled model exposing `save_weights`, `load_weights` and `fit`,
      compiled with an accuracy metric so that `fit` reports validation
      accuracy in its returned history.
    data: Array of samples, sliceable along axis 0.
    labels: Array of targets aligned with `data` along axis 0.
    epochs: Number of training epochs per fold. Defaults to 1.

  Returns:
    The mean, over the k folds, of the validation accuracy reported for the
    last training epoch of each fold.

  Raises:
    ValueError: If `k` is out of range, or if the history returned by `fit`
      contains no validation-accuracy entry.

  Side effects:
    Writes the model's starting weights to 'initial_weights' in the current
    working directory. On return the model holds the weights trained on the
    last fold.
  """
  length = len(data)
  if not 2 <= k <= length:
    raise ValueError(
        'k must be between 2 and the number of samples (%d), got %r' % (length, k))

  accuracy = 0.0
  # Snapshot the starting weights so every fold trains from the same state
  model.save_weights('initial_weights')

  for i in range(k):

    model.load_weights('initial_weights')

    # Integer arithmetic keeps fold boundaries exact (no float rounding)
    lower_bound = (i * length) // k
    upper_bound = ((i + 1) * length) // k

    # Held-out fold for validation; everything before and after it for training
    val_data = data[lower_bound:upper_bound]
    val_targets = labels[lower_bound:upper_bound]
    fit_data = np.concatenate((data[:lower_bound], data[upper_bound:]))
    fit_targets = np.concatenate((labels[:lower_bound], labels[upper_bound:]))

    history = model.fit(fit_data,
                        fit_targets,
                        epochs=epochs,
                        validation_data=(val_data, val_targets))

    # The metric is logged as 'val_acc' or 'val_accuracy' depending on the
    # Keras / TensorFlow version, so accept either spelling
    logs = history.history
    metric_key = next(
        (key for key in ('val_accuracy', 'val_acc') if key in logs), None)
    if metric_key is None:
      raise ValueError(
          'fit() history has no validation accuracy; '
          "compile the model with metrics=['accuracy']")

    # Accumulate the final-epoch validation accuracy of this fold
    accuracy += float(logs[metric_key][-1])

  return accuracy / k
    

In [27]:
# Define model: two conv -> max-pool -> dropout modules followed by a small
# fully connected classifier with a 10-way softmax output
model = tf.keras.Sequential([
    # First convolutional module
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(32,32,3)),
    layers.MaxPooling2D(pool_size=(2,2)),
    layers.Dropout(0.2),

    # Second convolutional module
    layers.Conv2D(filters=128, kernel_size=(3,3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2,2)),
    layers.Dropout(0.2),

    # Fully connected layers
    layers.Flatten(),
    layers.Dense(32, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Categorical cross-entropy expects one-hot targets; accuracy is tracked as a metric
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_6 (Conv2D)            (None, 30, 30, 64)        1792      
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 15, 15, 64)        0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 15, 15, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 13, 13, 128)       73856     
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 6, 6, 128)         0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 6, 6, 128)         0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 4608)              0         
__________

In [28]:
# Run 5-fold cross validation on the training set; the returned value is the cell output
kfold(k=5,
      model=model,
      data=train_images,
      labels=train_labels_categorical)


Consider using a TensorFlow optimizer from `tf.train`.
Instructions for updating:
Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.
Train on 40000 samples, validate on 10000 samples
Train on 40000 samples, validate on 10000 samples
 4544/40000 [==>...........................] - ETA: 2:08 - loss: 2.2547 - acc: 0.1345

KeyboardInterrupt: ignored