
Commit b190c5e
add a helper function wrapper around einops rearrange that allows for passing in anonymous dimensions as a list, addressing a common tensor manipulation pattern: saving adjacent dimensions before flattening, then reconstituting them afterwards
lucidrains committed Apr 2, 2022
1 parent 731206d commit b190c5e
Showing 3 changed files with 38 additions and 7 deletions.
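
The pattern the new helper addresses, written out in plain einops with illustrative shapes: adjacent dimensions are collapsed so an operation can treat them as one flat axis, then restored afterwards, which forces the caller to carry the original sizes around.

import torch
from einops import rearrange

x = torch.randn(2, 8, 64, 16)                  # (batch, heads, seq, dim)
b, h, n, d = x.shape                           # save sizes for reconstitution

flat = rearrange(x, 'b h n d -> b (h n) d')    # flatten adjacent dims
# ... an op that only handles a flat (batch, flat, dim) layout ...
out = rearrange(flat, 'b (h n) d -> b h n d', h = h)  # reconstitute

assert torch.equal(out, x)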
memorizing_transformers_pytorch/knn_memory.py (9 changes: 8 additions & 1 deletion)
@@ -3,7 +3,9 @@
 import torch
 import faiss
 import numpy as np

 from einops import rearrange
+from memorizing_transformers_pytorch.utils import rearrange_with_dim_list

 # constants

@@ -175,7 +177,9 @@ def search(
         nprobe = 8,
         increment_hits = False
     ):
-        check_shape(queries, 'b n d', d = self.dim, b = self.num_indices)
+        _, *prec_dims, _ = queries.shape
+        check_shape(queries, 'b ... d', d = self.dim, b = self.num_indices)
+        queries = rearrange(queries, 'b ... d -> b (...) d')

         device = queries.device
         queries = queries.detach().cpu().numpy()
@@ -197,6 +201,9 @@
         all_key_values = torch.stack(all_key_values)
         all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.)

+        all_key_values = rearrange_with_dim_list(all_key_values, 'b (...p) ... -> b ...p ...', p = prec_dims)
+        all_masks = rearrange_with_dim_list(all_masks, 'b (...p) ... -> b ...p ...', p = prec_dims)

         return all_key_values.to(device), all_masks.to(device)

     def __del__(self):
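With this change, search accepts queries with any number of dimensions between the batch and feature axes: prec_dims records them, the ellipsis pattern flattens them for the FAISS lookup, and rearrange_with_dim_list restores them on the returned memories and mask. A minimal sketch of the new calling convention; the KNNMemory constructor arguments below are assumptions for illustration, not the exact signature.

import torch
from memorizing_transformers_pytorch.knn_memory import KNNMemory

batch, heads, seq, dim, topk = 2, 8, 64, 64, 32

# constructor arguments are illustrative
memory = KNNMemory(dim = dim, max_memories = 2048, num_indices = batch)

# queries keep their (heads, seq) structure; no caller-side flattening needed
# (in practice memories would be added before searching)
queries = torch.randn(batch, heads, seq, dim)
mem_kv, mem_mask = memory.search(queries, topk)

# mem_kv  : (batch, heads, seq, topk, 2, dim), keys and values stacked at dim -2
# mem_mask: (batch, heads, seq, topk)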
@@ -260,12 +260,8 @@ def forward(

         # calculate knn attention over memory, if index is passed in

-        knn_queries = rearrange(q, 'b h n d -> b (h n) d')
-
-        mem_kv, mem_mask = knn_memory.search(knn_queries, self.num_retrieved_memories)
-
-        mem_mask = rearrange(mem_mask, 'b (h i) j -> b h i j', h = h)
-        mem_k, mem_v = rearrange(mem_kv, 'b (h i) j kv d -> b h i j kv d', h = h).unbind(dim = -2)
+        mem_kv, mem_mask = knn_memory.search(q, self.num_retrieved_memories)
+        mem_k, mem_v = mem_kv.unbind(dim = -2)

         # use null key / value to protect against empty memory

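On the attention side, the call site simplifies: the memory now hands back tensors already shaped per head, and unbind splits the stacked key/value pair. A toy demonstration of that unbind step, with illustrative shapes:

import torch

b, h, i, topk, d = 2, 8, 64, 32, 64

# stacked key / value pairs, as returned in mem_kv by knn_memory.search
mem_kv = torch.randn(b, h, i, topk, 2, d)

mem_k, mem_v = mem_kv.unbind(dim = -2)  # split along the kv axis
assert mem_k.shape == (b, h, i, topk, d) and mem_v.shape == (b, h, i, topk, d)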
memorizing_transformers_pytorch/utils.py (28 changes: 28 additions & 0 deletions)
@@ -0,0 +1,28 @@
import re
from einops import rearrange

# rearrange wrapper that accepts anonymous dimension groups: '...p' in the
# pattern stands for len(p) dimensions whose sizes are passed as a list
def rearrange_with_dim_list(tensor, pattern, **kwargs):
    regex = r'(\.\.\.[a-zA-Z]+)'
    matches = re.findall(regex, pattern)
    dim_prefixes = tuple(map(lambda t: t.lstrip('.'), set(matches)))

    update_kwargs_dict = dict()

    for prefix in dim_prefixes:
        assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
        dim_list = kwargs[prefix]
        assert isinstance(dim_list, (list, tuple)), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
        # name each anonymous dimension, e.g. p -> p0 p1 ... pn
        dim_names = list(map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
        update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))

    def sub_with_anonymous_dims(t):
        dim_name_prefix = t.groups()[0].strip('.')
        return ' '.join(update_kwargs_dict[dim_name_prefix].keys())

    # expand every '...p' in the pattern into its generated dimension names
    pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)

    # swap each dimension list kwarg for the individual named dimensions
    for prefix, update_dict in update_kwargs_dict.items():
        del kwargs[prefix]
        kwargs.update(update_dict)

    return rearrange(tensor, pattern_new, **kwargs)
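
A usage sketch of the helper: '...p' in a pattern names an anonymous group of dimensions whose sizes arrive as a list, so callers need not know ahead of time how many there are.

import torch
from memorizing_transformers_pytorch.utils import rearrange_with_dim_list

x = torch.randn(2, 6, 16)   # second axis is a flattened (2, 3) pair

# expands internally to rearrange(x, 'b (p0 p1) d -> b p0 p1 d', p0 = 2, p1 = 3)
y = rearrange_with_dim_list(x, 'b (...p) d -> b ...p d', p = [2, 3])
assert y.shape == (2, 2, 3, 16)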
