Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Conversation

@natuan
Copy link
Contributor

@natuan natuan commented Feb 15, 2021

Allow to use native Keras. Only Keras >= 2.4.3 is supported due to the following reasons:

  • for Keras < 2.4.0, the clone_model function does not have the param clone_layer required by the modifiers;
  • the model's fit functions in 2.4.0/2.4.1/2.4.2 do not work with the current custom callbacks such as LRModifierCallback w/ error AttributeError: 'LRModifierCallback' object has no attribute '_implements_train_batch_hooks'

For TF2.0, model save hit an assertion inside TF, and therefore we only enable >=TF2.1.

@natuan natuan requested a review from a team February 15, 2021 05:34
version = [int(v) for v in tensorflow.__version__.split(".")]
if version[0] != 2 or version[1] < 2:
raise Exception
raise RuntimeError("TensorFlow >= 2.2 is required, found {}".format(version))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we still restricted on 2.2 right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to support from TF 2.1: for TF2.0 there's a bug in TF model save.

@natuan natuan force-pushed the tuan/native_keras_tf2x branch 3 times, most recently from 9e83e5c to eece7bc Compare March 16, 2021 04:34
from tensorflow.keras.optimizers.schedules import LearningRateSchedule


try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to change this over to our shared conditional keras import

from typing import List, Union

import tensorflow
import tensorflow as tf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: remove the abbreviation for tensorflow

import tensorflow as tf


try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to change this over to our shared conditional keras import

from typing import Dict, List, Union

import tensorflow
import tensorflow as tf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: remove abbreviation for tensorflow

tensor
)
mask = tf.cast(tf.not_equal(tensor, 0.0), tensor.dtype)
sparsity = tf.math.reduce_sum(1.0 - mask).numpy() / tf.size(tensor)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is going to return a numpy float and not a python float. Need to update the types or use .item.

Additionally, the sparsity is set to None and returned and the return type does not reflect that for this function


from tensorflow.keras import Model

try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to change this over to our shared conditional keras import

from tensorflow.keras.models import Model


try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to change this over to our shared conditional keras import

import tensorflow as tf


try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to change this over to our shared conditional keras import

@natuan natuan changed the title Native keras support [WIP] Native keras support Mar 19, 2021
@natuan natuan force-pushed the tuan/native_keras_tf2x branch from eece7bc to 4c8561c Compare March 19, 2021 17:02
@natuan natuan changed the title [WIP] Native keras support Native Keras support Mar 19, 2021
@natuan natuan merged commit f366825 into main Mar 19, 2021
@markurtz markurtz deleted the tuan/native_keras_tf2x branch April 3, 2021 17:58
markurtz added a commit that referenced this pull request Sep 1, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants