Is it possible to convert the pretrained model into onnx format? #45

Open
LancerComet opened this issue Aug 24, 2023 · 7 comments

LancerComet commented Aug 24, 2023

Hi, after reading the code, it seems that the pretrained weights need to be used in conjunction with a tokenizer and some other libraries:

x = self._preprocess(img)                                                         # image -> pixel_values tensor
x = self.model.generate(x[None].to(self.model.device), max_length=300)[0].cpu()  # autoregressive decoding (beam search)
x = self.tokenizer.decode(x, skip_special_tokens=True)                            # token ids -> string
x = post_process(x)

import re
import jaconv

def post_process(text):
    text = ''.join(text.split())                                             # strip all whitespace
    text = text.replace('…', '...')                                          # normalize ellipsis
    text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)  # collapse dot/middle-dot runs into periods
    text = jaconv.h2z(text, ascii=True, digit=True)                          # half-width -> full-width characters

    return text
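For reference, self._preprocess essentially runs the image through the model's ViT feature extractor. A rough standalone sketch, assuming a 224x224 input size and 0.5 image mean/std (the exact values should be checked against the model's preprocessor_config.json):

# Rough standalone equivalent of self._preprocess (a sketch; the exact resize,
# mean and std values live in the model's preprocessor_config.json).
import numpy as np
from PIL import Image

def preprocess(img: Image.Image) -> np.ndarray:
    img = img.convert('L').convert('RGB')                    # grayscale, promoted back to 3 channels
    img = img.resize((224, 224), Image.Resampling.BILINEAR)  # assumed ViT input size
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = (x - 0.5) / 0.5                                       # assumed image_mean / image_std of 0.5
    return x.transpose(2, 0, 1)                               # (num_channels, height, width); add a batch dim before inference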

Given these dependencies on the tokenizer and post-processing, it seems there's no straightforward way to export the model to ONNX and use it from other languages. Do you have any thoughts on how this could be achieved? I don't know much about this area, thank you very much!

@kha-white (Owner)

Conversion to ONNX is possible and has been done, but inference in other languages is not trivial, mainly because HuggingFace's generate method (basically beam search) would have to be reimplemented. Tokenizer and post-processing are relatively easy to substitute; tokenizer's decode can be replaced with a look-up table and the rest are some rather simple operations on strings.
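For illustration, a minimal sketch of that look-up-table substitution (assuming vocab.txt has one token per line, so the id is just the line index, and that the special tokens are the bracketed entries):

import re
import jaconv

# Look-up table replacing tokenizer.decode: token id -> line of vocab.txt.
with open('vocab.txt', encoding='utf-8') as f:
    vocab = [line.rstrip('\n') for line in f]

def decode(token_ids):
    pieces = []
    for i in token_ids:
        token = vocab[i]
        if token.startswith('['):          # assumed special tokens: [PAD], [CLS], [SEP], [UNK], ...
            continue
        pieces.append(token[2:] if token.startswith('##') else token)  # strip WordPiece continuation marker
    return ''.join(pieces)

# The post_process from the question is then applied unchanged:
# text = post_process(decode(token_ids))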

Mar2ck (Contributor) commented Aug 25, 2023

An automated conversion of the model to ONNX can be found here https://huggingface.co/kha-white/manga-ocr-base/blob/refs%2Fpr%2F3/model.onnx

@LancerComet (Author)

@kha-white Thanks for your reply, I have converted an ONNX model with input and output shapes as follows:

Input Details:
Name: pixel_values, Shape: ['batch_size', 'num_channels', 'height', 'width'], Type: tensor(float)
Name: decoder_input_ids, Shape: ['batch_size', 'decoder_sequence_length'], Type: tensor(int64)

Output Details:
Name: logits, Shape: ['batch_size', 'decoder_sequence_length', 6144], Type: tensor(float)

I understand that pixel_values is a tensor of the bitmap, but I am unclear about decoder_input_ids. Does it come from data in the manga109 training set?

The logits output has a last dimension of 6144; I'm currently guessing it is meant to be used in conjunction with vocab.txt.

@Mar2ck Wow, I didn't realize there's even a bot that converts the model to ONNX automatically! But I see it is a little different from mine:

Input Details:
Name: pixel_values, Shape: ['batch_size', 'num_channels', 'height', 'width'], Type: tensor(float)
Name: decoder_input_ids, Shape: ['batch_size', 'decoder_sequence_length'], Type: tensor(int64)

Output Details:
Name: logits, Shape: ['batch_size', 'decoder_sequence_length', 6144], Type: tensor(float)
Name: encoder_last_hidden_state, Shape: ['batch_size', 'encoder_sequence_length', 768], Type: tensor(float)

Its output includes encoder_last_hidden_state, which mine doesn't have. I have no clue why that happens.
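For reference, the input/output details of both exports can be dumped with onnxruntime for a side-by-side comparison (the file names below are placeholders):

import onnxruntime as ort

def describe(path):
    sess = ort.InferenceSession(path)
    for kind, nodes in (('Input', sess.get_inputs()), ('Output', sess.get_outputs())):
        for n in nodes:
            print(f'{kind}: Name: {n.name}, Shape: {n.shape}, Type: {n.type}')

describe('my_export.onnx')   # placeholder: the manually exported model
describe('model.onnx')       # placeholder: the automated export from the HF pull request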

@kha-white (Owner)

Ok so I don't know exactly how to do the inference in onnx (I played around with it a little bit but it seemed rather tricky to do and I abandoned/postponed it), so I'll just tell you what I know.

This model has an encoder-decoder architecture. The encoder gets an image as input and outputs a feature vector (this is encoder_last_hidden_state). Then, this feature vector is passed to the decoder, which is run iteratively, outputting one token at a time until it reaches a special token indicating the end of the sequence. At each step, the decoder is fed the sequence of all the tokens it has output so far (this is decoder_input_ids). The decoder outputs the tokens as logits: each token is represented by a vector of 6144 values corresponding to the vocab, as you correctly noticed. You can take the argmax of the logits vector to get the most probable token, but what actually happens is that the logits are converted to probabilities and the top N hypotheses are considered at each step (beam search). This is done by HuggingFace's generate method, which is unfortunately quite complex.

The tricky part is replicating the beam search (although it could be replaced with a simpler greedy search at cost of some accuracy drop) and getting all the little details right when passing around the tensors.

BTW, I suspect that there is something wrong with both your and the bot's onnx exports. I think there should be separate onnx files for the encoder and the decoder, since the encoder is run only once per inference and then the decoder is run iteratively until the end of the sequence is reached.
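To make the loop concrete, here is a rough greedy-search sketch along those lines (no beam search, so some accuracy loss). The separate encoder/decoder files, the input/output names, and the start/end token ids are all assumptions and need to be checked against the actual export and the model's config.json:

import numpy as np
import onnxruntime as ort

# Assumptions: encoder_model.onnx / decoder_model.onnx are separate exports,
# the decoder takes 'input_ids' and 'encoder_hidden_states' and returns 'logits',
# and START_TOKEN_ID / EOS_TOKEN_ID match the ids in the model's config.json
# (the values below are placeholders).
encoder = ort.InferenceSession('encoder_model.onnx')
decoder = ort.InferenceSession('decoder_model.onnx')

START_TOKEN_ID, EOS_TOKEN_ID, MAX_LENGTH = 2, 3, 300

def generate_greedy(pixel_values: np.ndarray) -> list[int]:
    # The encoder runs once per image.
    encoder_hidden_states = encoder.run(None, {'pixel_values': pixel_values})[0]
    token_ids = [START_TOKEN_ID]
    # The decoder runs iteratively, fed with every token generated so far.
    for _ in range(MAX_LENGTH):
        logits = decoder.run(None, {
            'input_ids': np.array([token_ids], dtype=np.int64),
            'encoder_hidden_states': encoder_hidden_states,
        })[0]
        next_id = int(logits[0, -1].argmax())  # greedy: most probable token at the last position
        if next_id == EOS_TOKEN_ID:
            break
        token_ids.append(next_id)
    return token_ids[1:]  # drop the start token before mapping ids through the vocab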

@LancerComet (Author)

@kha-white Thank you for your reply, I'm starting to understand the whole workflow. I am currently looking for a solution for the generate function; I saw some material about BeamSearch in the onnxruntime repository but haven't delved into it yet. As for the onnx model issue, it's indeed possible to create two separate models; I had merged them during export because I didn't have a deep understanding of the model at the time. Again, thank you very much for your response.


mayocream commented Oct 3, 2023

@LancerComet Hi, would you like to share your method to export the pre-trained model to onnx format? I am getting the below errors when exporting with optimum-cli:

$ optimum-cli export onnx --model kha-white/manga-ocr-base bin/
Framework not specified. Using pt to export to ONNX.
Automatic task detection to image-to-text-with-past.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
/Users/mayo/miniconda3/lib/python3.11/site-packages/transformers/models/vit/feature_extraction_vit.py:28: FutureWarning: The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ViTImageProcessor instead.
  warnings.warn(
Traceback (most recent call last):
  File "/Users/mayo/miniconda3/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 163, in main
    service.run()
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 232, in run
    main_export(
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 399, in main_export
    onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 82, in _get_submodels_and_onnx_configs
    onnx_config = onnx_config_constructor(
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/base.py", line 623, in with_past
    return cls(
           ^^^^
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/model_configs.py", line 1231, in __init__
    super().__init__(
  File "/Users/mayo/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/config.py", line 322, in __init__
    raise ValueError(
ValueError: The decoder part of the encoder-decoder model is bert which does not need past key values.

Update:

Got the export to succeed by running:
optimum-cli export onnx --model kha-white/manga-ocr-base --task vision2seq-lm bin/


mathewthe2 commented May 24, 2024

(Quoting @kha-white's explanation of the encoder-decoder inference flow above.)

Seems like there's a Rust implementation called Candle by HuggingFace as well. There's also an example for trocr; I'm not sure how hard it is to convert this model to Candle or use candle-onnx, but it seems promising for standalone binaries/wasm.

Update: I modified the trocr example to work with manga-ocr, and the image processor and encoder seem fine (the tensor outputs are not exact matches), but I am having trouble with the final output of the decoder.

let output_projection = candle_nn::linear_no_bias(
  decoder_cfg.d_model,
  decoder_cfg.vocab_size,
  vb.pp("decoder.cls.predictions.decoder")
)?;

I am getting "Error: cannot find tensor decoder.cls.predictions.decoder.weight" for the output projection, even though that tensor seems to exist when I check the model's state_dict in Python:

decoder.cls.predictions.bias torch.Size([6144])
decoder.cls.predictions.transform.dense.weight torch.Size([768, 768])
decoder.cls.predictions.transform.dense.bias torch.Size([768])
decoder.cls.predictions.transform.LayerNorm.weight torch.Size([768])
decoder.cls.predictions.transform.LayerNorm.bias torch.Size([768])
decoder.cls.predictions.decoder.weight torch.Size([6144, 768])
decoder.cls.predictions.decoder.bias torch.Size([6144])
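One guess worth checking: the output projection is usually weight-tied to the decoder's word embeddings, and tied tensors can be deduplicated away when a checkpoint is saved as safetensors, so the key can exist in the in-memory state_dict yet be missing from the file Candle reads. A quick way to compare the two (assuming the weights are available as model.safetensors; otherwise point safe_open at whatever converted file Candle loads):

# Compare what the safetensors file actually contains with the in-memory state_dict.
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained('kha-white/manga-ocr-base')
in_memory = set(model.state_dict().keys())

weights_path = hf_hub_download('kha-white/manga-ocr-base', 'model.safetensors')
with safe_open(weights_path, framework='pt') as f:
    on_disk = set(f.keys())

print('in state_dict but missing from the file:', sorted(in_memory - on_disk))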
