In [None]:
# !pip install transformers
# !pip install faiss-gpu
# !pip install einops

In [None]:
from IPython import embed

from memorizing_transformers import MemorizingModel, MemorizingLMHeadModel
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import GPT2Tokenizer, GPT2Model


In [None]:
# constants

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Run sizing.
# NOTE: the training cell further down rebinds NUM_BATCHES before looping.
NUM_BATCHES = 100_000
BATCH_SIZE = 16
SEQ_LEN = 512    # tokens per segment handed to the sampler datasets
SEGMENTS = 5     # segments per sample; the training loop chunks each batch into this many pieces

# Optimisation.
LEARNING_RATE = 1e-3
MAX_GRAD_CLIP_NORM = 0.5

# Reporting cadence, in training iterations.
# NOTE(review): VALIDATE_EVERY only appears in a commented-out condition and the
# GENERATE_* values are not referenced by any active code in this notebook.
VALIDATE_EVERY = 10
GENERATE_EVERY = 500
GENERATE_LENGTH = 512

# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, restarting it whenever it is exhausted.

    Args:
        loader: Any re-iterable object (e.g. a ``DataLoader``); it is iterated
            from the start again each time a pass over it finishes.

    Yields:
        The items produced by ``loader``, pass after pass, without end.
    """
    while True:
        yield from loader

def decode_token(token):
    """Turn a single integer code point into a one-character string.

    Values below 32 are clamped up to 32, so they all come back as a space.

    Args:
        token: Integer code point.

    Returns:
        ``chr`` of the clamped code point.
    """
    lowest_code_point = 32
    code_point = max(lowest_code_point, token)
    return chr(code_point)

def decode_tokens(tokens):
    """Decode an iterable of integer code points into one string.

    Each element is converted with ``decode_token`` and the resulting
    characters are concatenated in order.

    Args:
        tokens: Iterable of integer code points.

    Returns:
        The decoded string (empty when ``tokens`` is empty).
    """
    return ''.join(decode_token(tok) for tok in tokens)


In [None]:
import os
import pickle

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Tokenising the whole corpus is slow, so the token ids are cached as a pickle.
# The cache is rebuilt from the raw gzip archive whenever it is missing, which
# keeps the notebook runnable from a fresh checkout (Restart & Run All).
if os.path.exists('./data/enwik8_token.pickle'):
    # NOTE: pickle.load can execute arbitrary code. Only load a cache file that
    # was produced locally by the branch below.
    with open('./data/enwik8_token.pickle', 'rb') as file:
        X = pickle.load(file)
else:
    with gzip.open('./data/enwik8.gz',"rt",encoding="utf-8") as file:
        text = file.read()
    X = tokenizer(text).input_ids
    with open('./data/enwik8_token.pickle', 'wb') as file:
        pickle.dump(X, file)

# First half of the token stream is used for training, second half for validation.
tr_num = len(X) // 2
trX, vaX = np.split(X, [tr_num])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

In [None]:
# Report the size of the tokenised corpus in millions of tokens.
tokens_in_millions = len(X) / 1_000_000
print(f"{tokens_in_millions} million tokens")

In [None]:
# Source of the pretrained weights. Alternatives tried previously:
#model = MemorizingLMHeadModel.from_pretrained('checkpoint/checkpoint_499').to(DEVICE)
#model = MemorizingLMHeadModel.from_pretrained('gpt2').to(DEVICE)
pretrained_source = 'checkpoint/memorizing_checkpoint_499'
model = MemorizingLMHeadModel.from_pretrained(pretrained_source)
model = model.to(DEVICE)

# prepare enwik8 data

class TextSamplerDataset(Dataset):
    """Dataset that serves random contiguous windows of a 1-D token tensor.

    Every lookup ignores the requested index and draws a fresh random window,
    so ``__len__`` only determines how many samples make up one pass.

    Args:
        data: 1-D tensor of token ids.
        seq_len: Number of tokens in each returned window.

    Raises:
        ValueError: If ``data`` holds fewer than ``seq_len`` tokens.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        if data.size(0) < seq_len:
            raise ValueError(
                f"data has {data.size(0)} tokens but seq_len is {seq_len}"
            )
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return a random window of ``seq_len`` tokens as an int64 tensor on DEVICE."""
        # Largest start offset that still leaves a full window.
        last_start = self.data.size(0) - self.seq_len
        # torch.randint's upper bound is exclusive, so +1 keeps the final window
        # reachable and makes the len(data) == seq_len case valid.
        rand_start = torch.randint(0, last_start + 1, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()

        # NOTE: moving to DEVICE here assumes the DataLoader runs in the main
        # process (num_workers=0), as it does in this notebook.
        return full_seq.to(DEVICE)

    def __len__(self):
        """Number of non-overlapping windows that fit in the data."""
        return self.data.size(0) // self.seq_len

class RepeatedTextSamplerDataset(Dataset):
    """Dataset that serves one random window of tokens repeated several times.

    Each sample is a random contiguous window of ``single_seq_len`` tokens
    tiled ``num_segments`` times back to back, giving a 1-D tensor of
    ``single_seq_len * num_segments`` tokens. The requested index is ignored.

    Args:
        data: 1-D tensor of token ids.
        single_seq_len: Number of tokens in the window that gets repeated.
        num_segments: How many copies of the window are concatenated.

    Raises:
        ValueError: If ``data`` holds fewer than ``single_seq_len`` tokens.
    """

    def __init__(self,data,single_seq_len,num_segments):
        super().__init__()
        if data.size(0) < single_seq_len:
            raise ValueError(
                f"data has {data.size(0)} tokens but single_seq_len is {single_seq_len}"
            )
        self.data = data
        self.single_seq_len = single_seq_len
        self.num_segments = num_segments

    def __getitem__(self,index):
        """Return the repeated random window as an int64 tensor on DEVICE."""
        # Largest start offset that still leaves a full window.
        last_start = self.data.size(0) - self.single_seq_len
        # torch.randint's upper bound is exclusive, so +1 keeps the final window
        # reachable and makes the len(data) == single_seq_len case valid.
        rand_start = torch.randint(0, last_start + 1, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.single_seq_len].long()
        full_seq = full_seq.repeat(self.num_segments)
        # NOTE: moving to DEVICE here assumes the DataLoader runs in the main
        # process (num_workers=0), as it does in this notebook.
        return full_seq.to(DEVICE)

    def __len__(self):
        """Nominal dataset size: the number of tokens in ``data``.

        NOTE(review): TextSamplerDataset reports ``size // seq_len`` instead;
        since sampling is random this only affects the nominal epoch length.
        """
        return self.data.size(0)
    
# dataset and dataloader

# Both loaders share the same batching options and are wrapped in cycle() so
# that next() can be called on them indefinitely.
loader_options = dict(batch_size = BATCH_SIZE, drop_last = True)

# Alternative without repetition:
#train_dataset = TextSamplerDataset(data_train, SEQ_LEN * SEGMENTS)
train_dataset = RepeatedTextSamplerDataset(
    data_train,
    single_seq_len = SEQ_LEN,
    num_segments = SEGMENTS,
)
train_loader = cycle(DataLoader(train_dataset, **loader_options))

# Alternative without repetition:
#valid_dataset = TextSamplerDataset(data_val, SEQ_LEN * SEGMENTS)
valid_dataset = RepeatedTextSamplerDataset(
    data_val,
    single_seq_len = SEQ_LEN,
    num_segments = SEGMENTS,
)
valid_loader = cycle(DataLoader(valid_dataset, **loader_options))

In [None]:
# Draw one batch from each loader as a quick check that the data pipeline runs.
for batch_source in (train_loader, valid_loader):
    next(batch_source)

# Parameter count, floor-divided by one million.
millions_of_parameters = model.num_parameters() // 1e6
print(f"{millions_of_parameters}M parameters")

In [None]:
# Only the single parameter with this exact name is handed to the optimizer;
# every other model parameter is left out of the update.
# NOTE(review): the name `optim` rebinds the `torch.optim` alias imported at the
# top of the notebook; the training cell refers to the optimizer by this name.
trainable_parameter_name = "transformer.h.5.attn.knn_attention_ratio"
trainable_parameters = [
    weight
    for (weight_name, weight) in model.named_parameters()
    if weight_name == trainable_parameter_name
]
optim = torch.optim.Adam(trainable_parameters, lr = LEARNING_RATE)
#optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [None]:
# sentence_prefix = """
# History of Rome dates back to
# """
# input_ids = tokenizer.encode(
#     sentence_prefix,
#     add_special_tokens=False,
#     return_tensors="pt",
# ).to(DEVICE)

