Skip to content

Commit

Permalink
Update google_research to use public TF/keras API if possible.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 370162707
  • Loading branch information
qlzh727 authored and Copybara-Service committed Apr 23, 2021
1 parent 0782952 commit 88a0963
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions low_rank_local_connectivity/layers.py
Expand Up @@ -29,7 +29,6 @@
import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils

Expand Down Expand Up @@ -543,9 +542,11 @@ def build(self, input_shape):
self.bias = tf.math.add(self.bias_spatial, self.bias_channels, name='bias')

if self.data_format == 'channels_last':
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
self.input_spec = tf.keras.layers.InputSpec(
ndim=4, axes={-1: input_filter})
else:
self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
self.input_spec = tf.keras.layers.InputSpec(
ndim=4, axes={1: input_filter})

self.built = True

Expand Down

0 comments on commit 88a0963

Please sign in to comment.