

DL4J: Add way to load/convert models to specific datatype #7520

Closed
dollarHome opened this issue Apr 10, 2019 · 2 comments · Fixed by #7531


@dollarHome

commented Apr 10, 2019

Issue Description

I was looking for a way to load a pretrained FP16 ResNet50 model and run it on CPUs. There is no direct way to do that, only a roundabout one, as suggested by @raver119 on Gitter chat (4/10/2019).

@raver119 suggested:
"to get resnet on fp16 right now you'll probably have to cast params to fp16, and assign these params to your model
but i think we should improve that
cc @AlexDBlack ^^^
there's also support for bfloat16 planned
it's already available in c++, just wasn't introduced to java yet
but params of your nn is just INDArray
but you'd better file an issue
and we'll provide convenient method to do that
we'll need to do that for quantized types anyway"

  • Deeplearning4j version - snapshots
  • platform information (OS, etc.): Linux, Ubuntu 18.04, 28-core SKX Xeon system

@AlexDBlack AlexDBlack changed the title DL4J: Need an efficient way to load and run pretrained FP16 models on CPU DL4J: Add way to load/convert models to specific datatype Apr 11, 2019

@AlexDBlack


commented Apr 11, 2019

Yes, we will want this functionality.

I am thinking we could have an API like:

  • MultiLayerNetwork.load(File, DataType) - "load as FP16 regardless of what the model is saved as"
  • MultiLayerNetwork.convertTo(DataType) - "recreate the network with the specified data type"

Of course we'll want to add the equivalent methods for ComputationGraph.
We should think about this for SameDiff also, though conversion might be on a per-variable basis there...
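For reference, the `MultiLayerNetwork.convertDataType(DataType)` method added for this issue in #7531 covers the proposed `convertTo` case. Below is a minimal sketch, assuming Deeplearning4j 1.0.0-beta4 or later on the classpath. The class name, the tiny 4-8-3 network, and the `loadAs` helper (a hypothetical stand-in for the proposed `load(File, DataType)`, which is not an existing method) are illustrative assumptions, not part of the DL4J API.

```java
import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ConvertDataTypeExample {

    // Hypothetical helper approximating the proposed load(File, DataType):
    // restore the model at its saved type, then return a converted copy.
    public static MultiLayerNetwork loadAs(File file, DataType dataType) throws IOException {
        MultiLayerNetwork saved = ModelSerializer.restoreMultiLayerNetwork(file);
        return saved.convertDataType(dataType);
    }

    // Small network (default FP32) built in code so the sketch runs without a model file.
    public static MultiLayerNetwork buildFp32Net() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(8)
                        .activation(Activation.RELU).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(8).nOut(3).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }

    public static void main(String[] args) {
        MultiLayerNetwork fp32 = buildFp32Net();
        // convertDataType returns a converted copy; fp32 itself is left unchanged.
        MultiLayerNetwork fp16 = fp32.convertDataType(DataType.HALF);

        // Feed FP16 input to the FP16 network.
        INDArray input = Nd4j.rand(2, 4).castTo(DataType.HALF);
        INDArray output = fp16.output(input);

        System.out.println("params dtype: " + fp16.params().dataType());
        System.out.println("output shape: " + Arrays.toString(output.shape()));
    }
}
```

Restoring at the saved type and converting afterwards keeps the on-disk format untouched, at the cost of briefly holding both copies in memory. `ComputationGraph` has a corresponding `convertDataType(DataType)`.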

@AlexDBlack AlexDBlack added this to the beta4 release milestone Apr 11, 2019

@AlexDBlack AlexDBlack self-assigned this Apr 11, 2019

AlexDBlack added a commit that referenced this issue Apr 11, 2019
AlexDBlack added a commit that referenced this issue Apr 17, 2019
[WIP] QA, fixes, DL4J net convertDataType methods (#7531)
* Fix BaseNDArray.equalsWithEps issue for scalars of different ranks

* #7447 Fix slice on row vector

* #7483 Remove old deserialization warnings

* #6861 SameDiff datatype validation, round 1

* #6861 SameDiff datatype validation, round 2

* #6861 SameDiff datatype validation, round 3

* More rank 2 minimum shape fixes

* Multiple test fixes after changing rank2 minimum shapes

* Test fixes

* #7520 add MultiLayerNetwork.convertDataType(DataType) + test

* Datatype cleanup and fixes

* DL4J: Fixes for global (default) vs. network datatypes

* Fix incorrect datatype when arrays (different to default dtype) are detached

* Multiple fixes, improve tests

* Test

* #7532 New network datatype configuration

* Pass network dtype to layer/vertex initialization

* Yolo datatype fixes

* More fixes, more tests

* More fixes, more tests

* Fix bug in PoolHelperVertex backprop

* Vertex dtype tests; misc fixes

* Fix for BaseReduce3Op dtype

* More fix; finally all layers/vertices/preprocessors tested for dtypes

* Fix slices()

* Fixes - gradient check dtype issues

* Pass network dtype when constructing layers

* Pass network dtype when constructing vertices

* Layer dtype/casting fixes

* Various fixes

* Fix Shape.elementWiseStride for 1d view case

* #7092 INDArray.get(point,x)/get(x,point) returns 1d array

* More 1d getRow/getCol fixes

* Indexing/sub-array fixes

* More test and indexing fixes

* More test fixes, add getRow(i,keepDim) and getColumn(i,keepDim)

* More indexing/test fixes

* More fixes

* More fixes

* More fixes

* #7550 Evaluation dtype tests + fixes

* Nd4j.gemm result dtype fix

* Next round of fixes

* Even more dtype fixes...

* Datavec and more DL4J fixes

* Next round of fixes

* DL4J cuDNN helpers - dtype improvements/fixes

* Another round of fixes

* Datavec fixes

* DL4J Fixes

* Keras/Spark/elementwisevertex fixes

* Final (hopefully) fixes

* Last set of fixes

@lock


commented May 17, 2019

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

@lock lock bot locked and limited conversation to collaborators May 17, 2019
