<a href="https://colab.research.google.com/github/nikitamaia/tensorflow-examples/blob/main/dist_strat_blog_single_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Train a ResNet50 model on the Cassava dataset

Use a GPU runtime for this notebook:
*Runtime > Change runtime type > Hardware accelerator: GPU*

To learn more about the dataset, check out https://www.tensorflow.org/datasets/catalog/cassava

This notebook is from the blog post "Getting Started With Distributed Training on GCP".





In [1]:
import tensorflow as tf
# Show the installed TensorFlow version so readers can match the environment
# this notebook was run with.
print(tf.__version__)

import tensorflow_datasets as tfds

2.3.0


## Import the data from TensorFlow Datasets

In [18]:
# Load the Cassava dataset together with its metadata. `as_supervised=True`
# makes each example an (image, label) pair instead of a feature dict.
data, info = tfds.load(name='cassava', with_info=True, as_supervised=True)

# Number of label classes, read from the dataset metadata; used to size the
# model's output layer.
NUM_CLASSES = info.features['label'].num_classes

## Set up the input pipeline using tf.data

*Using tf.data is highly recommended when doing distributed training*

In [28]:
def preprocess_data(image, label):
  """Resize one image to 300x300 and divide its values by 255.

  Args:
    image: image tensor for a single example.
      NOTE(review): presumably uint8 pixels in [0, 255], in which case the
      result lies in [0, 1] -- confirm against the dataset.
    label: label for the example; passed through unchanged.

  Returns:
    A `(scaled_image, label)` tuple where `scaled_image` is float32.
  """
  resized = tf.image.resize(image, (300, 300))
  scaled = tf.cast(resized, tf.float32) / 255.
  return scaled, label

In [30]:
def create_dataset(train_data, batch_size):
  """Build the training input pipeline: map, cache, shuffle, batch, prefetch.

  Args:
    train_data: dataset of `(image, label)` examples exposing the `tf.data`
      API (`map`, `cache`, `shuffle`, `batch`, `prefetch`).
    batch_size: number of examples per batch.

  Returns:
    The preprocessed, shuffled, batched and prefetched dataset.
  """
  train_data = train_data.map(preprocess_data,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Cache the deterministic preprocessing output *before* shuffling. Caching
  # after `shuffle` would store the first epoch's order and replay it on every
  # later epoch, so the data would effectively be shuffled only once.
  train_data = train_data.cache()
  # Shuffling after the cache draws a fresh order each epoch.
  train_data = train_data.shuffle(1000)
  train_data = train_data.batch(batch_size)
  # Overlap input preparation with training.
  train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)
  return train_data

In [31]:
def create_model():
  """Build a ResNet50-based classifier with `NUM_CLASSES` softmax outputs.

  The ResNet50 backbone is created with ImageNet weights and without its
  original top layers. Its output is globally average-pooled, passed through
  an L2-regularized 1016-unit ReLU layer, and then through a softmax layer
  with one unit per class.

  Returns:
    An uncompiled `tf.keras.Model` mapping the backbone's input to class
    probabilities.
  """
  backbone = tf.keras.applications.ResNet50(include_top=False,
                                            weights='imagenet')
  features = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  hidden = tf.keras.layers.Dense(
      1016,
      activation='relu',
      kernel_regularizer=tf.keras.regularizers.l2(0.01))(features)
  class_probs = tf.keras.layers.Dense(NUM_CLASSES,
                                      activation='softmax')(hidden)
  return tf.keras.Model(inputs=backbone.input, outputs=class_probs)

In [32]:
# Build the network, then configure it for training with Adam
# (learning rate 1e-4), sparse categorical cross-entropy, and accuracy.
model = create_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [33]:
# Number of examples per training batch (passed to create_dataset).
BATCH_SIZE = 64

In [34]:
# Build the training input pipeline from the dataset's 'train' split.
train_data = create_dataset(data['train'], BATCH_SIZE)

In [35]:
# Train for 5 epochs over the prepared pipeline; the returned History object
# is shown as the cell output.
model.fit(train_data, epochs=5)

Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fb9f008a358>