Skip to content

Commit

Permalink
fix null key / values
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 10, 2023
1 parent 7d3d501 commit a806696
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions memory_compressed_attention/memory_compressed_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def forward(self, x, input_mask = None):
k, v = map(self.compress_fn, (k, v))

# attach a null key and value, in the case that the first query has no keys to pay attention to
k = torch.cat((self.null_k, k), dim=1)
v = torch.cat((self.null_v, v), dim=1)
nk, nv = map(lambda t: t.expand(b, -1, -1), (self.null_k, self.null_v))
k = torch.cat((nk, k), dim=1)
v = torch.cat((nv, v), dim=1)

# merge heads
q, k, v = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (q, k, v))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'memory_compressed_attention',
packages = find_packages(),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Memory-Compressed Self Attention',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit a806696

Please sign in to comment.