Skip to content

Commit

Permalink
be able to accept non-square patches, thanks to @FilipAndersson245
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 2, 2021
1 parent 6a80a4e commit 6549522
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,29 @@ img = torch.randn(1, 3, 256, 128) # <-- not a square
preds = v(img) # (1, 1000)
```

- How do I pass in non-square patches?

```python
import torch
from vit_pytorch import ViT

v = ViT(
num_classes = 1000,
image_size = (256, 128), # image size is a tuple of (height, width)
patch_size = (32, 16), # patch size is a tuple of (height, width)
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128)

preds = v(img)
```

## Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.17.2',
version = '0.17.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
19 changes: 15 additions & 4 deletions vit_pytorch/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
Expand Down Expand Up @@ -74,13 +81,17 @@ def forward(self, x):
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)

Expand Down

0 comments on commit 6549522

Please sign in to comment.