In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time

from reformer_pytorch import LSHAttention,ReformerLM

%load_ext autoreload
%autoreload 2
%load_ext jupyter_spaces


In [2]:
# Single seed value applied to every random-number generator seeded below.
SEED = 1234

# Seed Python's `random`, NumPy, and PyTorch (generic + CUDA seeding calls),
# in the same order as before, from the one constant above.
for _seed_rng in (random.seed,
                  np.random.seed,
                  torch.manual_seed,
                  torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng  # keep the notebook namespace free of the loop helper

# Turn on cuDNN's deterministic flag.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the German and English spaCy pipelines (German first, then English);
# their tokenizers are what tokenize_de / tokenize_en call.
spacy_de, spacy_en = (
    spacy.load(pipeline_name)
    for pipeline_name in ('de_core_news_sm', 'en_core_web_sm')
)

In [4]:
def tokenize_de(text):
    """
    Split a German string into tokens with the `spacy_de` tokenizer.

    Returns a list holding the `.text` of each token, in order.
    """
    pieces = []
    for token in spacy_de.tokenizer(text):
        pieces.append(token.text)
    return pieces

def tokenize_en(text):
    """
    Split an English string into tokens with the `spacy_en` tokenizer.

    Returns a list holding the `.text` of each token, in order.
    """
    pieces = []
    for token in spacy_en.tokenizer(text):
        pieces.append(token.text)
    return pieces

In [5]:
# Source (German) and target (English) fields are configured identically
# except for the tokenizer, so the common options are written once.
_shared_field_options = dict(
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    batch_first=True,
)

SRC = Field(tokenize=tokenize_de, **_shared_field_options)
TRG = Field(tokenize=tokenize_en, **_shared_field_options)

In [6]:
# The '.de' extension is paired with SRC and '.en' with TRG; the result is
# unpacked into the training, validation and test splits.
train_data, valid_data, test_data = Multi30k.splits(
    fields=(SRC, TRG),
    exts=('.de', '.en'),
)

In [10]:
# Vocabulary size and context length are needed both by the model definition
# and by the dummy input further down. Naming them once keeps the two in sync
# (a mismatch would produce out-of-range token ids or a wrong-length input).
NUM_TOKENS = 20000
MAX_SEQ_LEN = 8192

model = ReformerLM(
    num_tokens = NUM_TOKENS,
    dim = 1024,
    depth = 12,
    max_seq_len = MAX_SEQ_LEN,
    heads = 8,
    lsh_dropout = 0.1,
    ff_dropout = 0.1,
    post_attn_dropout = 0.1,
    layer_dropout = 0.1,  # layer dropout from 'Reducing Transformer Depth on Demand' paper
    causal = True,        # auto-regressive or not
    bucket_size = 64,     # average size of qk per bucket, 64 was recommended in paper
    n_hashes = 4,         # 4 is permissible per author, 8 is the best but slower
    emb_dim = 128,        # embedding factorization for further memory savings
    dim_head = 64,        # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_chunks = 200,      # number of chunks for feedforward layer, make higher if there are memory issues
    attn_chunks = 8,      # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv = 128,       # persistent learned memory key values, from all-attention paper
    full_attn_thres = 1024, # use full attention if context length is less than set value
    reverse_thres = 1024,   # turn off reversibility for 2x speed for sequence lengths shorter or equal to the designated value
    use_scale_norm = False,  # use scale norm from 'Transformers without tears' paper
    use_rezero = False,      # remove normalization and use rezero from 'ReZero is All You Need'
    one_value_head = False,  # use one set of values for all heads from 'One Write-Head Is All You Need'
    weight_tie = False,           # tie parameters of each layer for no memory per additional depth
    weight_tie_embedding = False, # use token embedding for projection of output, some papers report better results
    n_local_attn_heads = 2,       # many papers suggest mixing local attention heads aids specialization and improves on certain tasks
    pkm_layers = (4,7),           # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,           # defaults to 128, but can be increased to 256 or 512 as memory allows
    use_full_attn = False    # only turn on this flag to override and turn on full attention for all sequence lengths. for comparison with LSH to show that it is working
).cuda()

# Smoke test: one random token sequence of the full context length.
# torch.randint already returns int64 indices, so no extra .long() cast is
# needed; the tensor is still created on CPU and then moved, as before.
x = torch.randint(0, NUM_TOKENS, (1, MAX_SEQ_LEN)).cuda()
y = model(x) # (1, 8192, 20000)

# Check the expected output shape instead of only stating it in a comment.
assert y.shape == (1, MAX_SEQ_LEN, NUM_TOKENS), f"unexpected output shape {tuple(y.shape)}"

torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])
torch.Size([1, 4, 8320])


In [14]:
# The attention module gets its own name (`lsh_attn`). Previously it was bound
# to `attn` and then overwritten by the unpacking below, so the module was lost
# after the first call and `attn` meant two different things in one cell.
lsh_attn = LSHAttention(
    bucket_size = 64,
    n_hashes = 16,
    causal = True,
    return_attn = True
)

# Random shared query/key and value inputs for the demo call.
qk = torch.randn(10, 1024, 128)
v = torch.randn(10, 1024, 128)

# `out`, `attn` and `buckets` end up bound exactly as before.
out, attn, buckets = lsh_attn(qk, v) # (10, 1024, 128)
# attn contains the unsorted attention weights, provided return_attn is set to True (costly otherwise)
# buckets will contain the bucket number (post-argmax) of each token of each batch

aaa
