Skip to content

Commit

Permalink
start adopting associative scan based circuits for time
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 20, 2023
1 parent 15818d3 commit 6bebc8f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
30 changes: 29 additions & 1 deletion magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__

from gateloop_transformer import SimpleGateLoopLayer

from kornia.filters import filter3d

import pickle
Expand Down Expand Up @@ -164,6 +166,23 @@ def __init__(self, fn: Module):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x

# for a bunch of tensor operations to change tensor to (batch, time, feature dimension) and back

class ToTimeSequence(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn

def forward(self, x, **kwargs):
x = rearrange(x, 'b c f ... -> b ... f c')
x, ps = pack_one(x, '* n c')

o = self.fn(x, **kwargs)

o = unpack_one(o, ps, '* n c')
return rearrange(o, 'b ... f c -> b c f ...')

# token shifting

class TokenShift(Module):
Expand Down Expand Up @@ -1030,7 +1049,8 @@ def __init__(
grad_penalty_loss_weight = 10.,
multiscale_adversarial_loss_weight = 1.,
flash_attn = True,
separate_first_frame_encoding = False
separate_first_frame_encoding = False,
gateloop_use_jax = False
):
super().__init__()

Expand Down Expand Up @@ -1157,6 +1177,14 @@ def __init__(
Residual(FeedForward(dim))
)

elif layer_type == 'gateloop_time':
gateloop_kwargs = dict(
use_jax_associative_scan = gateloop_use_jax
)

encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim = dim)))
decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim = dim)))

elif layer_type == 'attend_time':
attn_kwargs = dict(
dim = dim,
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.26'
__version__ = '0.1.27'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'beartype',
'einops>=0.7.0',
'ema-pytorch>=0.2.4',
'gateloop-transformer>=0.0.25',
'kornia',
'opencv-python',
'pillow',
Expand Down

0 comments on commit 6bebc8f

Please sign in to comment.