In [1]:
# !pip install transformers
# !pip install faiss-gpu
# !pip install einops

In [2]:
from IPython import embed

from memorizing_transformers import MemorizingModel, MemorizingLMHeadModel
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import GPT2Tokenizer, GPT2Model


In [3]:
# constants
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
NUM_BATCHES = int(1e5)      # number of optimizer steps in the training loop
BATCH_SIZE = 8
SEQ_LEN = 512               # tokens per segment fed to the model in one forward pass
SEGMENTS = 5                # segments per sample; each sample is SEQ_LEN * SEGMENTS tokens

LEARNING_RATE = 0.001
MAX_GRAD_CLIP_NORM = 0.5    # max global grad norm passed to clip_grad_norm_

VALIDATE_EVERY  = 100       # run a validation pass every N training steps
GENERATE_EVERY  = 500       # NOTE(review): not used by any visible cell yet
GENERATE_LENGTH = 512       # NOTE(review): not used by any visible cell yet

# Seed every RNG the notebook touches (window sampling uses torch.randint) so a
# Restart & Run All reproduces the same batches and losses.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# helpers

def cycle(loader):
    """Yield items from `loader` endlessly.

    Each pass calls `iter(loader)` afresh rather than caching items, so a
    DataLoader over a randomly-sampling dataset produces new data every pass.
    """
    while True:
        yield from loader

def decode_token(token):
    """Turn one integer id into a single character.

    Ids below 32 (the ASCII control range) are clamped to 32, i.e. a space.

    NOTE(review): the id is treated as a Unicode code point (char-level
    decoding). It does not decode GPT-2 BPE ids; use `tokenizer.decode` for those.
    """
    code_point = token if token > 32 else 32
    return chr(code_point)

def decode_tokens(tokens):
    """Decode an iterable of ids into one string by applying `decode_token` to each."""
    return ''.join(decode_token(tok) for tok in tokens)


In [4]:
import os
import pickle

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# The whole corpus is tokenized once and cached as a pickle of token ids.
# Build the cache when it is missing so a fresh checkout survives
# Restart & Run All, instead of failing on a file that only exists locally.
RAW_TEXT_PATH = './data/enwik8.gz'
TOKEN_CACHE_PATH = './data/enwik8_token.pickle'

if os.path.exists(TOKEN_CACHE_PATH):
    # SECURITY: pickle.load can execute arbitrary code. Only load a cache that
    # this notebook wrote itself, never one obtained from elsewhere.
    with open(TOKEN_CACHE_PATH, 'rb') as file:
        X = pickle.load(file)
else:
    with gzip.open(RAW_TEXT_PATH, "rt", encoding="utf-8") as file:
        X = tokenizer(file.read()).input_ids
    with open(TOKEN_CACHE_PATH, 'wb') as file:
        pickle.dump(X, file)

# First half of the token stream is for training, the second half for validation.
tr_num = len(X) // 2
trX, vaX = np.split(X, [tr_num])
assert len(trX) + len(vaX) == len(X)
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

In [5]:
# instantiate GPT-like decoder model
#
# Earlier from-scratch configuration, kept for reference only (not executed):
# model = MemorizingTransformer(
#     num_tokens = 256, # int8 tokens, 32k (x128 = 2^7d) in the paper
#     dim = 512, # dim of token in embedding space, 1024 (x2) in the paper
#     depth = 8, # 12 
#     memorizing_layers = 4, # 9 in the paper
#     max_knn_memories = 512 * 15, 
#     num_retrieved_memories = 32, 
#     xl_memory_layers = (7, 8),
#     xl_max_memories = 512,
# )

# Start instead from the pretrained 'gpt2' checkpoint, then move it to DEVICE.
model = MemorizingLMHeadModel.from_pretrained('gpt2')
model = model.to(DEVICE)

# prepare enwik8 data