# with open('./data/history_rome.txt',"rt",encoding="utf-8") as file:
#     sample_text = file.read()
#     sample_X = tokenizer(sample_text).input_ids
#     sample_X = torch.from_numpy(np.array([sample_X]))
    
# seq=sample_X[:,0:SEGMENTS*SEQ_LEN].long().to(DEVICE)
    
# with model.knn_memories_context(batch_size = 1, num_heads=12) as knn_memories:
    
#     for index_seg, seq_segment in enumerate(seq.chunk(SEGMENTS, dim = -1)):     
#         with torch.no_grad():
#             result = model(
#                 input_ids = seq_segment,
#                 labels = seq_segment.clone(),
#                 knn_memories = knn_memories,
#                 knn_memory_layer= 5
#             )
    
#     output_ids = model.generate(
#         input_ids=input_ids,
#         do_sample=True,
#         max_length=300,  # desired output sentence length
#         pad_token_id=model.config.eos_token_id,
#         knn_memories = knn_memories,
#         knn_memory_layer= 5,
        
#     )[0].tolist()

#     generated_text = tokenizer.decode(
#         output_ids,
#         clean_up_tokenization_spaces=True)
    
#     print(generated_text)


In [None]:
# training
#
# Each iteration draws one batch, splits it into SEGMENTS chunks along the last
# dimension and feeds the chunks to the model one after another inside a single
# `knn_memories_context`, accumulating gradients over the chunks before one
# optimizer step. A validation batch is then evaluated the same way without
# gradients, and losses are written to TensorBoard.

