[models] update vit and transformer layer norm #1059
Conversation
Codecov Report
@@            Coverage Diff            @@
##              main    #1059    +/-   ##
=========================================
  Coverage    95.13%   95.14%
=========================================
  Files          141      141
  Lines         5819     5827     +8
=========================================
+ Hits          5536     5544     +8
  Misses         283      283
Flags with carried forward coverage won't be shown.
@frgfm About nn.Sequential: I would keep nn.Module, to avoid having to build a separate nn.Module just for the head.
Thanks @felixdittrich92!
Also, thanks @frgfm for the review 👍
@odulcy-mindee will update some last things in a few minutes
@felixdittrich92 ok, I'll review it afterwards 👌
Last PR:
Ok, weird behaviour ... Conv2d with padding='valid' works well; without it (which should be the default) it doesn't ...
@odulcy-mindee should be ok now 😅 Unfortunately, I saw that padding='valid' with ONNX is a known issue which is being handled internally by Microsoft, so I think we will get a fix soon.
Will test "manual" patchify without Conv tomorrow morning... maybe a better solution |
Now it works much better and ONNX export works as well; the only disadvantage is that it is a bit slower than using Conv2d with padding='valid' (introduced in PyTorch 1.10). The TF side still runs fine with Conv2d.
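For context, a minimal sketch of what a "manual" (reshape-based) patchify could look like; the patch size and shapes below are illustrative, and the actual implementation in this PR may differ:

import torch

B, C, H, W = 2, 3, 32, 32
ph, pw = 8, 8                                # patch size (illustrative)
x = torch.rand(B, C, H, W)

# split H and W into patches, then flatten each patch into a vector
patches = x.reshape(B, C, H // ph, ph, W // pw, pw)
patches = patches.permute(0, 2, 4, 1, 3, 5)  # (B, H//ph, W//pw, C, ph, pw)
patches = patches.flatten(1, 2).flatten(2)   # (B, num_patches, C * ph * pw)

assert patches.shape == (B, (H // ph) * (W // pw), C * ph * pw)

A Linear layer applied to the last dimension then plays the role of the Conv2d projection with kernel_size = stride = patch size, at the cost of the extra reshapes, which matches the small slowdown mentioned above.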
Now I'm really done; ready for review 😅 sorry for all the changes afterwards :)
🚀 🚀
For reference, this is linked to #1050 (always better to be able to trace back the evolution/fixes :))
Late review again, but hopefully it helps :)
class ClassifierHead(nn.Module):
    """Classifier head for Vision Transformer

    Args:
        in_channels: number of input channels
        num_classes: number of output classes
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
    ) -> None:
        super().__init__()

        self.head = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, num_classes) cls token
        return self.head(x[:, 0])
Mmmmh, what's the difference with
head = nn.Linear(in_channels, num_classes)
...
out = head(x[:, 0])
(Linear actually supports higher dimensions than 2, we can reshape it afterwards I think)
It would be cleaner to add a squeeze or flatten layer in the sequential, rather than creating a class that is doing 99% the same as a Linear :)
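To make the comparison concrete, here is a minimal sketch (the shapes and class count are illustrative, not taken from the PR) showing that a plain nn.Linear already handles both ways of selecting the cls token:

import torch
from torch import nn

batch, seq_len, embed_dim, num_classes = 2, 197, 768, 10
x = torch.rand(batch, seq_len, embed_dim)  # ViT encoder output

head = nn.Linear(embed_dim, num_classes)

# select the cls token first, then project
out = head(x[:, 0])        # (batch, num_classes)

# or project the whole sequence (Linear accepts inputs with more than 2 dims)
# and select the cls token afterwards
out_alt = head(x)[:, 0]    # (batch, num_classes)

assert torch.allclose(out, out_alt)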
@@ -109,7 +120,7 @@ def _vit(
     return model


-def vit(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     """VisionTransformer architecture as described in
I suggest specifying the version of the archi in the docstring as well "VisionTransformer-B"
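For illustration, the suggested tweak would just add the variant suffix to the first line of the docstring (a sketch, not the exact wording used in the PR; VisionTransformer is the model class from the diff and the rest of the docstring is elided):

from typing import Any


def vit_b(pretrained: bool = False, **kwargs: Any) -> "VisionTransformer":
    """VisionTransformer-B architecture as described in ..."""
    ...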
@@ -135,7 +135,7 @@ def _vit(
     return model


-def vit(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     """VisionTransformer architecture as described in
same here
def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, C, H, W = x.shape
    assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
    assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

    patches = self.proj(x)  # BCHW
    # patchify image without convolution
    # adopted from:
typo "adapted"
This PR:
Any feedback is welcome!