Skip to content

Commit

Permalink
add vit with patch dropout, fully embrace structured dropout as multi…
Browse files Browse the repository at this point in the history
…ple papers are now corroborating each other
  • Loading branch information
lucidrains committed Dec 2, 2022
1 parent 2f87c0c commit e051522
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 1 deletion.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1884,4 +1884,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@article{Liu2022PatchDropoutEV,
title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.07220}
}
```

*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.39.1',
version = '0.40.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
152 changes: 152 additions & 0 deletions vit_pytorch/vit_with_patch_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

# classes

class PatchDropout(nn.Module):
def __init__(self, prob):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob

def forward(self, x):
if not self.training or self.prob == 0.:
return x

b, n, _, device = *x.shape, x.device

batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * (1 - self.prob)))
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

return x[batch_indices, patch_indices_keep]

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)

self.heads = heads
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)

self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

self.patch_dropout = PatchDropout(patch_dropout)
self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape

x += self.pos_embedding

x = self.patch_dropout(x)

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

x = torch.cat((cls_tokens, x), dim=1)
x = self.dropout(x)

x = self.transformer(x)

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)

0 comments on commit e051522

Please sign in to comment.