From cce1416862201ab3d1f54b521bd8170d4b4c6703 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 4 Aug 2021 05:20:21 -0700 Subject: [PATCH] fix mask needing repeats for non-autoregressive case --- long_short_transformer/long_short_transformer.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/long_short_transformer/long_short_transformer.py b/long_short_transformer/long_short_transformer.py index 389a121..78b3810 100644 --- a/long_short_transformer/long_short_transformer.py +++ b/long_short_transformer/long_short_transformer.py @@ -167,7 +167,7 @@ def forward(self, x, mask = None): pkv = self.to_dynamic_proj(gkv) if exists(mask): - pmask = rearrange(mask, 'b (n s) -> b n s', s = s) + pmask = repeat(mask, 'b (n s) -> (b h) n s', s = s, h = h) pkv.masked_fill_(~pmask[..., None], mask_value) pkv = pkv.softmax(dim = -2) diff --git a/setup.py b/setup.py index f0a1d41..ce9c830 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'long-short-transformer', packages = find_packages(), - version = '0.0.4', + version = '0.0.5', license='MIT', description = 'Long Short Transformer - Pytorch', author = 'Phil Wang',