Add support for distribute dataset (keras-team#129)
qlzh727 committed May 10, 2023
1 parent 421df85 commit 6be0e2a
Showing 2 changed files with 51 additions and 4 deletions.
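
For context (this sketch is not part of the commit): the change centers on tf.distribute.Strategy.experimental_distribute_dataset, which wraps a tf.data.Dataset so that each global batch is split across the strategy's replicas and iteration yields PerReplica values. A minimal standalone sketch of that behavior, assuming two logical CPU devices; the test file below presumably configures these in setup code not shown in this diff:

import numpy as np
import tensorflow as tf

# Assumption: carve the single physical CPU into two logical devices so
# MirroredStrategy has two replicas to mirror across.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

# With a global batch size of 16 and two replicas, each replica sees a
# local batch of 8, which is exactly what the new test below asserts.
dataset = tf.data.Dataset.from_tensor_slices(
    np.random.random((100, 16)).astype("float32")
).batch(16)
distributed = strategy.experimental_distribute_dataset(dataset)

batch = next(iter(distributed))
print(type(batch).__name__)   # PerReplica
print(batch.values[0].shape)  # (8, 16)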
@@ -1,10 +1,12 @@
"""Tests for tf.distribute related functionality under tf implementation."""

import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.eager import context

from keras_core import backend
from keras_core.backend.tensorflow import trainer as tf_trainer
from keras_core import layers
from keras_core import models
from keras_core import testing
@@ -79,3 +81,41 @@ def run_fn(data):
self.assertEqual(result.values[1].shape, [8, 2])
self.assertNotAllClose(result.values[0], result.values[1])
self.assertAllClose(result.values[0], tf.zeros([8, 2]))
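
The three assertions above are the tail of an earlier test that this diff shows only as context. The pattern they exercise is strategy.run() returning a PerReplica result with one tensor per replica. A hypothetical reconstruction of that pattern (value_fn, run_fn's body, and the shapes are inferred from the assertions, not taken from the file):

def value_fn(ctx):
    # Replica 0 contributes zeros and replica 1 ones, so the two replica
    # outputs differ, matching the assertNotAllClose check above.
    return tf.ones([8, 2]) * ctx.replica_id_in_sync_group

dist_x = strategy.experimental_distribute_values_from_function(value_fn)

def run_fn(data):
    return data

result = strategy.run(run_fn, args=(dist_x,))
print(result.values[0].shape)  # (8, 2); all zeros on replica 0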

def test_epoch_iterator(self):
x = np.random.random((100, 16))
y = np.random.random((100, 4))
sample_weight = np.random.random((100,))
batch_size = 16
shuffle = True

strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])

epoch_iterator = tf_trainer.TFEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
shuffle=shuffle,
distribute_strategy=strategy
)
steps_seen = []
for step, data_iterator in epoch_iterator.enumerate_epoch():
steps_seen.append(step)
batch = next(data_iterator)
self.assertEqual(len(batch), 3)
x, y, sample_weight = batch
            self.assertTrue(
                isinstance(x, tf.types.experimental.distributed.PerReplica)
            )
            # Make sure the local batch size is 8 (global batch of 16
            # split across 2 replicas)
if step < 6:
self.assertEqual(x.values[0].shape, [8, 16])
self.assertEqual(y.values[0].shape, [8, 4])
self.assertEqual(sample_weight.values[0].shape, [8])
else:
# Last partial batch
self.assertEqual(x.values[0].shape, [2, 16])
self.assertEqual(y.values[0].shape, [2, 4])
self.assertEqual(sample_weight.values[0].shape, [2])
self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])
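
The step count and shapes asserted above follow from simple arithmetic, sketched here for reference:

import math

num_samples, global_batch, num_replicas = 100, 16, 2

steps = math.ceil(num_samples / global_batch)              # 7, steps 0..6
full_local = global_batch // num_replicas                  # 8 for steps 0..5
last_local = (num_samples % global_batch) // num_replicas  # 2 at step 6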
keras_core/backend/tensorflow/trainer.py: 15 changes (11 additions & 4 deletions)
@@ -243,6 +243,7 @@ def fit(
steps_per_epoch=steps_per_epoch,
shuffle=shuffle,
class_weight=class_weight,
distribute_strategy=self.distribute_strategy,
)

# Container that configures and calls callbacks.
@@ -285,6 +286,7 @@ def fit(
y=val_y,
sample_weight=val_sample_weight,
batch_size=validation_batch_size or batch_size,
distribute_strategy=self.distribute_strategy,
)
val_logs = self.evaluate(
x=val_x,
@@ -346,6 +348,7 @@ def evaluate(
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
distribute_strategy=self.distribute_strategy,
)

# Container that configures and calls callbacks.
@@ -385,6 +388,7 @@ def predict(
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
distribute_strategy=self.distribute_strategy,
)

# Container that configures and calls callbacks.
@@ -428,20 +432,23 @@ def predict(


 class TFEpochIterator(EpochIterator):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, distribute_strategy=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._distribute_strategy = distribute_strategy
         self._steps_seen = 0
 
     def enumerate_epoch(self):
         if self.steps_per_epoch:
             if not self._current_iterator:
                 self._current_iterator = iter(
-                    self.data_adapter.get_tf_dataset()
-                )
+                    self._distribute_strategy.experimental_distribute_dataset(
+                        self.data_adapter.get_tf_dataset()))
             for step in range(self.steps_per_epoch):
                 yield step, self._current_iterator
         else:
-            iterator = iter(self.data_adapter.get_tf_dataset())
+            iterator = iter(
+                self._distribute_strategy.experimental_distribute_dataset(
+                    self.data_adapter.get_tf_dataset()))
             if self.num_batches:
                 for step in range(self.num_batches):
                     yield step, iterator
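
One detail worth flagging in the new TFEpochIterator: distribute_strategy defaults to None, yet enumerate_epoch now calls experimental_distribute_dataset on it unconditionally, so the iterator depends on callers forwarding a real strategy, as the fit/evaluate/predict hunks above do with self.distribute_strategy. For the non-distributed case, TensorFlow's default strategy makes the same call a pass-through; a sketch of that fallback (an assumption about usage, not code from this commit):

import tensorflow as tf

# Outside any Strategy scope, TensorFlow still exposes a default
# single-replica strategy, so distributing a dataset through it is a
# no-op split: iteration yields plain tensors rather than PerReplica.
strategy = tf.distribute.get_strategy()
dataset = tf.data.Dataset.range(8).batch(4)
for batch in strategy.experimental_distribute_dataset(dataset):
    print(batch)  # [0 1 2 3], then [4 5 6 7]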
