# Keras with MirroredStrategy
In this lab, you will take the solution you created in Lab 03b and implement MirroredStrategy. You will take a Keras Sequential model, convert it to an Estimator, and then distribute training with MirroredStrategy.

The easiest approach is to start from your results from 4.2b and replace the model with a Keras Sequential model.

Helpful references:
- https://www.tensorflow.org/guide/keras#multiple_gpus
- https://keras.io/losses/#sparse_categorical_crossentropy

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D
import os
import numpy as np
import time

### Data Pipeline
Reuse data pipeline components from previous labs.

In [None]:
def parse_tfrecord(example):
    """Decode one serialized tf.Example record into an ``(image, label)`` pair.

    Args:
        example: a scalar string tensor holding one serialized record with
            'idx', 'label' (both int64) and 'image' (raw bytes) features.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a float32 tensor
        reshaped to [32, 32, 3] and ``label`` is the parsed int64 label.
    """
    schema = {
        'idx': tf.FixedLenFeature((), tf.int64),
        'label': tf.FixedLenFeature((), tf.int64),
        'image': tf.FixedLenFeature((), tf.string, default_value=""),
    }
    record = tf.parse_single_example(example, schema)
    # The image bytes are decoded as float64 values, then narrowed to float32.
    pixels = tf.cast(tf.decode_raw(record['image'], tf.float64), tf.float32)
    return tf.reshape(pixels, [32, 32, 3]), record['label']

def image_scaling(x):
    """Return image ``x`` after ``tf.image.per_image_standardization``."""
    standardized = tf.image.per_image_standardization(x)
    return standardized

def distort(x):
    """Apply random training-time augmentation to a single image.

    The image is padded/cropped to 40x40, a random 32x32x3 window is cut out
    of it, and the result is randomly mirrored left-to-right.
    """
    enlarged = tf.image.resize_image_with_crop_or_pad(x, 40, 40)
    window = tf.random_crop(enlarged, [32, 32, 3])
    return tf.image.random_flip_left_right(window)

In [None]:
def dataset_input_fn(params):
    """Build the batched, repeating ``tf.data`` pipeline for train or eval.

    Args:
        params: dict with the keys
            'filenames': TFRecord file paths to read,
            'threads': parallelism for reading and for every ``map`` stage,
            'mode': a ``tf.estimator.ModeKeys`` value; TRAIN additionally
                applies ``distort`` and shuffling,
            'shuffle_buff': shuffle buffer size (only read in TRAIN mode),
            'batch': batch size,
            'prefetch_batches' (optional): number of batches to keep
                prefetched; defaults to 8.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` batches that repeats
        indefinitely.
    """
    threads = params['threads']
    dataset = tf.data.TFRecordDataset(
        params['filenames'], num_parallel_reads=threads)
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=threads)
    dataset = dataset.map(lambda x, y: (image_scaling(x), y),
                          num_parallel_calls=threads)
    if params['mode'] == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.map(lambda x, y: (distort(x), y),
                              num_parallel_calls=threads)
        dataset = dataset.shuffle(buffer_size=params['shuffle_buff'])
    dataset = dataset.repeat()
    dataset = dataset.batch(params['batch'])
    # prefetch() counts dataset elements, and after batch() each element is a
    # whole batch.  The previous 8 * batch therefore buffered hundreds of
    # batches; a handful is enough to overlap input with compute.
    dataset = dataset.prefetch(params.get('prefetch_batches', 8))
    return dataset

In [None]:
# Collect the TFRecord shards with a Python-side glob instead of shelling out
# to `ls`, so this cell does not depend on IPython `!` syntax or a POSIX
# shell.  Sorting gives a deterministic file order.
train_files = sorted(tf.gfile.Glob('./data/cifar10_data_00*.tfrecords'))
val_files   = sorted(tf.gfile.Glob('./data/cifar10_data_01*.tfrecords'))

# Settings read by dataset_input_fn for the training pipeline.
train_params = dict(
    filenames=train_files,
    mode=tf.estimator.ModeKeys.TRAIN,
    threads=16,
    shuffle_buff=100000,
    batch=100,
)

# Settings read by dataset_input_fn for the evaluation pipeline; no
# 'shuffle_buff' entry because that key is only read in TRAIN mode.
eval_params = dict(
    filenames=val_files,
    mode=tf.estimator.ModeKeys.EVAL,
    threads=8,
    batch=200,
)

# Specs handed to tf.estimator.train_and_evaluate.  The lambdas defer building
# each dataset until the estimator calls the input function.
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: dataset_input_fn(train_params),
    max_steps=200000,
)
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: dataset_input_fn(eval_params),
    steps=10,
    throttle_secs=60,
)


### DO: Setup Keras model

In [None]:
# Number of output classes for the classifier.
num_classes = 10

# Model hyperparameters; they are also encoded into the model directory name.
model_params = dict(
    drop_out=0.2,
    dense_units=1024,
    learning_rate=1e-3,
)

In [None]:
# Define Keras model (lab exercise: add layers to `model`, then compile it).
# NOTE(review): parse_tfrecord yields 32x32x3 float32 images with scalar int64
# labels, so the first layer presumably takes input_shape=(32, 32, 3), the
# output has num_classes units, and the loss is presumably sparse categorical
# cross-entropy (see the reference links at the top of the lab) -- confirm
# against the lab solution.
model = Sequential()
#
# code here
#

In [None]:
# Distribution strategy configured for 4 GPUs.
distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=4)

# Estimator run configuration: checkpoint settings, session options and the
# distribution strategy used for training.
config = tf.estimator.RunConfig(
    train_distribute=distribution,
    session_config=tf.ConfigProto(
        log_device_placement=True,
        allow_soft_placement=True,
    ),
    keep_checkpoint_max=5,
    save_checkpoints_secs=300,
)

In [None]:
# Build an output directory name that encodes the hyperparameters plus a
# timestamp, so each run gets its own directory.
name = ''.join([
    'cnn_model/keras_model_',
    'dense', str(model_params['dense_units']), '_',
    'drop', str(model_params['drop_out']), '_',
    'lr', str(model_params['learning_rate']), '_',
    time.strftime("%Y%m%d%H%M%S"),
])
model_dir = os.path.join('./', name)

print(model_dir)

### DO: Create estimator from Keras Model

In [None]:
# Lab exercise: replace the placeholder below with code that wraps the Keras
# `model` in an Estimator.  Until it is filled in, this line is not valid
# Python.
# NOTE(review): presumably tf.keras.estimator.model_to_estimator(...) called
# with the `config` and `model_dir` defined above -- confirm against the lab
# solution.
estimator = # code here

In [None]:
# Run training with periodic evaluation, as described by the two specs.
tf.estimator.train_and_evaluate(
    estimator=estimator,
    train_spec=train_spec,
    eval_spec=eval_spec,
)