Skip to content

Commit

Permalink
Fixes #68 - changes to tensorflow 2.2 api
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Fischer committed May 23, 2020
1 parent 7f34bd5 commit 88732d8
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions dataworkspaces/kits/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __iter__(self):
return _wrap_generator(self.wrapped, self.hash_state)

def on_epoch_end(self):
return self.on_epoch_end()
return self.wrapped.on_epoch_end()


class DwsModelCheckpoint(ModelCheckpoint):
Expand Down Expand Up @@ -492,17 +492,28 @@ def compile(
self._dws_state.lineage.add_param("loss_function", loss)
elif isinstance(loss, losses.Loss):
self._dws_state.lineage.add_param("loss_function", loss.__class__.__name__)
return super().compile(
optimizer,
loss,
metrics,
loss_weights,
sample_weight_mode,
weighted_metrics,
target_tensors,
distribute,
**kwargs,
)
if tensorflow.__version__<"2.2.": # type: ignore
return super().compile(
optimizer,
loss,
metrics,
loss_weights,
sample_weight_mode,
weighted_metrics,
target_tensors,
distribute,
**kwargs,
)
else: # starting in 2.2, tensorflow removed the tartet_tensors and distribute args
return super().compile(
optimizer,
loss,
metrics,
loss_weights,
sample_weight_mode,
weighted_metrics,
**kwargs,
)

def fit(self, x, y=None, **kwargs):
"""x, y can be arrays or x can be a generator.
Expand Down

0 comments on commit 88732d8

Please sign in to comment.