In [14]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import math
import os
import random
import tqdm
import gzip

!pip3 install einops
from einops import rearrange, repeat, pack, unpack, einsum
from einops.layers.torch import Rearrange


from functools import partial, wraps
from contextlib import contextmanager, ExitStack
from pathlib import Path
from filelock import FileLock
import pickle

import transformers
from transformers import AutoTokenizer

!pip3 install faiss-gpu
import faiss

!pip3 install datasets
import datasets




In [15]:
sequence_length = 5
sequence_pos = torch.arange(sequence_length, dtype=torch.long)
context_pos = torch.arange(2*sequence_length, dtype=torch.long)
#context_pos = torch.arange(-sequence_length, sequence_length, dtype=torch.long)
sequence_rel_pos = rearrange(sequence_pos, 'i -> i 1')
context_rel_pos = rearrange(context_pos, 'j -> 1 j')
rel_pos = context_rel_pos - sequence_rel_pos
rel_pos

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [-1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
        [-2, -1,  0,  1,  2,  3,  4,  5,  6,  7],
        [-3, -2, -1,  0,  1,  2,  3,  4,  5,  6],
        [-4, -3, -2, -1,  0,  1,  2,  3,  4,  5]])

In [17]:
# # Outline

# # Embedding
# token_ids = tokenizer(raw_text)
# x = embedding(token_ids)

# # BLOCK x n (layers)

# # Attention
# residual = x.copy()
# x = layernorm(x)
# x = attention(x) # XL, KNN_XL
# x = x + residual

# # Feedforward
# residual = x.copy()
# x = layernorm(x)
# x = linear(x)
# x = activation(x)
# x = dropout(x)
# x = linear_2(x)
# x = x + residual


# # Output
# x = layernorm(x)
# token_ids = embedding_reverse(x)
# loss = cross_entropy(token_ids, labels)

In [19]:
class MemorizingTransformer(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        vocab_size,
        heads = 8,
        depth = 10,
        dropout = 0,
        head_dimension = 64,
        max_knn_memories = 32000,
        topk = 5,

    ):
        super().__init__()
        self.heads = heads
        self.embedding_dimension = embedding_dimension
        self.dropout = dropout
        self.depth = depth
        self.head_dimension = head_dimension
        self.max_knn_memories = max_knn_memories
        self.topk = topk

        self.rel_pos = RelativePosition(rp_scale = head_dimension** 0.5,
                                        heads = self.heads)
        self.rel_pos_knn = RelativePosition(rp_scale = head_dimension** 0.5,
                                        heads = self.heads)
        self.embedding_matrix = nn.Embedding(vocab_size, self.embedding_dimension)

        self.knn = KNN(head_dimension * heads, self.max_knn_memories)

        self.layers = nn.ModuleList([])

        for i in range(self.depth):

            if i == self.depth-2:
                layer_knn = self.knn
            else:
                layer_knn = None

            self.layers.append(Block(layer_knn))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(self.embedding_dimension),
            nn.Linear(self.embedding_dimension, vocab_size)
        )

    def forward(
        self,
        x,
        relative_positions = None,
        xl_memories = None,
        labels = None,
    ):

        batch_size, sequence_length = x.shape[0], x.shape[1]

        # Position values
        rel_pos = self.rel_pos(sequence_length)
        rel_pos_knn = self.rel_pos_knn(sequence_length)

        if xl_memories is not None:
            xl_memories = xl_memories
        else:
            xl_memories = (None,) * self.depth # if we're in first chunk of document

        # Iterator
        xl_memories_iter = iter(xl_memories)

        # Store the XL memories for each pass
        new_xl_memories = []

        # Embeddings
        x = self.embedding_matrix(x)

        for ind, block in enumerate(self.layers):

            if ind == self.depth-2:
                layer_rel_pos = rel_pos_knn
            else:
                layer_rel_pos = rel_pos

            x, xl_mem = block(x, next(xl_memories_iter), layer_rel_pos)

            if xl_mem is not None:
                new_xl_memories.append(xl_mem)

        logits = self.to_logits(x)

        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
        if len(new_xl_memories) > 0:
            return loss, new_xl_memories
        return loss





## Final Architecture

