
[models] update vit and transformer layer norm #1059

Merged 14 commits into mindee:main on Sep 15, 2022

Conversation

felixdittrich92
Contributor

This PR:

  • apply suggestions for ViT from @frgfm

Any feedback is welcome 🤗

@felixdittrich92 felixdittrich92 added this to the 0.6.0 milestone Sep 12, 2022
@felixdittrich92 felixdittrich92 added module: models Related to doctr.models framework: pytorch Related to PyTorch backend framework: tensorflow Related to TensorFlow backend type: misc Miscellaneous labels Sep 12, 2022
@codecov

codecov bot commented Sep 12, 2022

Codecov Report

Merging #1059 (564a789) into main (28a6cce) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main    #1059   +/-   ##
=======================================
  Coverage   95.13%   95.14%           
=======================================
  Files         141      141           
  Lines        5819     5827    +8     
=======================================
+ Hits         5536     5544    +8     
  Misses        283      283           
Flag Coverage Δ
unittests 95.14% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
doctr/models/classification/zoo.py 100.00% <ø> (ø)
doctr/models/classification/vit/pytorch.py 100.00% <100.00%> (+2.43%) ⬆️
doctr/models/classification/vit/tensorflow.py 97.82% <100.00%> (ø)
doctr/models/modules/transformer/pytorch.py 100.00% <100.00%> (ø)
doctr/models/modules/transformer/tensorflow.py 99.03% <100.00%> (+0.04%) ⬆️
doctr/models/modules/vision_transformer/pytorch.py 100.00% <100.00%> (ø)
doctr/transforms/functional/base.py 95.65% <0.00%> (-1.45%) ⬇️
doctr/transforms/modules/base.py 94.59% <0.00%> (ø)


@felixdittrich92
Contributor Author

felixdittrich92 commented Sep 13, 2022

@frgfm About nn.Sequential: I would keep nn.Module here, to avoid having to build an extra nn.Module just for the head.

odulcy-mindee
odulcy-mindee previously approved these changes Sep 14, 2022
Collaborator

@odulcy-mindee odulcy-mindee left a comment


Thanks @felixdittrich92 !

Also, thanks @frgfm for the review 👍

@felixdittrich92
Contributor Author

@odulcy-mindee I will update some last things in a few minutes

@odulcy-mindee
Collaborator

@felixdittrich92 ok, I'll review that after 👌

@felixdittrich92
Contributor Author

felixdittrich92 commented Sep 14, 2022

Last changes in this PR:

  • fix a minor mistake in PatchEmbedding; the model now trains well (without larger accuracy drops)
  • the PyTorch ViT is now also an nn.Sequential, with the classifier as a standalone module (suggestion from @frgfm)

@felixdittrich92
Contributor Author

felixdittrich92 commented Sep 14, 2022

Ok, weird behaviour ... Conv2d with padding='valid' works well, but without it (which should be the default) it doesn't ...
Disabled the ONNX export for the moment: pytorch/pytorch#68880
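For context, here is a minimal sketch of the Conv2d-based patch embedding under discussion (illustrative names and sizes, not doctr's actual code): a kernel equal to the patch size with a matching stride projects each non-overlapping patch in one step, and padding='valid' simply means no padding is applied.

```python
import torch
from torch import nn

# Illustrative sketch, not doctr's actual code: Conv2d as a patch embedding.
# kernel_size == stride == patch size, so each output position is one patch.
# padding="valid" (i.e. no padding) is the setting the linked ONNX issue concerns.
patch_size, embed_dim = 4, 16
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size, padding="valid")

x = torch.randn(2, 3, 32, 32)             # (B, C, H, W)
feat = proj(x)                            # (2, 16, 8, 8): one embedding per patch
tokens = feat.flatten(2).transpose(1, 2)  # (2, 64, 16): sequence of patch tokens
```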

@felixdittrich92
Contributor Author

@odulcy-mindee should be ok now 😅 Unfortunately, I saw that padding='valid' with ONNX export is a known issue that is being handled internally by Microsoft, so I think we will get a fix soon

@felixdittrich92
Contributor Author

I will test a "manual" patchify without Conv tomorrow morning ... maybe a better solution

@felixdittrich92
Contributor Author

Now it works much better, and ONNX export works as well. The only disadvantage is that it is a bit slower than using Conv2d with padding='valid' (introduced in PyTorch 1.10). The TF side still runs well with Conv2d.

(doctr-dev) felix@felix-GS66-Stealth-11UH:~/Desktop/doctr$ python3 /home/felix/Desktop/doctr/references/classification/train_pytorch.py vit_b --epochs=50
2022-09-15 07:56:02.911032: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Namespace(amp=False, arch='vit_b', batch_size=64, device=None, epochs=50, export_onnx=False, find_lr=False, font='FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf', input_size=32, lr=0.001, name=None, pretrained=False, push_to_hub=False, resume=None, sched='cosine', show_samples=False, test_only=False, train_samples=1000, val_samples=20, vocab='french', wb=False, weight_decay=0, workers=None)
Validation set loaded in 1.966s (2520 samples in 40 batches)
Train set loaded in 0.2431s (126000 samples in 1968 batches)
Validation loss decreased inf --> 2.42893: saving state...                                                                                                      
Epoch 1/50 - Validation loss: 2.42893 (Acc: 30.28%)
Validation loss decreased 2.42893 --> 2.14431: saving state...                                                                                                  
Epoch 2/50 - Validation loss: 2.14431 (Acc: 39.84%)
Validation loss decreased 2.14431 --> 1.9619: saving state...                                                                                                   
Epoch 3/50 - Validation loss: 1.9619 (Acc: 42.74%)
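The "manual" patchify mentioned above can be sketched with plain reshapes (a hypothetical helper, not doctr's exact implementation), which sidesteps the Conv2d padding='valid' ONNX issue at the cost of doing the linear projection separately:

```python
import torch

def patchify(x: torch.Tensor, patch_size: tuple) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened patches without a convolution.

    Hypothetical sketch of the reshape-based approach, not doctr's exact code.
    Returns (B, num_patches, patch_h * patch_w * C).
    """
    B, C, H, W = x.shape
    ph, pw = patch_size
    assert H % ph == 0 and W % pw == 0, "image dims must be divisible by patch size"
    # split H into (H//ph, ph) and W into (W//pw, pw), then group patch dims together
    x = x.reshape(B, C, H // ph, ph, W // pw, pw)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H//ph, W//pw, ph, pw, C)
    return x.reshape(B, (H // ph) * (W // pw), ph * pw * C)

imgs = torch.randn(2, 3, 32, 32)
patches = patchify(imgs, (4, 4))  # (2, 64, 48); a Linear(48, embed_dim) would follow
```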

@felixdittrich92
Contributor Author

Now I'm really done and it's ready for review 😅 sorry for all the changes afterwards :)

Collaborator

@odulcy-mindee odulcy-mindee left a comment


🚀 🚀

@felixdittrich92 felixdittrich92 merged commit a95baaa into mindee:main Sep 15, 2022
@felixdittrich92 felixdittrich92 deleted the transformer-updates branch September 15, 2022 09:03
@frgfm
Collaborator

frgfm commented Sep 16, 2022

For reference, this is linked to #1050 (always better to be able to trace back the evolution/fixes :))

Collaborator

@frgfm frgfm left a comment


Late review again, but hopefully it helps :)

Comment on lines +32 to +51
class ClassifierHead(nn.Module):
    """Classifier head for Vision Transformer

    Args:
        in_channels: number of input channels
        num_classes: number of output classes
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
    ) -> None:
        super().__init__()

        self.head = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, num_classes) cls token
        return self.head(x[:, 0])

Mmmmh, what's the difference with

head = nn.Linear(in_channels, num_classes)
...
out = head(x[:, 0])

(Linear actually supports higher dimensions than 2, we can reshape it afterwards I think)


It would be cleaner to add a squeeze or flatten layer in the sequential, rather than creating a class that is doing 99% the same as a Linear :)
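The reviewer's point can be illustrated with a short sketch (hypothetical shapes, not doctr's code): nn.Linear operates on the last dimension only, so it accepts a (batch, seq_len, dim) input directly, and slicing the cls token before or after the projection gives the same result.

```python
import torch
from torch import nn

# Illustrative sketch: a plain Linear can serve as the ViT classification head.
embed_dim, num_classes = 8, 5
head = nn.Linear(embed_dim, num_classes)

x = torch.randn(2, 10, embed_dim)  # (batch, 1 + num_patches, embed_dim)
logits = head(x[:, 0])             # slice the cls token, then project: (2, 5)
same = head(x)[:, 0]               # project the 3-D input, then slice: also (2, 5)
assert torch.allclose(logits, same, atol=1e-6)
```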

@@ -109,7 +120,7 @@ def _vit(
return model


def vit(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in

I suggest specifying the version of the archi in the docstring as well "VisionTransformer-B"

@@ -135,7 +135,7 @@ def _vit(
return model


def vit(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in

same here


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

        patches = self.proj(x)  # BCHW
        # patchify image without convolution
        # adopted from:

typo "adapted"

@felixdittrich92 felixdittrich92 mentioned this pull request Sep 26, 2022