In [None]:
from PIL import Image
from pathlib import Path
from fastai.basics import *
from fastai.text.all import *
from transformers import ReformerModelWithLMHead, ReformerTokenizerFast
from fastai.callback.wandb import *
#import wandb
import torch
from reformer_pytorch import Reformer
from torch.utils.checkpoint import get_device_states
TOKEN_SELF_ATTN_VALUE = -5e4 
from local_attention import LocalAttention
from functools import partial, reduce, wraps

In [None]:
from reformer_pytorch import LSHSelfAttention
from reformer_pytorch.reversible import ReversibleSequence

In [None]:
torch.cuda.set_device('cuda:1')

In [None]:
# Smoke test: build a Reformer and push one random batch through it.
# Everything random happens inside no_random() so the run is repeatable.
with no_random():
    model = Reformer(
        dim = 512,
        depth = 12,
        heads = 4,
        lsh_dropout = 0.1
    ).cuda()

    # One sequence of 8192 positions with 512 features each (matches dim above).
    x = torch.randn(1, 8192, 512).cuda()
    print(model(x)) # (1, 8192, 512)

tensor([[[ 0.2265, -0.4902,  0.7183,  ..., -0.1998,  1.0268,  1.6641],
         [-0.1284,  0.0688,  0.2674,  ...,  0.8180, -1.5623,  1.0299],
         [ 0.9158,  0.6591, -1.8887,  ...,  0.4773, -0.2201,  1.6994],
         ...,
         [-1.3225,  1.1408, -0.4832,  ..., -0.0524,  0.6005,  1.3357],
         [-1.1027,  0.1178, -0.5195,  ...,  0.4283,  0.2834, -0.2728],
         [-0.1449, -3.3321, -1.4543,  ..., -0.5748,  1.1886, -0.0275]]],
       device='cuda:1')


In [None]:
seq_len=8192
# Grab the LSH attention module that lives inside the first block of the model.
lsh_attn=model.layers.blocks[0].f.net.fn.lsh_attn
# Hash a random batch of vectors ([bs, seq_len, features]) into
# seq_len // bucket_size buckets; the rest of this section rebuilds this call by hand.
lsh_attn.hash_vectors(seq_len // lsh_attn.bucket_size,torch.randn([4, 8192, 128]))

tensor([[  30,   28,   83,  ...,  934,  933,  916],
        [   3,   13,   18,  ...,  942,  980,  922],
        [  65,  101,  105,  ...,  913,  977,  930],
        [  56,   10,   15,  ..., 1019,  901,  962]])

#### Understanding `lsh_attn.hash_vectors`

We start off by defining the inputs

In [None]:
#def hash_vectors(self, n_buckets, vecs)
vecs=torch.randn([4, 8192, 128]) #[bs,seq_len,features]
n_buckets=seq_len // lsh_attn.bucket_size

We then set a few variables for easy use and validate `n_buckets`. `n_buckets` must be divisible by 2 because we only compute "scores" for half of the buckets; the other half comes from negating those scores later.

In [None]:
batch_size = vecs.shape[0]
device = vecs.device

# See https://arxiv.org/pdf/1509.02897.pdf
# We sample a different random rotation for each round of hashing to
# decrease the probability of hash misses.
assert n_buckets % 2 == 0 #we need positive + negative buckets, so must be multiple of 2

rot_size = n_buckets

We define the shape of the rotations.

In [None]:
rotations_shape = (
     1, #bs placeholder
    vecs.shape[-1], #features
    lsh_attn.n_hashes, 
    rot_size // 2) #same as n_buckets//2
print(rotations_shape)

(1, 128, 8, 64)


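We now generate the random rotations. `expand` broadcasts the single set of rotations across the batch dimension (as a view, without copying memory), so every item in the batch is hashed with the same rotations.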

These are applied in a similar way to a matrix multiply :) 

Note: `torch.randn` draws random numbers, so we wrap it in `no_random` to make the result reproducible.

In [None]:
# Draw ONE set of rotations (batch dim of size 1 in rotations_shape) and
# broadcast it over the batch with expand, so every item shares the same rotations.
# -1 in expand means "keep this dimension's size as it is".
with no_random():
    random_rotations = torch.randn(rotations_shape, dtype=vecs.dtype, device=device).expand(batch_size, -1, -1, -1)

Einsum is a concise way to express matrix multiplications. The main thing to notice is that `f` (features) is the index that disappears: it is summed over. So we get a "score" for each combination of `b`, `h`, `t` and `i`.

In [None]:
# shapes
# b = bs
# t = seq_len
# i = n_buckets//2
# h = n_hashes
# f = features
# features f is what is removed here, we end up with a "score" for each n_hash,seq_len,bucket
rotated_vecs = torch.einsum('btf,bfhi->bhti', vecs, random_rotations)

##### Begin Einsum explanation: If you are already familiar with einsum, please skip this explanation!!!! 

This is the shape we want in the end: 

In [None]:
rotated_vecs.shape

torch.Size([4, 8, 8192, 64])

These are the shapes we currently have; notice that `random_rotations` has too many dimensions!

In [None]:
vecs.shape,random_rotations.shape

(torch.Size([4, 8192, 128]), torch.Size([4, 128, 8, 64]))

So we flatten it! They are now compatible with a matrix multiply. 

In [None]:
random_rotations.flatten(start_dim=-2).shape

torch.Size([4, 128, 512])

In [None]:
(vecs@random_rotations.flatten(start_dim=-2)).shape

torch.Size([4, 8192, 512])

Ugh! Now we need to `unflatten` 512 :(

In [None]:
((vecs@random_rotations.flatten(start_dim=-2)).unflatten(-1,(8, 64))).shape

torch.Size([4, 8192, 8, 64])

Nooooo!!!! We still don't match the shape of `rotated_vecs` so we have to transpose T.T

In [None]:
(vecs@random_rotations.flatten(start_dim=-2)).unflatten(-1,(8, 64)).transpose(1,2).shape

torch.Size([4, 8, 8192, 64])

Let's check that it all worked:

In [None]:
((vecs@random_rotations.flatten(start_dim=-2)).unflatten(-1,(8, 64)).transpose(1,2)==rotated_vecs).all()

tensor(True)

But Einsum is so much neater. 

In [None]:
(torch.einsum('btf,bfhi->bhti', vecs, random_rotations)==rotated_vecs).all()

tensor(True)

##### End Einsum explanation

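Let's take a look at `rotated_vecs` to remind us where we are. The full shape is [bs,n_hashes,seq_len,n_buckets//2]; the slice shown below is for the first batch item, so it reads [n_hash,seq_len,bucket].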

In [None]:
rotated_vecs.shape,rotated_vecs[0,:2,:2,:2]

(torch.Size([4, 8, 8192, 64]),
 tensor([[[-22.9653, -15.4491],
          [  0.5603,  -0.2839]],
 
         [[  5.3804,  -2.2699],
          [ -4.3336,   0.3952]]]))

We now append the negated scores along the bucket dimension, so large negative scores become large positive scores for the other half of the buckets.

In [None]:
rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)

We now take the argmax, to get the bucket for each bs,n_hash,seq_len

In [None]:
buckets = torch.argmax(rotated_vecs, dim=-1)

In [None]:
buckets.shape,buckets[0,:2,:2]

(torch.Size([4, 8, 8192]),
 tensor([[47, 86],
         [36, 61]]))

Problem!!!! We do not want bucket numbers from different hash rounds to overlap, so let's add an appropriate offset to each entry along the `n_hashes` dim.

In [None]:
offsets = torch.arange(lsh_attn.n_hashes, device=device)
offsets

tensor([0, 1, 2, 3, 4, 5, 6, 7])

Each hash gets its own range of buckets.

In [None]:
offsets = torch.reshape(offsets * n_buckets, (1, -1, 1)) 
offsets[...,0] #just for display reasons, the last dim has 1 value

tensor([[  0, 128, 256, 384, 512, 640, 768, 896]])

We now add the offsets to their associated n_hashes values, and flatten. 

In [None]:
buckets.shape #dim 1(8) is n_hashes

torch.Size([4, 8, 8192])

In [None]:
buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
buckets.shape

torch.Size([4, 65536])

Check that we are correct. 

In [None]:
# Run the library implementation under no_random(), just like our manual
# random_rotations above, so that it presumably draws the same rotations.
with no_random():
    check = lsh_attn.hash_vectors(n_buckets,vecs)
# Every manually computed bucket id must equal the library's result.
(buckets == check).all()

tensor(True)

#### Chunking

After hashing we can sort by the hash and then chunk, so that each vector attends only to vectors in its own chunk or the previous chunk, i.e. to a maximum of 2*chunksize vectors. Not going to go in depth on this one.

In [None]:
 with torch.no_grad():
    # Cut the last (feature) dimension of x into 64 equally sized pieces.
    chunked = torch.chunk(x, 64, dim=-1)
    # Summarise what torch.chunk handed back: container type, shape of one piece, piece count.
    info = {
        'type': type(chunked),
        '0shape': chunked[0].shape,
        'len': len(chunked),
    }
    print('chunked:', info)
    print('x.shape:', x.shape)

chunked: {'type': <class 'tuple'>, '0shape': torch.Size([1, 8192, 8]), 'len': 64}
x.shape: torch.Size([1, 8192, 512])


#### Reversible Networks

Reversible Networks allow us to recompute the input to a layer on the backward pass without storing the activations, and without using gradient checkpointing. This does however require us to split x into x1 and x2, and y into y1 and y2. 

The `forward` pass computes `y1 = x1 + f(x2)` and `y2 = x2 + g(y1)`. Please notice that these equations can be inverted:
<ul>
<li>x2 = y2 - self.g(y1, **g_args)</li>
<li>x1 = y1 - self.f(x2, **f_args)</li>
</ul>
So we can get back x1 and x2 given f, g, y1, and y2.

In [None]:
targ=torch.ones(x.shape[:-1],dtype=int).cuda()
targ

tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:1')

In [None]:
# Fresh input and the two sub-layers of the reversible block, created under
# no_random() so the same values can be regenerated later for verification.
with no_random():
    x = torch.randn(1, 8192, 512).cuda()
    # 256 = half of the 512 features: each sub-layer only ever sees one half of x.
    f=nn.Linear(256,256).cuda() #doesn't matter too much here
    g=nn.Linear(256,256).cuda()

This is what the forward pass looks like; let's just do two blocks :) (with the same weights)

In [None]:
def forward_pass(x):
    """Run one reversible block over `x` without recording an autograd graph.

    `x` is split in two along dim 2. The halves are combined as
    ``y1 = x1 + f(x2)`` and ``y2 = x2 + g(y1)`` using the notebook-level
    sub-layers `f` and `g`, and the two results are concatenated back
    along dim 2.
    """
    left, right = torch.chunk(x, 2, dim=2)

    # No graph is kept: the backward pass below rebuilds the inputs from the outputs.
    with torch.no_grad():
        out_left = left + f(right)
        out_right = right + g(out_left)

    return torch.cat((out_left, out_right), dim=2)
y=forward_pass(x)
y=forward_pass(y)

In [None]:
y.shape

torch.Size([1, 8192, 512])

To keep this simple I am going to run this through the loss function now; this would be where you would add your regular classification or LM head. We transpose because cross entropy expects the second dimension to be the class dimension.

In [None]:
loss = F.cross_entropy(y.transpose(-1,-2),targ)
loss

tensor(7.7459, device='cuda:1')

Since we are doing this super by hand, we calculate the gradient now. 

In [None]:
def cross_entropy_grad(y,targ):
    """Hand-written gradient of the cross-entropy loss with respect to `y`.

    The class scores are read from dim 2 of `y`, and `targ` holds one class
    index per position. The result is ``softmax(y) - one_hot(targ)`` divided
    by the sizes of the first two dimensions of `y`.
    """
    n_classes = y.shape[-1]
    with torch.no_grad():
        probs = F.softmax(y, dim=2)
        hot = F.one_hot(targ, num_classes=n_classes)
        grad = probs - hot
        # Divide by each leading dimension in turn, as the loss averages over both.
        grad = grad / y.shape[0]
        grad = grad / y.shape[1]
        return grad

In [None]:
dy=cross_entropy_grad(y,targ)

Next is the `backward_pass`. The full function is below.

We start by setting up our variables

In [None]:
#method head
#def backward_pass(self, y, dy, f_args = {}, g_args = {}):

In [None]:
def split_y(y,dy):
    """Halve the block output `y` and its gradient `dy` along dim 2.

    Returns the tuple ``(y1, y2, dy1, dy2)``.
    """
    out_first, out_second = torch.chunk(y, 2, dim=2)
    grad_first, grad_second = torch.chunk(dy, 2, dim=2)
    return out_first, out_second, grad_first, grad_second
y1,y2,dy1,dy2=split_y(y,dy)

We have `g`, `y1` and `dy2`, so we can calculate the gradients for `g(y1)` because `y2 = g(y1) + x2`.

In [None]:
def update_g_grad(g,y1,dy2):
    """Recompute ``g(y1)`` with autograd on and backpropagate `dy2` through it.

    Side effects: `y1` is switched to ``requires_grad=True``; the gradients of
    g's parameters and ``y1.grad`` are accumulated by the backward call.
    Returns the recomputed ``g(y1)``.
    """
    with torch.enable_grad():
        # Track y1 so that the backward call also fills y1.grad.
        y1.requires_grad_(True)
        g_out = g(y1)
        g_out.backward(dy2)
    return g_out
gy1=update_g_grad(g,y1,dy2)

Next we can calculate `x2`, since `y2 = x2 + g(y1)`. We can also calculate `dx1`, because the `torch.autograd.backward` call above populated `y1.grad`.

In [None]:
def getx2dx1(y1,y2,dy1, gy1):
    """Recover `x2` and the gradient flowing into `x1`.

    ``x2 = y2 - gy1`` undoes the second half of the forward pass, and
    ``dx1 = dy1 + y1.grad`` adds the gradient that reached `y1` through `g`.
    ``y1.grad`` is reset to None afterwards.
    Returns ``(x2, dx1)``.
    """
    with torch.no_grad():
        recovered_x2 = y2 - gy1
        grad_x1 = dy1 + y1.grad
        # Clear the stored gradient so a later backward call starts from scratch.
        y1.grad = None
    return recovered_x2, grad_x1
x2,dx1=getx2dx1(y1,y2,dy1, gy1)

We can run the backward pass for `f(x2)` now that we have `x2`

In [None]:
def update_f_grad(f,x2,dx1=None):
    """Recompute ``f(x2)`` with autograd on and backpropagate `dx1` through it.

    Args:
        f: the sub-layer applied to `x2` in the forward pass.
        x2: recovered second half of the block input; it is switched to
            ``requires_grad=True`` so the backward call fills ``x2.grad``.
        dx1: gradient with respect to ``f(x2)``. When omitted, the
            notebook-level variable `dx1` is used, which keeps existing
            ``update_f_grad(f, x2)`` calls working.

    Returns:
        The recomputed ``f(x2)``.

    Raises:
        NameError: if `dx1` is not passed and no notebook-level `dx1` exists.
    """
    if dx1 is None:
        # Backward compatibility with callers that rely on the global dx1.
        dx1 = globals().get("dx1")
        if dx1 is None:
            raise NameError("name 'dx1' is not defined")
    with torch.enable_grad():
        x2.requires_grad = True
        fx2 = f(x2)
        torch.autograd.backward(fx2, dx1, retain_graph=True)
    return fx2
fx2=update_f_grad(f,x2)

We can calculate `x1` since `y1 = x1 + f(x2)`. `dx2` can be determined because the backward pass above populated `x2.grad`.

In [None]:
def get_xdx(fx2,x2,y1,dx1,dy2):
    """Rebuild the block input `x` and its gradient `dx`.

    ``x1 = y1 - fx2`` undoes the first half of the forward pass, and
    ``dx2 = dy2 + x2.grad`` adds the gradient that reached `x2` through `f`.
    ``x2.grad`` is reset to None afterwards. Both halves are concatenated
    along dim 2.
    Returns ``(x, dx)``.
    """
    with torch.no_grad():
        recovered_x1 = y1 - fx2
        grad_x2 = dy2 + x2.grad
        # Clear the stored gradient so a later backward call starts from scratch.
        x2.grad = None

        # x2 was marked as requiring grad earlier; detach it before re-joining.
        full_x = torch.cat((recovered_x1, x2.detach()), dim=2)
        full_dx = torch.cat((dx1, grad_x2), dim=2)
    return full_x, full_dx
x,dx=get_xdx(fx2,x2,y1,dx1,dy2)

We have now completed one backward pass for a ReversibleBlock! Now let's do it for the second block.

In [None]:
# The forward pass applied the block twice, so the input/gradient we just
# recovered are the output/gradient of the earlier block: feed them back in.
y,dy=x,dx
y1,y2,dy1,dy2=split_y(y,dy)
# Same f and g as before: their parameter gradients accumulate across both passes.
gy1=update_g_grad(g,y1,dy2)
x2,dx1=getx2dx1(y1,y2,dy1, gy1)
fx2=update_f_grad(f,x2)
x,dx=get_xdx(fx2,x2,y1,dx1,dy2)

Save the gradients so we can verify our results. 

In [None]:
# Keep the manually accumulated parameter gradients of f and g for later comparison.
# detach() returns a tensor cut off from autograd; it does not copy the data.
f_grads = [param.grad.detach() for param in f.parameters()]
g_grads = [param.grad.detach() for param in g.parameters()]

In [None]:
from reformer_pytorch.reversible import ReversibleBlock

Re-initialize variables for repeatability~

In [None]:
# Recreate x, f and g under no_random(), in the same order as in the manual run,
# so the library path starts from the same input and weights.
with no_random():
    # NOTE(review): .cuda() on a requires_grad tensor presumably yields a non-leaf
    # tensor, so x.grad would stay empty; only the parameter gradients are checked below.
    x = torch.randn(1, 8192, 512,requires_grad=True).cuda()
    f=nn.Linear(256,256).cuda() #doesn't matter too much here
    g=nn.Linear(256,256).cuda()

In [None]:
# Two reversible blocks that share the same f and g, mirroring the two
# forward_pass calls of the manual run.
#blocks=nn.Sequential(ReversibleBlock(f,g),ReversibleBlock(f,g))
blocks=ReversibleSequence([(f,g),(f,g)])

In [None]:
y=blocks(x)
y.retain_grad()

In [None]:
loss = F.cross_entropy(y.transpose(-1,-2),targ)
print(loss)

tensor(7.7459, device='cuda:1', grad_fn=<NllLoss2DBackward>)


In [None]:
loss.backward()

Test that reversible blocks are implemented correctly. 

In [None]:
#works with lists of uneven tensors
def is_close(a,b,eps=1e-5):
    """Return True when two sequences of tensors match element-wise within `eps`.

    Args:
        a, b: iterables of tensors. The tensors inside one iterable may have
            different shapes, but `a[i]` and `b[i]` must be broadcast-compatible.
        eps: maximum allowed absolute difference per element.

    Returns:
        bool: True only if both iterables hold the same number of tensors and
        every pair differs by less than `eps` in absolute value.
    """
    a, b = list(a), list(b)
    # zip() would silently ignore the extra tensors of the longer list.
    if len(a) != len(b):
        return False
    # Compare |a - b| so that large negative differences are not accepted.
    return all(bool(((a_ - b_).abs() < eps).all()) for a_, b_ in zip(a, b))

In [None]:
is_close(f_grads,list(map(lambda p:p.grad,f.parameters())))

True

In [None]:
is_close(g_grads,list(map(lambda p:p.grad,g.parameters())))

True