Add *_on_batch functions to torch trainer #328

Merged · 2 commits · Jun 12, 2023
Changes from all commits
5 changes: 5 additions & 0 deletions integration_tests/torch_backend_keras_workflow.py
@@ -29,6 +29,11 @@
model.evaluate(x, y, verbose=0)
model.predict(x, verbose=0)

# Test on batch functions
model.train_on_batch(x, y)
model.test_on_batch(x, y)
model.predict_on_batch(x)

# Test functional model.
inputs = keras_core.Input(shape=(32, 32, 3))
outputs = layers.Conv2D(filters=10, kernel_size=3)(inputs)
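Taken together, the three added lines exercise the new single-batch entry points end to end. As a rough standalone sketch of the same workflow (the model architecture, data shapes, and compile arguments below are illustrative, not taken from this PR):

import numpy as np
import keras_core
from keras_core import layers

# A small compiled model (architecture and hyperparameters are arbitrary).
model = keras_core.Sequential(
    [layers.Dense(4, activation="relu"), layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")

train_logs = model.train_on_batch(x, y, return_dict=True)  # one gradient update
test_logs = model.test_on_batch(x, y, return_dict=True)    # one evaluation pass
preds = model.predict_on_batch(x)                          # one forward pass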
257 changes: 201 additions & 56 deletions keras_core/backend/torch/trainer.py
@@ -18,45 +18,100 @@ def __init__(self):
self.test_function = None
self.predict_function = None

-    def train_step(self, data):
-        data = data[0]
-        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
-
-        # Compute prediction error
-        if self._call_has_training_arg():
-            y_pred = self(x, training=True)
-        else:
-            y_pred = self(x)
-
-        # Call torch.nn.Module.zero_grad() to clear the leftover gradients for
-        # the weights from the previous train step.
-        self.zero_grad()
-
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
-        )
-        self._loss_tracker.update_state(loss)
-
-        # Compute gradients
-        if self.trainable_weights:
-            # Backpropagation
-            trainable_weights = [v for v in self.trainable_weights]
-
-            # Call torch.Tensor.backward() on the loss to compute gradients for
-            # the weights.
-            loss.backward()
-
-            gradients = [v.value.grad for v in trainable_weights]
-
-            # Update weights
-            with torch.no_grad():
-                self.optimizer.apply_gradients(
-                    zip(gradients, trainable_weights)
-                )
-        else:
-            warnings.warn("The model does not have any trainable weights.")
-
-        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
+    def make_train_function(self, force=False):
+        if self.train_function is not None and not force:
+            return self.train_function
+
+        def one_step_on_data(data):
+            """Runs a single training step on a batch of data."""
+            data = data[0]
+            x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
+                data
+            )
+
+            # Compute prediction error
+            if self._call_has_training_arg():
+                y_pred = self(x, training=True)
+            else:
+                y_pred = self(x)
+
+            # Call torch.nn.Module.zero_grad() to clear the leftover gradients
+            # for the weights from the previous train step.
+            self.zero_grad()
+
+            loss = self.compute_loss(
+                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+            )
+            self._loss_tracker.update_state(loss)
+
+            # Compute gradients
+            if self.trainable_weights:
+                # Backpropagation
+                trainable_weights = [v for v in self.trainable_weights]
+
+                # Call torch.Tensor.backward() on the loss to compute gradients
+                # for the weights.
+                loss.backward()
+
+                gradients = [v.value.grad for v in trainable_weights]
+
+                # Update weights
+                with torch.no_grad():
+                    self.optimizer.apply_gradients(
+                        zip(gradients, trainable_weights)
+                    )
+            else:
+                warnings.warn("The model does not have any trainable weights.")
+
+            return self.compute_metrics(
+                x, y, y_pred, sample_weight=sample_weight
+            )
+
+        self.train_function = one_step_on_data
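The inline comments above map one-to-one onto the standard PyTorch training recipe. For reference only (plain torch, not keras_core code), the same three steps look like this:

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

x, y = torch.randn(8, 3), torch.randn(8, 1)

model.zero_grad()            # clear leftover gradients from the previous step
loss = loss_fn(model(x), y)
loss.backward()              # populate .grad on every trainable parameter
optimizer.step()             # apply the update; keras_core instead calls
                             # optimizer.apply_gradients() under torch.no_grad()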

+    def make_test_function(self, force=False):
+        if self.test_function is not None and not force:
+            return self.test_function
+
+        def one_step_on_data(data):
+            """Runs a single test step on a batch of data."""
+            data = data[0]
+            (
+                x,
+                y,
+                sample_weight,
+            ) = data_adapter_utils.unpack_x_y_sample_weight(data)
+            if self._call_has_training_arg():
+                y_pred = self(x, training=False)
+            else:
+                y_pred = self(x)
+            loss = self.compute_loss(
+                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+            )
+            self._loss_tracker.update_state(loss)
+            return self.compute_metrics(
+                x, y, y_pred, sample_weight=sample_weight
+            )
+
+        self.test_function = one_step_on_data
+
+    def make_predict_function(self, force=False):
+        if self.predict_function is not None and not force:
+            return self.predict_function
+
+        def one_step_on_data(data):
+            """Runs a single predict step on a batch of data."""
+            with torch.no_grad():
+                data = data[0]
+                x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
+                if self._call_has_training_arg():
+                    y_pred = self(x, training=False)
+                else:
+                    y_pred = self(x)
+                return y_pred
+
+        self.predict_function = one_step_on_data

def fit(
self,
@@ -127,6 +182,7 @@ def fit(
)

self.stop_training = False
self.make_train_function()
callbacks.on_train_begin()

for epoch in range(initial_epoch, epochs):
@@ -141,7 +197,7 @@
# Callbacks
callbacks.on_train_batch_begin(step)

-                logs = self.train_step(data)
+                logs = self.train_function(data)

# Callbacks
callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
@@ -197,19 +253,6 @@ def fit(
callbacks.on_train_end(logs=training_logs)
return self.history

-    def test_step(self, data):
-        data = data[0]
-        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
-        if self._call_has_training_arg():
-            y_pred = self(x, training=False)
-        else:
-            y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
-        )
-        self._loss_tracker.update_state(loss)
-        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

def evaluate(
self,
x=None,
@@ -256,13 +299,13 @@ def evaluate(
# Switch the torch Module back to testing mode.
self.eval()

self.make_test_function()
callbacks.on_test_begin()
logs = None
self.reset_metrics()
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_test_batch_begin(step)
-            with torch.no_grad():
-                logs = self.test_step(data)
+            logs = self.test_function(data)
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self.get_metrics_result()
callbacks.on_test_end(logs)
@@ -271,15 +314,6 @@
return logs
return self._flatten_metrics_in_order(logs)

-    def predict_step(self, data):
-        data = data[0]
-        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
-        if self._call_has_training_arg():
-            y_pred = self(x, training=False)
-        else:
-            y_pred = self(x)
-        return y_pred

def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
@@ -322,15 +356,126 @@ def append_to_outputs(batch_outputs, outputs):
# Switch the torch Module back to testing mode.
self.eval()

self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
-            with torch.no_grad():
-                batch_outputs = self.predict_step(data)
+            batch_outputs = self.predict_function(data)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
batch_outputs, np.concatenate, outputs
)

def train_on_batch(
self,
x,
y=None,
sample_weight=None,
class_weight=None,
return_dict=False,
):
"""Runs a single gradient update on a single batch of data.

[Review thread anchored to this docstring]
Member: Something we should do soon is make sure the base Trainer class has docstrings everywhere, that the Trainer subclasses don't have docstrings, then we programmatically set the docstrings on the subclasses (using the ones from the base class).
Collaborator (author): Cool. Issue created: #329
(A sketch of the docstring-propagation idea appears at the end of this page.)

Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
class_weight: Optional dictionary mapping class indices (integers)
to a weight (float) to apply to the model's loss for the samples
from this class during training. This can be useful to tell the
model to "pay more attention" to samples from an
under-represented class. When `class_weight` is specified
and targets have a rank of 2 or greater, either `y` must
be one-hot encoded, or an explicit final dimension of 1
must be included for sparse class labels.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.

Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
self._assert_compile_called("train_on_batch")
self.make_train_function()
if class_weight is not None:
if sample_weight is not None:
raise ValueError(
"Arguments `sample_weight` and `class_weight` "
"cannot be specified at the same time. "
f"Received: sample_weight={sample_weight}, "
f"class_weight={class_weight}"
)
sample_weight = data_adapter_utils.class_weight_to_sample_weights(
y, class_weight
)

data = (x, y, sample_weight)
logs = self.train_function([data])
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
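`class_weight_to_sample_weights` is a keras_core helper; its implementation is not part of this diff. The mapping it performs is roughly the following (a hypothetical standalone sketch under that assumption, not the library code):

import numpy as np

def class_weight_to_sample_weights_sketch(y, class_weight):
    """Turn a class->weight dict into per-sample weights (illustrative)."""
    y = np.asarray(y)
    if y.ndim > 1 and y.shape[-1] > 1:
        # One-hot (or probabilistic) targets: recover integer class ids.
        y = np.argmax(y, axis=-1)
    y = np.reshape(y, (-1,)).astype(int)
    return np.array([class_weight.get(int(cls), 1.0) for cls in y])

# class_weight={0: 1.0, 1: 5.0} and y=[0, 1, 1, 0]
# -> sample_weight=[1.0, 5.0, 5.0, 1.0]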

def test_on_batch(
self,
x,
y=None,
sample_weight=None,
return_dict=False,
):
"""Test the model on a single batch of samples.

Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.

Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
self._assert_compile_called("test_on_batch")
self.make_test_function()

data = (x, y, sample_weight)

logs = self.test_function([data])
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
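To make the `return_dict` contract concrete, for a model compiled with, say, loss="mse" and metrics=["mae"] (the numbers below are illustrative):

logs = model.test_on_batch(x, y)                    # -> [0.42, 0.31]  (loss, then metrics)
logs = model.test_on_batch(x, y, return_dict=True)  # -> {"loss": 0.42, "mae": 0.31}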

def predict_on_batch(self, x):
"""Returns predictions for a single batch of samples.

Args:
x: Input data. It must be array-like.

Returns:
NumPy array(s) of predictions.
"""
self.make_predict_function()
batch_outputs = self.predict_function((x,))
batch_outputs = tf.nest.map_structure(
lambda x: np.array(x), batch_outputs
)
return batch_outputs
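On the review thread above (tracked in #329): one way to programmatically propagate docstrings from the base Trainer onto the backend subclasses is a small class decorator along these lines. This is a sketch of the idea only; it is not code from this PR, and `inherit_docstrings` is a hypothetical name:

def inherit_docstrings(cls):
    """Copy missing method docstrings from the nearest base class."""
    for name, attr in vars(cls).items():
        if callable(attr) and attr.__doc__ is None:
            for base in cls.__mro__[1:]:
                base_attr = getattr(base, name, None)
                if base_attr is not None and base_attr.__doc__:
                    attr.__doc__ = base_attr.__doc__
                    break
    return cls

# Usage sketch:
# @inherit_docstrings
# class TorchTrainer(base_trainer.Trainer):
#     ...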