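# Train an image classifier with transfer learning (ResNet50)

A ResNet50 backbone pretrained on ImageNet is frozen, and a new 10-class softmax layer is trained on top of it using augmented images read from the training directory. The trained model and its training history are saved at the end.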

In [1]:
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
import pickle
from shutil import copy2
import tensorflow
import IPython

### Build the model: frozen ResNet50 backbone + new classifier head

In [16]:
# Use the public tf.keras API: `tensorflow.python.keras` is TensorFlow's private
# implementation package and is not a supported import path.
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D

# Number of output classes. This must equal the number of class sub-folders that
# flow_from_directory discovers (the generator cell below reports 10).
num_classes = 10

model = Sequential()

# Pretrained ImageNet backbone without its original classifier (include_top=False).
# pooling='avg' applies global average pooling inside the backbone. It therefore
# already emits one feature vector per image, and no Flatten /
# GlobalAveragePooling2D layer is needed here.
model.add(ResNet50(
    include_top=False,
    weights='imagenet',
    pooling='avg',
))

# New classification head: one softmax output per class.
model.add(Dense(num_classes, activation='softmax'))

# Freeze the backbone so that only the new Dense head is trained. Changing
# `trainable` only takes effect for a subsequent model.compile(), so this must
# stay before the compile cell.
model.layers[0].trainable = False

In [17]:
# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical' in the data generators. The 'sgd' string alias
# selects Keras' SGD optimizer with its default settings.
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

### Create data generators

In [18]:
# Public tf.keras API instead of the private `tensorflow.python.keras` package.
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Every image is resized to image_size x image_size (224 is ResNet50's default
# input size in Keras).
image_size = 224
# Fixed seed so that shuffling and random augmentations are reproducible.
SEED = 42

# NOTE(review): `working_train_dir` and `working_test_dir` are not defined in any
# cell of this notebook. Define them in a config cell near the top so that
# Restart Kernel -> Run All works on a fresh kernel.

# Validation images only get the ResNet50 input preprocessing (no augmentation).
data_generator_no_aug = ImageDataGenerator(preprocessing_function=preprocess_input)

# Training images additionally get random horizontal flips, rotations of up to
# 20 degrees and shifts of up to 20% of the width / height.
data_generator_with_aug = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

# Each sub-folder of the directory is one class. Labels are one-hot encoded
# (class_mode='categorical') to match the categorical_crossentropy loss.
train_generator_with_aug = data_generator_with_aug.flow_from_directory(
    working_train_dir,
    target_size=(image_size, image_size),
    batch_size=4,
    class_mode='categorical',
    seed=SEED,
)

# batch_size is left at the Keras default here.
validation_generator = data_generator_no_aug.flow_from_directory(
    working_test_dir,
    target_size=(image_size, image_size),
    class_mode='categorical',
    seed=SEED,
)

# Label indices are derived from folder names, so both directories must contain
# the same class folders. Otherwise validation labels are silently wrong.
# The class count must also match the size of the model's softmax head.
assert train_generator_with_aug.class_indices == validation_generator.class_indices, (
    "train and test directories contain different class folders")
assert train_generator_with_aug.num_classes == num_classes, (
    "num_classes in the model cell does not match the number of class folders")

Found 3547 images belonging to 10 classes.
Found 400 images belonging to 10 classes.


### Train the model

In [10]:
# Deliberately short run: 10 batches of 4 augmented images (40 of the 3547
# training images) for a single epoch. Use len(train_generator_with_aug) as the
# number of steps for a full pass over the training set.
TRAIN_STEPS_PER_EPOCH = 10
EPOCHS = 1

# NOTE(review): fit_generator is deprecated from TF 2.1 on, where model.fit accepts
# generators directly. It is kept because this notebook appears to have run on
# TF 1.x (judging by the deprecation output) — confirm before switching.
history_aug = model.fit_generator(
    train_generator_with_aug,
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=validation_generator,
    # Evaluate on every validation batch. A single step would score only one
    # random batch of the 400 validation images, so val metrics would be noise.
    validation_steps=len(validation_generator),
)

Instructions for updating:
Use tf.cast instead.


### Save the model and the history

In [12]:
# Persist only History.history: a plain dict mapping each metric name to its
# per-epoch values, which can be unpickled without rebuilding the model.
with open('history.pkl', mode='wb') as history_file:
    pickle.dump(history_aug.history, history_file)

In [None]:
# Save the full model (architecture + weights) to a single HDF5 (.h5) file.
# NOTE(review): this cell is marked In [None], i.e. it was not executed in the
# saved run, so model.h5 was not produced by the training shown above.
model.save(filepath='model.h5')