Skip to content

Commit

Permalink
address #300
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 2, 2024
1 parent 96f66d2 commit bca88e9
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.6.7',
version = '1.6.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
171 changes: 171 additions & 0 deletions vit_pytorch/simple_flash_attn_vit_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from packaging import version
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange
from einops.layers.torch import Rearrange

# constants

Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

z, y, x = torch.meshgrid(
torch.arange(f, device = device),
torch.arange(h, device = device),
torch.arange(w, device = device),
indexing = 'ij')

fourier_dim = dim // 6

omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
omega = 1. / (temperature ** omega)

z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]

pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
return pe.type(dtype)

# main class

class Attend(Module):
def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
super().__init__()
self.config = config
self.use_flash = use_flash
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

def flash_attn(self, q, k, v):
# flash attention - https://arxiv.org/abs/2205.14135

with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)

return out

def forward(self, q, k, v):
n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5

if self.use_flash:
return self.flash_attn(q, k, v)

# similarity

sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

# attention

attn = sim.softmax(dim=-1)

# aggregate values

out = einsum("b h i j, b j d -> b h i d", attn, v)

return out

# classes

class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)

class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)

self.attend = Attend(use_flash = use_flash)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(self, x):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

out = self.attend(q, k, v)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
FeedForward(dim, mlp_dim)
]))

def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x

return x

class SimpleViT(Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)

self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)

def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype

x = self.to_patch_embedding(video)
pe = posemb_sincos_3d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe

x = self.transformer(x)
x = x.mean(dim = 1)

x = self.to_latent(x)
return self.linear_head(x)

0 comments on commit bca88e9

Please sign in to comment.