In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import datasets

Note: memorizing transformer is trained with long documents, where the document is longer than the seq_len (input size), so that we have to break the document into chunks, so that for each chunk we look backwards and see if we can find info in XL recurrent mem or in kNN stored memory to help us predict -> arxiv-summarization is a good HF dataset as it contains long doucments  

In [2]:
dataset = datasets.load_dataset("ccdv/arxiv-summarization", split="train", streaming=True)
raw_dataset = list(dataset.take(10)) # take a subset

Note: normally we would shuffle the dataset to avoid bias, but here we consciously don't do that, because we want our transformer to use memory about the past sequences -> we are going to feed in the same article, chunk by chunk / sequence by sequence. This is mentioned in the paper. We're calling these chunks/sequences "segments" from now on.

Normally (vanilla) a batch contains a certain number of sequences of seq_len/input_len tokens, and the whole sequence is used to train in parallel and there is not relation to other sequences (ideally it's all shuffled and mentioned above).

Here however we need that relationship between a sequence and the past sequence, so what we do is in batch 1 we put all the first sequences/segments of b_s docs, then batch 2 will have all the 2nd segments of these same docs, and so on. So each "row" that spans the batches are segments that come from the same document. This implies that each "row" has its own training "context" where the memory will be used/trained.


In [3]:
segments = 10    # This means we take 10 segments of each doc. This is not a lot but this way we don't have to pad.
seg_len = 512   # seq_len usually but here we'll go with seg_len - same thing
chunk_size = segments * seg_len # Not sure why chunk is needed at this point TBD
chunk_size

5120

In [6]:
raw_articles = [x['article'] for x in raw_dataset] # The other key is 'abstract' -> see with raw_dataset[0].keys
# 5120 is min size so any articles shorter can be discarded
raw_articles = [x for x in raw_articles if len(x) > 5120] # Not sure why > and not >=
# Note that we're doing character-based tokenization here to keep the math simple: normally a token will have multiple chars but not now

unique_tokens = set(''.join([raw_article for raw_article in raw_articles]))

print("token set length: ", len(unique_tokens))
print("toket set: ", ''.join(sorted(unique_tokens)))

# Again: taking a massive shortcut here by using characters as tokens -> replace by BPE or subword
# We now need to convert each token in our set to a token_id (number) - quick way to do this is to just use the unicode int for it
# uint8 is the first 128 and that will cover the 70 that we have - there will be "dead" slots (numbers) for which we will create a
# "dead" embedding vector, but ok for now, we'll switch to BPE or Subword anyway.

# quick play with this:
#print(raw_articles[0][:20])
#token_ids = np.fromstring(raw_articles[0], dtype=np.uint8)[:20]
#token_ids
#tokens = ''.join([chr(token_id) for token_id in token_ids])
#tokens

# Token_id all the articles - note that string is converted to array here:

raw_articles_as_token_ids = [np.fromstring(raw_article, dtype=np.uint8) for raw_article in raw_articles]

# Now make sure that each article is clipped at a certain number of chunk sizes

def clip_article(article, chunk_size):
    remainder = len(article) % chunk_size
    return article[:-remainder]

clipped_raw_articles_as_token_ids = [clip_article(raw_article, chunk_size) for raw_article in raw_articles_as_token_ids]

# What we have at this point is a list of articles, where each article is in token_ids and a multiple of chunk_size and then cut off


#for article in clipped_raw_articles_as_token_ids:
#    print(article.reshape(-1,chunk_size).shape)


clipped_raw_arts_as_t_ids_as_chunks = np.array([article.reshape(-1, chunk_size) for article in clipped_raw_articles_as_token_ids], dtype=object)

# The reshape transforms each article from an array of token ids to a 2-D array (#chunks, chunk_size) 
# #chunks can vary dep. on article length

#clipped_raw_arts_as_t_ids_as_chunks.shape
# (10,) second dim is not given as it's a ragged array and the dim size varies dep. on length of article

processed_data = torch.tensor(np.concatenate(clipped_raw_arts_as_t_ids_as_chunks), dtype=torch.long)
processed_data.shape

# numbers for 10 articles:
#(5, 5120)
#(3, 5120)
#(5, 5120)
#(4, 5120)
#(10, 5120)
#(7, 5120)
#(2, 5120)
#(3, 5120)
#(2, 5120)
#(5, 5120)
#
# --> torch.Size([46, 5120]) so concat happens on first dimension
# We now have chunks of chunk_size to feed in for training where chunks relate to each other (from same article) until article changes









token set length:  70
toket set:  
 !"#$%&'()*+,-./0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{|}~


  raw_articles_as_token_ids = [np.fromstring(raw_article, dtype=np.uint8) for raw_article in raw_articles]


torch.Size([46, 5120])

In [7]:
loader = iter(DataLoader(processed_data, batch_size=8, shuffle=True))

# We're going to have a batch of 8 random chunks. The chunks will not be related. However we're then going to push these chunks through
# the model in parallel, and we'll split each chunk up in seg_len parts that WILL be sequentially related.

THIS IS JUST FOR UNDERSTANDING

example = next(loader) # Get a batch
print(example.shape)
# torch.Size([8, 5120])
seqs, labels = example[:,:-1], example[:,1:]
print(seqs[0][:10])
print(labels[0][:10])
#tensor([116, 104, 101,  32, 100, 105, 114,  97,  99,  32]) -> label is the token to the right
#tensor([104, 101,  32, 100, 105, 114,  97,  99,  32, 115]) -> the above but shifted left

# Now we want to feed in b_s parts of seg_len, with each segment coming from a different chunk, 
# but then the next segment coming from the same chuck again 

seqs.chunk(10, dim=-1)[0].shape
# We split each chunk verticall in 10 pieces of 512, and then get a tensor of size b_s with the first 512 token_ids from each chunk
# torch.Size([8, 512])  so we went from 5120 
# Note that for segs.chunk(10, dim=-1)[0].shape we get
# torch.Size([8, 511]) 
# because there are not 5120 but only 5119 token_ids in segs and labels (because we either lose the last or first token)

def decode_text(token_ids):
    return ''.join([chr(token_id) for token_id in token_ids])

for seqs_segment, labels_segment in zip(seqs.chunk(10, dim=-1), labels.chunk(10, dim=-1)):
    print(decode_text(seqs_segment[0]), "\n ***************** \n")

# In the above we chunk each chunk in the batch up vertically in 10 parts and we loop 10 times, where at each iteration
# we have a 8 (b_s) slices. The slices don't relate to each other (they are from different chunks) but each slice relates to the
# slice in the same "position" in the previous batch.

# So in print(decode_text(seqs_segment[0]) you can go up to 7 for a b_s of 8, but 8 throws an error.



In [53]:
# Dummy model

model = nn.Sequential(
    nn.Embedding(128, 16), # (vocab_size, embed_dim)
    nn.Linear(16, 150),
    nn.ReLU(),
    nn.Linear(150,150),
    nn.ReLU(),
    nn.Linear(150, 128) # back to vocab_size
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
model.train()

# We have 46 chunks, so with a batch size of 8 we can call next 6 times -> 5 * 8 == 40, call 6 will only have 6 and not 8 in the batch

segments = 10 # Just because 5120 / 10 == 512 

loader = iter(DataLoader(processed_data, batch_size=8, shuffle=True))

for idx in range(6):
    
    batch = next(loader)  # (b_s, seq_len) to (8, 5120) here
    seqs, labels = batch[:,:-1], batch[:,1:]
    train_loss = 0

    # This is special, normally we would just loop through the batch, but here we need to cut up each chunk

    # The use of chunk() is confusing here, since we're using it to cut up what we otherwise call a chunk, which is the
    # article that is clipped to a size of 5120 == segments * seqment length or 10 * 512

    for seqs_segment, labels_segment in zip(seqs.chunk(segments, dim=-1), labels.chunk(segments, dim=-1)): 
        # "segments" passes of (b_s, chunk_size/segments) or (8, 512 here)
        optimizer.zero_grad()
        y_pred = model(seqs_segment) # The model is trained on the batch-segment in parallel on the GPU
        y_pred = y_pred.transpose(2,1) # Just for the loss function

        loss = loss_fn(y_pred, labels_segment) # loss for all batch-segments
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # We just keep adding to train_loss, segment per segment, and in the end we have it for the whole batch

    if idx % 2 == 0: # Every 5 batches (chunks) print the training loss for that batch
        print(train_loss / segments) # We average over the segments
        


4.795375871658325
4.3197229385375975
3.697280502319336