class TextSamplerDataset(Dataset):
    """Random fixed-length windows from one long 1-D token tensor.

    `__getitem__` ignores `index` and returns a uniformly random window of
    `seq_len` consecutive tokens, cast to int64 and moved to `device`.
    `__len__` reports how many non-overlapping windows fit. That number only
    sets the DataLoader's epoch length.

    Args:
        data: 1-D tensor of token ids.
        seq_len: window length in tokens.
        device: where sampled windows are placed. Defaults to None, which means
            the notebook's global DEVICE, looked up at sampling time.

    Raises:
        ValueError: if `data` holds fewer than `seq_len` tokens. No window fits
            in that case, and a zero-length dataset would make an endless
            `cycle` over its DataLoader spin forever.
    """

    def __init__(self, data, seq_len, device = None):
        super().__init__()
        if data.size(0) < seq_len:
            raise ValueError(
                f'need at least seq_len={seq_len} tokens, got {data.size(0)}'
            )
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # Valid start offsets are 0 .. len - seq_len inclusive. randint's upper
        # bound is exclusive, hence the + 1, so the final window is reachable.
        max_start = self.data.size(0) - self.seq_len
        rand_start = torch.randint(0, max_start + 1, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()
        device = DEVICE if self.device is None else self.device
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# dataset and dataloader
# Every sample is SEQ_LEN * SEGMENTS tokens long; the training loop chunks it
# back into SEGMENTS pieces along the last dimension.
train_dataset = TextSamplerDataset(data = data_train, seq_len = SEQ_LEN * SEGMENTS)
valid_dataset = TextSamplerDataset(data = data_val, seq_len = SEQ_LEN * SEGMENTS)

train_loader = cycle(DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, drop_last = True))
valid_loader = cycle(DataLoader(dataset = valid_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim as optim` module
# alias imported at the top. The training loop refers to this name, so rename
# both places together if you change it.
optim = torch.optim.Adam(params = model.parameters(), lr = LEARNING_RATE)



In [6]:
# Smoke-test the input pipeline: pull one batch from each (infinite) loader and
# check it has the shape the training loop chunks into SEGMENTS segments.
for loader_name, loader in (('train', train_loader), ('valid', valid_loader)):
    batch = next(loader)
    assert batch.shape == (BATCH_SIZE, SEQ_LEN * SEGMENTS), (loader_name, tuple(batch.shape))

# True division with one decimal: `// 1e6` floor-divides a float and prints e.g. "124.0M".
print(f"{model.num_parameters() / 1e6:.1f}M parameters")

In [7]:
# Display the model configuration loaded with the 'gpt2' checkpoint.
model.config

In [8]:
# sentence_prefix = "History of Rome dates back to"
 
# input_ids = tokenizer.encode(
#     sentence_prefix,
#     add_special_tokens=False,
#     return_tensors="pt",
# ).to(DEVICE)
 
# output_ids = model.generate(
#     input_ids=input_ids,
#     do_sample=True,
#     max_length=50,  # desired output sentence length
#     pad_token_id=model.config.eos_token_id,
# )[0].tolist()
 
# generated_text = tokenizer.decode(
#     output_ids,
#     clean_up_tokenization_spaces=True)
 
# print(generated_text)


In [None]:
# training
#
# Each sample holds SEGMENTS consecutive windows of SEQ_LEN tokens. The windows
# are fed to the model one after another and share one kNN-memory context
# (presumably so later segments can retrieve memories written by earlier ones).

# Head count for the kNN memory context, read from the config instead of a
# hard-coded 12. `n_head` is the GPT-2 config field name; fall back to the
# previous value 12 in case this config class names it differently.
num_heads = getattr(model.config, 'n_head', 12)


def _run_segments(batch, knn_memories, backward):
    """Run `batch` through `model` one segment at a time.

    Args:
        batch: token ids; split into SEGMENTS chunks along the last dim.
        knn_memories: handle from `model.knn_memories_context`, shared across
            all segments of this batch.
        backward: if true, backpropagate each segment's loss scaled by
            1 / SEGMENTS, so the accumulated gradient is that of the mean loss.

    Returns:
        The mean of the per-segment losses as a Python float.
    """
    mean_loss = 0.
    for seq_segment in batch.chunk(SEGMENTS, dim = -1):
        # labels == input_ids: assumes the LM head shifts labels internally for
        # next-token prediction (HF convention) — confirm for MemorizingLMHeadModel.
        loss = model(
            input_ids = seq_segment,
            labels = seq_segment,
            knn_memories = knn_memories
        ).loss
        mean_loss += loss.item() / SEGMENTS
        if backward:
            (loss / SEGMENTS).backward()
    return mean_loss


# TODO: GENERATE_EVERY / GENERATE_LENGTH are not used yet; sampling was only
# prototyped as commented-out code.
progress = tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training')
for i in progress:
    model.train()
    train_batch = next(train_loader)

    with model.knn_memories_context(batch_size = BATCH_SIZE, num_heads = num_heads) as knn_memories:
        train_loss = _run_segments(train_batch, knn_memories, backward = True)

        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
        optim.step()
        optim.zero_grad()

    # Report via the progress bar (shown at its next refresh) instead of
    # printing one line per step for NUM_BATCHES steps.
    progress.set_postfix(train_loss = f'{train_loss:.4f}', refresh = False)

    if i % VALIDATE_EVERY == 0:
        model.eval()

        # Evaluate a held-out batch (not the training batch) under the same
        # segmented kNN-memory setup as training, so the two losses are comparable.
        valid_batch = next(valid_loader)
        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE, num_heads = num_heads) as knn_memories:
            valid_loss = _run_segments(valid_batch, knn_memories, backward = False)

        # tqdm.write prints without breaking the progress bar.
        tqdm.tqdm.write(f'valid loss: {valid_loss}')


In [None]:
# Display the model's repr (its layer structure).
model