From 6bebc8f605feb203356a2be058fec3585cfad0d7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 20 Nov 2023 08:48:34 -0800
Subject: [PATCH] start adopting associative scan based circuits for time

---
Review notes (text between the "---" marker and the diffstat is commentary
and is not part of the commit):

* ToTimeSequence wraps a module so that it is applied along the frame axis.
  forward() moves the channel axis last and the frame axis second to last
  ('b c f ... -> b ... f c'), packs with the pattern '* n c', calls the
  wrapped module, then unpacks and restores the 'b c f ...' layout.
  pack_one / unpack_one are helpers defined elsewhere in the module;
  presumably they flatten and restore the leading axes - confirm there.

* Fix applied while reviewing: the 'gateloop_time' branch built
  gateloop_kwargs from the new gateloop_use_jax argument but never passed it
  on, so the flag had no effect. Both SimpleGateLoopLayer constructions now
  receive **gateloop_kwargs. This assumes SimpleGateLoopLayer in
  gateloop-transformer>=0.0.25 accepts use_jax_associative_scan - confirm
  against that package. Line counts are unchanged, so the hunk headers and
  the diffstat below are still accurate.

 magvit2_pytorch/magvit2_pytorch.py | 30 +++++++++++++++++++++++++++++-
 magvit2_pytorch/version.py         |  2 +-
 setup.py                           |  1 +
 3 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/magvit2_pytorch/magvit2_pytorch.py b/magvit2_pytorch/magvit2_pytorch.py
index 397897a..485a01a 100644
--- a/magvit2_pytorch/magvit2_pytorch.py
+++ b/magvit2_pytorch/magvit2_pytorch.py
@@ -25,6 +25,8 @@ from magvit2_pytorch.attend import Attend
 
 from magvit2_pytorch.version import __version__
 
+from gateloop_transformer import SimpleGateLoopLayer
+
 from kornia.filters import filter3d
 
 import pickle
@@ -164,6 +166,23 @@ def __init__(self, fn: Module):
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) + x
 
+# for a bunch of tensor operations to change tensor to (batch, time, feature dimension) and back
+
+class ToTimeSequence(Module):
+    @beartype
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        x = rearrange(x, 'b c f ... -> b ... f c')
+        x, ps = pack_one(x, '* n c')
+
+        o = self.fn(x, **kwargs)
+
+        o = unpack_one(o, ps, '* n c')
+        return rearrange(o, 'b ... f c -> b c f ...')
+
 # token shifting
 
 class TokenShift(Module):
@@ -1030,7 +1049,8 @@ def __init__(
         grad_penalty_loss_weight = 10.,
         multiscale_adversarial_loss_weight = 1.,
         flash_attn = True,
-        separate_first_frame_encoding = False
+        separate_first_frame_encoding = False,
+        gateloop_use_jax = False
     ):
         super().__init__()
 
@@ -1157,6 +1177,14 @@ def __init__(
                     Residual(FeedForward(dim))
                 )
 
+            elif layer_type == 'gateloop_time':
+                gateloop_kwargs = dict(
+                    use_jax_associative_scan = gateloop_use_jax
+                )
+
+                encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim = dim, **gateloop_kwargs)))
+                decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim = dim, **gateloop_kwargs)))
+
             elif layer_type == 'attend_time':
                 attn_kwargs = dict(
                     dim = dim,
diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py
index 641e1a8..692b2b8 100644
--- a/magvit2_pytorch/version.py
+++ b/magvit2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.26'
+__version__ = '0.1.27'
diff --git a/setup.py b/setup.py
index f3a542c..82fa19c 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@
     'beartype',
     'einops>=0.7.0',
     'ema-pytorch>=0.2.4',
+    'gateloop-transformer>=0.0.25',
     'kornia',
     'opencv-python',
     'pillow',