In [1]:
from fastai.vision.all import *
from einops import rearrange, repeat, reduce
from functools import wraps
import pdb

This is a simplified implementation of LSH-attention based on [lucidrains](https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py).


**Not implemented:**
* Direct attention masking. This means passing in a [sl x sl] mask for the attention matrix (**not** the same as causal masking)
* Variable query length (query length is always sequence length)
* Attention across buckets (we assume attention only within bucket - as suggested in paper)

**Note:**
* For a detailed walkthru of the code, see this exploration [notebook](https://github.com/morganmcg1/reformer-fastai/blob/main/exploration/LSH-attention.ipynb)
* 8 hash rounds with 64 buckets seems to be what is suggested in the paper/trax
* We'll assume rehashing in each hash round as in the trax library. Otherwise we have to reuse our hashes in some way.
* We'll assume the same random rotations for all items in a batch (i.e. copy rotations across the batch dimension). Lucidrains' source code refers to this as "random_rotations_per_head", which is confusing imo. This simply refers to whether we copy or draw new random rotations across the batch dimension. Note that each head gets its own LSHAttention module and its "own" random rotations.
* Normalize **key vectors** (but not q vectors)
* Attention is optionally returned, but unsorted. Need to take care of unsorting before use.
* Duplicate attention defaults to False. That means that k/q pairs that end up in the same attention chunk across hash rounds get penalized with the log of the count.
* Summation of duplicate attention is done with chunked summations to limit memory use. The result is the same as with a normal summation.
* The mask value for self/self pairs is `TOKEN_SELF_ATTN_VALUE = -5e4`; lucidrains notes that it is "carefully set for half precision to work". The other masks use the most negative finite value of the dtype.

**Masking**:
* Other hash buckets (LSH-specific)
* mask self/self pairs (LSH-specific), i.e. k and q at position i can't attend to each other
* pad mask
* causal masking

**Code**:
* Uses fastai Module - no super().__init__()
* uses fastcore store_attr() to store self.attributes

## helpers

From: https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py

In [2]:
def look_one_back(x):
    """Append to every chunk the contents of the chunk that precedes it.

    Each chunk is allowed to attend within itself and also one chunk back.
    Chunk boundaries might fall in the middle of a run of items from the same
    bucket, so looking one chunk back increases the chances of attending to
    relevant items.

    x: tensor indexed as [bs, n_chunks, chunk_size, ...].
    Returns a tensor whose dim 2 is twice as long: the original chunk followed
    by the previous chunk. Chunk 0 is paired with the last chunk (wrap-around).
    """
    wrapped_last, all_but_last = x[:, -1:], x[:, :-1]
    previous = torch.cat((wrapped_last, all_but_last), dim=1)
    return torch.cat((x, previous), dim=2)

In [3]:
def sort_key_val(t1, t2, dim=-1):
    """Sort `t1` along `dim` and reorder `t2` with the same permutation.

    `t2` is first expanded to the shape of `t1`.
    Returns a tuple `(sorted_t1, reordered_t2)`.
    """
    sorted_keys, order = torch.sort(t1, dim=dim)
    companions = t2.expand_as(t1)
    return sorted_keys, torch.gather(companions, dim, order)

In [4]:
def default(val, default_val):
    """Return `val`, falling back to `default_val` only when `val` is None."""
    if val is None:
        return default_val
    return val

In [5]:
def batched_index_select(values, indices):
    """Pick rows of `values` along dim 1, independently for every batch item.

    The index tensor gets a new axis at position 2 which is expanded to the
    size of the last dim of `values`, so whole feature vectors are gathered.
    With values [bs, n, d] and indices [bs, m] the result is [bs, m, d].
    """
    feature_dim = values.shape[-1]
    row_index = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return torch.gather(values, 1, row_index)

In [6]:
def max_neg_value(tensor):
    """Return the negated largest finite value of `tensor`'s floating dtype."""
    limits = torch.finfo(tensor.dtype)
    return -limits.max

In [7]:
def chunked_sum(tensor, chunks=1):
    """Sum `tensor` over its last dim, processing the rows in `chunks` pieces.

    All leading dims are flattened into rows, the rows are split into chunks,
    each chunk is summed separately, and the partial results are concatenated
    and reshaped back to the leading dims. The result has the shape of
    `tensor` without its last dim.
    """
    *lead_shape, width = tensor.shape
    rows = tensor.reshape(-1, width)
    partial_sums = []
    for piece in rows.chunk(chunks, dim=0):
        partial_sums.append(piece.sum(dim=-1))
    return torch.cat(partial_sums, dim=0).reshape(lead_shape)

In [8]:
def cache_method_decorator(cache_attr, cache_namespace, reexecute = False):
    """Build a decorator that caches a method's return value on its instance.

    cache_attr:      name of the instance attribute that holds the cache mapping.
    cache_namespace: prefix of the cache key; the full key is
                     '<cache_namespace>:<key_namespace>'.
    reexecute:       when a value is fetched from the cache, still call the
                     wrapped method and discard its result.

    The wrapped method accepts three extra keyword-only arguments:
    key_namespace: suffix of the cache key (None is treated as '').
    fetch:         return the cached value instead of computing a new one.
    set_cache:     store a freshly computed value in the cache.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            suffix = str(default(key_namespace, ''))
            store = getattr(self, cache_attr)
            cache_key = f'{cache_namespace}:{suffix}'

            if not fetch:
                result = fn(self, *args, **kwargs)
                if set_cache:
                    # bind a new mapping instead of mutating the existing one
                    updated = {**store}
                    updated[cache_key] = result
                    setattr(self, cache_attr, updated)
                return result

            cached = store[cache_key]
            if reexecute:
                # result is intentionally discarded; only the cached value is returned
                fn(self, *args, **kwargs)
            return cached
        return wrapper
    return inner_fn


In [9]:
# Fill value written into the attention logits of self/self pairs (query and
# key at the same sequence position) by LSHAttention.forward. The other masks
# there use max_neg_value(dots) instead, which is far more negative.
TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work

## LSHAttention

In [10]:
class LSHAttention(Module):
    """LSH attention with shared query/key vectors.

    Positions are hashed into buckets with random rotations, sorted by
    (bucket, position), split into equally sized chunks, and each chunk attends
    to itself and to the preceding chunk. The outputs of the `n_hashes` hash
    rounds are combined with softmax weights computed from every round's
    log-sum-exp of the attention logits.

    All index tensors and random rotations are created on the device of the
    input, so the module works for inputs on any device.
    """
    def __init__( self,
                  dropout = 0.,                       # attention matrix dropout
                  bucket_size = 64,                   # at least 64 suggested in trax
                  n_hashes = 8,                       # paper suggests 8
                  causal = False,
                  allow_duplicate_attention = False,  # as in the paper
                  attend_across_buckets = False,      # as in the paper
                  drop_for_hash_rate = 0.0,           # unsure of default, not mentioned in paper
                  return_attn = False):
        """Store the configuration and build the two dropout layers.

        Raises ValueError when `dropout` or `drop_for_hash_rate` is >= 1.
        """
        if dropout >= 1.0 or drop_for_hash_rate >=1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        store_attr(but=['dropout', 'drop_for_hash_rate'])  # fastcore - store attributes
        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)
        self._cache = {} # cache buckets for reversible network, required to make Reformer work at depth

    @cache_method_decorator('_cache', 'buckets', reexecute=True)
    def hash_vectors(self, n_buckets, vecs):
        """Assign every vector a bucket id for each of the `n_hashes` hash rounds.

        n_buckets: number of buckets per hash round, must be even.
        vecs:      [bs, sl, dim] vectors to hash.

        Returns integer bucket ids of shape [bs, n_hashes * sl]; ids of hash
        round `h` are offset by `h * n_buckets` so rounds never share an id.

        The caching decorator adds the keyword-only arguments `key_namespace`,
        `fetch` and `set_cache` to this method.
        """
        # 0. We need an even number of buckets:
        assert n_buckets % 2 == 0

        # 1. account for the input shapes. vecs = [bs, sl, dim]
        batch_size, _, dim = vecs.shape
        device, dtype = vecs.device, vecs.dtype
        rotations_shape = (dim, self.n_hashes, n_buckets // 2)

        # 2. Calculate hash bucket id via random rotations, concatenation and argmax
        # note: we copy rotations across batch dimension (see exploration notebook for details).
        # The rotations must live on the device (and in the dtype) of `vecs`,
        # otherwise the einsum below fails for non-CPU / non-float32 inputs.
        random_rotations = repeat(torch.randn(rotations_shape, dtype=dtype, device=device),
                                  'd nh nb -> bs d nh nb', bs=batch_size)
        dropped_vecs = self.dropout_for_hash(vecs)

        rotated_vecs = torch.einsum('bsd,bdhn->bhsn',
                                    dropped_vecs,       # [bs, sl, dim]
                                    random_rotations)   # [bs, dim, n_hashes, n_buckets//2]
                                                        # rotated vecs: [bs, n_hashes, sl, n_buckets//2]

        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) # [bs, n_hashes, sl, n_buckets]
        buckets = torch.argmax(rotated_vecs, dim=-1)                    # [bs, n_hashes, sl]

        # 3. Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
        # We also reshape the buckets so that each hash round is concatenated along the -1 dim
        offsets = torch.arange(self.n_hashes, device=device)            # [0,1,2,..n_hashes-1], same device as buckets
        offsets = rearrange(offsets * n_buckets, 'nh -> 1 nh 1')        # [1, n_hashes, 1]
        buckets = rearrange(buckets+offsets, 'bs nh sl -> bs (nh sl)')  # [bs, (n_hashes*sl)]
        return buckets

    def forward(self, qk, v, input_mask = None, **kwargs):
        """Compute LSH attention.

        qk:         [bs, seqlen, dim] shared query/key vectors.
        v:          [bs, seqlen, dim] value vectors.
        input_mask: optional boolean pad mask [bs, <=seqlen]; it is padded with
                    True up to `seqlen`. False positions are masked out.
        kwargs:     `_reverse` (fetch the cached buckets instead of rehashing)
                    and `_depth` (namespace of the bucket cache) are consumed.

        `seqlen` must be divisible by `2 * bucket_size`.

        Returns `(out, attn, buckets)`:
        out:     [bs, seqlen, dim] attention output.
        attn:    [bs, seqlen, seqlen] attention weights if `return_attn`,
                 otherwise an empty tensor.
        buckets: [bs, n_hashes * seqlen] bucket ids from `hash_vectors`.
        """
        batch_size, seqlen, dim, device = *qk.shape, qk.device

        # caching
        is_reverse = kwargs.pop('_reverse', False)
        depth = kwargs.pop('_depth', None)

        # We will have an even number of buckets, and our attention chunks need to fit completely within a seqlen
        assert seqlen % (self.bucket_size * 2) == 0, f'Sequence length ({seqlen}) needs to be divisible by target bucket size  x 2 - {self.bucket_size * 2}'

        # get the hash buckets for our qk input vectors
        n_buckets = seqlen // self.bucket_size
        buckets = self.hash_vectors(n_buckets, qk, key_namespace=depth, fetch=is_reverse, set_cache=self.training)

        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        # Create an index that reflects both bucket id and sequence id. This lets us sort qk according
        # to both simultaneously. Repeated across the batch dimension.
        # Built on the input device so that it can be combined with `buckets`.
        ticker = repeat(torch.arange(self.n_hashes * seqlen, device=device), 'l -> bs l', bs=batch_size)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()                # [bs, seqlen*n_hashes]

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)  # [bs, seqlen*n_hashes]
        _, undo_sort = sticker.sort(dim=-1)                                    # indexes to undo sortings
        del ticker

        sbuckets_and_t = sbuckets_and_t.detach()   # no need to store gradients for indexes
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)             # index of [0..seqlen-1] for each hash round
        sqk = batched_index_select(qk, st)  # get the sorted qk, [bs, seqlen*n_hashes, dim]
        sv = batched_index_select(v, st)    # get the sorted v, [bs, seqlen*n_hashes, dim]

        # Reshape to include a n_chunks axis.
        n_chunks = self.n_hashes * n_buckets
        bq_t = bkv_t = rearrange(st, 'bs (n s) -> bs n s', n=n_chunks) # [bs, n_chunks, chunk_size]
        bqk = rearrange(sqk, 'bs (n s) d -> bs n s d', n=n_chunks)     # [bs, n_chunks, chunk_size, dim]
        bv = rearrange(sv, 'bs (n s) d -> bs n s d', n=n_chunks)       # [bs, n_chunks, chunk_size, dim]

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = F.normalize(bqk, p=2, dim=-1).type_as(bq)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # Note: no look_back for queries

        bk = look_one_back(bk)        # [bs, n_chunks, chunk_size*2, dim]
        bv = look_one_back(bv)        # [bs, n_chunks, chunk_size*2, dim]
        bkv_t = look_one_back(bkv_t)

        # Dot-product attention.
        dots = torch.einsum('bnsd,bnzd->bnsz',
                    bq,                  # [bs, n_chunks, chunk_size, dim]
                    bk                   # [bs, n_chunks, chunk_size*2, dim]
                   ) * (dim ** -0.5)     # dots: [bs, n_chunks, chunk_size, chunk_size*2]
        masked_value = max_neg_value(dots)

        # Input mask for padding in variable-length sequences
        if input_mask is not None:
            input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]), value=True)
            mq = input_mask.gather(1, st).reshape((batch_size, n_chunks, -1))
            mkv = look_one_back(mq)
            mask = mq[:, :, :, None] * mkv[:, :, None, :]
            dots.masked_fill_(~mask, masked_value)
            del mask

        # Causal masking
        if self.causal:
            mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
            dots.masked_fill_(mask, masked_value)
            del mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
        del self_mask

        # Mask out attention to other hash buckets.
        if not self.attend_across_buckets:
            bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, n_chunks, -1))
            bkv_buckets = look_one_back(bkv_buckets)
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots.masked_fill_(bucket_mask, masked_value)
            del bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition.

        if not self.allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % n_chunks
            if not self.attend_across_buckets:
                locs1 = buckets * n_chunks + locs1
                locs2 = buckets * n_chunks + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, n_chunks, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :])
            # for memory considerations, chunk summation of last dimension for counting duplicates
            dup_counts = chunked_sum(dup_counts, chunks=(self.n_hashes * batch_size))
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)
            del dup_counts

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp).type_as(dots)
        dropped_dots = self.dropout(dots)

        # calculate self-attention (attn * values)
        bo = torch.einsum('bnsz,bnzd->bnsd',
                          dropped_dots,      # [bs, n_chunks, chunk_size, chunk_size*2]
                          bv)                # [bs, n_chunks, chunk_size*2, dim]
                                             # bo: [bs, n_chunks, chunk_size, dim]

        # unchunk, unsort and reshape self-attention
        so = rearrange(bo, 'b n s d -> b (n s) d')                     # [bs, seqlen*n_hashes, dim]
        o = batched_index_select(so, undo_sort)                        # [bs, seqlen*n_hashes, dim]
        o = rearrange(o, 'b (nh sl) d -> b nh sl d', nh=self.n_hashes) # [bs, n_hashes, seqlen, dim]

        # unchunk, unsort and reshape logits
        slogits = rearrange(dots_logsumexp, 'bs n s 1 -> bs (n s)')              # [bs, seqlen*n_hashes]
        logits = slogits.gather(1, undo_sort)                                    # [bs, seqlen*n_hashes]
        logits = rearrange(logits, 'bs (nr sl) -> bs nr sl 1', nr=self.n_hashes) # [bs, n_hashes, seqlen, 1]

        # average probabilities across hash rounds (dim 1) and get weighted attention
        probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True)) # [bs, n_rounds, seqlen, 1]
        out = torch.sum(o * probs, dim=1)                                        # [bs, seqlen, dim]

        # return unsorted attention weights - empty otherwise
        attn = torch.empty(0, device=device)
        if self.return_attn:
            attn_unsort = ((bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :])
            attn_unsort = attn_unsort.view(batch_size * self.n_hashes, -1).long()
            # scatter_add_ needs the target to share dtype and device with `dots`
            unsorted_dots = torch.zeros(batch_size * self.n_hashes, seqlen * seqlen,
                                        dtype=dots.dtype, device=device)
            unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort))
            del attn_unsort
            unsorted_dots = unsorted_dots.reshape(batch_size, self.n_hashes, seqlen, seqlen)
            attn = torch.sum(unsorted_dots * probs, dim=1)

        # return output, attention matrix, and bucket distribution
        return out, attn, buckets

Testing with random data:

In [11]:
# Smoke test on random inputs laid out as [bs=64, seqlen=512, dim=128].
# seqlen must be divisible by 2 * bucket_size (2 * 64 with the defaults).
qk = torch.randn(64, 512, 128)
v = torch.rand(64, 512, 128)
lsh_att = LSHAttention()
# attn is an empty tensor here because return_attn defaults to False
out, attn, buckets = lsh_att(qk, v)
# out keeps the input shape; buckets holds one id per position and hash round
# (n_hashes * seqlen = 8 * 512 = 4096 per batch item)
out.shape, buckets.shape

(torch.Size([64, 512, 128]), torch.Size([64, 4096]))