Model load failed #1663

Closed
freemanliu opened this issue May 18, 2022 · 7 comments
Labels
bug Something isn't working

Comments

@freemanliu
Contributor

Description

This is reproducible with the following test case:

    @Test
    fun testModelLoad() {
        val model = Model.newInstance("model")
        model.block = Mlp(2, 1, intArrayOf(10))
        model.newTrainer(DefaultTrainingConfig(Loss.l2Loss())).use { trainer ->
            trainer.initialize(Shape(2))
            val manager = model.ndManager
            val input = manager.ones(Shape(1, 2), DataType.FLOAT32)
            val label = manager.create(floatArrayOf(0.5f))
            val trainingDs = ArrayDataset.Builder().setData(input)
                .optLabels(label).setSampling(1, false).build()
            EasyTrain.fit(trainer, 100, trainingDs, trainingDs)
            model.save(Path.of("/tmp"), "predictorAndTrainer")
        }
        val model2 = Model.newInstance("model")
        model2.load(Path.of("/tmp"), "model")
        val p2 = model2.newPredictor(NoopTranslator())
        NDManager.newBaseManager().use { manager ->
            println(p2.predict(NDList(manager.ones(Shape(1, 2)))))
        }
    }

Here are the Gradle dependencies used to select the PyTorch engine:

    implementation 'ai.djl:basicdataset:0.17.0'
    implementation 'ai.djl:model-zoo:0.17.0'
    implementation 'ai.djl.pytorch:pytorch-model-zoo:0.17.0'

Further investigation shows that save() is implemented in BaseModel, while load() is implemented in PtModel. I expected save() to be implemented in PtModel as well.

Expected Behavior

model.load succeeds.

Error Message

model.pt file not found in: /tmp
java.io.FileNotFoundException: model.pt file not found in: /tmp
at ai.djl.pytorch.engine.PtModel.load(PtModel.java:74)
at ai.djl.Model.load(Model.java:121)
at helloworld.jdl.AppTest.testModelLoad(AppTest.kt:79)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.runTestClass(JUnitTestClassExecutor.java:110)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:58)
at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:38)
at org.gradle.api.internal.tasks.testing.junit.AbstractJUnitTestClassProcessor.processTestClass(AbstractJUnitTestClassProcessor.java:62)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.processTestClass(SuiteTestClassProcessor.java:51)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at com.sun.proxy.$Proxy2.processTestClass(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker$2.run(TestWorker.java:176)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:133)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:71)
at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)


Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

> Task :api:compileJava UP-TO-DATE
> Task :api:processResources UP-TO-DATE
> Task :api:classes UP-TO-DATE
> Task :api:jar UP-TO-DATE
> Task :basicdataset:compileJava UP-TO-DATE
> Task :basicdataset:processResources UP-TO-DATE
> Task :basicdataset:classes UP-TO-DATE
> Task :basicdataset:jar UP-TO-DATE
> Task :model-zoo:compileJava UP-TO-DATE
> Task :testing:compileJava UP-TO-DATE
> Task :integration:compileJava UP-TO-DATE
> Task :integration:processResources UP-TO-DATE
> Task :integration:classes UP-TO-DATE
> Task :model-zoo:processResources UP-TO-DATE
> Task :model-zoo:classes UP-TO-DATE
> Task :model-zoo:jar UP-TO-DATE
> Task :testing:processResources NO-SOURCE
> Task :testing:classes UP-TO-DATE
> Task :testing:jar UP-TO-DATE
> Task :engines:ml:xgboost:compileJava UP-TO-DATE
> Task :engines:ml:xgboost:processResources UP-TO-DATE
> Task :engines:ml:xgboost:classes UP-TO-DATE
> Task :engines:ml:xgboost:jar UP-TO-DATE
> Task :engines:mxnet:jnarator:generateGrammarSource UP-TO-DATE
> Task :engines:mxnet:jnarator:compileJava UP-TO-DATE
> Task :engines:mxnet:jnarator:processResources UP-TO-DATE
> Task :engines:mxnet:jnarator:classes UP-TO-DATE
> Task :engines:mxnet:jnarator:jar UP-TO-DATE
> Task :engines:mxnet:mxnet-engine:jnarator UP-TO-DATE
> Task :engines:mxnet:mxnet-engine:compileJava UP-TO-DATE
> Task :engines:mxnet:mxnet-engine:processResources UP-TO-DATE
> Task :engines:mxnet:mxnet-engine:classes UP-TO-DATE
> Task :engines:mxnet:mxnet-engine:jar UP-TO-DATE
> Task :engines:mxnet:mxnet-model-zoo:compileJava UP-TO-DATE
> Task :engines:mxnet:mxnet-model-zoo:processResources UP-TO-DATE
> Task :engines:mxnet:mxnet-model-zoo:classes UP-TO-DATE
> Task :engines:mxnet:mxnet-model-zoo:jar UP-TO-DATE
> Task :engines:pytorch:pytorch-engine:processResources UP-TO-DATE
> Task :engines:pytorch:pytorch-engine:compileJava UP-TO-DATE
> Task :engines:pytorch:pytorch-engine:classes UP-TO-DATE
> Task :engines:pytorch:pytorch-engine:jar UP-TO-DATE
> Task :engines:pytorch:pytorch-jni:processResources UP-TO-DATE
> Task :engines:pytorch:pytorch-jni:compileJava NO-SOURCE
> Task :engines:pytorch:pytorch-jni:classes UP-TO-DATE
> Task :engines:pytorch:pytorch-jni:jar UP-TO-DATE
> Task :engines:pytorch:pytorch-model-zoo:compileJava UP-TO-DATE
> Task :engines:pytorch:pytorch-model-zoo:processResources UP-TO-DATE
> Task :engines:pytorch:pytorch-model-zoo:classes UP-TO-DATE
> Task :engines:pytorch:pytorch-model-zoo:jar UP-TO-DATE
> Task :engines:tensorflow:tensorflow-api:compileJava NO-SOURCE
> Task :engines:tensorflow:tensorflow-api:processResources UP-TO-DATE
> Task :engines:tensorflow:tensorflow-api:classes UP-TO-DATE
> Task :engines:tensorflow:tensorflow-api:jar UP-TO-DATE
> Task :engines:tensorflow:tensorflow-engine:compileJava UP-TO-DATE
> Task :engines:tensorflow:tensorflow-engine:processResources UP-TO-DATE
> Task :engines:tensorflow:tensorflow-engine:classes UP-TO-DATE
> Task :engines:tensorflow:tensorflow-engine:jar UP-TO-DATE
> Task :engines:tensorflow:tensorflow-model-zoo:compileJava UP-TO-DATE
> Task :engines:tensorflow:tensorflow-model-zoo:processResources UP-TO-DATE
> Task :engines:tensorflow:tensorflow-model-zoo:classes UP-TO-DATE
> Task :engines:tensorflow:tensorflow-model-zoo:jar UP-TO-DATE

> Task :integration:debugEnv
[DEBUG] - Registering EngineProvider: XGBoost
[DEBUG] - Registering EngineProvider: MXNet
[DEBUG] - Registering EngineProvider: PyTorch
[DEBUG] - Registering EngineProvider: TensorFlow
[DEBUG] - Found default engine: MXNet
----------- System Properties -----------
gopherProxySet: false
awt.toolkit: sun.lwawt.macosx.LWCToolkit
java.specification.version: 11
sun.cpu.isalist: 
sun.jnu.encoding: UTF-8
java.class.path: /Users/freeman.liu/codes/djl/integration/build/classes/java/main:/Users/freeman.liu/codes/djl/integration/build/resources/main:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.5.0/dc98be5d5390230684a092589d70ea76a147925c/commons-cli-1.5.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j-impl/2.17.2/183f7c95fc981f3e97d008b363341343508848e/log4j-slf4j-impl-2.17.2.jar:/Users/freeman.liu/codes/djl/basicdataset/build/libs/basicdataset-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/model-zoo/build/libs/model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/testing/build/libs/testing-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.5/1416a607fae667c14e390b484e8d02b5824c0674/testng-7.5.jar:/Users/freeman.liu/codes/djl/engines/mxnet/mxnet-model-zoo/build/libs/mxnet-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-model-zoo/build/libs/pytorch-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-jni/build/libs/pytorch-jni-1.11.0-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-model-zoo/build/libs/tensorflow-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/ml/xgboost/build/libs/xgboost-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/mxnet/mxnet-engine/build/libs/mxnet-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-engine/build/libs/pytorch-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-engine/build/libs/tensorflow-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/api/build/libs/api-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/1.7.36/6c62681a2f655b49963a5983b8b0950a6120ae14/slf4j-api-1.7.36.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.17.2/fa43ba4467f5300b16d1e0742934149bfc5ac564/log4j-core-2.17.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.17.2/f42d6afa111b4dec5d2aea0fe2197240749a4ea6/log4j-api-2.17.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.9.0/b59d8f64cd0b83ee1c04ff1748de2504457018c1/commons-csv-1.9.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.code.findbugs/jsr305/3.0.1/f7be08ec23c21485b9b5a1cf1654c2ec8c58168d/jsr305-3.0.1.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.78/a3927de9bd6f351429bcf763712c9890629d8f51/jcommander-1.78.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.webjars/jquery/3.5.1/2392938e374f561c27c53872bdc9b6b351b6ba34/jquery-3.5.1.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/ml.dmlc/xgboost4j_2.12/1.6.0/4623e78f614c998b4600c1cc58441ce06d80ba49/xgboost4j_2.12-1.6.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/commons-logging/commons-logging/1.2/4bfc12adfe4842bf07b657f0369c4cb522955686/commons-logging-1.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.9.0/8a1167e089096758b49f9b34066ef98b2f4b37aa/gson-2.9.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.10.0/7cf4c87dd802db50721db66947aa237d7ad09418/jna-5.10.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.21/4ec95b60d4e86b5c95a0e919cb172a0af98011ef/commons-compress-1.21.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-api/build/libs/tensorflow-api-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-api/0.4.0/2ac35ca087607cce0e5419953cc1ef0c3a5edaea/tensorflow-core-api-0.4.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.bytedeco/javacpp/1.5.6/1f18a820aadd943577b0b372554f9e35e1232e25/javacpp-1.5.6.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.protobuf/protobuf-java/3.19.2/e958ce38f96b612d3819ff1c753d4d70609aea74/protobuf-java-3.19.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.tensorflow/ndarray/0.3.3/1b6d8cc3e3762f6e465b884580d9fc17ab7aeb4/ndarray-0.3.3.jar
java.vm.vendor: AdoptOpenJDK
sun.arch.data.model: 64
user.variant: 
java.vendor.url: https://adoptopenjdk.net/
user.timezone: Australia/Sydney
os.name: Mac OS X
java.vm.specification.version: 11
sun.java.launcher: SUN_STANDARD
user.country: AU
sun.boot.library.path: /Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home/lib:/Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home/lib
sun.java.command: ai.djl.integration.util.DebugEnvironment
jdk.debug: release
sun.cpu.endian: little
user.home: /Users/freeman.liu
org.gradle.appname: gradlew
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2021-04-20
java.home: /Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home
ai.djl.logging.level: debug
org.gradle.internal.http.connectionTimeout: 60000
file.separator: /
java.vm.compressedOopsMode: Zero based
line.separator: 

java.specification.name: Java Platform API Specification
java.vm.specification.vendor: Oracle Corporation
java.awt.graphicsenv: sun.awt.CGraphicsEnvironment
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.runtime.version: 11.0.11+9
user.name: freeman.liu
path.separator: :
os.version: 11.5
java.runtime.name: OpenJDK Runtime Environment
file.encoding: UTF-8
java.vm.name: OpenJDK 64-Bit Server VM
java.vendor.version: AdoptOpenJDK-11.0.11+9
java.vendor.url.bug: https://github.com/AdoptOpenJDK/openjdk-support/issues
java.io.tmpdir: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T/
org.gradle.internal.http.socketTimeout: 120000
java.version: 11.0.11
user.dir: /Users/freeman.liu/codes/djl/integration
os.arch: x86_64
java.vm.specification.name: Java Virtual Machine Specification
java.awt.printerjob: sun.lwawt.macosx.CPrinterJob
sun.os.patch.level: unknown
java.library.path: /Users/freeman.liu/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
java.vm.info: mixed mode
java.vendor: AdoptOpenJDK
java.vm.version: 11.0.11+9
sun.io.unicode.encoding: UnicodeBig
library.jansi.path: /Users/freeman.liu/.gradle/native/jansi/1.18/osx
java.class.version: 55.0
org.gradle.internal.publish.checksums.insecure: true

--------- Environment Variables ---------
PATH: /usr/local/opt/node@16/bin:/Users/freeman.liu/.amplify/bin:/usr/local/opt/node@12/bin:/Users/freeman.liu/bin:/usr/local/opt/node@16/bin:/Users/freeman.liu/.amplify/bin:/Users/freeman.liu/bin:/usr/local/opt/node@16/bin:/Users/freeman.liu/.cargo/bin:/Users/freeman.liu/.amplify/bin:/usr/local/opt/node@12/bin:/Users/freeman.liu/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/usr/local/go/bin:/opt/puppetlabs/pdk/bin:/Library/Apple/usr/bin:/usr/local/bin:/Users/freeman.liu/Library/Python/3.9/bin:/usr/local/bin:/Users/freeman.liu/Library/Python/3.9/bin
APP_ICON_39014: /Users/freeman.liu/codes/djl/media/gradle.icns
APP_NAME_39014: Gradle
WORKON_HOME: /Users/freeman.liu/.virtualenvs
TERM: screen-bce
LANG: en_AU.UTF-8
VIRTUALENVWRAPPER_SCRIPT: /usr/local/bin/virtualenvwrapper.sh
VIRTUALENVWRAPPER_WORKON_CD: 1
STY: 36507.ttys000.MREM277DB4AD
LOGNAME: freeman.liu
XPC_SERVICE_NAME: 0
PWD: /Users/freeman.liu/codes/djl
TERM_PROGRAM_VERSION: 440
JAVA_MAIN_CLASS_39026: ai.djl.integration.util.DebugEnvironment
__CFBundleIdentifier: com.apple.Terminal
SHELL: /usr/local/bin/bash
TERM_PROGRAM: Apple_Terminal
SECURITYSESSIONID: 186aa
OLDPWD: /Users/freeman.liu/codes/djl
VIRTUALENVWRAPPER_HOOK_DIR: /Users/freeman.liu/.virtualenvs
USER: freeman.liu
WINDOW: 7
LaunchInstanceID: 57D413AC-7DB9-4C1A-BE2D-EE29A69C8716
TMPDIR: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T/
SSH_AUTH_SOCK: /private/tmp/com.apple.launchd.dnFo9uppka/Listeners
XPC_FLAGS: 0x0
LIBTORCH: /Users/freeman.liu/libtorch
TERM_SESSION_ID: 73D5F5B6-416C-42D9-B8D1-A124365689AE
VIRTUALENVWRAPPER_PROJECT_FILENAME: .project
TERMCAP: SC|screen-bce|VT 100/ANSI X3.64 virtual terminal:DO=\E[%dB:LE=\E[%dD:RI=\E[%dC:UP=\E[%dA:bs:bt=\E[Z:cd=\E[J:ce=\E[K:cl=\E[H\E[J:cm=\E[%i%d;%dH:ct=\E[3g:do=^J:nd=\E[C:pt:rc=\E8:rs=\Ec:sc=\E7:st=\EH:up=\EM:le=^H:bl=^G:cr=^M:it#8:ho=\E[H:nw=\EE:ta=^I:is=\E)0:li#45:co#178:am:xn:xv:LP:sr=\EM:al=\E[L:AL=\E[%dL:cs=\E[%i%d;%dr:dl=\E[M:DL=\E[%dM:dc=\E[P:DC=\E[%dP:im=\E[4h:ei=\E[4l:mi:IC=\E[%d@:ks=\E[?1h\E=:ke=\E[?1l\E>:vi=\E[?25l:ve=\E[34h\E[?25h:vs=\E[34l:ti=\E[?1049h:te=\E[?1049l:us=\E[4m:ue=\E[24m:so=\E[3m:se=\E[23m:mb=\E[5m:md=\E[1m:mr=\E[7m:me=\E[m:ms:Co#8:pa#64:AF=\E[3%dm:AB=\E[4%dm:op=\E[39;49m:AX:vb=\Eg:G0:as=\E(0:ae=\E(B:ac=\140\140aaffggjjkkllmmnnooppqqrrssttuuvvwwxxyyzz{{||}}~~..--++,,hhII00:po=\E[5i:pf=\E[4i:Km=\E[M:k0=\E[10~:k1=\EOP:k2=\EOQ:k3=\EOR:k4=\EOS:k5=\E[15~:k6=\E[17~:k7=\E[18~:k8=\E[19~:k9=\E[20~:k;=\E[21~:F1=\E[23~:F2=\E[24~:kB=\E[Z:kh=\E[1~:@1=\E[1~:kH=\E[4~:@7=\E[4~:kN=\E[6~:kP=\E[5~:kI=\E[2~:kD=\E[3~:ku=\EOA:kd=\EOB:kr=\EOC:kl=\EOD:km:
__CF_USER_TEXT_ENCODING: 0x1F6:0x0:0xF
PROJECT_HOME: /Users/freeman.liu/dev
JAVA_MAIN_CLASS_39014: org.gradle.wrapper.GradleWrapperMain
HOME: /Users/freeman.liu
SHLVL: 2

-------------- Directories --------------
temp directory: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T
DJL cache directory: /Users/freeman.liu/.djl.ai
Engine cache directory: /Users/freeman.liu/.djl.ai

------------------ CUDA -----------------
[DEBUG] - cudart library not found.
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.18.0
Default Engine: MXNet
[DEBUG] - Using cache dir: /Users/freeman.liu/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64
[DEBUG] - Loading mxnet library from: /Users/freeman.liu/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64/libmxnet.dylib
Default Device: cpu()
PyTorch: 2
MXNet: 0
XGBoost: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 12
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 240247000
Maximum memory (bytes): 4294967296
Total memory available to JVM (bytes): 268435456
Heap committed: 268435456
Heap nonCommitted: 30474240
GCC: 
Apple clang version 13.0.0 (clang-1300.0.29.30)
Target: x86_64-apple-darwin20.6.0
Thread model: posix
InstalledDir: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin

BUILD SUCCESSFUL in 2s
44 actionable tasks: 1 executed, 43 up-to-date

@freemanliu freemanliu added the bug Something isn't working label May 18, 2022
@freemanliu
Contributor Author

freemanliu commented May 19, 2022

Please fix this ASAP. This is a deal breaker: I can't load the trained model, so I would have to switch to another framework.

@frankfliu
Contributor

@freemanliu

Sorry for the delay.

When you train a model in DJL, the trainer saves only the model's parameters; the block definition is not serialized into the model directory. To load such a model, you need to set the Block manually before calling load():

        model2.block = Mlp(2, 1, intArrayOf(10))
        // The model prefix in your code was also wrong; it must match the one passed to save():
        model2.load(Path.of("/tmp"), "predictorAndTrainer")
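A quick way to see why the prefix matters is to list what save() actually left in the directory. The sketch below uses only the JDK and assumes, based on our reading of DJL's BaseModel.save, that parameter files are named <prefix>-NNNN.params with a zero-padded four-digit epoch; that naming pattern is an assumption, and the ParamFiles class and its methods are hypothetical helpers, not DJL APIs. With the prefix passed to save(), a file is found; with "model", nothing matches, which is consistent with the "model.pt file not found" error in the report.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper, not a DJL API. It assumes parameter files are named
// "<prefix>-NNNN.params", where NNNN is a zero-padded four-digit epoch.
public class ParamFiles {

    /** Builds the assumed parameter file name, e.g. ("mlp", 3) -> "mlp-0003.params". */
    static String paramFileName(String prefix, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch);
    }

    /** Returns the names of files directly inside dir that match the assumed pattern. */
    static List<String> find(Path dir, String prefix) throws IOException {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-\\d{4}\\.params");
        try (Stream<Path> stream = Files.list(dir)) {
            return stream
                    .map(p -> p.getFileName().toString())
                    .filter(name -> pattern.matcher(name).matches())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("djl-params");
        // Simulate a saved checkpoint; epoch 7 is an arbitrary example value.
        Files.createFile(dir.resolve(paramFileName("predictorAndTrainer", 7)));
        System.out.println(find(dir, "predictorAndTrainer")); // [predictorAndTrainer-0007.params]
        System.out.println(find(dir, "model"));               // []
    }
}
```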

@freemanliu
Contributor Author

freemanliu commented Jun 2, 2022 via email

@frankfliu
Contributor

@freemanliu
I tested your code in Java, and it works:

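    import ai.djl.MalformedModelException;
    import ai.djl.Model;
    import ai.djl.basicmodelzoo.basic.Mlp;
    import ai.djl.inference.Predictor;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.EasyTrain;
    import ai.djl.training.Trainer;
    import ai.djl.training.dataset.ArrayDataset;
    import ai.djl.training.loss.Loss;
    import ai.djl.translate.NoopTranslator;
    import ai.djl.translate.TranslateException;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;

    public class MlpSaveLoad {
        public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
            System.setProperty("ai.djl.default_engine", "PyTorch");
            Block mlp = new Mlp(2, 1, new int[] {10});
            Model model = Model.newInstance("model");
            model.setBlock(mlp);
            Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()));
            trainer.initialize(new Shape(2));
            NDManager manager = model.getNDManager();
            NDArray input = manager.ones(new Shape(1, 2), DataType.FLOAT32);
            NDArray label = manager.create(new float[] {0.5f});
            ArrayDataset trainingDs = new ArrayDataset.Builder().setData(input)
                    .optLabels(label).setSampling(1, false).build();
            EasyTrain.fit(trainer, 100, trainingDs, trainingDs);

            Path dir = Paths.get("build/mlp");
            Files.createDirectories(dir);
            model.save(dir, "predictorAndTrainer");

            // Only parameters are saved: set a fresh block of the same structure, then load into it.
            Model model2 = Model.newInstance("model");
            model2.setBlock(new Mlp(2, 1, new int[] {10}));
            model2.load(dir, "predictorAndTrainer");

            Predictor<NDList, NDList> p2 = model2.newPredictor(new NoopTranslator());
            NDManager manager2 = NDManager.newBaseManager();
            NDList output = p2.predict(new NDList(manager2.ones(new Shape(1, 2))));
            System.out.println(output.get(0));
        }
    }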

The output is:

ND: (1, 1) cpu() float32
[[0.4958],
]

@freemanliu
Contributor Author

freemanliu commented Jun 2, 2022 via email

@frankfliu
Contributor

@freemanliu

Since we only look one level into the directory, FileVisitOption.FOLLOW_LINKS should work here. Would you mind raising a PR to improve this?
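Following links matters when the model directory is itself a symbolic link; on macOS, for example, /tmp (the directory used in the original report) is a symbolic link to /private/tmp. The JDK-only sketch below is not DJL's code, and the LinkWalk class is hypothetical. It walks one level of a symlinked directory with and without FileVisitOption.FOLLOW_LINKS: without the option, Files.walk visits the link itself as a single entry and never lists the files behind it.

```java
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical demo class: compares a depth-1 walk with and without FOLLOW_LINKS.
public class LinkWalk {

    /** Lists the names of regular files found at most one level below dir. */
    static List<String> listOneLevel(Path dir, FileVisitOption... options) throws IOException {
        try (Stream<Path> stream = Files.walk(dir, 1, options)) {
            return stream
                    .filter(Files::isRegularFile)
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        Path base = Files.createTempDirectory("walk-demo");
        Path realDir = Files.createDirectory(base.resolve("real"));
        Files.createFile(realDir.resolve("model-0001.params"));
        // A symlink standing in for something like /tmp -> /private/tmp on macOS.
        Path link = Files.createSymbolicLink(base.resolve("link"), realDir);

        // Without FOLLOW_LINKS the walk stops at the link itself.
        System.out.println(listOneLevel(link));                               // []
        // With FOLLOW_LINKS the walk enters the target directory.
        System.out.println(listOneLevel(link, FileVisitOption.FOLLOW_LINKS)); // [model-0001.params]
    }
}
```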

@frankfliu
Contributor

Fixed by #1692