In [20]:
class RelativePosition(nn.Module):
  def __init__(
      self,
      rp_scale,
      num_buckets = 32,
      rp_max_distance = 128,
      heads = 8
  ):
      super().__init__()
      self.scale = rp_scale
      self.num_buckets = num_buckets
      self.rp_max_distance = rp_max_distance
      self.relative_attention_embedding = nn.Embedding(num_buckets, heads)

  def relative_position_bucket(self, relative_position_matrix):
      n = -relative_position_matrix
      n = torch.max(n, torch.zeros_like(n))

      max_exact = self.num_buckets // 2

      is_small = n < max_exact
      val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(self.rp_max_distance / max_exact) * (self.num_buckets - max_exact)).long()
      val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, self.num_buckets - 1))

      return torch.where(is_small, n, val_if_large)

  def forward(self, sequence_length):

      sequence_pos = torch.arange(sequence_length, dtype=torch.long)
      context_pos = torch.arange(2 * sequence_length, dtype=torch.long)
      sequence_rel_pos = rearrange(sequence_pos, 'i -> i 1')
      context_rel_pos = rearrange(context_pos, 'j -> 1 j')
      rel_pos = context_rel_pos - sequence_rel_pos

      position_bucket_indices = self.relative_position_bucket(rel_pos)

      rp_values = self.relative_attention_embedding(position_bucket_indices)
      rp_values = rearrange(rp_values, 'i j h -> () h i j')
      return rp_values * self.scale



class KNN():
    def __init__(
        self,
        dim,
        max_memories,
        ):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        self.db_filepath = "./memory.memmap"
        self.db = np.memmap(self.db_filepath, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.index = faiss.IndexFlatL2(dim)


    def add_to_db(self, new_data):
        new_data_len = new_data.shape[0]
        ids = (np.arange(new_data_len) + self.db_offset)
        self.db[ids] = new_data.detach().numpy()
        self.db_offset += new_data_len
        # Write to file
        self.db.flush()


    def search_and_retrieve(self, query_vecs, topk):
        query_vecs = query_vecs
        distances, indices = self.index.search(query_vecs, topk)
        kvs = self.db[indices]
        return kvs

    def add(self, new_data):
        # Input is b n 2 d, flatten to (b n) 2 d
        new_data = new_data.flatten(0,1)
        # Add to db
        self.add_to_db(new_data)
        # Only keys are used in knn index
        keys, vals = new_data.unbind(dim=-2)
        keys = keys.detach().numpy()
        # Add (b n) d tensors to index
        keys = np.ascontiguousarray(keys)
        # Add to index
        self.index.add(keys)

    def search(self, query_vecs, topk):
        # can override topk
        query_batch_size, query_seq_len = query_vecs.shape[0], query_vecs.shape[1]
        # Input is b n d, flatten to (b n) d
        query_vecs = query_vecs.flatten(0,1)
        kvs = self.search_and_retrieve(np.ascontiguousarray(query_vecs.detach().numpy()), topk)
        # kvs are (b n) k 2 d, unflatten to b n k 2 d
        kvs = torch.tensor(kvs)
        kvs = torch.unflatten(kvs, 0, (query_batch_size, query_seq_len))
        return kvs


    def clear(self):
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0


class XLAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 64,
        dropout = 0.,
    ):
        super().__init__()
        self.heads = heads
        self.dropout = nn.Dropout(dropout)
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.output_matrix = nn.Linear(self.heads * head_dimension, embedding_dimension)

    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        relative_positions = None,
        xl_memory = None
    ):

        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        queries = queries * self.scale

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # assume stacked
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory
            xl_sequence_length = k_xl.shape[1]

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        i, j = qk.shape[-2:]
        if relative_positions is not None:
            qk = relative_positions[..., -i:, -j:] + qk

        qk = qk * self.scale

        mask = torch.ones((i,j), dtype = torch.bool).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)
        qk = self.dropout(qk)

        values = rearrange(values, 'b t (h d) -> b h t d', h=self.heads)
        qkv = qk@values
        qkv = rearrange(qkv, 'b h t d -> b t (h d)')

        out = self.output_matrix(qkv)

        # new XL memories

        keys = rearrange(keys, 'b h t d -> b t (h d)', h = self.heads)
        values = rearrange(values, 'b h t d -> b t (h d)', h=self.heads)
        kv_memories = torch.stack((keys, values), dim=-2) # (batch, sequence_len, 2, dimension)


        if xl_memory is not None:
            xl_memories, current_input = kv_memories[:, :-xl_sequence_length], kv_memories[:, -xl_sequence_length:]
            kv_to_add_xl = current_input
        else:
            kv_to_add_xl = kv_memories

        return out, kv_to_add_xl



class KNNAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        knn,
        heads = 8,
        head_dimension = 64,
        topk_retrieved_memories = 3,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5
        self.dropout = nn.Dropout(dropout)

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

        self.gate_bias = nn.Parameter(torch.randn(self.heads, 1, 1))
        self.topk_retrieved_memories = topk_retrieved_memories
        self.knn = knn

    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        relative_positions = None,
        xl_memory = None
    ):
        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        queries = F.normalize(queries, dim=-1)
        keys = F.normalize(keys, dim=-1)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # unstack
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory
            xl_sequence_length = k_xl.shape[1]

        ### LOCAL ATTENTION

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        i, j = qk.shape[-2:]
        if relative_positions is not None:
            qk = relative_positions[..., -i:, -j:] + qk

        qk = qk * self.scale

        mask = torch.ones((i,j), dtype = torch.bool).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        qk = self.dropout(qk)

        values = rearrange(values, 'b t (h d) -> b h t d', h=self.heads)
        qkv = qk@values

        ### KNN ATTENTION

        # If there are knn memories (we're not on the first segment) then perform knn attention
        if self.knn.index.ntotal > 0:
            # Convert queries to search form
            queries = rearrange(queries, 'b h t d -> b t (h d)')
            mem_kv = self.knn.search(queries, topk = self.topk_retrieved_memories) # returns b t k 2 d
            mem_k, mem_v = mem_kv.unbind(dim = -2)
            mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=self.heads)
            mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=self.heads)

            # Convert queries to attention form
            queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
            mem_qk = einsum(queries, mem_k, 'b h t d, b h t k d -> b h t k')
            mem_qk = mem_qk * self.scale

            mem_qk = F.softmax(mem_qk, dim=-1)
            mem_qk = self.dropout(mem_qk)
            mem_qkv = einsum(mem_qk, mem_v, 'b h t k, b h t k d -> b h t d')

            # Combined attentions

            combined_qkv = mem_qkv * self.gate_bias + qkv * (1 - self.gate_bias)
            combined_qkv = rearrange(combined_qkv, 'b h t d -> b t (h d)')
            out = self.output_matrix(combined_qkv)

        else:
            qkv = rearrange(qkv, 'b h t d -> b t (h d)')
            out = self.output_matrix(qkv)

        # New XL memories
        keys = rearrange(keys, 'b h t d -> b t (h d)', h = self.heads)
        values = rearrange(values, 'b h t d -> b t (h d)', h=self.heads)
        kv_memories = torch.stack((keys, values), dim=-2) # (batch, sequence_len, 2, dimension)

        if xl_memory is not None:
            # if we're on a middle/end segment of a document (there are previous XL memories)
            xl_memories, current_kv = kv_memories[:, :-xl_sequence_length], kv_memories[:, -xl_sequence_length:]
        else:
            # if we're at the first segment
            current_kv = kv_memories

        self.knn.add(current_kv)

        return out, current_kv


class Block(nn.Module):
    def __init__(self, embedding_dimension, attention_type, dropout=0.):
        super().__init__()
        self.attention = attention_type
        self.dim = embedding_dimension

        self.ff_block = nn.Sequential(
            nn.LayerNorm(self.dim),
            nn.Linear(self.dim, self.dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.dim * 4, self.dim))

    def forward(self, x, xl_memories, rel_pos):
        residual = x
        norm = nn.LayerNorm(self.dim)
        attn_out = norm(x)
        attn_out, new_xl_memories = self.attention(attn_out, relative_positions=rel_pos, xl_memory=xl_memories)
        attn_out += residual

        residual = attn_out
        ff_out = self.ff_block(attn_out)
        ff_out += residual
        return ff_out, new_xl_memories


