share an idea that should be tried if it has not been
lucidrains committed Nov 15, 2023
1 parent 0ad09c4 commit d446a41
Showing 2 changed files with 163 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
-  version = '1.6.3',
+  version = '1.6.4',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  long_description_content_type = 'text/markdown',
162 changes: 162 additions & 0 deletions vit_pytorch/simple_vit_with_fft.py
@@ -0,0 +1,162 @@
import torch
from torch.fft import fft
from torch import nn

from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        freq_patch_height, freq_patch_width = pair(freq_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'

        patch_dim = channels * patch_height * patch_width
        freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width  # 2x for the real and imaginary parts

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_freq_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
            nn.LayerNorm(freq_patch_dim),
            nn.Linear(freq_patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.freq_pos_embedding = posemb_sincos_2d(
            h = image_height // freq_patch_height,
            w = image_width // freq_patch_width,
            dim = dim
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device, dtype = img.device, img.dtype

        x = self.to_patch_embedding(img)
        freqs = torch.view_as_real(fft(img))  # fft along the last (width) dimension, split into real and imaginary parts

        f = self.to_freq_embedding(freqs)

        x += self.pos_embedding.to(device, dtype = dtype)
        f += self.freq_pos_embedding.to(device, dtype = dtype)

        x, ps = pack((f, x), 'b * d')  # concatenate frequency tokens and image patch tokens

        x = self.transformer(x)

        _, x = unpack(x, ps, 'b * d')  # keep only the image patch tokens for pooling
        x = reduce(x, 'b n d -> b d', 'mean')

        x = self.to_latent(x)
        return self.linear_head(x)

if __name__ == '__main__':
    vit = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        freq_patch_size = 8,
        dim = 1024,
        depth = 1,
        heads = 8,
        mlp_dim = 2048,
    )

    images = torch.randn(8, 3, 256, 256)

    logits = vit(images)
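For the configuration above, the image branch should produce (256 / 8)^2 = 1024 patch tokens and the frequency branch another (256 / 8)^2 = 1024 frequency tokens, so the transformer processes 2048 tokens per image, and logits comes out with shape (8, 1000).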

6 comments on commit d446a41

@lucidrains (Owner, Author) commented on d446a41 Nov 15, 2023


if any vision researchers have seen an idea like this for vision transformers, please let me know and i'll cite it. there was a big success in the music separation space from applying attention in the fourier domain

@skumar-ml commented on d446a41

Two papers that come to mind:

https://arxiv.org/pdf/2107.00645.pdf - replaces attention with a global frequency filter. This may be preferable to attention, since frequency information is not spatially localized, and the global frequency filter amounts to a global convolution in the spatial domain.

https://arxiv.org/pdf/2304.06446.pdf - builds off the previous work, but mixes the frequency filtering with MHSA.

They don't concatenate frequency tokens with the image data as you are proposing, though.

Anecdotally, I've tried converting each patch to the frequency domain and running only the frequency information through a ViT, without much success.
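A minimal sketch of the global frequency filter from the first paper linked above, as I read it (this is an assumption about the method, not the authors' code): the token grid is FFT'd, multiplied elementwise by a learnable complex filter, and inverse-FFT'd back, which amounts to a global circular convolution in the spatial domain. The GlobalFilter name and the h, w grid arguments are hypothetical.

import torch
from torch import nn

class GlobalFilter(nn.Module):
    def __init__(self, dim, h, w):
        super().__init__()
        # learnable complex filter over the patch grid, stored as (real, imag);
        # rfft2 keeps w // 2 + 1 frequencies along the width axis
        self.filter = nn.Parameter(torch.randn(h, w // 2 + 1, dim, 2) * 0.02)

    def forward(self, x):
        # x: (batch, h * w, dim) tokens laid out on an h x w patch grid
        b, n, d = x.shape
        h = self.filter.shape[0]
        w = n // h

        x = x.view(b, h, w, d)
        x = torch.fft.rfft2(x, dim = (1, 2), norm = 'ortho')               # to the frequency domain
        x = x * torch.view_as_complex(self.filter)                         # elementwise global filtering
        x = torch.fft.irfft2(x, s = (h, w), dim = (1, 2), norm = 'ortho')  # back to the spatial domain
        return x.reshape(b, n, d)

In the spirit of the second paper, this filter would be mixed with MHSA rather than replacing it; in the file above it would slot in roughly where Attention sits inside Transformer.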

@lucidrains (Owner, Author) commented on d446a41

@skumar-ml thank you Shubham! i'll give those a read!

@lucidrains (Owner, Author) commented on d446a41

when you say, "without much success", only for classification, or what tasks have you tried this on?

@skumar-ml commented on d446a41 Dec 6, 2023

@lucidrains - I tried it on CIFAR-100 for object classification (using the model layouts specified in DeiT-Tiny and DeiT-Small). Each 16x16x3 image patch was transformed into a 16x9 array via FFT (the output is complex, and the last dimension shrinks to 16/2 + 1 = 9 bins because the input data is real). I unrolled the frequency response into phase and magnitude and used the resulting 16x9x2 as input to the linear embedding. I also only took the FFT of the grayscale image, which may be why my performance was substantially lower.

It's something I'm still working through for a class project, so I'll let you know if I'm able to figure anything else out.
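A rough sketch of the patch featurization described above, under my reading of it (the patch_freq_features helper is hypothetical, not the actual experiment code): each 16x16 grayscale patch is run through torch.fft.rfft along its last axis, so 16 real samples give 16/2 + 1 = 9 complex bins, which are then unrolled into magnitude and phase for 16 * 9 * 2 = 288 features per patch.

import torch
from einops import rearrange

def patch_freq_features(img, patch = 16):
    # img: (batch, 1, height, width) grayscale images
    patches = rearrange(img, 'b 1 (h p1) (w p2) -> b (h w) p1 p2', p1 = patch, p2 = patch)
    freq = torch.fft.rfft(patches, dim = -1)                   # (b, n, 16, 9) complex, since the input is real
    feats = torch.stack((freq.abs(), freq.angle()), dim = -1)  # unroll into magnitude and phase
    return rearrange(feats, 'b n p1 p2 ri -> b n (p1 p2 ri)')  # (b, n, 16 * 9 * 2) per-patch features

These features would then replace the raw pixel values as input to the linear patch embedding of a DeiT-Tiny/Small style ViT.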

@lucidrains (Owner, Author) commented on d446a41

ah! thanks for the context!