from torch.utils.tensorboard import SummaryWriter

# Index of the transformer block that receives `knn_memories`.
knn_memory_layer = 5

tb = SummaryWriter()

NUM_BATCHES=500

# Parameters the optimizer actually updates. Gradient clipping is restricted to
# these so that gradients of parameters the optimizer never touches cannot
# dominate the clipping norm.
optimized_params = [p for group in optim.param_groups for p in group['params']]

try:
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
        # x-axis value shared by every TensorBoard record of this iteration.
        global_step = i*SEGMENTS*BATCH_SIZE

        model.train()

        data = next(train_loader)

        train_loss = 0.
        seq = data

        with model.knn_memories_context(batch_size = BATCH_SIZE, num_heads=12) as knn_memories:

            for index_seg, seq_segment in enumerate(seq.chunk(SEGMENTS, dim = -1)):
                result = model(
                    input_ids = seq_segment,
                    labels = seq_segment,
                    knn_memories = knn_memories,
                    knn_memory_layer=knn_memory_layer
                )

                loss = result.loss
                print(f'training loss for {index_seg}th segment: {loss}')

                # Average over segments, both for reporting and for the gradient.
                train_loss += loss.item() / SEGMENTS
                (loss / SEGMENTS).backward()


            print(f'training loss: {train_loss}')
            tb.add_scalars("Loss", {"training_loss": train_loss}, global_step)


            torch.nn.utils.clip_grad_norm_(optimized_params, MAX_GRAD_CLIP_NORM)
            optim.step()
            optim.zero_grad()
            # backward() also fills .grad on parameters that are not in the
            # optimizer; optim.zero_grad() does not reset those, so clear every
            # model gradient to stop them accumulating across iterations.
            model.zero_grad()

        # Validation currently runs on every iteration; swap in the commented
        # condition to run it only every VALIDATE_EVERY iterations.
        # if not (i % VALIDATE_EVERY):
        if True:
            model.eval()

            valid_data = next(valid_loader)
            valid_loss = 0.

            with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE, num_heads=12) as knn_memories:
                seq = valid_data

                for index_seg, seq_segment in enumerate(seq.chunk(SEGMENTS, dim = -1)):

                    result = model(
                        input_ids = seq_segment,
                        labels = seq_segment,
                        knn_memories = knn_memories,
                        knn_memory_layer=knn_memory_layer
                    )

                    loss = result.loss
                    tb.add_scalars("Segment Loss", {f"segment_loss_{index_seg}": loss}, global_step)
                    tb.add_scalars("Segment Loss Perplexity", {f"segment_loss_perplexity_{index_seg}": torch.exp(loss)}, global_step)

                    valid_loss += loss.item() / SEGMENTS

            print(f'valid loss: {valid_loss}')
            tb.add_scalars("Loss", {"validation_loss": valid_loss}, global_step)
            tb.add_histogram("memorizing_attention_ratio", torch.sigmoid(model.transformer.h[knn_memory_layer].attn.knn_attention_ratio*100), global_step)

        # Checkpoint after every 100th iteration (directory suffix is the 0-based index).
        if not ((i+1) % 100):
            model.save_pretrained(f"./checkpoint/memorizing_checkpoint_{i}")
finally:
    # Flush and release the TensorBoard writer even if training fails or is interrupted.
    tb.close()

In [None]:
# !jupyter nbconvert --to python train.ipynb
# !nohup python train.py > output.txt &