fix a bug with weight tying across layers and number of self attention transformer blocks > 1, thanks to @yuanmao
lucidrains committed Dec 12, 2021
1 parent 2d59df4 commit e5a81bd
Showing 2 changed files with 11 additions and 10 deletions.
19 changes: 10 additions & 9 deletions perceiver_pytorch/perceiver_pytorch.py
@@ -17,16 +17,17 @@ def default(val, d):
     return val if exists(val) else d
 
 def cache_fn(f):
-    cache = None
+    cache = dict()
     @wraps(f)
-    def cached_fn(*args, _cache = True, **kwargs):
+    def cached_fn(*args, _cache = True, key = None, **kwargs):
         if not _cache:
             return f(*args, **kwargs)
         nonlocal cache
-        if cache is not None:
-            return cache
-        cache = f(*args, **kwargs)
-        return cache
+        if key in cache:
+            return cache[key]
+        result = f(*args, **kwargs)
+        cache[key] = result
+        return result
     return cached_fn
 
 def fourier_encode(x, max_freq, num_bands = 4):
@@ -196,10 +197,10 @@ def __init__(
 
             self_attns = nn.ModuleList([])
 
-            for _ in range(self_per_cross_attn):
+            for block_ind in range(self_per_cross_attn):
                 self_attns.append(nn.ModuleList([
-                    get_latent_attn(**cache_args),
-                    get_latent_ff(**cache_args)
+                    get_latent_attn(**cache_args, key = block_ind),
+                    get_latent_ff(**cache_args, key = block_ind)
                 ]))
 
             self.layers.append(nn.ModuleList([
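
Context for the change above: the old cache_fn stored a single cached result, so once weight tying kicked in, every self attention block within a layer was collapsed into the same cached module whenever self_per_cross_attn > 1. Keying the cache by block index keeps weights tied across depth while still giving each block position its own module. Below is a minimal sketch of the keyed cache in isolation, along the lines of the patched cache_fn; make_block is a hypothetical factory standing in for get_latent_attn / get_latent_ff.

from functools import wraps

def cache_fn(f):
    # cache is a dict keyed by the caller-supplied `key`
    cache = dict()
    @wraps(f)
    def cached_fn(*args, _cache = True, key = None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result
    return cached_fn

# hypothetical stand-in for get_latent_attn / get_latent_ff
make_block = cache_fn(lambda: object())

a = make_block(key = 0)                  # block position 0: created and cached
b = make_block(key = 1)                  # block position 1: a distinct object
c = make_block(key = 0)                  # same key later (e.g. a deeper layer): cached object reused
assert a is c and a is not b

d = make_block(_cache = False, key = 0)  # _cache = False bypasses the cache entirely
assert d is not a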
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-pytorch',
   packages = find_packages(),
-  version = '0.8.0',
+  version = '0.8.1',
   license='MIT',
   description = 'Perceiver - Pytorch',
   author = 'Phil Wang',
