-
Notifications
You must be signed in to change notification settings - Fork 157
Native Keras support #74
Conversation
src/sparseml/keras/__init__.py
Outdated
| version = [int(v) for v in tensorflow.__version__.split(".")] | ||
| if version[0] != 2 or version[1] < 2: | ||
| raise Exception | ||
| raise RuntimeError("TensorFlow >= 2.2 is required, found {}".format(version)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we still restricted on 2.2 right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to support from TF 2.1: for TF2.0 there's a bug in TF model save.
9e83e5c to
eece7bc
Compare
| from tensorflow.keras.optimizers.schedules import LearningRateSchedule | ||
|
|
||
|
|
||
| try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to change this over to our shared conditional keras import
| from typing import List, Union | ||
|
|
||
| import tensorflow | ||
| import tensorflow as tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: remove the abbreviation for tensorflow
| import tensorflow as tf | ||
|
|
||
|
|
||
| try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to change this over to our shared conditional keras import
| from typing import Dict, List, Union | ||
|
|
||
| import tensorflow | ||
| import tensorflow as tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: remove abbreviation for tensorflow
| tensor | ||
| ) | ||
| mask = tf.cast(tf.not_equal(tensor, 0.0), tensor.dtype) | ||
| sparsity = tf.math.reduce_sum(1.0 - mask).numpy() / tf.size(tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is going to return a numpy float and not a python float. Need to update the types or use .item.
Additionally, the sparsity is set to None and returned and the return type does not reflect that for this function
src/sparseml/keras/utils/exporter.py
Outdated
|
|
||
| from tensorflow.keras import Model | ||
|
|
||
| try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to change this over to our shared conditional keras import
tests/sparseml/keras/optim/mock.py
Outdated
| from tensorflow.keras.models import Model | ||
|
|
||
|
|
||
| try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to change this over to our shared conditional keras import
| import tensorflow as tf | ||
|
|
||
|
|
||
| try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to change this over to our shared conditional keras import
eece7bc to
4c8561c
Compare
Allow to use native Keras. Only Keras >= 2.4.3 is supported due to the following reasons:
clone_modelfunction does not have the paramclone_layerrequired by the modifiers;fitfunctions in 2.4.0/2.4.1/2.4.2 do not work with the current custom callbacks such asLRModifierCallbackw/ errorAttributeError: 'LRModifierCallback' object has no attribute '_implements_train_batch_hooks'For TF2.0, model save hit an assertion inside TF, and therefore we only enable >=TF2.1.