<a href="https://colab.research.google.com/github/omarcevi/ML-Projects/blob/main/TF_Distributed_training_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Load the Cassava dataset from TensorFlow Datasets (downloaded on the first run, reused from the local cache afterwards).
# as_supervised=True yields (image, label) tuples; with_info=True additionally returns the dataset metadata.
data, info = tfds.load('cassava', with_info=True, as_supervised=True)
NUM_CLASSES = info.features['label'].num_classes

Downloading and preparing dataset 1.26 GiB (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /root/tensorflow_datasets/cassava/0.1.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/5656 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/cassava/0.1.0.incompleteWDG3TP/cassava-train.tfrecord*...:   0%|          …

Generating test examples...:   0%|          | 0/1885 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/cassava/0.1.0.incompleteWDG3TP/cassava-test.tfrecord*...:   0%|          |…

Generating validation examples...:   0%|          | 0/1889 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/cassava/0.1.0.incompleteWDG3TP/cassava-validation.tfrecord*...:   0%|     …

Dataset cassava downloaded and prepared to /root/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.


In [3]:
# Resize each image and apply the input preprocessing the ImageNet-pretrained ResNet50 expects
def preprocess_data(image, label, image_size=(300, 300)):
	"""Prepare one ``(image, label)`` example for the ResNet50-based model.

	Args:
		image: Image tensor; assumed (height, width, 3) with variable height/width
			as decoded by tfds — TODO confirm. It must have 3 channels for ResNet50.
		label: Integer class label (the model uses sparse categorical
			cross-entropy); returned unchanged.
		image_size: Target ``(height, width)`` for resizing. Defaults to (300, 300).

	Returns:
		Tuple ``(image, label)`` where ``image`` is a float32 tensor of size
		``image_size`` that has been passed through
		``tf.keras.applications.resnet50.preprocess_input``.
	"""
	# tf.image.resize already returns float32 for the default bilinear method;
	# the cast just makes the dtype explicit.
	image = tf.cast(tf.image.resize(image, image_size), tf.float32)
	# The 'imagenet' ResNet50 weights were trained on caffe-style inputs
	# (RGB->BGR, per-channel ImageNet mean subtracted, *no* /255 scaling).
	# Feeding [0, 1]-scaled pixels would not match what the pretrained filters expect.
	image = tf.keras.applications.resnet50.preprocess_input(image)
	return image, label

In [4]:
# Define the model: ImageNet-pretrained ResNet50 backbone plus a new classification head
def create_model():
	"""Build the (uncompiled) Cassava classifier.

	The architecture is:
	ResNet50 (ImageNet weights, no top) -> global average pooling
	-> Dense(1016, relu) -> Dense(NUM_CLASSES, softmax).

	The backbone is not frozen, so all of its layers are trained together with the head.
	Call this inside ``strategy.scope()`` when training with a distribution strategy.

	Returns:
		A ``tf.keras.Model`` mapping the ResNet50 input to class probabilities.
	"""
	backbone = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
	pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
	# NOTE(review): 1016 is an unusual width (1024 intended?) — confirm it is deliberate.
	hidden = tf.keras.layers.Dense(1016, activation='relu')(pooled)
	probabilities = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
	return tf.keras.Model(inputs=backbone.input, outputs=probabilities)

In [5]:
# Synchronous data-parallel training: one replica per visible GPU (falls back to the CPU if there is none).
strategy = tf.distribute.MirroredStrategy()

# Everything that creates variables (the model and its optimizer) must be built
# inside the strategy scope so those variables are mirrored across replicas.
with strategy.scope():
	model = create_model()
	model.compile(
		optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
		loss='sparse_categorical_crossentropy',
		metrics=['accuracy'],
	)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [6]:
# Build the input pipelines and train.
# MirroredStrategy splits every global batch across replicas, so scale the
# global batch size by the replica count to keep a fixed per-replica batch.
PER_REPLICA_BATCH_SIZE = 64
batch_size = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# Shuffle the raw examples *before* preprocessing. Otherwise the 1000-element
# shuffle buffer holds 300x300x3 float32 tensors (~1 GB); raw tfds images are
# presumably uint8 and much smaller — confirm.
# Preprocessing runs in parallel, and prefetch overlaps input work with training.
train_data = (
	data['train']
	.shuffle(1000)
	.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
	.batch(batch_size)
	.prefetch(tf.data.AUTOTUNE)
)

# The held-out validation split reports generalization after every epoch;
# no shuffling is needed for evaluation.
validation_data = (
	data['validation']
	.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
	.batch(batch_size)
	.prefetch(tf.data.AUTOTUNE)
)

# Keep the History object (per-epoch loss/accuracy, including val_*) for later
# inspection instead of letting it print as a bare repr.
history = model.fit(train_data, validation_data=validation_data, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7ebc0817f550>