add pytorch backend trainer (#247)
* add pytorch backend trainer

* addressing the comments

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
haifeng-jin and haifeng-jin committed Jun 2, 2023
1 parent 49991ef commit 0724e27
Showing 7 changed files with 238 additions and 16 deletions.
17 changes: 17 additions & 0 deletions integration_tests/torch_backend_keras_workflow.py
@@ -0,0 +1,17 @@
import numpy as np

import keras_core
from keras_core import layers

model = keras_core.Sequential(
    [
        layers.Dense(1),
    ]
)
model.compile(loss="mse", optimizer="adam", metrics=["mae"])
history = model.fit(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    epochs=10,
    shuffle=False,
)
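
Side note (not part of the commit): the workflow above only exercises the new trainer when keras_core is configured for the torch backend. A minimal sketch of one way to do that, assuming the KERAS_BACKEND environment variable is read at import time, as with the other backends:

import os

os.environ["KERAS_BACKEND"] = "torch"  # assumed to be read before the imports below

import numpy as np

import keras_core
from keras_core import layers

model = keras_core.Sequential([layers.Dense(1)])
model.compile(loss="mse", optimizer="adam", metrics=["mae"])
model.fit(x=np.random.rand(8, 4), y=np.random.rand(8, 1), epochs=1)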
9 changes: 3 additions & 6 deletions keras_core/backend/common/variables.py
@@ -395,6 +395,8 @@ def standardize_dtype(dtype):
        dtype = "int32"
    if hasattr(dtype, "name"):
        dtype = dtype.name
+    elif config.backend() == "torch":
+        dtype = str(dtype).split(".")[-1]

    if dtype not in ALLOWED_DTYPES:
        raise ValueError(f"Invalid dtype: {dtype}")
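
For context (a quick illustration, not part of the diff): the new branch covers torch dtype objects, which reach this point without a usable string name, so they are stringified and the "torch." prefix is stripped:

import torch

dtype = torch.float32
print(str(dtype))                 # "torch.float32"
print(str(dtype).split(".")[-1])  # "float32", a name present in ALLOWED_DTYPES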
@@ -454,12 +456,7 @@ def shape_equal(a, b):


def is_float_dtype(dtype):
-    if hasattr(dtype, "name"):
-        dtype = dtype.name
-    # The is a torch.dtype when using torch backend.
-    # Need to convert it to a str.
-    if not isinstance(dtype, str):
-        dtype = str(dtype).split(".")[-1]
+    dtype = standardize_dtype(dtype)
    return dtype.startswith("float") or dtype.startswith("bfloat")


9 changes: 7 additions & 2 deletions keras_core/backend/torch/core.py
@@ -36,9 +36,10 @@ def to_torch_dtype(dtype):
class Variable(KerasVariable):
    def _initialize(self, value):
        self._value = convert_to_tensor(value, dtype=self._dtype)
+        self._value.requires_grad_(self.trainable)

    def _direct_assign(self, value):
-        self._value = value
+        self._value.copy_(value)

    def _convert_to_tensor(self, value, dtype=None):
        return convert_to_tensor(value, dtype=dtype)
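
A side note on the two changes above (a plain-torch sketch, not keras_core code): marking the underlying tensor with requires_grad_() makes it a leaf that accumulates gradients, and copy_() updates it in place so the same tensor object, and therefore its .grad, stays usable across updates. The names below are illustrative only:

import torch

w = torch.zeros(3)
w.requires_grad_(True)  # analogous to _initialize() above

loss = (w * 2.0).sum()
loss.backward()
print(w.grad)  # tensor([2., 2., 2.])

with torch.no_grad():
    w.copy_(w - 0.1 * w.grad)  # analogous to the in-place _direct_assign()
print(w)  # same tensor object, updated values, still requires_grad=True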
@@ -78,7 +79,11 @@ def shape(x):

def cast(x, dtype):
    dtype = to_torch_dtype(dtype)
-    return x.to(dtype)
+    if isinstance(x, KerasVariable):
+        x = x.value
+    if is_tensor(x):
+        return x.to(dtype)
+    return convert_to_tensor(x, dtype)


def name_scope(name):
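
Rough illustration of why cast() gained the extra branches (plain torch, not the keras_core API): only torch tensors have a .to() method, so KerasVariable and non-tensor inputs such as numpy arrays have to be unwrapped or converted first.

import numpy as np
import torch

x = np.ones((2, 2))
# x.to(torch.float16)  # would fail: numpy arrays have no .to() method
print(torch.as_tensor(x).to(torch.float16).dtype)  # torch.float16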
2 changes: 1 addition & 1 deletion keras_core/backend/torch/numpy.py
@@ -69,7 +69,7 @@ def ones(shape, dtype="float32"):

def zeros(shape, dtype="float32"):
    dtype = to_torch_dtype(dtype)
-    return torch.zeros(*shape, dtype=dtype)
+    return torch.zeros(size=shape, dtype=dtype)


def absolute(x):
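
Quick check of the keyword form adopted above, under the assumption that shape always arrives as a tuple or list (as in the Keras-level zeros(shape, dtype) signature):

import torch

shape = (2, 3)
z = torch.zeros(size=shape, dtype=torch.float32)
print(z.shape)  # torch.Size([2, 3])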
201 changes: 201 additions & 0 deletions keras_core/backend/torch/trainer.py
@@ -1,6 +1,207 @@
import warnings

import torch

from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator


class TorchTrainer(base_trainer.Trainer):
    def __init__(self):
        super().__init__()
        self.train_function = None
        self.test_function = None
        self.predict_function = None

    def train_step(self, data):
        data = data[0]
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

        # Compute prediction error
        if self._call_has_training_arg():
            y_pred = self(x, training=True)
        else:
            y_pred = self(x)
        self.zero_grad()
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )
        self._loss_tracker.update_state(loss)

        # Compute gradients
        if self.trainable_weights:
            # Backpropagation
            trainable_weights = [v for v in self.trainable_weights]
            loss.backward()
            gradients = [v.value.grad for v in trainable_weights]

            # Update weights
            with torch.no_grad():
                self.optimizer.apply_gradients(
                    zip(gradients, trainable_weights)
                )
        else:
            warnings.warn("The model does not have any trainable weights.")

        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        if not self.compiled:
            raise ValueError(
                "You must call `compile()` before calling `fit()`."
            )

        # TODO: respect compiled trainable state
        if validation_split and validation_data is None:
            # Create the validation data using the training data. Only supported
            # for TF/numpy/jax arrays.
            # TODO: Support torch tensors for validation data.
            (
                x,
                y,
                sample_weight,
            ), validation_data = data_adapter_utils.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weight,
            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

        # Create an iterator that yields batches for one epoch.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=shuffle,
            class_weight=class_weight,
            steps_per_execution=self.steps_per_execution,
        )

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=epochs,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        self.stop_training = False
        callbacks.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)

            # Switch the torch Module to training mode. Inform torch layers to
            # do training behavior in case the user did not use `self.training`
            # when implementing a custom layer with torch layers.
            self.train()
            for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
                # Callbacks
                callbacks.on_train_batch_begin(step)

                logs = self.train_step(data)

                # Callbacks
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break

            # Override with model metrics instead of last step logs
            epoch_logs = self._pythonify_logs(self.get_metrics_result())

            # Switch the torch Module back to testing mode.
            self.eval()

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create EpochIterator for evaluation and cache it.
                if getattr(self, "_eval_epoch_iterator", None) is None:
                    self._eval_epoch_iterator = EpochIterator(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_execution=self.steps_per_execution,
                    )
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    return_dict=True,
                    _use_cached_eval_dataset=True,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(self._pythonify_logs(val_logs))

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        if (
            isinstance(self.optimizer, optimizers_module.Optimizer)
            and epochs > 0
        ):
            self.optimizer.finalize_variable_values(self.trainable_weights)

        # If _eval_epoch_iterator exists, delete it after all epochs are done.
        if getattr(self, "_eval_epoch_iterator", None) is not None:
            del self._eval_epoch_iterator
        callbacks.on_train_end(logs=training_logs)
        return self.history

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        pass

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        pass
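
For orientation (a hand-written plain-torch sketch, illustrative names only, not the keras_core API): train_step() above follows the canonical torch training pattern of zeroing gradients, running the forward pass, calling backward(), and applying the update without tracking gradients, while fit() toggles train()/eval() mode around each epoch and its validation run. Roughly:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.MSELoss()

x = torch.rand(100, 10)
y = torch.rand(100, 1)

for epoch in range(10):
    model.train()                # like self.train() at the top of each epoch
    optimizer.zero_grad()        # like self.zero_grad() in train_step()
    loss = loss_fn(model(x), y)  # forward pass plus compute_loss()
    loss.backward()              # populates .grad on the trainable weights
    with torch.no_grad():
        optimizer.step()         # like optimizer.apply_gradients(...)
    model.eval()                 # like self.eval() before validation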
7 changes: 3 additions & 4 deletions keras_core/utils/progbar.py
@@ -3,8 +3,7 @@
import sys
import time

-import numpy as np
-
+from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.utils import io_utils

@@ -159,7 +158,7 @@ def update(self, current, values=None, finalize=None):
            for k in self._values_order:
                info += f" - {k}:"
                if isinstance(self._values[k], list):
-                    avg = np.mean(
+                    avg = ops.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    if abs(avg) > 1e-3:
@@ -188,7 +187,7 @@ def update(self, current, values=None, finalize=None):
                info += " -" + self._format_time(time_per_unit, self.unit_name)
                for k in self._values_order:
                    info += f" - {k}:"
-                    avg = np.mean(
+                    avg = ops.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    if avg > 1e-3:
9 changes: 6 additions & 3 deletions keras_core/utils/traceback_utils.py
@@ -169,7 +169,6 @@ def error_handler(*args, **kwargs):
                else:
                    value = arg.default
                arguments_context.append(f" • {arg.name}={value}")
-
            if arguments_context:
                arguments_context = "\n".join(arguments_context)
                # Get original error message and append information to it.
@@ -220,9 +219,13 @@ def format_argument_value(value):
            tensor_cls = "tf.Tensor"
        elif backend.backend() == "jax":
            tensor_cls = "jnp.ndarray"
-        elif backend.backend() == "pytorch":
+        elif backend.backend() == "torch":
            tensor_cls = "torch.Tensor"
        else:
            tensor_cls = "array"
-        return f"{tensor_cls}(shape={value.shape}, dtype={value.dtype.name})"
+
+        return (
+            f"{tensor_cls}(shape={value.shape}, "
+            f"dtype={backend.standardize_dtype(value.dtype)})"
+        )
    return repr(value)
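
Rough illustration of the new formatting under the torch backend (plain torch; backend.standardize_dtype is mimicked here by its observed effect on torch dtypes):

import torch

value = torch.ones((2, 3))
dtype = str(value.dtype).split(".")[-1]  # what standardize_dtype yields for torch
print(f"torch.Tensor(shape={value.shape}, dtype={dtype})")
# torch.Tensor(shape=torch.Size([2, 3]), dtype=float32)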
