# Training a Model for Inferring if an Image is a Squirrel or Bird

This Jupyter notebook follows the [TensorFlow CNN tutorial](https://www.tensorflow.org/tutorials/images/cnn) to build a classification model that identifies whether an image shows a squirrel or a bird. The resulting model is used in a blog post about [Intel® Deep Learning Workbench with DevCloud for the Edge](https://devcloud.intel.com/edge).

In [1]:
import warnings
warnings.filterwarnings('ignore')  # TensorFlow emits deprecation warnings

import tensorflow as tf
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np

### Import Dataset from Pickle

The dataset pickle combines the [CalTech UCSD Birds 200_2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and [Animals-10](https://www.kaggle.com/alessiocorrado99/animals10) datasets. The code to create the pickle file is in [this repository](). For the code below, the pickle file was uploaded to the executing directory (SquirrelVsBirdClassifier).

In [None]:
# Open the pickle with a context manager so the file handle is closed after loading
with open('squirrelvsbirddataset_64.p', 'rb') as f:
    dataset_dict = pickle.load(f)

### Extract Training and Validation Sets from Pickle File
The pickle file contains a dictionary with four keys: `'training_samples'`, `'validation_samples'`, `'training_labels'`, and `'validation_labels'`. The samples are NumPy arrays representing the images. The labels are parallel to the samples; the code below keeps element 1 (`x[1]`) of each label entry, which gives `0` for squirrel images and `1` for bird images.
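The dictionary layout described above can be illustrated with a small synthetic stand-in. This is only a sketch: the random pixel data, the tiny sample counts, and the assumption that each stored label is a two-element one-hot pair (so that `x[1]` is `0` for a squirrel and `1` for a bird) are illustrative assumptions, not the actual contents of `squirrelvsbirddataset_64.p`.

```python
import io
import pickle

import numpy as np

# Hypothetical stand-in for squirrelvsbirddataset_64.p: random 64x64 RGB images.
# ASSUMPTION: each label is stored as a one-hot pair, [1, 0] = squirrel, [0, 1] = bird.
rng = np.random.default_rng(0)
fake_dataset = {
    'training_samples': rng.integers(0, 256, size=(4, 64, 64, 3)).astype(np.uint8),
    'validation_samples': rng.integers(0, 256, size=(2, 64, 64, 3)).astype(np.uint8),
    'training_labels': [[1, 0], [0, 1], [0, 1], [1, 0]],
    'validation_labels': [[0, 1], [1, 0]],
}

# Round-trip through pickle in memory, as the notebook does with a file on disk.
buffer = io.BytesIO()
pickle.dump(fake_dataset, buffer)
buffer.seek(0)
dataset_dict = pickle.load(buffer)

# Same extraction as the notebook: keep element 1 of each label entry.
training_samples = np.asarray(dataset_dict['training_samples'])
validation_samples = np.asarray(dataset_dict['validation_samples'])
training_labels = np.asarray([x[1] for x in dataset_dict['training_labels']])
validation_labels = np.asarray([x[1] for x in dataset_dict['validation_labels']])

# Samples and labels are parallel arrays, so their lengths must match.
assert len(training_samples) == len(training_labels)
assert len(validation_samples) == len(validation_labels)
print(training_samples.shape, training_labels)  # (4, 64, 64, 3) [0 1 1 0]
```

If the real labels turn out to be stored as plain integers instead of pairs, the `x[1]` indexing would fail, so a length check like the one above is a cheap way to catch a mismatch early.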

In [None]:
# Convert the stored samples to NumPy arrays of 64x64x3 images
training_samples = np.asarray(dataset_dict['training_samples'])
validation_samples = np.asarray(dataset_dict['validation_samples'])
# Keep element 1 of each stored label entry: 0 = squirrel, 1 = bird
training_labels = np.asarray([x[1] for x in dataset_dict['training_labels']])
validation_labels = np.asarray([x[1] for x in dataset_dict['validation_labels']])

# Make float16 the default dtype for Keras layers created after this call.
# Note: the Keras documentation cautions that training purely in float16 can be
# numerically unstable; mixed precision is the usual alternative.
tf.keras.backend.set_floatx('float16')

### Create Keras Model
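This cell defines the network: three 3x3 convolutional layers (64, 128, and 128 filters), the first two each followed by 2x2 max pooling, then a flattening step, a 128-unit dense layer, and a final dense layer that outputs one raw score (logit) per class.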

In [4]:
#Create the network layers with Keras
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(64, (3,3), activation = 'relu', input_shape = (64,64,3)))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Conv2D(128, (3,3), activation = 'relu'))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Conv2D(128, (3,3), activation = 'relu'))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation = 'relu'))
model.add(tf.keras.layers.Dense(2))

#Display info about our model network
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 62, 62, 64)        1792      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 31, 31, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 29, 29, 128)       73856     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 12, 12, 128)       147584    
_________________________________________________________________
flatten (Flatten)            (None, 18432)             0         
_________________________________________
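The shapes and parameter counts in the summary follow from two rules for a `Conv2D` layer with the default `'valid'` padding and stride 1: a k x k kernel shrinks each spatial side by k - 1, and the layer has (k·k·C_in + 1)·filters weights (one bias per filter). A 2x2 `MaxPooling2D` halves each side, rounding down. The sketch below recomputes the rows shown above in plain Python; the helper names are hypothetical, and the two `Dense` rows, which are cut off in the printed summary, are derived from the same arithmetic (inputs × units + units) rather than copied from the notebook output.

```python
# Recompute output shapes and parameter counts for the Sequential model above.
# Assumes Keras defaults: 'valid' padding, stride 1 for Conv2D, pool size 2 for pooling.

def conv2d(shape, filters, k=3):
    """Return (output_shape, params) for a k x k 'valid' convolution."""
    h, w, c_in = shape
    out = (h - k + 1, w - k + 1, filters)
    params = (k * k * c_in + 1) * filters  # weights plus one bias per filter
    return out, params

def max_pool(shape, p=2):
    """p x p max pooling with stride p floors each spatial side; no weights."""
    h, w, c = shape
    return (h // p, w // p, c), 0

def dense(n_in, units):
    """Fully connected layer: one weight per input per unit, plus one bias per unit."""
    return units, n_in * units + units

rows = []
shape = (64, 64, 3)
for name, layer in [('conv2d', lambda s: conv2d(s, 64)),
                    ('max_pooling2d', max_pool),
                    ('conv2d_1', lambda s: conv2d(s, 128)),
                    ('max_pooling2d_1', max_pool),
                    ('conv2d_2', lambda s: conv2d(s, 128))]:
    shape, params = layer(shape)
    rows.append((name, shape, params))

flat = shape[0] * shape[1] * shape[2]
rows.append(('flatten', (flat,), 0))
units, params = dense(flat, 128)
rows.append(('dense', (units,), params))
units, params = dense(units, 2)
rows.append(('dense_1', (units,), params))

for name, out_shape, params in rows:
    print(f'{name:16} {str(out_shape):16} {params}')
print('total params:', sum(p for _, _, p in rows))
```

The arithmetic also shows where most of the weights live: the 128-unit dense layer sits on an 18432-value flattened input, so it accounts for far more parameters than all three convolutional layers combined.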

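### Train the Model
The model is compiled with the Adam optimizer and a sparse categorical cross-entropy loss, which accepts the integer labels (0 or 1) directly instead of one-hot vectors. It is then trained for 10 epochs, with the validation set evaluated at the end of each epoch.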

In [None]:
model.compile(optimizer='adam',
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])

history = model.fit(training_samples, training_labels, epochs=10, validation_data=(validation_samples, validation_labels))
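Because the last `Dense(2)` layer has no activation, the model outputs two unnormalized scores (logits) per image, which is why the loss is created with `from_logits=True`: the softmax is applied inside the loss. The NumPy sketch below spells out what sparse categorical cross-entropy computes from such logits and integer labels, and how a predicted class can be read off with `argmax`. The logits are made-up numbers, and `CLASS_NAMES` is a hypothetical helper that assumes the label convention described earlier (0 = squirrel, 1 = bird).

```python
import numpy as np

CLASS_NAMES = ['squirrel', 'bird']  # hypothetical; index matches the integer labels

def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean over the batch of -log(softmax(logits)[i, labels[i]])."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(picked)))

# Made-up logits for a batch of three images.
logits = np.array([[2.0, -1.0],   # strongly favors squirrel
                   [0.5, 1.5],    # leans bird
                   [0.0, 0.0]])   # undecided
labels = np.array([0, 1, 1])

loss = sparse_categorical_crossentropy_from_logits(labels, logits)
# argmax breaks ties toward the first index, so the undecided row maps to 'squirrel'.
predicted = [CLASS_NAMES[i] for i in logits.argmax(axis=1)]
print(round(loss, 4), predicted)
```

The same `argmax` step applies to the output of `model.predict`, since softmax does not change which logit is largest.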

### Save the Model
The final step is to save the model. Because the filename ends in `.h5`, `model.save` writes it in the HDF5 format.

In [None]:
#Save
model.save('squirrelvsbird_64.h5')
