
3-2 Mid-level API: train_model fails with Internal: No unary variant device copy function found for direction... #18

Closed
kaixih opened this issue Apr 11, 2020 · 1 comment


@kaixih

kaixih commented Apr 11, 2020

The example code in this chapter does not seem to work with the latest TensorFlow. Running it fails with:

Traceback (most recent call last):
  File "demo.py", line 40, in <module>
    train_model(model,epochs = 200)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 608, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 678, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal:  No unary variant device copy function found for direction: 1 and Variant type_index: tensorflow::data::(anonymous namespace)::DatasetVariantWrapper
         [[{{node while_input_5/_12}}]]
         [[Func/while/body/_1/while/cond/then/_78/input/_91/_52]]
  (1) Internal:  No unary variant device copy function found for direction: 1 and Variant type_index: tensorflow::data::(anonymous namespace)::DatasetVariantWrapper
         [[{{node while_input_5/_12}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_model_342]

Function call stack:
train_model -> train_model

To make the problem easier to reproduce, I removed all the visualization code from the example:

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers,losses,metrics,optimizers

n = 400

X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-3.0]])
b0 = tf.constant([[3.0]])
Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)

ds = tf.data.Dataset.from_tensor_slices((X,Y)) \
     .shuffle(buffer_size = 100).batch(10) \
     .prefetch(tf.data.experimental.AUTOTUNE)

model = layers.Dense(units = 1)
model.build(input_shape = (2,))
model.loss_func = losses.mean_squared_error
model.optimizer = optimizers.SGD(learning_rate=0.001)

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = model.loss_func(tf.reshape(labels,[-1]), tf.reshape(predictions,[-1]))
    grads = tape.gradient(loss,model.variables)
    model.optimizer.apply_gradients(zip(grads,model.variables))
    return loss

@tf.function
def train_model(model,epochs):
    # AutoGraph converts this tf.range loop into a tf.while_loop, so the
    # dataset `ds` is iterated inside the graph loop (the failing case here)
    for epoch in tf.range(1,epochs+1):
        loss = tf.constant(0.0)
        for features, labels in ds:
            loss = train_step(model,features,labels)
        if epoch%50==0:
            tf.print("epoch =",epoch,"loss = ",loss)
            tf.print("w =",model.variables[0])
            tf.print("b =",model.variables[1])

train_model(model,epochs = 200)

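The problem appears to be in the train_model function: if I remove the @tf.function decorator from train_model, everything works. Could the cause be that a tf.data.Dataset cannot be used inside a tf.function?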
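The workaround described above (dropping @tf.function from train_model) can be sketched as follows. This is a minimal sketch assuming TensorFlow 2.x (tf.data.AUTOTUNE needs TF 2.4 or later): only the per-batch train_step is compiled, while the epoch and batch loops run eagerly in Python, so the dataset is never captured inside a graph while loop. Keeping loss_func and optimizer as module-level names instead of attributes on the layer is a simplification of my own, and the full_mse helper is hypothetical, added only to check that training reduces the error.

```python
import tensorflow as tf
from tensorflow.keras import layers, losses, optimizers

tf.random.set_seed(0)

# Same synthetic linear-regression data as the reproducer above
n = 400
X = tf.random.uniform([n, 2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-3.0]])
b0 = tf.constant([[3.0]])
Y = X @ w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)

ds = (tf.data.Dataset.from_tensor_slices((X, Y))
      .shuffle(buffer_size=100)
      .batch(10)
      .prefetch(tf.data.AUTOTUNE))

model = layers.Dense(units=1)
model.build(input_shape=(None, 2))
loss_func = losses.mean_squared_error
optimizer = optimizers.SGD(learning_rate=0.001)

@tf.function  # only the per-batch step is compiled into a graph
def train_step(features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = loss_func(tf.reshape(labels, [-1]), tf.reshape(predictions, [-1]))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

def train_model(epochs):
    # No @tf.function here: epochs and dataset iteration run eagerly in Python
    loss = tf.constant(0.0)
    for epoch in range(1, epochs + 1):
        for features, labels in ds:
            loss = train_step(features, labels)
        if epoch % 50 == 0:
            tf.print("epoch =", epoch, "loss =", loss)
    return loss

def full_mse():
    # Hypothetical helper: mean squared error over the whole dataset
    return float(tf.reduce_mean(tf.square(model(X) - Y)))

mse_before = full_mse()
train_model(epochs=200)
mse_after = full_mse()
print("MSE before:", mse_before, "after:", mse_after)
```

The trade-off is that the Python-level loops add some per-batch overhead compared with a fully compiled train_model, but each train_step call still runs as a graph.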

I am using a TensorFlow nightly build. Thanks.

@kaixih
Author

kaixih commented Apr 12, 2020

The problem is probably related to the combination of tf.data datasets and the GPU: with the GPU disabled, the error does not occur either. Closing.
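The issue does not say how the GPU was disabled; one way to do it, assuming TensorFlow 2.1 or later, is tf.config.set_visible_devices, which must be called before any op has run. The sketch below hides all GPUs and then runs a hypothetical sum_over_epochs function with the same structure as train_model: a dataset iterated inside a tf.range loop under @tf.function. It also illustrates that iterating a tf.data.Dataset inside a tf.function is supported in general.

```python
import tensorflow as tf

# Hide every GPU so that all ops, including tf.data ones, run on the CPU.
# This must happen before TensorFlow initializes its devices.
tf.config.set_visible_devices([], "GPU")

ds = tf.data.Dataset.range(1, 5)  # elements 1, 2, 3, 4 (int64)

@tf.function
def sum_over_epochs(epochs):
    # Hypothetical example: same loop nesting as train_model in the issue
    total = tf.constant(0, dtype=tf.int64)
    for _ in tf.range(epochs):  # becomes a tf.while_loop
        for x in ds:            # dataset iterated inside the graph loop
            total += x
    return total

print(tf.config.get_visible_devices("GPU"))  # []
print(sum_over_epochs(tf.constant(3)).numpy())  # 30
```

Setting the environment variable CUDA_VISIBLE_DEVICES=-1 before TensorFlow is imported has a similar effect without changing the code.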

@kaixih kaixih closed this as completed Apr 12, 2020