Model load failed #1663
Please fix this ASAP. This is a deal breaker: I can't load the trained model, so I'll have to turn to another framework.
Sorry for the delay. When you train a model in DJL, the trainer only saves the model's parameters. The block information is not serialized in the model directory. In order to load such a model, you need to manually set the Block before you load the model.
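To illustrate what "only the parameters are saved" means, here is a toy analogy in plain Java. `TinyMlp` is a hypothetical class, not DJL's `Mlp`, and its byte layout is not DJL's `.params` format; it stores weights only (no biases) to keep things short. The layer sizes exist only in code, so restoring a checkpoint requires constructing the same architecture first, just as the block must be set before calling `load()`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

// Hypothetical toy stand-in for a trained network: layer sizes live only in code,
// and only raw parameter values go into the checkpoint bytes.
final class TinyMlp {
    final float[] w1; // input x hidden weights, flattened
    final float[] w2; // hidden x output weights, flattened

    // Argument order mirrors Mlp(2, 1, intArrayOf(10)): input, output, hidden size.
    TinyMlp(int input, int output, int hidden) {
        w1 = new float[input * hidden];
        w2 = new float[hidden * output];
    }

    // Writes parameter values only; nothing records the layer sizes.
    byte[] saveParams() throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            for (float v : w1) out.writeFloat(v);
            for (float v : w2) out.writeFloat(v);
        }
        return bytes.toByteArray();
    }

    // Fills this already-built model; it only works if the architecture matches.
    void loadParams(byte[] data) throws IOException {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            for (int i = 0; i < w1.length; i++) w1[i] = in.readFloat(); // EOFException if too short
            for (int i = 0; i < w2.length; i++) w2[i] = in.readFloat();
            if (in.available() != 0) throw new IOException("checkpoint has leftover values");
        }
    }

    // Saves a (2, 1, [10]) model and tries to restore it into a (2, 1, [hidden]) model.
    static boolean roundTrip(int hidden) {
        TinyMlp trained = new TinyMlp(2, 1, 10);
        for (int i = 0; i < trained.w1.length; i++) trained.w1[i] = i * 0.1f;
        for (int i = 0; i < trained.w2.length; i++) trained.w2[i] = -i;
        try {
            TinyMlp restored = new TinyMlp(2, 1, hidden);
            restored.loadParams(trained.saveParams());
            return Arrays.equals(trained.w1, restored.w1) && Arrays.equals(trained.w2, restored.w2);
        } catch (IOException e) {
            return false; // the bytes alone cannot describe a different architecture
        }
    }

    public static void main(String[] args) {
        System.out.println(roundTrip(10)); // true
        System.out.println(roundTrip(20)); // false
    }
}
```

`roundTrip(10)` rebuilds the original layer sizes and succeeds; `roundTrip(20)` fails because nothing in the checkpoint says how large the hidden layer was.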
Hi, Frank,
Thanks for looking into this. I added the block and the issue is the same. As
I said before, the load and save implementations do not match. Are you able
to make this piece of code work?
Cheers,
Freeman
On Thu, Jun 2, 2022 at 7:05 AM Frank Liu ***@***.***> wrote:
@freemanliu <https://github.com/freemanliu>
Sorry for the delay.
When you train a model in DJL, the trainer only saves the model's
parameters. The block information is not serialized in the model directory.
In order to load such a model, you need to manually set the Block before you
load the model:
model2.block = Mlp(2, 1, intArrayOf(10))
// the model prefix you provided was also wrong in your code; it should be:
model2.load(Path.of("/tmp"), "predictorAndTrainer")
--
Language? Kotlin, Typescript or Rust?
All of them!
@freemanliu
The output is:
Hi, Frank,
Thanks a lot for that!
I added the block and it still does not work. Following the code, I found the bug
in ai.djl.util.Utils.getCurrentEpoch at line 246:
Files.walk(modelDir, 1)
It does not look inside modelDir if modelDir is a symlink. Adding
FileVisitOption.FOLLOW_LINKS should fix it.
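The symlink behavior Freeman describes can be checked with the JDK alone. The sketch below builds a temporary directory containing one file (its name assumes DJL's `<prefix>-NNNN.params` naming and is only illustrative), creates a symbolic link to that directory, and walks the link with `maxDepth = 1` with and without `FOLLOW_LINKS`. One common way to end up with a symlinked model directory is macOS, where `/tmp` is a symbolic link to `/private/tmp`; the thread does not say which OS the reporter used.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

final class SymlinkWalkDemo {

    // Counts the paths Files.walk reports at depth <= 1 (the start path included).
    static long countDepthOne(Path dir, boolean followLinks) throws IOException {
        FileVisitOption[] opts = followLinks
                ? new FileVisitOption[] {FileVisitOption.FOLLOW_LINKS}
                : new FileVisitOption[0];
        // The stream holds open directory handles, so close it with try-with-resources.
        try (Stream<Path> paths = Files.walk(dir, 1, opts)) {
            return paths.count();
        }
    }

    // Builds a temp directory with one file plus a symlink to it, and returns
    // {count without FOLLOW_LINKS, count with FOLLOW_LINKS} when walking the link.
    static long[] demo() {
        try {
            Path realDir = Files.createTempDirectory("model-dir");
            Files.createFile(realDir.resolve("predictorAndTrainer-0100.params"));
            Path linkDir = realDir.resolveSibling(realDir.getFileName() + "-link");
            Files.createSymbolicLink(linkDir, realDir);
            return new long[] {countDepthOne(linkDir, false), countDepthOne(linkDir, true)};
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        long[] counts = demo();
        // Expected: 1 (only the link itself) and 2 (the directory plus its file).
        System.out.println(counts[0] + " " + counts[1]);
    }
}
```

Without `FOLLOW_LINKS`, the walker reads the start path's attributes without following links, sees a symbolic link rather than a directory, and reports only the link, so a depth-1 search never sees the files inside.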
Cheers,
Freeman
On Thu, Jun 2, 2022 at 1:40 PM Frank Liu ***@***.***> wrote:
@freemanliu <https://github.com/freemanliu>
I tested your code in Java, and it works:
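import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.loss.Loss;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
public class MlpSaveLoad {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
        System.setProperty("ai.djl.default_engine", "PyTorch");
        Block mlp = new Mlp(2, 1, new int[] {10});
        Model model = Model.newInstance("model");
        model.setBlock(mlp);
        Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()));
        trainer.initialize(new Shape(2));
        NDManager manager = model.getNDManager();
        NDArray input = manager.ones(new Shape(1, 2), DataType.FLOAT32);
        NDArray label = manager.create(new float[] {0.5f});
        ArrayDataset trainingDs = new ArrayDataset.Builder().setData(input)
                .optLabels(label).setSampling(1, false).build();
        EasyTrain.fit(trainer, 100, trainingDs, trainingDs);
        Path dir = Paths.get("build/mlp");
        Files.createDirectories(dir);
        model.save(dir, "predictorAndTrainer");
        // Only parameters were saved, so the block must be set before load().
        Model model2 = Model.newInstance("model");
        model2.setBlock(mlp);
        model2.load(dir, "predictorAndTrainer");
        Predictor<NDList, NDList> p2 = model2.newPredictor(new NoopTranslator());
        NDManager manager2 = NDManager.newBaseManager();
        NDList output = p2.predict(new NDList(manager2.ones(new Shape(1, 2))));
        System.out.println(output.get(0));
    }
}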
The output is:
ND: (1, 1) cpu() float32
[[0.4958],
]
Since we only look one level into the directory,
Fixed by #1692.
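The thread does not show what #1692 changed. As an illustration of the fix Freeman suggested, here is a hypothetical depth-1 epoch lookup (`findLatestEpoch` is an illustrative name, not DJL's `Utils.getCurrentEpoch`) that assumes parameter files are named `<prefix>-NNNN.params` and passes `FileVisitOption.FOLLOW_LINKS`. Following links looks low-risk here because with `maxDepth = 1` the walk never recurses below the model directory, which may be what the truncated comment above was getting at.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

final class EpochLookup {

    // Hypothetical helper: highest NNNN among "<prefix>-NNNN.params" files
    // directly inside modelDir, or -1 when there are none.
    static int findLatestEpoch(Path modelDir, String prefix) throws IOException {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d{4})\\.params");
        // FOLLOW_LINKS lets the walk enter modelDir even if modelDir itself is a symlink;
        // maxDepth = 1 restricts it to modelDir's direct children.
        try (Stream<Path> paths = Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)) {
            return paths.map(Path::getFileName)
                    .filter(Objects::nonNull)
                    .map(name -> pattern.matcher(name.toString()))
                    .filter(Matcher::matches)
                    .mapToInt(m -> Integer.parseInt(m.group(1)))
                    .max()
                    .orElse(-1);
        }
    }

    // Creates a real directory with a few checkpoint files, then queries it via a symlink.
    static int latestViaSymlink(String prefix) {
        try {
            Path realDir = Files.createTempDirectory("checkpoints");
            for (String f : new String[] {
                    "predictorAndTrainer-0005.params",
                    "predictorAndTrainer-0100.params",
                    "other-0200.params"}) {
                Files.createFile(realDir.resolve(f));
            }
            Path link = realDir.resolveSibling(realDir.getFileName() + "-link");
            Files.createSymbolicLink(link, realDir);
            return findLatestEpoch(link, prefix);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(latestViaSymlink("predictorAndTrainer")); // 100
        System.out.println(latestViaSymlink("missing"));             // -1
    }
}
```

An alternative with the same effect would be to resolve the directory with `modelDir.toRealPath()` before walking; either way the walk starts from something it recognizes as a directory.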
Description
This is reproducible with the following test case:
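import ai.djl.Model
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.dataset.ArrayDataset
import ai.djl.training.loss.Loss
import ai.djl.translate.NoopTranslator
import java.nio.file.Path
import org.junit.Test

class AppTest {
    @Test
    fun testModelLoad() {
        val model = Model.newInstance("model")
        model.block = Mlp(2, 1, intArrayOf(10))
        model.newTrainer(DefaultTrainingConfig(Loss.l2Loss())).use { trainer ->
            trainer.initialize(Shape(2))
            val manager = model.ndManager
            val input = manager.ones(Shape(1, 2), DataType.FLOAT32)
            val label = manager.create(floatArrayOf(0.5f))
            val trainingDs = ArrayDataset.Builder().setData(input)
                .optLabels(label).setSampling(1, false).build()
            EasyTrain.fit(trainer, 100, trainingDs, trainingDs)
            model.save(Path.of("/tmp"), "predictorAndTrainer")
        }
        val model2 = Model.newInstance("model")
        model2.load(Path.of("/tmp"), "model")
        val p2 = model2.newPredictor(NoopTranslator())
        NDManager.newBaseManager().use { manager ->
            println(p2.predict(NDList(manager.ones(Shape(1, 2)))))
        }
    }
}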
Here are the Gradle dependencies for using the PyTorch engine:
implementation 'ai.djl:basicdataset:0.17.0'
implementation 'ai.djl:model-zoo:0.17.0'
implementation 'ai.djl.pytorch:pytorch-model-zoo:0.17.0'
Further investigation shows that save() is implemented in BaseModel, while load() is implemented in PtModel. I expected save() to be implemented in PtModel as well.
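The save/load asymmetry may be easier to see as a single decision. Going only by the maintainer's explanation and the `model.pt file not found` error in this report (an assumption about PtModel's behavior, not its actual code), loading with no block set appears to look for a self-contained TorchScript file named `<prefix>.pt`, while loading with a block set restores parameter files such as `<prefix>-NNNN.params` into that block. The hypothetical `chooseArtifact` below models only that choice; the `.params` naming is also an assumption.

```java
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;

// Hypothetical model of the load-time decision discussed in this thread;
// names and file patterns are illustrative assumptions, not DJL code.
final class LoadDispatchSketch {

    // No block: expect a self-describing "<prefix>.pt" file.
    // Block set: expect a parameter-only "<prefix>-NNNN.params" checkpoint.
    static Path chooseArtifact(Path dir, String prefix, boolean hasBlock, int epoch)
            throws FileNotFoundException {
        Path candidate = hasBlock
                ? dir.resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch))
                : dir.resolve(prefix + ".pt");
        if (!Files.isRegularFile(candidate)) {
            throw new FileNotFoundException(candidate.getFileName() + " file not found in: " + dir);
        }
        return candidate;
    }

    // Mimics the report: only a params checkpoint saved under "predictorAndTrainer" exists.
    static String describe(String prefix, boolean hasBlock) {
        try {
            Path dir = Files.createTempDirectory("load-sketch");
            Files.createFile(dir.resolve("predictorAndTrainer-0100.params"));
            return "loads " + chooseArtifact(dir, prefix, hasBlock, 100).getFileName();
        } catch (FileNotFoundException e) {
            return e.getMessage();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(describe("model", false));              // model.pt file not found in: ...
        System.out.println(describe("predictorAndTrainer", true)); // loads predictorAndTrainer-0100.params
    }
}
```

Under these assumptions, the report's `load(Path.of("/tmp"), "model")` with no block falls into the first branch, which is consistent with both fixes suggested in the thread: set the block and use the same prefix that was passed to `save()`.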
Expected Behavior
model.load succeeds.
Error Message
model.pt file not found in: /tmp
java.io.FileNotFoundException: model.pt file not found in: /tmp
at ai.djl.pytorch.engine.PtModel.load(PtModel.java:74)
at ai.djl.Model.load(Model.java:121)
at helloworld.jdl.AppTest.testModelLoad(AppTest.kt:79)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.runTestClass(JUnitTestClassExecutor.java:110)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:58)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:38)
at org.gradle.api.internal.tasks.testing.junit.AbstractJUnitTestClassProcessor.processTestClass(AbstractJUnitTestClassProcessor.java:62)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.processTestClass(SuiteTestClassProcessor.java:51)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at com.sun.proxy.$Proxy2.processTestClass(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker$2.run(TestWorker.java:176)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:133)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:71)
at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)