handle non square patch sizes #104

Closed
FilipAndersson245 opened this issue May 1, 2021 · 11 comments

Comments

@FilipAndersson245

FilipAndersson245 commented May 1, 2021

Hello, would it be difficult to allow non-square patches? I'm examining the use of ViT on some non-image data that cannot be represented in a square format.
Ideally the patch size would be given as a tuple selecting height and width. Is there anything that currently blocks this?
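One way to accept either form is to normalize the argument before using it. In the sketch below, `to_pair` is a hypothetical helper (not part of vit_pytorch's documented API), and it assumes a tuple is ordered (height, width), matching the `(batch, channels, height, width)` tensor layout used later in this thread.

```python
from typing import Tuple, Union


def to_pair(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Normalize a size argument to (height, width).

    Hypothetical helper: an int means a square size, a 2-tuple is
    assumed to already be (height, width).
    """
    if isinstance(value, int):
        return (value, value)
    if isinstance(value, tuple) and len(value) == 2:
        return (int(value[0]), int(value[1]))
    raise TypeError(f"expected int or (height, width) tuple, got {value!r}")


# Square and non-square arguments both end up as (height, width).
print(to_pair(32))       # (32, 32)
print(to_pair((25, 1)))  # (25, 1)
```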

@FilipAndersson245
Author

I added initial functionality for this in the vanilla ViT class; see the comment in #105.

@lucidrains
Owner

@FilipAndersson245 Hey! It actually already works with non-square patches, as long as you set the image size to max(height, width) and your height and width are divisible by the patch size.
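The two conditions in this suggestion come down to a little arithmetic. `patch_grid` below is a hypothetical helper written for illustration; it assumes square patches of side `patch_size` and a square `image_size` setting, and returns how many patches fit along each axis.

```python
def patch_grid(height: int, width: int, patch_size: int, image_size: int) -> tuple:
    """Return (rows, cols) of square patches for a height x width input.

    Hypothetical helper: enforces the two conditions from the comment above.
    """
    # Both sides must split evenly into patches.
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    # The configured square image_size must cover the longer side.
    if max(height, width) > image_size:
        raise ValueError("image_size must be at least max(height, width)")
    return height // patch_size, width // patch_size


rows, cols = patch_grid(256, 128, patch_size=32, image_size=256)
print(rows, cols, rows * cols)  # 8 4 32
```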

@FilipAndersson245
Author

@lucidrains Hey! Do you have an example of how you do that?
This is the code I currently use with my pull request:

import torch
from vit_pytorch import vit

v = vit.ViT(
    image_size = (25,60),   # (height, width)
    patch_size = (25,1),    # (patch height, patch width)
    num_classes = 30,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 25, 60)

preds = v(img)
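To make concrete what `patch_size = (25,1)` means for a 25 x 60 input, the sketch below splits a tensor into non-overlapping `patch_h` x `patch_w` patches using plain `reshape`/`permute`. It is an illustrative stand-in for the einops-style pattern `b c (h p1) (w p2) -> b (h w) (p1 p2 c)`, not the code from the pull request; `to_patches` is a hypothetical name.

```python
import torch


def to_patches(img: torch.Tensor, patch_h: int, patch_w: int) -> torch.Tensor:
    """Split (B, C, H, W) into (B, N, patch_h * patch_w * C) non-overlapping patches."""
    b, c, h, w = img.shape
    assert h % patch_h == 0 and w % patch_w == 0, "image dims must be divisible by patch dims"
    gh, gw = h // patch_h, w // patch_w
    # (B, C, gh, patch_h, gw, patch_w): split each spatial axis into grid x patch
    x = img.reshape(b, c, gh, patch_h, gw, patch_w)
    # (B, gh, gw, patch_h, patch_w, C): grid positions first, pixels and channels last
    x = x.permute(0, 2, 4, 3, 5, 1)
    # Flatten the grid into a sequence and each patch into a vector
    return x.reshape(b, gh * gw, patch_h * patch_w * c)


img = torch.randn(1, 3, 25, 60)
patches = to_patches(img, 25, 1)
print(patches.shape)  # torch.Size([1, 60, 75])
```

Each of the 60 patches is one full-height column of the input, flattened to 25 * 1 * 3 = 75 values, which a linear layer would then project to `dim`.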

@lucidrains
Owner

@FilipAndersson245 Yup, just try:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128) # not a square

preds = v(img) # (1, 1000)
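A non-square image with square patches simply produces fewer tokens than a full `image_size` x `image_size` input. The sketch below assumes the model allocates learned positional embeddings for the square case (plus one CLS token) and adds only the first `n + 1` of them; that is an assumption about how a ViT can handle the shorter sequence, not code taken from vit_pytorch.

```python
import torch

patch_size, image_size, dim = 32, 256, 1024
max_patches = (image_size // patch_size) ** 2  # 64 slots for a 256x256 input

# One learned embedding per possible patch, plus one for the CLS token.
pos_embedding = torch.nn.Parameter(torch.randn(1, max_patches + 1, dim))

# Token sequence for a 256x128 image: (256 // 32) * (128 // 32) = 32 patches, plus CLS.
n = (256 // patch_size) * (128 // patch_size)
tokens = torch.randn(1, n + 1, dim)

# Use only as many positional embeddings as there are tokens.
tokens = tokens + pos_embedding[:, : n + 1]
print(tokens.shape)  # torch.Size([1, 33, 1024])
```

Because the slice follows flattened order, embedding indices do not line up with positions in an 8 x 8 grid; with learned embeddings the model just learns whichever indices it actually sees.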

@lucidrains
Owner

@FilipAndersson245 Updated the README too :) This is a frequently asked question.

@FilipAndersson245
Author

FilipAndersson245 commented May 1, 2021

@lucidrains Okay, but in your example the patch size is still square (32x32), correct?
Thanks for the quick response!

@lucidrains
Owner

Ohh, gotcha, you want non-square patches! 🤦‍♂️ OK, I'll get this built :)

@FilipAndersson245
Author

I opened a pull request implementing it, if you want to take a look :)

@FilipAndersson245
Author

My pull request only covers ViT, but it would probably be good to do the same for all variants.

@lucidrains
Owner

@FilipAndersson245 6549522

@Fodark

Fodark commented May 10, 2021

Are there plans to bring this modification to the other ViT variants, e.g. the efficient one?
