* add pytorch backend trainer
* addressing the comments

--------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
Commit 0724e27 (1 parent: 49991ef), showing 7 changed files with 238 additions and 16 deletions.
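The first changed file is a new minimal end-to-end example: it builds, compiles, and fits a one-layer model on random data, exercising the new trainer.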
@@ -0,0 +1,17 @@
import numpy as np

import keras_core
from keras_core import layers

model = keras_core.Sequential(
    [
        layers.Dense(1),
    ]
)
model.compile(loss="mse", optimizer="adam", metrics=["mae"])
history = model.fit(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    epochs=10,
    shuffle=False,
)
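A note on running the example: keras_core reads the `KERAS_BACKEND` environment variable at import time to select its backend, so exercising the new trainer presumably means running this script with the torch backend selected, e.g. `KERAS_BACKEND=torch python <script>`.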
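The main change is the torch backend trainer itself; per the hunk header below, it expands what was a six-line file into the full implementation.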
@@ -1,6 +1,207 @@
import warnings

import torch

from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator


class TorchTrainer(base_trainer.Trainer):
    def __init__(self):
        super().__init__()
        self.train_function = None
        self.test_function = None
        self.predict_function = None

    def train_step(self, data):
        data = data[0]
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

        # Compute prediction error
        if self._call_has_training_arg():
            y_pred = self(x, training=True)
        else:
            y_pred = self(x)
        self.zero_grad()
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )
        self._loss_tracker.update_state(loss)

        # Compute gradients
        if self.trainable_weights:
            # Backpropagation
            trainable_weights = [v for v in self.trainable_weights]
            loss.backward()
            gradients = [v.value.grad for v in trainable_weights]

            # Update weights
            with torch.no_grad():
                self.optimizer.apply_gradients(
                    zip(gradients, trainable_weights)
                )
        else:
            warnings.warn("The model does not have any trainable weights.")

        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
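
    # For orientation: `train_step` above mirrors the canonical torch
    # training sequence (a sketch only; `opt` and `loss_fn` are stand-in
    # names, not attributes of this class):
    #
    #     opt.zero_grad()
    #     loss = loss_fn(model(x), y)
    #     loss.backward()
    #     opt.step()
    #
    # except that gradients are read off each variable's `.value.grad` and
    # routed through the Keras optimizer's `apply_gradients()` rather than
    # `opt.step()`, with the update wrapped in `torch.no_grad()` so autograd
    # does not track the weight assignment itself.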

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        if not self.compiled:
            raise ValueError(
                "You must call `compile()` before calling `fit()`."
            )

        # TODO: respect compiled trainable state
        if validation_split and validation_data is None:
            # Create the validation data using the training data. Only
            # supported for TF/numpy/jax arrays.
            # TODO: Support torch tensors for validation data.
            (
                x,
                y,
                sample_weight,
            ), validation_data = data_adapter_utils.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weight,
            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

        # Create an iterator that yields batches for one epoch.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=shuffle,
            class_weight=class_weight,
            steps_per_execution=self.steps_per_execution,
        )

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=epochs,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        self.stop_training = False
        # Keep `training_logs` defined even if the epoch loop never runs
        # (e.g. `initial_epoch >= epochs`), so `on_train_end` below does
        # not raise a NameError.
        training_logs = None
        callbacks.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)

            # Switch the torch Module to training mode, so torch layers
            # behave correctly even if the user did not consult
            # `self.training` when implementing a custom layer on top of
            # torch layers.
            self.train()
            for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
                # Callbacks
                callbacks.on_train_batch_begin(step)

                logs = self.train_step(data)

                # Callbacks
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break

            # Override with model metrics instead of last step logs
            epoch_logs = self._pythonify_logs(self.get_metrics_result())

            # Switch the torch Module back to testing mode.
            self.eval()

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create EpochIterator for evaluation and cache it.
                if getattr(self, "_eval_epoch_iterator", None) is None:
                    self._eval_epoch_iterator = EpochIterator(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_execution=self.steps_per_execution,
                    )
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    return_dict=True,
                    _use_cached_eval_dataset=True,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(self._pythonify_logs(val_logs))

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        if (
            isinstance(self.optimizer, optimizers_module.Optimizer)
            and epochs > 0
        ):
            self.optimizer.finalize_variable_values(self.trainable_weights)

        # If _eval_epoch_iterator exists, delete it after all epochs are done.
        if getattr(self, "_eval_epoch_iterator", None) is not None:
            del self._eval_epoch_iterator
        callbacks.on_train_end(logs=training_logs)
        return self.history

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        pass

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        pass
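`predict()` and `evaluate()` are left as stubs in this commit. For orientation only, a plausible eval-side counterpart to `train_step` might look like the sketch below. This is an assumption of this write-up, not part of the commit: the `test_step` name simply mirrors `train_step`, and the calls reuse only methods already exercised above.

    def test_step(self, data):
        # Hypothetical sketch, not part of this commit.
        data = data[0]
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
        # Evaluation needs no gradient tracking.
        with torch.no_grad():
            if self._call_has_training_arg():
                y_pred = self(x, training=False)
            else:
                y_pred = self(x)
            loss = self.compute_loss(
                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
            )
        self._loss_tracker.update_state(loss)
        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)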