Commit

make continuous positional bias calculations more efficient, while generalizing to any number of dimensions
lucidrains committed Mar 19, 2023
1 parent 50bc159 commit b6e0a17
Showing 2 changed files with 35 additions and 19 deletions.
52 changes: 34 additions & 18 deletions make_a_video_pytorch/make_a_video.py
@@ -3,6 +3,7 @@
 from operator import mul
 
 import torch
+import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange, repeat, pack, unpack
@@ -81,13 +82,10 @@ def __init__(
         dim,
         heads,
         num_dims = 1,
-        layers = 2,
-        log_dist = True,
-        cache_rel_pos = False
+        layers = 2
     ):
         super().__init__()
         self.num_dims = num_dims
-        self.log_dist = log_dist
 
         self.net = nn.ModuleList([])
         self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
@@ -97,33 +95,51 @@ def __init__(
 
         self.net.append(nn.Linear(dim, heads))
 
-        self.cache_rel_pos = cache_rel_pos
-        self.register_buffer('rel_pos', None, persistent = False)
-
     @property
     def device(self):
         return next(self.parameters()).device
 
     def forward(self, *dimensions):
         device = self.device
 
-        if not exists(self.rel_pos) or not self.cache_rel_pos:
-            positions = [torch.arange(d, device = device) for d in dimensions]
-            grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
-            grid = rearrange(grid, 'c ... -> (...) c')
-            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
+        shape = torch.tensor(dimensions, device = device)
+        rel_pos_shape = 2 * shape - 1
+
+        # calculate strides
+
+        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim = -1)
+        strides = torch.flip(F.pad(strides, (1, -1), value = 1), (0,))
+
+        # get all positions and calculate all the relative distances
+
+        positions = [torch.arange(d, device = device) for d in dimensions]
+        grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'), dim = -1)
+        grid = rearrange(grid, '... c -> (...) c')
+        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
 
-            if self.log_dist:
-                rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
+        # get all relative positions across all dimensions
 
-            self.register_buffer('rel_pos', rel_pos, persistent = False)
+        rel_positions = [torch.arange(-d + 1, d, device = device) for d in dimensions]
+        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing = 'ij'), dim = -1)
+        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')
 
-        rel_pos = self.rel_pos.float()
+        # mlp input
+
+        bias = rel_pos_grid.float()
 
         for layer in self.net:
-            rel_pos = layer(rel_pos)
+            bias = layer(bias)
 
-        return rearrange(rel_pos, 'i j h -> h i j')
+        # convert relative distances to indices of the bias
+
+        rel_dist += (shape - 1) # make sure all positive
+        rel_dist *= strides
+        rel_dist_indices = rel_dist.sum(dim = -1)
+
+        # now select the bias for each unique relative position combination
+
+        bias = bias[rel_dist_indices]
+        return rearrange(bias, 'i j h -> h i j')
 
 # helper classes
 
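The gather at the end of the new forward is where the efficiency comes from: the MLP is evaluated only on the prod(2*d - 1) unique relative offsets, and each of the n^2 position pairs then picks up its bias through a flat index built from per-dimension strides. The sketch below is a standalone check that this indexing reproduces the naive pairwise offsets; it only assumes plain PyTorch and einops, and helper names such as naive_rel_dist are illustrative rather than taken from the repository.

import torch
import torch.nn.functional as F
from einops import rearrange

def naive_rel_dist(dimensions):
    # all pairwise relative offsets, shape (n, n, num_dims) with n = prod(dimensions)
    positions = [torch.arange(d) for d in dimensions]
    grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'), dim = -1)
    grid = rearrange(grid, '... c -> (...) c')
    return rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

dimensions = (4, 6, 6)                 # e.g. (frames, height, width)
shape = torch.tensor(dimensions)
rel_pos_shape = 2 * shape - 1          # count of unique offsets per dimension

# row-major strides over the offset grid, mirroring the commit
strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim = -1)
strides = torch.flip(F.pad(strides, (1, -1), value = 1), (0,))

# shift offsets to be non-negative, then flatten each pair's offset into one index
rel_dist = naive_rel_dist(dimensions)
indices = ((rel_dist + (shape - 1)) * strides).sum(dim = -1)

# the table of unique offsets, flattened in the same row-major order the MLP sees
rel_positions = [torch.arange(-d + 1, d) for d in dimensions]
table = torch.stack(torch.meshgrid(*rel_positions, indexing = 'ij'), dim = -1)
table = rearrange(table, '... c -> (...) c')

# gathering from the table by the flat index recovers every pairwise offset exactly
assert torch.equal(table[indices], rel_dist)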
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
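For a sense of how the rewritten forward is called, it accepts one size per dimension and returns a bias of shape (heads, n, n), where n is the product of those sizes. A minimal usage sketch follows, assuming the class in make_a_video.py is named ContinuousPositionBias and that the output is added to pre-softmax attention logits; neither detail is visible in this diff.

import torch
from make_a_video_pytorch.make_a_video import ContinuousPositionBias  # class name assumed, not shown in the hunks above

# num_dims = 2 covers (height, width); num_dims = 3 would cover (frames, height, width)
rel_pos_bias = ContinuousPositionBias(dim = 64, heads = 8, num_dims = 2)

bias = rel_pos_bias(32, 32)    # pass each dimension's size, here a 32 x 32 feature map
print(bias.shape)              # torch.Size([8, 1024, 1024]) -> (heads, n, n) with n = 32 * 32

# presumably added to the attention similarity matrix before softmax, e.g. sim = sim + bias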
