In [41]:
import json
from dataset import TextDataset
from model import LanguageModel
from train import run_train
from torch.utils.data import DataLoader
import torch 
import wandb
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from tqdm import tqdm
import torch.nn as nn
import numpy as np
from transformers import pipeline, set_seed


In [39]:
valid_set = TextDataset(data_file='val.json', 
                            tokenizer_path='bpe.model',
                            train=False, 
                            max_length=256)

In [34]:
all_texts = []
for item in range(len(valid_set)):
    tokens, _ = valid_set[item]
    all_texts.append(valid_set.ids2text(tokens))

In [35]:
device = "cpu"
model_id = "gpt2-xl"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [36]:
encodings = tokenizer("\n\n".join(all_texts)[:500000], return_tensors="pt")
all_texts = []
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

Token indices sequence length is longer than the specified maximum sequence length for this model (123110 > 1024). Running this sequence through the model will result in indexing errors
 99%|█████████▉| 239/241 [1:10:44<00:35, 17.76s/it]


In [52]:
print(ppl.item())

5.2897


In [50]:
val_loader = DataLoader(valid_set, batch_size=10, shuffle=False)
criterion = nn.CrossEntropyLoss(ignore_index=val_loader.dataset.pad_id)

our_model = LanguageModel(valid_set, 4, 256, 4, valid_set.vocab_size, 512, 256, 0.1)
our_model.load_state_dict(torch.load('checkpoint_epoch_100', map_location='cpu'))
our_model.eval()

val_loss = 0.0
cnt = 0
total = 0
with torch.no_grad():
    for indices, lengths in tqdm(val_loader):
        cnt += 1
        total += len(lengths)
        indices = indices[:, :lengths.max()]
        logits = our_model(indices[:, :-1]) 
        loss = criterion(logits.transpose(1, 2), indices[:, 1:])
        val_loss += loss.item() * indices.shape[0]
        if cnt >= 1000:
            break

    val_loss /= total
print(np.exp(val_loss))

4.206220314115082


In [61]:
for i in range(5):
    print(our_model.inference('Alice was so tired when she got back home so she went', temp=2))

alice was so tired when she got back home so she went ride ra ra r o always per journey both excited clear be canyard rella mess. more from' mommyhello. a autve' your sometimes they lillieonpe  and un bolet gu stor sun an sourlyge. seen feeling looking tr baby wordsym washing together maysw' babys explore scared nightw theyain yelledhock this aways st." number patit got fun ball when washingable f very hisizardpl means big to be bl wonderfulureensgroundim ranside kept any got pink squither strongump,, play someed together bright beice perfect began leurt rounddopport ranpe dragail and green sun under worryals fatp and ar box theater lots. a gu? now expensive book at go else disgusting. this ind
alice was so tired when she got back home so she went up anyse thought.
alice was so tired when she got back home so she went until," selly stop compitherasticving bec untiliness onamp ⁇  long manyumesgordames with shiny ting lovedr remind adviceite strfortak ind got knehes and silly sl all pret

In [62]:
generator = pipeline('text-generation', model='gpt2-xl')
set_seed(42)
generator("Alice was so tired when she got back home so she went", max_length=200, num_return_sequences=5)



config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Alice was so tired when she got back home so she went into her room and fell asleep.\n\nI got in late and started with her homework, but I couldn't finish it for the rest of the night. Instead, I went out and bought a bottle of water. The first thing I did was fill it with ice water. I sat in my room for a long time trying to think of what to order, but I couldn't come up with anything. Alice couldn't find something either; she didn't want to go to McDonald's because she liked her apple pie but she had to have a bagel. That left me with two choices: I could order the same thing at a deli but they wouldn't give me the bagel and I could go by myself with a soda. So I got up from my desk, opened the deli door and went in as Alice and she was behind me.\n\nWhen I entered the deli, I saw two people waiting for"},
 {'generated_text': 'Alice was so tired when she got back home so she went to a park by herself and spent a lot of time there, and later in the day she went to