
Pytorch EngineException: Inference tensors cannot be saved for backward #2159

Closed
aminnasiri opened this issue Nov 17, 2022 · 3 comments
Labels
bug Something isn't working

Comments

@aminnasiri

aminnasiri commented Nov 17, 2022

Description

I developed an example of rank classification using BERT on the Amazon Review dataset, following this guide. It works fine with the Apache MXNet engine, but it throws an EngineException with the PyTorch engine.

Expected Behavior

I expect the PyTorch engine to work as well.

Error Message

Error: ai.djl.engine.EngineException: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchNNLinear(Native Method)
    at ai.djl.pytorch.jni.JniUtils.linear(JniUtils.java:1189)
    at ai.djl.pytorch.engine.PtNDArrayEx.linear(PtNDArrayEx.java:390)
    at ai.djl.nn.core.Linear.linear(Linear.java:183)
    at ai.djl.nn.core.Linear.forwardInternal(Linear.java:88)
    at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:126)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.training.Trainer.forward(Trainer.java:175)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at com.thinksky.classification.TrainModel.trainModel(TrainModel.java:96)
    at com.thinksky.classification.TrainModel_ClientProxy.trainModel(Unknown Source)
    at com.thinksky.PredictResource.executeQuery(PredictResource.java:104)
    at com.thinksky.PredictResource_VertxInvoker_executeQuery_315faff7c0f6c7728fd1c92cfb1b39aa7f024059.invokeBean(Unknown Source)
    at io.quarkus.vertx.runtime.EventConsumerInvoker.invoke(EventConsumerInvoker.java:41)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:135)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:105)
    at io.vertx.core.impl.ContextInternal.dispatch(ContextInternal.java:264)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.dispatch(MessageConsumerImpl.java:177)
    at io.vertx.core.eventbus.impl.HandlerRegistration$InboundDeliveryContext.execute(HandlerRegistration.java:137)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.next(DeliveryContextBase.java:72)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.dispatch(DeliveryContextBase.java:43)
    at io.vertx.core.eventbus.impl.HandlerRegistration.dispatch(HandlerRegistration.java:98)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.deliver(MessageConsumerImpl.java:183)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.doReceive(MessageConsumerImpl.java:168)
    at io.vertx.core.eventbus.impl.HandlerRegistration.lambda$receive$0(HandlerRegistration.java:49)
    at io.netty.util.concurrent.AbstractEventExecutor.runTask(AbstractEventExecutor.java:174)
    at io.netty.util.concurrent.AbstractEventExecutor.safeExecute(AbstractEventExecutor.java:167)
    at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:470)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:569)
    at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
    at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
    at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
    at java.base/java.lang.Thread.run(Thread.java:1589)

How to Reproduce?

This project is essentially the same setup.

Steps to reproduce

  1. Run the project: mvn quarkus:dev
  2. Call training endpoint: curl -X GET http://localhost:8080/predict/model

What have you tried to solve it?

I set these system properties:

  • System.setProperty("requires_grad", "True");
  • System.setProperty("retain_graph", "False");

but they have no effect.

Environment Info

OS: macOS
JDK: Java 19

Dependencies

<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-native-cpu</artifactId>
      <classifier>osx-x86_64</classifier>
      <version>1.12.1</version>
      <scope>runtime</scope>
</dependency>
<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-jni</artifactId>
      <version>1.12.1-0.19.0</version>
      <scope>runtime</scope>
</dependency>
aminnasiri added the bug label Nov 17, 2022
@frankfliu
Contributor

I think you hit the same issue as #2144.

@KexinFeng
Contributor

KexinFeng commented Nov 18, 2022

This error indicates that backward propagation is trying to update the parameters, which are protected in inference mode and cannot be updated. Look at issue #2144 and consider the solution there.

Or, with the PyTorch engine, you can also look at this transfer learning example, which trains on top of an embedding block: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java
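
For context, the relevant pattern in that example (as far as I recall) is loading the pretrained block with trainable parameters instead of inference tensors, and then stacking a new head on top. The following is only a rough sketch: the class name, model URL, and output size are placeholders, and the "trainParam" option should be double-checked against the linked example.

import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

public class LoadTrainableEmbedding {

    public static void main(String[] args) throws Exception {
        // "trainParam" = "true" asks the PyTorch engine to load the weights as
        // normal autograd-capable tensors instead of inference tensors, which is
        // what avoids "Inference tensors cannot be saved for backward" in fit().
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optModelUrls("https://example.com/bert-embedding.zip") // placeholder URL
                .optEngine("PyTorch")
                .optOption("trainParam", "true")
                .optProgress(new ProgressBar())
                .build();

        try (ZooModel<NDList, NDList> embedding = criteria.loadModel()) {
            Block baseBlock = embedding.getBlock();

            // Stack a trainable classification head on top of the embedding block;
            // this composite block is what gets passed to Trainer/EasyTrain.fit().
            Block network = new SequentialBlock()
                    .add(baseBlock)
                    .add(Linear.builder().setUnits(5).build()); // e.g. 5 review ratings
        }
    }
}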

@aminnasiri
Author

aminnasiri commented Nov 18, 2022

Thanks @frankfliu & @KexinFeng.
This approach works for this issue, so I'm going to close it: #2144 (comment)
