
[Feat] Predictor precision PT backend #1204

Merged
merged 9 commits into mindee:main on Jun 20, 2023

Conversation

felixdittrich92
Contributor

@felixdittrich92 felixdittrich92 commented May 29, 2023

This PR:

  • adds a precision dtype to the predictor / det predictor / reco predictor (affects only PT)

Any feedback is welcome!

Reverted for TF, which would need to convert the weights while loading (this takes a lot of time).

Tests depend on: #1201
Issue: #1112

@felixdittrich92 felixdittrich92 added this to the 0.6.1 milestone May 29, 2023
@felixdittrich92 felixdittrich92 self-assigned this May 29, 2023
@felixdittrich92 felixdittrich92 added the labels module: models, framework: pytorch, framework: tensorflow, topic: text detection, topic: text recognition, and type: new feature on May 29, 2023
@felixdittrich92 felixdittrich92 linked an issue May 29, 2023 that may be closed by this pull request
@felixdittrich92 felixdittrich92 removed the framework: tensorflow label on May 30, 2023
@felixdittrich92 felixdittrich92 changed the title [Feat] Predictor precision for TF and PT backend [Feat] Predictor precision PT backend May 30, 2023
@felixdittrich92
Contributor Author

@odulcy-mindee ready for review once the tests are done :)

@codecov

codecov bot commented Jun 1, 2023

Codecov Report

Merging #1204 (ad0aaa7) into main (fdd00a3) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #1204      +/-   ##
==========================================
- Coverage   93.68%   93.67%   -0.01%     
==========================================
  Files         154      154              
  Lines        6903     6911       +8     
==========================================
+ Hits         6467     6474       +7     
- Misses        436      437       +1     
Flag Coverage Δ
unittests 93.67% <100.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
doctr/models/detection/zoo.py 96.96% <ø> (ø)
doctr/models/zoo.py 100.00% <ø> (ø)
doctr/models/classification/predictor/pytorch.py 95.45% <100.00%> (+0.45%) ⬆️
doctr/models/detection/predictor/pytorch.py 95.23% <100.00%> (+0.50%) ⬆️
doctr/models/recognition/predictor/pytorch.py 91.66% <100.00%> (+0.49%) ⬆️
doctr/models/utils/pytorch.py 100.00% <100.00%> (ø)

... and 2 files with indirect coverage changes

@felixdittrich92
Contributor Author

felixdittrich92 commented Jun 9, 2023

(quoting @nikokks) I am not an expert of this model and I am new to doctr, but is the problem here:

    if conf.expanded_channels != conf.input_channels:
        _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs))

https://github.com/felixdittrich92/doctr/blob/1bf12a3ec73ddb463420d4243133b2d423a602d3/doctr/models/utils/tensorflow.py#L87

Hey @nikokks, this is not related to the PR, so please keep it clean ^^ Feel free to contact me on LinkedIn about general issues.

About the problem: it seems to be an issue with your TensorFlow installation -> the DNN lib cannot be found :)

Review comment on doctr/models/recognition/zoo.py (outdated, resolved)
@felixdittrich92
Contributor Author

felixdittrich92 commented Jun 20, 2023

Now without the extra kwarg, which would only affect PyTorch. Simply do:

from doctr.models import ocr_predictor

predictor = (
    ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="db_mobilenet_v3_large", pretrained=True).cuda().half()
)

or

import torch

predictor = (
    ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="db_mobilenet_v3_large", pretrained=True).to(device="cuda", dtype=torch.bfloat16)
)
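
For context, a minimal usage sketch with the converted predictor above; the image path is only a placeholder, and the assumption is that the predictor's preprocessing handles the input pages while inference runs in the precision the predictor was moved to:

from doctr.io import DocumentFile

# "sample_page.jpg" is a placeholder; any image or PDF supported by doctr works
doc = DocumentFile.from_images("sample_page.jpg")
result = predictor(doc)  # inference runs in the precision set above
print(result.render())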

In TensorFlow you can do:

from tensorflow.keras import mixed_precision

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

We cannot automate this because all these settings are global in TF
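
For completeness, a minimal end-to-end TF sketch, assuming the global policy has to be set before the predictor is built, since Keras layers pick up the mixed-precision policy when they are constructed:

from tensorflow.keras import mixed_precision

from doctr.models import ocr_predictor

# Set the global policy first: it only affects layers created afterwards
mixed_precision.set_global_policy("mixed_float16")

predictor = ocr_predictor(
    reco_arch="crnn_mobilenet_v3_small",
    det_arch="db_mobilenet_v3_large",
    pretrained=True,
)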

Collaborator

@odulcy-mindee odulcy-mindee left a comment


Definitely better, thanks Felix!

@felixdittrich92 felixdittrich92 merged commit 31f05c8 into mindee:main Jun 20, 2023
56 of 58 checks passed
@felixdittrich92 felixdittrich92 deleted the predictor-precision branch June 20, 2023 10:33
Labels
framework: pytorch, module: models, topic: text detection, topic: text recognition, type: new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

16 bit precision support in predictors
2 participants