
How to run FLOAT16 OnnxRuntime models #3190

Open
zaobao opened this issue May 14, 2024 · 3 comments
Labels
enhancement New feature or request

Comments


zaobao commented May 14, 2024

It looks like FLOAT16 is not supported yet.

Caused by: java.lang.UnsupportedOperationException: type is not supported: FLOAT16
at ai.djl.onnxruntime.engine.OrtUtils.toDataType(OrtUtils.java:101)
at ai.djl.onnxruntime.engine.OrtNDArray.getDataType(OrtNDArray.java:65)
at ai.djl.onnxruntime.engine.OrtNDArray.toByteBuffer(OrtNDArray.java:121)
at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:55)
at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:31)
at ai.djl.ndarray.NDArrayAdapter.getAlternativeArray(NDArrayAdapter.java:1315)
at ai.djl.ndarray.NDArrayAdapter.split(NDArrayAdapter.java:876)
at ai.djl.ndarray.NDArray.split(NDArray.java:3173)
at ai.djl.translate.StackBatchifier.unbatchify(StackBatchifier.java:118)
at ai.djl.huggingface.translator.CrossEncoderBatchTranslator.processOutput(CrossEncoderBatchTranslator.java:60)
at ai.djl.huggingface.translator.CrossEncoderBatchTranslator.processOutput(CrossEncoderBatchTranslator.java:30)
at ai.djl.inference.Predictor.batchPredict(Predictor.java:173)
... 5 more
zaobao added the enhancement (New feature or request) label May 14, 2024
frankfliu (Contributor) commented

You can convert the model to fp16, but you need a CUDA device. You can use the following command:

djl-convert -o model -f OnnxRuntime -m <MODEL_ID> --optimize O4 --device cuda

see: https://github.com/deepjavalibrary/djl/tree/master/extensions/tokenizers#use-command-line
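
For reference, a minimal sketch (not from this thread) of loading the converted model with the OnnxRuntime engine in DJL. The model path ("model"), the GPU device, and the use of CrossEncoderTranslatorFactory with StringPair from the tokenizers extension are assumptions based on the stack trace above:

import ai.djl.Device;
import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.util.StringPair;

import java.nio.file.Paths;

public class LoadFp16CrossEncoder {
    public static void main(String[] args) throws Exception {
        // "model" is assumed to be the output directory produced by djl-convert above
        Criteria<StringPair, float[]> criteria =
                Criteria.builder()
                        .setTypes(StringPair.class, float[].class)
                        .optModelPath(Paths.get("model"))
                        .optEngine("OnnxRuntime")
                        .optDevice(Device.gpu()) // the fp16 model is intended for a CUDA device
                        .optTranslatorFactory(new CrossEncoderTranslatorFactory())
                        .build();
        try (ZooModel<StringPair, float[]> model = criteria.loadModel();
                Predictor<StringPair, float[]> predictor = model.newPredictor()) {
            float[] scores = predictor.predict(new StringPair("query", "candidate passage"));
            System.out.println(scores[0]);
        }
    }
}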

zaobao (Author) commented May 14, 2024

> You can convert the model to fp16, but you need a CUDA device. You can use the following command:
>
> djl-convert -o model -f OnnxRuntime -m <MODEL_ID> --optimize O4 --device cuda
>
> see: https://github.com/deepjavalibrary/djl/tree/master/extensions/tokenizers#use-command-line

I converted the model to fp16 and encountered an exception while loading the fp16 model with CrossEncoderBatchTranslator:

Caused by: java.lang.UnsupportedOperationException: type is not supported: FLOAT16
at ai.djl.onnxruntime.engine.OrtUtils.toDataType(OrtUtils.java:101)
at ai.djl.onnxruntime.engine.OrtNDArray.getDataType(OrtNDArray.java:65)
at ai.djl.onnxruntime.engine.OrtNDArray.toByteBuffer(OrtNDArray.java:121)
at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:55)
at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:31)
at ai.djl.ndarray.NDArrayAdapter.getAlternativeArray(NDArrayAdapter.java:1315)
at ai.djl.ndarray.NDArrayAdapter.split(NDArrayAdapter.java:876)
at ai.djl.ndarray.NDArray.split(NDArray.java:3173)
at ai.djl.translate.StackBatchifier.unbatchify(StackBatchifier.java:118)
at ai.djl.huggingface.translator.CrossEncoderBatchTranslator.processOutput(CrossEncoderBatchTranslator.java:60)
at ai.djl.huggingface.translator.CrossEncoderBatchTranslator.processOutput(CrossEncoderBatchTranslator.java:30)
at ai.djl.inference.Predictor.batchPredict(Predictor.java:173)
... 5 more

    public static DataType toDataType(OnnxJavaType javaType) {
        switch (javaType) {
            case FLOAT:
                return DataType.FLOAT32;
            case DOUBLE:
                return DataType.FLOAT64;
            case INT8:
                return DataType.INT8;
            case UINT8:
                return DataType.UINT8;
            case INT32:
                return DataType.INT32;
            case INT64:
                return DataType.INT64;
            case BOOL:
                return DataType.BOOLEAN;
            case UNKNOWN:
                return DataType.UNKNOWN;
            case STRING:
                return DataType.STRING;
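            // Note: FLOAT16 has no mapping here, so fp16 tensors fall through to the
            // default branch and trigger the UnsupportedOperationException above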
            default:
                throw new UnsupportedOperationException("type is not supported: " + javaType);
        }
    }

zaobao (Author) commented May 15, 2024

I added FLOAT16 to OrtUtils.toDataType and disabled CrossEncoderBatchTranslator.sigmoid (PyTorchLibrary doesn't support the fp16 sigmoid op on CPU), and the problem was solved:

    public static DataType toDataType(OnnxJavaType javaType) {
        switch (javaType) {
            case FLOAT:
                return DataType.FLOAT32;
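            // added: map ONNX Runtime's FLOAT16 to DJL's FLOAT16 data type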
            case FLOAT16:
                return DataType.FLOAT16;
            case DOUBLE:
                return DataType.FLOAT64;
            case INT8:
                return DataType.INT8;
            case UINT8:
                return DataType.UINT8;
            case INT32:
                return DataType.INT32;
            case INT64:
                return DataType.INT64;
            case BOOL:
                return DataType.BOOLEAN;
            case UNKNOWN:
                return DataType.UNKNOWN;
            case STRING:
                return DataType.STRING;
            default:
                throw new UnsupportedOperationException("type is not supported: " + javaType);
        }
    }
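
As an alternative to disabling sigmoid entirely, a hedged sketch (the helper name is made up, not the actual translator code) that upcasts the fp16 output to fp32 before applying sigmoid, so the PyTorch CPU kernel can handle it:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Activation;

final class Fp16SigmoidWorkaround {

    private Fp16SigmoidWorkaround() {}

    // Upcast fp16 logits to fp32 before sigmoid, since the PyTorch CPU kernel
    // does not support fp16 sigmoid; illustrative only.
    static NDArray sigmoidSafe(NDArray logits) {
        if (logits.getDataType() == DataType.FLOAT16) {
            logits = logits.toType(DataType.FLOAT32, false);
        }
        return Activation.sigmoid(logits);
    }
}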
