follow spatiotemporal attention with a feedforward, and add the highly effective token shift along the time axis in the hidden layer
lucidrains committed Mar 19, 2023
1 parent b6e0a17 commit 7e0838e
Showing 3 changed files with 50 additions and 17 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -158,3 +158,13 @@ video_as_images_out = unet(video, enable_time = False)
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```

+```bibtex
+@article{Dong2021AttentionIN,
+    title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
+    author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
+    journal = {ArXiv},
+    year = {2021},
+    volume = {abs/2103.03404}
+}
+```
55 changes: 39 additions & 16 deletions make_a_video_pytorch/make_a_video.py
@@ -58,18 +58,35 @@ def forward(self, x):

# feedforward

+def shift_token(t):
+    t, t_shift = t.chunk(2, dim = 1)
+    t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value = 0.)
+    return torch.cat((t, t_shift), dim = 1)
+
class GEGLU(nn.Module):
    def forward(self, x):
-       x, gate = x.chunk(2, dim = -1)
+       x, gate = x.chunk(2, dim = 1)
        return x * F.gelu(gate)

-def FeedForward(dim, mult = 4):
-    inner_dim = int(dim * mult * 2 / 3)
-    return nn.Sequential(
-        nn.Linear(dim, inner_dim, bias = False),
-        GEGLU(),
-        nn.Linear(inner_dim, dim, bias = False)
-    )
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4):
+        super().__init__()
+
+        inner_dim = int(dim * mult * 2 / 3)
+        self.proj_in = nn.Sequential(
+            nn.Conv3d(dim, inner_dim * 2, 1, bias = False),
+            GEGLU()
+        )
+
+        self.proj_out = nn.Conv3d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, x, enable_time = True):
+        x = self.proj_in(x)
+
+        if enable_time:
+            x = shift_token(x)
+
+        return self.proj_out(x)

# best relative positional encoding
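
As a quick illustration of the token shift introduced above: half of the hidden channels are shifted forward by one frame, so each frame also sees the previous frame's features, and the first frame receives zeros. Below is a minimal standalone sketch; `shift_token` is copied from the diff, while the toy tensor and printed values are purely illustrative.

```python
import torch
import torch.nn.functional as F

def shift_token(t):
    # split the channels in half and delay the second half by one frame
    t, t_shift = t.chunk(2, dim = 1)
    # pad one zero frame at the front of the time axis and drop the last frame
    t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value = 0.)
    return torch.cat((t, t_shift), dim = 1)

# toy hidden features: (batch, channels, frames, height, width)
x = torch.arange(1., 5.).view(1, 1, 4, 1, 1).repeat(1, 2, 1, 1, 1)

out = shift_token(x)
print(out[0, 0, :, 0, 0])   # tensor([1., 2., 3., 4.]) - first half of channels unchanged
print(out[0, 1, :, 0, 0])   # tensor([0., 1., 2., 3.]) - second half carries the previous frame
```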

@@ -242,15 +259,16 @@ def forward(
return x

# factorized spatial temporal attention from Ho et al.
-# todo - take care of relative positional biases + rotary embeddings

class SpatioTemporalAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
-       heads = 8
+       heads = 8,
+       add_feed_forward = True,
+       ff_mult = 4
    ):
        super().__init__()
        self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
@@ -259,6 +277,11 @@ def __init__(
        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)

+       if not add_feed_forward:
+           return
+
+       self.ff = FeedForward(dim = dim, mult = ff_mult)

    def forward(
        self,
        x,
@@ -282,17 +305,17 @@ def forward(
        else:
            x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)

-       if not enable_time:
-           return x
+       if enable_time:

-       x = rearrange(x, 'b c f h w -> (b h w) f c')
+           x = rearrange(x, 'b c f h w -> (b h w) f c')

-       time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+           time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])

-       x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x
+           x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x

-       x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)
+           x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)

+       x = self.ff(x, enable_time = enable_time) + x
        return x

# resnet block
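
A rough usage sketch of the block after this commit, assuming `SpatioTemporalAttention` is importable from the package root as in the README examples; the tensor sizes and hyperparameters below are illustrative assumptions, not taken from the commit.

```python
import torch
from make_a_video_pytorch import SpatioTemporalAttention  # assumed package-root export

attn = SpatioTemporalAttention(
    dim = 64,
    dim_head = 32,
    heads = 4,
    add_feed_forward = True,  # new flag: trailing feedforward with the token shift
    ff_mult = 4
)

video = torch.randn(1, 64, 8, 16, 16)  # (batch, channels, frames, height, width)

# spatial attention, then temporal attention, then the token-shifted feedforward
out = attn(video, enable_time = True)
assert out.shape == video.shape
```

With `enable_time = False`, the temporal attention and the token shift are skipped, matching the `enable_time` switch shown in the README excerpt above.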
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
-version = '0.0.10',
+version = '0.1.0',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
