UnliftableError while upgrading from TFF 0.15.0 to 0.16.0 and later #913

Open
dpreuveneers opened this issue Sep 30, 2020 · 3 comments
Labels
bug Something isn't working


@dpreuveneers

The sample below is a simplified code snippet for an autoencoder. It works with TFF 0.15.0 (on top of TF 2.2.1) but no longer works with TFF 0.16.0 and later.

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

num_clients = 4
x_train = np.load('x_train.npy').astype(np.float16)
input_dim = x_train.shape[1]
x_train_clients = np.array_split(x_train, num_clients)

def map_fn(example):
    return collections.OrderedDict(x=example, y=example)

def client_data(n):
    ds = tf.data.Dataset.from_tensor_slices(x_train_clients[n])
    return ds.batch(128).map(map_fn)

# Build one tf.data.Dataset per simulated client (all clients participate).
train_data = [client_data(n) for n in range(num_clients)]
input_spec = train_data[0].element_spec

# Wrap a Keras model for use with TFF.
def model_fn():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(50, activation='relu', activity_regularizer=tf.keras.regularizers.l1(1e-4)),
        tf.keras.layers.Dense(input_dim, activation='sigmoid'),
    ])
    return tff.learning.from_keras_model(
        model,
        input_spec=input_spec,
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[tf.keras.metrics.Accuracy()])

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=2e-3))

def evaluate(num_epochs):
    state = trainer.initialize()
    for epoch in range(num_epochs):
        state, metrics = trainer.next(state, train_data)
        print('Epoch {}: training loss {:.8f}'.format(epoch, metrics.train.loss))

evaluate(10)

With TFF 0.16.0 and later I get the following error:
tensorflow.python.ops.op_selector.UnliftableError: Unable to lift tensor <tf.Tensor 'StatefulPartitionedCall_1:0' shape=(77, 50) dtype=float32> because it depends transitively on placeholder <tf.Operation 'input_1' type=Placeholder> via at least one path, e.g.: StatefulPartitionedCall_1 (StatefulPartitionedCall) <- dense/ActivityRegularizer/truediv (RealDiv) <- dense/ActivityRegularizer/Cast (Cast) <- dense/ActivityRegularizer/strided_slice (StridedSlice) <- dense/ActivityRegularizer/Shape (Shape) <- dense/Relu (Relu) <- dense/BiasAdd (BiasAdd) <- dense/MatMul (MatMul) <- input_1 (Placeholder)
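The op chain in this trace (Shape, strided_slice, Cast, truediv under dense/ActivityRegularizer) appears to match how tf.keras 2.x normalizes an activity penalty: it computes the regularizer on the layer output and divides by a batch size read from that output's shape, and the output is itself computed from input_1. A minimal NumPy sketch of that computation, assuming the tf.keras 2.x behavior (activity_l1_penalty is a hypothetical helper, not a Keras API, and Keras 3 may normalize differently):

```python
import numpy as np


def activity_l1_penalty(activations, l1=1e-4):
    """Hypothetical mirror of tf.keras 2.x activity regularization.

    Comments name the TF ops that appear in the UnliftableError trace.
    """
    activity_loss = l1 * np.sum(np.abs(activations))   # regularizers.l1(1e-4)
    batch_size = activations.shape[0]                  # Shape + strided_slice
    batch_size = np.asarray(batch_size, dtype=activity_loss.dtype)  # Cast
    return activity_loss / batch_size                  # RealDiv (truediv)


# Fake ReLU activations: a batch of 4 examples, 50 hidden units.
hidden = np.maximum(np.random.default_rng(0).normal(size=(4, 50)), 0.0)
print(activity_l1_penalty(hidden))
```

Because the divisor is taken from the activations' shape, the penalty's value is tied to whatever batch flows through input_1, which is the dependency the error message reports.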

The same UnliftableError occurs with the TFF/TF nightly builds.

According to the GitHub release notes for TFF 0.16.0, the only breaking change is "Renamed AnonymousTuple to Struct", so I (perhaps mistakenly) assumed the above code would still work without any changes.

Is there a way I can work around the UnliftableError in my code or is this something that needs to be addressed in TFF?

@dpreuveneers dpreuveneers added the bug Something isn't working label Sep 30, 2020
@jkr26
Collaborator

jkr26 commented Oct 21, 2020

You've hit an interesting bug. Looking at the stack trace of a reproduction, I see that this error comes from TFF execution (i.e., the trainer.next call here), specifically from a prune call that we (well, I) added in June, in fact to avoid a similar issue.
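To make the failure mode concrete: pruning keeps only the ops that a set of fetches needs given a set of feeds, and a fetch that still reaches a placeholder which is not fed cannot be lifted into the new graph. The pure-Python sketch below models that check on a toy dependency graph built from the op names in the trace above. It is a simplified illustration under the assumption that the regularizer tensor is requested without input_1 among the feeds; it is not TensorFlow's lift_to_graph or op_selector code, and find_unfed_path and GRAPH are hypothetical names.

```python
from collections import deque

# Toy graph: op name -> names of its input ops (taken from the error trace).
GRAPH = {
    "StatefulPartitionedCall_1": ["dense/ActivityRegularizer/truediv"],
    "dense/ActivityRegularizer/truediv": ["dense/ActivityRegularizer/Cast"],
    "dense/ActivityRegularizer/Cast": ["dense/ActivityRegularizer/strided_slice"],
    "dense/ActivityRegularizer/strided_slice": ["dense/ActivityRegularizer/Shape"],
    "dense/ActivityRegularizer/Shape": ["dense/Relu"],
    "dense/Relu": ["dense/BiasAdd"],
    "dense/BiasAdd": ["dense/MatMul"],
    "dense/MatMul": ["input_1"],
    "input_1": [],
}
PLACEHOLDERS = {"input_1"}


def find_unfed_path(graph, fetch, feeds, placeholders):
    """Return a fetch-to-placeholder path through unfed ops, or None.

    Walks op inputs breadth-first and stops at ops listed in `feeds`,
    because a fed tensor does not need the ops that produce it.
    """
    parents = {fetch: None}
    queue = deque([fetch])
    while queue:
        op = queue.popleft()
        if op in placeholders and op not in feeds:
            # Rebuild the path by following parent links back to the fetch.
            path = []
            while op is not None:
                path.append(op)
                op = parents[op]
            return list(reversed(path))
        if op in feeds:
            continue
        for inp in graph[op]:
            if inp not in parents:
                parents[inp] = op
                queue.append(inp)
    return None


path = find_unfed_path(GRAPH, "StatefulPartitionedCall_1",
                       feeds=set(), placeholders=PLACEHOLDERS)
if path is not None:
    print("Unable to lift; path: " + " <- ".join(path))
```

The printed chain has the same shape as the one in the UnliftableError message; adding input_1 (or any op downstream of it, such as dense/Relu) to feeds makes the function return None.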

Generally, TFF needs quite precise control over the TensorFlow graphs it works with; this produces some patterns that TensorFlow is not used to seeing, so TFF tends to hit odd corners of TF. Throw Keras into the mix and it gets exciting, which seems to be what is happening here. One thing definitely jumps out at me from the code above: I'm not sure I've seen anyone use the activity_regularizer parameter of a Keras layer with TFF before. In fact, removing this parameter makes my reproduction pass.
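One way to see why activity_regularizer stands out: an L1 penalty on a layer's kernel is a function of the weights alone, while an activity penalty is a function of each input batch. The NumPy sketch below (hypothetical helpers, random data standing in for x_train and the Dense(50) weights) shows that only the activity penalty changes when the input changes, which is the kind of input dependence that creates a graph path back to input_1. Note that a kernel_regularizer optimizes a different objective, and whether it avoids this bug in TFF is an untested assumption; dropping the parameter, as described above, is the only workaround confirmed in this thread.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_units = 8, 50
# Stand-ins for the Dense(50) layer's variables.
kernel = rng.normal(size=(input_dim, hidden_units))
bias = np.zeros(hidden_units)


def hidden_activations(x):
    # Forward pass of Dense(50, activation='relu').
    return np.maximum(x @ kernel + bias, 0.0)


def kernel_penalty(l1=1e-4):
    # Weight-only penalty: no dependence on the input batch.
    return l1 * np.sum(np.abs(kernel))


def activity_penalty(x, l1=1e-4):
    # Input-dependent penalty: L1 of activations divided by batch size.
    a = hidden_activations(x)
    return l1 * np.sum(np.abs(a)) / a.shape[0]


batch_a = rng.normal(size=(4, input_dim))
batch_b = rng.normal(size=(4, input_dim))
print("kernel penalty:", kernel_penalty())
print("activity penalty, batch a:", activity_penalty(batch_a))
print("activity penalty, batch b:", activity_penalty(batch_b))
```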

I'm writing a failing test that reproduces this inside TFF, and will investigate further.

@jkr26 jkr26 self-assigned this Oct 21, 2020
tensorflow-copybara pushed a commit that referenced this issue Oct 21, 2020
PiperOrigin-RevId: 338282399
@samuelstevens

Any progress on this issue? I am running into this error with TFF 0.19.0 and TF 2.5.1 (the most recent versions at the time of writing).

@jkr26
Collaborator

jkr26 commented Sep 8, 2021

Sorry, I spent a little time attempting to debug this without getting very far. Our best bet may be to bring this issue to the TF or Keras teams.
