Step-by-step walkthrough of lucidrains' [LSHSelfAttention module](https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py)

**Notes:**
* lucidrains' local attention is not part of the Reformer itself
* Reusing one value head across all attention heads (`one_value_head`) is not standard, as far as I know
* `post_attn_dropout` is not standard, to my knowledge

In [1]:
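from fastai.vision.all import *
from functools import partial, reduce  # used by merge_dims and the partials below
from operator import mul               # used by merge_dims
import pdb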

# Helpers

In [2]:
def default(val, default_val):
    return default_val if val is None else val

In [3]:
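def merge_heads(v):
    # [bs, kv_len, h*dh] -> [bs, h, kv_len, dh]; relies on globals b, kv_len, h set in the Forward section
    return v.view(b, kv_len, h, -1).transpose(1, 2)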

In [4]:
def split_heads(v):
    # [bs*h, t, dh] -> [bs, t, h, dh]; relies on globals b, h, t set in the Forward section
    return v.view(b, h, t, -1).transpose(1, 2).contiguous()

In [5]:
def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

In [6]:
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]
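
A small sketch of how the two shape helpers above behave. The helpers are repeated so the snippet runs on its own; the tensor `demo` and its sizes are made up for illustration.

```python
from functools import reduce
from operator import mul

import torch

def merge_dims(ind_from, ind_to, tensor):
    # Collapse dimensions ind_from..ind_to (inclusive) into a single dimension.
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

def split_at_index(dim, index, t):
    # Split tensor t in two along dimension dim at position index.
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

demo = torch.zeros(2, 8, 16, 4)            # hypothetical [bs, n_heads, sl, head_dim]
merged = merge_dims(0, 1, demo)            # batch and heads collapsed into one dim
left, right = split_at_index(1, 3, demo)   # first 3 heads vs. the remaining 5
print(merged.shape, left.shape, right.shape)
```

With `index=0`, as in this walkthrough where there are no local heads, the left part is simply an empty tensor.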

In [7]:
def process_inputs_chunk(fn, chunks=1, dim=0):
    def inner_fn(*args, **kwargs):
        keys, values, len_args = kwargs.keys(), kwargs.values(), len(args)
        chunked_args = list(zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values))))
        all_args = map(lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args)
        outputs = [fn(*c_args, **c_kwargs) for c_args, c_kwargs in all_args]
        return tuple(map(lambda x: torch.cat(x, dim=dim), zip(*outputs)))
    return inner_fn
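
To see what `process_inputs_chunk` does, the sketch below wraps a toy function (`toy_fn`, a hypothetical stand-in for the attention call) and checks that chunked and unchunked results agree. It assumes every positional and keyword argument is a tensor that can be chunked along `dim`, and that `fn` returns a tuple of tensors.

```python
import torch

def process_inputs_chunk(fn, chunks=1, dim=0):
    def inner_fn(*args, **kwargs):
        keys, values, len_args = kwargs.keys(), kwargs.values(), len(args)
        chunked_args = list(zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values))))
        all_args = map(lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args)
        outputs = [fn(*c_args, **c_kwargs) for c_args, c_kwargs in all_args]
        return tuple(map(lambda x: torch.cat(x, dim=dim), zip(*outputs)))
    return inner_fn

calls = []  # records the leading-dim size seen by each call

def toy_fn(a, b, scale=None):
    calls.append(a.shape[0])
    return a + b * scale, a - b

a, b = torch.arange(8.).view(8, 1), torch.ones(8, 1)
scale = torch.full((8, 1), 2.)

full = toy_fn(a, b, scale=scale)   # one call on all 8 rows
calls.clear()
chunked = process_inputs_chunk(toy_fn, chunks=4)(a, b, scale=scale)  # four calls on 2 rows each
print(calls, torch.equal(full[0], chunked[0]))
```

Each chunk is processed by a separate call, so peak memory scales with the chunk size while the concatenated result stays the same.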

# Init

In [8]:
dim = 128                         # feature dim of the input `x` (`e` below); not the sequence length
heads = 8
bucket_size = 64
n_hashes = 8
causal = False
dim_head = 64                     # set qk, v head dim manually; if None it is computed as dim // heads
attn_chunks = 1                   # Process the attention calculation in chunks if memory is a concern
random_rotations_per_head = False # This is assumed to be false, and is removed in our LSHAttention implementation
attend_across_buckets = False     # same as LSHAttention
allow_duplicate_attention = False # same as LSHAttention
num_mem_kv = 0                    # number of extra randomly initialised parameter vectors concatenated to x. Reason for this?
one_value_head = False            # (True) one repeated v for all heads, or (False) separate value for each head (standard)
use_full_attn = False             # Full attention or LSH
full_attn_thres = None            # logic to decide if we can use full attention? 
return_attn = False               
post_attn_dropout = 0.            # Dropout after attention is calculated
dropout = 0.                      # Normal dropout passed to LSH layer
n_local_attn_heads = 0            # number of local attention heads; 0 disables local attention

In [9]:
assert dim_head or (dim % heads) == 0, 'dimensions must be divisible by number of heads'
assert n_local_attn_heads < heads, 'local attention heads must be less than number of heads'

We can set `dim_head` to a custom value. If it is not set, `dim_head = dim // heads`, so that `dim_heads = dim_head * heads = dim`.

In [10]:
dim_head = default(dim_head, dim // heads)
dim_heads = dim_head * heads
dim_head, dim_heads, dim

(64, 512, 128)
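
The relationship between `dim`, `heads`, `dim_head` and `dim_heads` can be sketched in plain Python. The function name `head_dims` is made up for this illustration and is not part of the library.

```python
def head_dims(dim, heads, dim_head=None):
    # Fall back to an even split of the model dimension across heads.
    dim_head = dim // heads if dim_head is None else dim_head
    return dim_head, dim_head * heads

print(head_dims(128, 8))               # (16, 128)
print(head_dims(128, 8, dim_head=64))  # (64, 512)
```

Setting `dim_head` explicitly decouples the total projection width from `dim`, which is why the projections in this walkthrough go from 128 to 512 features.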

We can choose to split the attention calculation into chunks to reduce peak memory (note: this is unlike LSH chunking; it trades compute time against memory).

In [11]:
attn_chunks = default(attn_chunks, 1)
attn_chunks

1

If `one_value_head` is True, `v_head_repeats` is set to `heads`, so `tov` projects to a single head's worth of values (`v_dim = dim_head`) that is reused across all attention heads. Not standard to my knowledge. Maybe it saves computation or memory?

In [12]:
v_head_repeats = (heads if one_value_head else 1)
v_dim = dim_heads // v_head_repeats
attn_chunks, v_head_repeats, v_dim

(1, 1, 512)

Set up projection layers:

In [13]:
toqk = nn.Linear(dim, dim_heads, bias = False)
tov = nn.Linear(dim, v_dim, bias = False)
to_out = nn.Linear(dim_heads, dim)
toqk, tov, to_out

(Linear(in_features=128, out_features=512, bias=False),
 Linear(in_features=128, out_features=512, bias=False),
 Linear(in_features=512, out_features=128, bias=True))

Dropout and attention layer:

In [14]:
from reformer_pytorch import LSHAttention
# self.bucket_size = bucket_size
lsh_attn = LSHAttention(bucket_size=bucket_size, n_hashes=n_hashes, causal=causal, random_rotations_per_head=random_rotations_per_head, attend_across_buckets = attend_across_buckets,  allow_duplicate_attention = allow_duplicate_attention, return_attn = return_attn, dropout = dropout)
# self.full_attn = FullQKAttention(causal=causal, dropout=dropout)
post_attn_dropout = nn.Dropout(post_attn_dropout)
# self.use_full_attn = use_full_attn
# self.full_attn_thres = default(full_attn_thres, bucket_size)

`num_mem_kv` lets us append additional parameter vectors to our input `x`. They are randomly initialised. Note: these are not the keys from the encoder (if we are in a decoder setting); those are passed as `keys` in the forward.

In [15]:
mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) if num_mem_kv > 0 else None
num_mem_kv, mem_kv

(0, None)

Allows for local attention in some heads: https://github.com/lucidrains/local-attention. Not part of the Reformer (even though attention within the sorted chunks is a form of local attention).

In [16]:
# self.n_local_attn_heads = n_local_attn_heads
# self.local_attn = LocalAttention(window_size=bucket_size * 2, causal=causal, dropout=dropout, shared_qk=True, look_forward=(1 if not causal else 0))
# self.callback = None

# Forward

In [17]:
x = torch.randn(10, 1024, 128) # random data for testing
keys = None
input_mask = None       # padding
input_attn_mask = None  # direct attention mask - remove
context_mask = None     # mask for `keys` (the context), shape [bs, c]

In [18]:
device, dtype = x.device, x.dtype
b, t, e, h, dh, m, l_h = *x.shape, heads, dim_head, num_mem_kv, n_local_attn_heads
b, t, e, h, dh, m, l_h

(10, 1024, 128, 8, 64, 0, 0)

* b - bs
* t - sl
* e - embedding dim
* h - n_heads
* dh - head dim
* m - num_mem_kv
* l_h - n_local_attn_heads

What is `mem_kv`? Randomly initialised parameters of shape [1, num_mem_kv, e], expanded to [bs, num_mem_kv, e] and concatenated to x along the sequence dimension.

In [19]:
mem_kv = default(mem_kv, torch.empty(b, 0, e, dtype=dtype, device=device))
mem = mem_kv.expand(b, m, -1)
mem.shape

torch.Size([10, 0, 128])

Create empty keys if none are passed in (depends on the role of the layer: encoder or decoder)

In [20]:
keys = default(keys, torch.empty(b, 0, e, dtype=dtype, device=device))
c = keys.shape[1]
kv_len = t + m + c
kv_len, keys.shape

(1024, torch.Size([10, 0, 128]))

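Logic to decide whether full attention is used: it is chosen when `use_full_attn` is set or when `kv_len` does not exceed `full_attn_thres` (which defaults to `bucket_size` in the module). The check is commented out here because `full_attn_thres` is left as None, so the comparison would fail: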

In [21]:
#use_full_attn = use_full_attn or kv_len <= full_attn_thres
use_full_attn

False
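
Going by the commented-out lines in this notebook (`full_attn_thres` defaults to `bucket_size`, and the check is `use_full_attn or kv_len <= full_attn_thres`), the decision can be sketched as a small function. The name `should_use_full_attn` is hypothetical.

```python
def should_use_full_attn(use_full_attn, kv_len, bucket_size, full_attn_thres=None):
    # Threshold falls back to the bucket size when not given explicitly.
    thres = bucket_size if full_attn_thres is None else full_attn_thres
    # Short sequences are handled with full attention even if LSH was requested.
    return use_full_attn or kv_len <= thres

print(should_use_full_attn(False, kv_len=1024, bucket_size=64))  # False
print(should_use_full_attn(False, kv_len=48, bucket_size=64))    # True
```

The intuition is that a sequence no longer than one bucket gains nothing from hashing.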

Concatenate `mem` and `keys` to `x` along the sequence dimension if they exist:

In [22]:
x = torch.cat((x, mem, keys), dim=1) # why do we cat keys to input?
x.shape

torch.Size([10, 1024, 128])
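
In this walkthrough both `mem` and `keys` are empty, so the shape is unchanged. A sketch with made-up, non-zero sizes shows how the sequence dimension grows to `kv_len = t + m + c`:

```python
import torch
from torch import nn

b, t, e = 2, 6, 8        # batch, sequence length, embedding dim (illustrative sizes)
num_mem_kv, c = 3, 4     # memory slots and context length (illustrative sizes)

x = torch.randn(b, t, e)
mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, e))
mem = mem_kv.expand(b, num_mem_kv, -1)   # same parameters shared across the batch
keys = torch.randn(b, c, e)              # e.g. encoder output in a decoder layer

x_cat = torch.cat((x, mem, keys), dim=1)
kv_len = t + num_mem_kv + c
print(x_cat.shape, kv_len)
```

Only the first `t` positions act as queries later on (see `query_len`); the appended positions serve as extra keys and values.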

Projection matrices to create shared `qk` and `values`.

In [23]:
qk = toqk(x)
v = tov(x)
x.shape, qk.shape, v.shape

(torch.Size([10, 1024, 128]),
 torch.Size([10, 1024, 512]),
 torch.Size([10, 1024, 512]))

Repeat `v` across heads if `one_value_head` is True:

In [24]:
v = v.repeat(1, 1, v_head_repeats)   # repeat v if desired
v_head_repeats, v.shape

(1, torch.Size([10, 1024, 512]))
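
A minimal sketch of what the repeat does when `one_value_head=True`, using small made-up sizes: the projection then yields a single head's worth of features, and `repeat` tiles that block so every head sees the same values.

```python
import torch

heads, dim_head = 4, 3                    # illustrative sizes
v_single = torch.randn(2, 5, dim_head)    # [bs, sl, head_dim], one shared value head
v_tiled = v_single.repeat(1, 1, heads)    # [bs, sl, heads * head_dim]

# Viewing the last dim as (heads, head_dim) recovers one identical copy per head.
per_head = v_tiled.view(2, 5, heads, dim_head)
same = all(torch.equal(per_head[:, :, i], v_single) for i in range(heads))
print(v_tiled.shape, same)
```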

Split out the `n_heads` dim and move it to position 1:

In [25]:
merge_batch_and_heads = partial(merge_dims, 0, 1)
qk, v = map(merge_heads, (qk, v))
qk.shape, v.shape   # [bs, n_heads, sl, head_dim]

(torch.Size([10, 8, 1024, 64]), torch.Size([10, 8, 1024, 64]))

Decide the number of LSH heads, depending on the number of local heads:

In [26]:
has_local = l_h > 0
lsh_h = h - l_h
has_local, lsh_h, h

(False, 8, 8)

In [27]:
qk.shape

torch.Size([10, 8, 1024, 64])

Split local heads (if any) and normal heads:

In [29]:
split_index_fn = partial(split_at_index, 1, l_h)
(lqk, qk), (lv, v) = map(split_index_fn, (qk, v))
lqk.shape, qk.shape, lv.shape, v.shape

(torch.Size([10, 0, 1024, 64]),
 torch.Size([10, 8, 1024, 64]),
 torch.Size([10, 0, 1024, 64]),
 torch.Size([10, 8, 1024, 64]))

Merge the batch and `n_heads` dims into a single leading dimension:

In [30]:
lqk, qk, lv, v = map(merge_batch_and_heads, (lqk, qk, lv, v))
lqk.shape, qk.shape, lv.shape, v.shape

(torch.Size([0, 1024, 64]),
 torch.Size([80, 1024, 64]),
 torch.Size([0, 1024, 64]),
 torch.Size([80, 1024, 64]))

Deal with masks, which are assumed to be set up properly before being passed in. This code concatenates the various masks:

In [30]:
masks = {}
if input_mask is not None or context_mask is not None:
    default_mask = torch.tensor([True], device=device)
    i_mask = default(input_mask, default_mask.expand(b, t))
    m_mask = default_mask.expand(b, m)
    c_mask = default(context_mask, default_mask.expand(b, c))
    mask = torch.cat((i_mask, m_mask, c_mask), dim=1)
    mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask))
    masks['input_mask'] = mask

if input_attn_mask is not None:
    input_attn_mask = merge_batch_and_heads(expand_dim(1, lsh_h, input_attn_mask))
    masks['input_attn_mask'] = input_attn_mask
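
`expand_dim` is not defined in this notebook. A minimal version, assumed to behave like the helper of the same name in reformer_pytorch, inserts a new axis and broadcasts it to size `k`. The sketch below applies it to a made-up padding mask:

```python
import torch

def expand_dim(dim, k, t):
    # Insert a new axis at position dim and broadcast it to size k.
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

b, t_len, heads = 2, 5, 4                                # illustrative sizes
input_mask = torch.ones(b, t_len, dtype=torch.bool)
input_mask[1, 3:] = False                                # pretend the second sequence is padded

per_head_mask = expand_dim(1, heads, input_mask)         # [b, heads, t]
flat_mask = per_head_mask.reshape(b * heads, t_len)      # [b*heads, t], matches the merged qk layout
print(per_head_mask.shape, flat_mask.shape)
```

Every head of a given batch element ends up with the same mask row, which is what the merged `[bs*n_heads, sl, head_dim]` tensors require.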

Select attention function:

In [31]:
attn_fn = lsh_attn if not use_full_attn else full_attn
attn_fn

LSHAttention(
  (dropout): Dropout(p=0.0, inplace=False)
  (dropout_for_hash): Dropout(p=0.0, inplace=False)
)

Pass the query length via `query_len`; in our case `t` = sl:

In [32]:
partial_attn_fn = partial(attn_fn, query_len = t)
partial_attn_fn

functools.partial(LSHAttention(
  (dropout): Dropout(p=0.0, inplace=False)
  (dropout_for_hash): Dropout(p=0.0, inplace=False)
), query_len=1024)

The attention call can be split into chunks along the merged batch*heads dimension. The `attn_chunks` argument sets the number of chunks, which has memory/performance implications. The default is 1 chunk, i.e. everything is processed in a single call.

In [33]:
attn_fn_in_chunks = process_inputs_chunk(partial_attn_fn, chunks = attn_chunks)
attn_fn_in_chunks

<function __main__.process_inputs_chunk.<locals>.inner_fn(*args, **kwargs)>

Run the attention function. **Note! our qk has batch and n_heads dimensions collapsed, so LSH attention treats every (batch, head) pair as an independent sequence. `process_inputs_chunk` splits this leading dimension into `attn_chunks` chunks.**

In [34]:
qk.shape

torch.Size([80, 1024, 64])

In [35]:
out, attn, buckets = attn_fn_in_chunks(qk, v, **masks)
out.shape, attn.shape, buckets.shape # out: [bs*n_heads, sl, head_dim], buckets: [bs*n_heads, sl*n_hashes]

(torch.Size([80, 1024, 64]), torch.Size([0]), torch.Size([80, 8192]))
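
To give a feel for where the `[bs*n_heads, sl*n_hashes]` bucket tensor comes from, here is a rough sketch of angular LSH via random rotations, the hashing scheme described in the Reformer paper. This is not the library's code: the function name `lsh_buckets`, the seed handling and the sizes are assumptions, and the real module does additional bookkeeping that is omitted here.

```python
import torch

def lsh_buckets(qk, n_buckets, n_hashes, seed=0):
    # qk: [batch, seq_len, dim]; n_buckets is assumed to be even.
    gen = torch.Generator().manual_seed(seed)
    batch, seq_len, dim = qk.shape
    # One random projection per hash round, onto n_buckets // 2 directions.
    rotations = torch.randn(dim, n_hashes, n_buckets // 2, generator=gen)
    rotated = torch.einsum('btd,dhr->bhtr', qk, rotations)
    # Concatenating with the negation gives n_buckets candidate directions.
    rotated = torch.cat([rotated, -rotated], dim=-1)   # [batch, n_hashes, seq_len, n_buckets]
    buckets = rotated.argmax(dim=-1)                   # [batch, n_hashes, seq_len]
    return buckets.reshape(batch, -1)                  # [batch, n_hashes * seq_len]

qk_demo = torch.randn(4, 128, 16)                      # illustrative [bs*n_heads, sl, head_dim]
buckets_demo = lsh_buckets(qk_demo, n_buckets=128 // 32, n_hashes=2)
print(buckets_demo.shape)
```

Vectors pointing in similar directions tend to get the same bucket id, and repeating the hash `n_hashes` times reduces the chance that similar vectors are never bucketed together.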

Maybe use callbacks:

In [36]:
# if self.callback is not None:
#     self.callback(attn.reshape(b, lsh_h, t, -1), buckets.reshape(b, lsh_h, -1))

Deal with local attention; we won't use it here (`has_local` is False, so the block is skipped):

In [37]:
if has_local:
    lqk, lv = lqk[:, :t], lv[:, :t]
    local_out = self.local_attn(lqk, lqk, lv, input_mask=input_mask)
    local_out = local_out.reshape(b, l_h, t, -1)
    out = out.reshape(b, lsh_h, t, -1)
    out = torch.cat((local_out, out), dim=1)

Finally reshape output to [bs, sl, n_heads*head_dim]:

In [38]:
out = split_heads(out).view(b, t, -1)
out.shape

torch.Size([10, 1024, 512])
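
The head reshaping before and after attention is a pure permutation of the data. A small round-trip sketch with made-up sizes, using explicit arguments instead of the notebook's globals:

```python
import torch

b, t, h, dh = 2, 5, 4, 3                 # illustrative sizes
x_demo = torch.randn(b, t, h * dh)

# "merge_heads" in the notebook: [b, t, h*dh] -> [b, h, t, dh]
per_head = x_demo.view(b, t, h, dh).transpose(1, 2)
# Collapse batch and heads, as done before calling attention.
flat = per_head.reshape(b * h, t, dh)
# "split_heads" plus the final view: [b*h, t, dh] -> [b, t, h, dh] -> [b, t, h*dh]
restored = flat.view(b, h, t, dh).transpose(1, 2).contiguous().view(b, t, -1)
print(torch.equal(restored, x_demo))
```

The `.contiguous()` call matters: after `transpose` the memory layout no longer matches the logical shape, so a plain `.view` would raise an error.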

And process through the final linear layer. Note: if n_heads*head_dim differs from the embedding dim, this step projects the output back to the embedding dim.

In [39]:
out = to_out(out)
out.shape

torch.Size([10, 1024, 128])

Return the output, possibly with a final dropout (not standard to my knowledge).

In [40]:
post_attn_dropout(out).shape

torch.Size([10, 1024, 128])