class MemorizingTransformer(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        vocab_size,
        max_knn_memories = 81920,
        heads = 8,
        depth = 10,
        dropout = 0,
        head_dimension = 64,
        topk = 5,

    ):
        super().__init__()
        self.heads = heads
        self.embedding_dimension = embedding_dimension
        self.dropout = dropout
        self.depth = depth
        self.head_dimension = head_dimension
        self.max_knn_memories = max_knn_memories
        self.topk = topk

        ###########
        self.rel_pos = RelativePosition(rp_scale = head_dimension** 0.5,
                                        heads = self.heads)
        self.rel_pos_knn = RelativePosition(rp_scale = head_dimension** 0.5,
                                        heads = self.heads)
        self.embedding_matrix = nn.Embedding(vocab_size, self.embedding_dimension)

        self.knn = KNN(head_dimension * heads, self.max_knn_memories)



        self.layers = nn.ModuleList([])
        for i in range(self.depth):

            if i == self.depth-2:
                attention_type = KNNAttention(self.embedding_dimension,
                            self.knn,
                            heads = self.heads,
                            head_dimension = self.head_dimension,
                            dropout = self.dropout)
            else:
                attention_type = XLAttention(self.embedding_dimension,
                            heads = self.heads,
                            head_dimension = self.head_dimension,
                            dropout = self.dropout)

            self.layers.append(Block(self.embedding_dimension, attention_type))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(self.embedding_dimension),
            nn.Linear(self.embedding_dimension, vocab_size)
        )


    def forward(
        self,
        x,
        relative_positions = None,
        xl_memories = None,
        labels = None,
    ):

        batch_size, sequence_length = x.shape[0], x.shape[1]

        # Position values
        rel_pos = self.rel_pos(sequence_length)
        rel_pos_knn = self.rel_pos_knn(sequence_length)

        # If no XL memories (start of a sequence) then None type for each layer.
        # There is one set of XL memories for each layer
        # xl_memories = default(xl_memories, (None,) * self.num_xl_memory_layers)
        if xl_memories is not None:
            xl_memories = xl_memories
        else:
            xl_memories = (None,) * self.depth

        # Iterator
        xl_memories_iter = iter(xl_memories)

        # Embeddings
        x = self.embedding_matrix(x)

        # Store the XL memories for each pass
        new_xl_memories = []

        for ind, block in enumerate(self.layers):

            if i == self.depth-2:
                layer_rel_pos = rel_pos_knn
            else:
                layer_rel_pos = rel_pos

            x, xl_mem = block(x, next(xl_memories_iter), layer_rel_pos)

            if xl_mem is not None:
                ############
                new_xl_memories.append(xl_mem.detach())



        logits = self.to_logits(x)

        # Training
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
        if len(new_xl_memories) > 0:
            return loss, new_xl_memories
        return loss

## Training Loop

In [21]:
SEGMENTS = 10
SEQUENCE_LENGTH = 512
CHUNK_SIZE = (SEGMENTS * SEQUENCE_LENGTH) + 1 #### we need +1 because we shift by 1 for each sequence
BATCH_SIZE = 8
LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5
VALIDATE_EVERY = 100
MAX_KNN_MEMORIES = BATCH_SIZE * 1 * SEQUENCE_LENGTH * SEGMENTS


dataset = datasets.load_dataset("ccdv/arxiv-summarization", split='train', streaming=True)
raw_dataset = list(dataset.take(3500))

raw_articles = [x['article'] for x in raw_dataset]
raw_articles = [x for x in raw_articles if len(x) > CHUNK_SIZE]
converted = [np.fromstring(doc, dtype=np.uint8) for doc in raw_articles]

def clip_article(doc, chunk_size):
    remainder = len(doc) % chunk_size
    return doc[:-remainder]

clipped = [clip_article(doc, CHUNK_SIZE) for doc in converted]


chunked = np.array([doc.reshape(-1, CHUNK_SIZE) for doc in clipped])

processed_data = torch.tensor(np.concatenate(chunked), dtype=torch.long)
processed_data.shape
eighty_split = int(processed_data.shape[0] * .8)
ninety_split = int(processed_data.shape[0] * .9)
train_loader = iter(DataLoader(processed_data[:eighty_split], batch_size = BATCH_SIZE, shuffle = True))
val_loader = iter(DataLoader(processed_data[eighty_split:ninety_split], batch_size = BATCH_SIZE, shuffle = True))
test_loader = iter(DataLoader(processed_data[ninety_split:], batch_size = BATCH_SIZE, shuffle = True))

  converted = [np.fromstring(doc, dtype=np.uint8) for doc in raw_articles]


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3401,) + inhomogeneous part.

In [None]:


model = MemorizingTransformer(embedding_dimension = 128,
                              vocab_size = 128,
                              max_knn_memories = MAX_KNN_MEMORIES)

optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
model.train()


for i in tqdm.tqdm(range(200), mininterval = 10., desc = 'training'):

    model.train()
    train_loss = 0.

    # Clear XL memories
    xl_memories = None

    # Clear KNN memory
    model.knn.clear()

    data = next(train_loader)
    seq, labels = data[:, :-1], data[:, 1:]


    # Each pass will be (BATCH_SIZE * SEGMENTS) iterations
    for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):

        loss, xl_memories = model(
            seq_segment,
            labels = labels_segment,
            xl_memories = xl_memories
        )

        train_loss += loss.item() / SEGMENTS
        (loss / SEGMENTS).backward()
        print ("segment complete")


    print(f'training loss: {train_loss}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optim.step()
    optim.zero_grad()


    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(val_loader)
        valid_loss = 0.

        with torch.no_grad():
            xl_memories = None
            model.knn.clear()
            seq, labels = data[:, :-1], data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):

                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        print(f'valid loss: {valid_loss}')