Skip to content

Commit

Permalink
Simplified handling of 1-dimensional data #1
Browse files Browse the repository at this point in the history
Added expansion of 1-dimensional Tensors into 2-dimensional Tensors.
  • Loading branch information
fabiodimarco committed Dec 29, 2020
1 parent 4740db3 commit ec9b661
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions levenberg_marquardt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# ==============================================================================

import tensorflow as tf
from tensorflow.python.keras.engine import data_adapter

# ==============================================================================

Expand Down Expand Up @@ -651,7 +652,11 @@ def fit(self, dataset, epochs=1, metrics=None):

pl.on_train_batch_begin(step)

inputs, targets = next(iterator)
data = next(iterator)

data = data_adapter.expand_1d(data)
inputs, targets, sample_weight = \
data_adapter.unpack_x_y_sample_weight(data)

loss, outputs, attempts, stop_training = \
self.train_step(inputs, targets)
Expand Down Expand Up @@ -728,7 +733,9 @@ def _assign_stop_training(self, value):
self.stop_training = value.numpy().item()

def train_step(self, data):
inputs, targets = data
data = data_adapter.expand_1d(data)
inputs, targets, sample_weight = \
data_adapter.unpack_x_y_sample_weight(data)

loss, outputs, attempts, stop_training = \
self.trainer.train_step(inputs, targets)
Expand Down

0 comments on commit ec9b661

Please sign in to comment.