In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from fastai2.text.all import L, test_eq
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import re
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import warnings

from htools import *
from stormlight.utils import *

In [3]:
cd_root()

Current directory: /Users/hmamin/stormlight


In [4]:
def load_book(book):
    return load(PATHS[book])[slice(*BOOK_IDX[book])]

In [5]:
def load_books(*books):
    return DotDict((book, load_book(book)) for book in books or PATHS.keys())

In [6]:
def check_clean(book, *args):
    print('Length:', len(book))
    print('Page breaks:', book.count(BREAK))
    print('Feed form chars:', book.count('\x0c'))
    for arg in args:
        print(repr(arg) + ':', book.count(arg))

In [7]:
def print_pages(book, n=150, end=True, max_pages=None):
    for page in book.split(BREAK)[:max_pages]:
        print(page[:n])
        if end:
            print(spacer(n_chars=3))
            print(page[-n:])
        print(spacer())

In [8]:
def load_endnotes(book):
    return load(PATHS[book])[BOOK_IDX[book][-1]:]

In [148]:
books = load_books(*NOVELLAS+STORMLIGHT)
{k: v[:50] for k, v in books.items()}

{'edge': 'xxbrxxPROLOGUE\n\nLift had never robbed a palace bef',
 'war': "xxbrxxPrologue\n\nIt's funny, Vasher thought, how ma",
 'kings': 'xxbrxxPRELUDE TO\n\nTHE STORMLIGHT ARCHIVE\n\nKalak ro',
 'words': 'xxbrxxSIX YEARS AGO\nJasnah Kholin pretended to enj',
 'oath': 'xxbrxxSIX YEARS AGO\nEshonai had always told her si'}

In [10]:
# Way of Kings has fewer words per page. Keep in mind for later decisions.
for name, text in books.items():
    print(name)
    check_clean(text, '\n', '\n\n')
    print('\n')

edge
Length: 287754
Page breaks: 159
Feed form chars: 0
'\n': 5024
'\n\n': 185


war
Length: 1116087
Page breaks: 575
Feed form chars: 0
'\n': 18216
'\n\n': 959


kings
Length: 2200799
Page breaks: 1456
Feed form chars: 0
'\n': 32501
'\n\n': 1585


words
Length: 2285222
Page breaks: 1113
Feed form chars: 0
'\n': 38305
'\n\n': 1133


oath
Length: 2600942
Page breaks: 946
Feed form chars: 0
'\n': 35206
'\n\n': 49




## Dataset

In [140]:
tok = GPT2TokenizerFast.from_pretrained('gpt2', pad_token='<|endoftext|>')

In [141]:
len(tok), tok.vocab_size

(50257, 50257)

In [142]:
special = dict(additional_special_tokens=[BREAK])
tok.add_special_tokens(special)

1

In [143]:
tok.special_tokens_map

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['xxbrxx']}

In [144]:
tok.decode([50257]), tok.decode([50258])

('xxbrxx', '')

In [145]:
len(tok), tok.vocab_size

(50258, 50257)

In [17]:
lengths = [len(tok.encode(line)) for line in books.edge.split('.') if line]
pd.Series(lengths).describe(percentiles=np.r_[.25:.8:.25, .9:1:.01])

count    4904.000000
mean       14.369494
std         9.778607
min         1.000000
25%         7.000000
50%        12.000000
75%        19.000000
90%        27.000000
91%        28.000000
92%        29.000000
93%        30.000000
94%        31.000000
95%        33.000000
96%        35.000000
97%        37.000000
98%        40.000000
99%        45.000000
max        83.000000
dtype: float64

In [18]:
lengths = [len(tok.encode(pg)) for pg in books.edge.split(BREAK) if pg]
pd.Series(lengths).describe(percentiles=np.r_[.25:.8:.25, .9:1:.01])

count    158.000000
mean     470.582278
std      118.312509
min       27.000000
25%      421.250000
50%      508.500000
75%      554.500000
90%      572.600000
91%      574.870000
92%      576.440000
93%      578.060000
94%      584.000000
95%      586.150000
96%      587.720000
97%      588.580000
98%      597.740000
99%      604.870000
max      633.000000
dtype: float64

In [19]:
del lengths; gc.collect()

61

In [20]:
tok.encode('kaladin walked towards the cliff')

[74, 282, 17072, 6807, 3371, 262, 19516]

In [21]:
for n in [74, 282, 17072, 6807, 3371, 262, 19516]:
    print(n, tok.decode([n]))

74 k
282 al
17072 adin
6807  walked
3371  towards
262  the
19516  cliff


In [22]:
tok.max_len, tok.max_len_single_sentence

(1024, 1024)

In [139]:
class LMDataset(Dataset):
    
    def __init__(self, books=tuple(PATHS.keys()), tok=None, seq_len=512,
                 tokens=None, subset_frac=1, return_tuple=False):
        if not tok and tokens is None:
            raise ValueError('Must pass in either tokenizer or tokens.')
        self.books = books
        self.tok = tok
        self.seq_len = seq_len
        self.subset_frac = subset_frac
        self.tokens = self._create_tokens() if tokens is None else tokens
        self.return_tuple = return_tuple
        
    def _create_tokens(self):
        books = load_books(*self.books)
        text = BREAK.join(book[:int(len(book)*self.subset_frac)] 
                          for book in books.values())
        tokens = self.tok.encode(text)
            
        # Instead of padding/dropping last batch, load last book's endnotes.
        n_missing = self.seq_len - len(tokens) % self.seq_len
        if n_missing > 0:
            endnotes = load_endnotes(self.books[-1])
            tokens += self.tok.encode(endnotes)[:n_missing]
        return np.array(tokens)
    
    def __getitem__(self, i):
        seq = self.tokens[i*self.seq_len:(i+1)*self.seq_len]
        return (seq, seq) if self.return_tuple else seq
    
    def __len__(self):
        return int(np.ceil(len(self.tokens) / self.seq_len))
    
    def save(self, path):
        data = select(vars(self), drop=['tok'])
        data['tok_type'] = type(self.tok)
        save(data, path)
        
    @classmethod
    def from_pickle(cls, path, tok=None, **kwargs):
        data = load(path)
        data.update(kwargs)
        if type(tok) != data.pop('tok_type'):
            warnings.warn('Tokenizer is different than what was used to '
                          'tokenize data.')
        return cls(tok=tok, **data)

In [134]:
size2kwargs = {'tiny': dict(seq_len=16, subset_frac=0.01),
               'med': dict(seq_len=512, subset_frac=0.05),
               'all': dict(seq_len=512, subset_frac=1)}
ds_paths = {sz: f'data/datasets/gpt2_lm_tokens_{sz}.pkl' 
            for sz in size2kwargs}

bs = 4
shuffle = False

In [135]:
for sz, kwargs in size2kwargs.items():
    ds = LMDataset(tok=tok, **kwargs)
    ds.save(ds_paths[sz])

Writing data to data/datasets/gpt2_lm_tokens_tiny.pkl.
Writing data to data/datasets/gpt2_lm_tokens_med.pkl.
Writing data to data/datasets/gpt2_lm_tokens_all.pkl.


In [18]:
# ~40 sec to create DS for all 6 books w/ seq_len=512 and frac=1
ds = LMDataset(tok=tok, seq_len=seq_len, subset_frac=subset_frac)
dl = DataLoader(ds, batch_size=bs, shuffle=shuffle)

In [36]:
assert all(x.shape[-1] == dl.dataset.seq_len for x in dl),\
    'Batch sizes differ. Check if last batch is incomplete.'
if ds.seq_len == 512 and ds.books == tuple(PATHS.keys()):
    ds.save(ds_paths['all'])
    test_eq(ds.tokens, LMDataset.from_pickle(ds_all_path, tok).tokens)

In [37]:
eprint(tok.decode(block[:20]) for block in item(dl).numpy())

 0:  sight. I
wouldn't send my own brother ashore there without guards, and he's killed
 1: 've revised my earlier
decision. I need you to halt the ship and let me inspect the
 2:  as other men might play with their mustaches. "Brightness,
that's not advisable.
 3:  I've experienced one or two
times in my life."
"No, I simply cannot allow


In [19]:
ds.tokens.shape, len(ds)

((122250,), 2445)

In [20]:
len(tok), tok.vocab_size

(50258, 50257)

## Model

Situation: Initially, we are using pre-trained weights so loss is relatively low and text generation works fairly well. However, the model can't handle the new xxbrxx token and will throw an error when it appears (embedding matrix is 1 row too small).

If we resize the model's embedding matrix, we no longer get errors but now loss is high and text generation is terrible. My hope is that the untrained embedding and new row in the last linear layer is just really bad this point and training a bit will fix this problem. My worry is that something's not working correctly (e.g. indices are shifted when the new row is added so they no longer correspond to the right embeddings).

Ways to check:
- examine rows of model.transformer.wte before and after resize. Only the last one should change.
- Train a bit and see if the issue disappears.
- logic: if everything was shifted by one, we'd expect to see nonsense generation but there would still be variety in token choices. Because we're seeing xxbrxx output repeatedly, I'm guessing that's not an issue here.

Followup: After resizing embedding, generating text while preventing the model from choosing the xxbrxx token gives us results similar to before resizing. Therefore, I'm pretty sure this is just reflecting the fact that there are some
untrained weights now, which should be fixed by training.

In [109]:
model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tok.pad_token_id)

In [30]:
model.transformer.wte, len(tok), tok.vocab_size

(Embedding(50257, 768), 50258, 50257)

In [31]:
model.transformer.wte.weight[-3:]

tensor([[-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       grad_fn=<SliceBackward>)

In [32]:
with torch.no_grad():
    x = torch.tensor(ds[44])
    loss, logits, past = model(x, labels=x)

In [33]:
loss, logits.shape, attrmap('shape', *past)

(tensor(4.5749),
 torch.Size([50, 50257]),
 [torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64]),
  torch.Size([2, 1, 12, 50, 64])])

In [34]:
x = tok.encode(['Huqin turned to one of his companions. Using their ropes, they scuttled'])
gen = model.generate(torch.tensor(x).unsqueeze(0), max_length=112, min_length=10, 
                      repetition_penalty=10, no_repeat_ngram_size=4,
                      early_stopping=True, do_sample=True, temperature=.7)
tok.decode(gen.numpy()[0])

'Huqin turned to one of his companions. Using their ropes, they scuttled the building and reached an area with a few buildings in it that could be considered as part "Shuaai".\n"I am about ready." said Jinxiu Feng from within Wu Hua Tower\'s walls! This was also not surprising considering how many people came here just now (but still there were lots already). But he had no idea what did go on inside this space or where all those who went out yesterday got into… but at least for today everyone would'

In [35]:
with assert_raises(RuntimeError):
    with torch.no_grad():
        x = torch.tensor(ds[0])
        loss, logits, past = model(x, labels=x)

As expected, got RuntimeError(index out of range: Tried to access index 50257 out of table with 50256 rows. at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:418).


In [36]:
model.resize_token_embeddings(len(tok))
model.tie_weights()

In [37]:
model.transformer.wte.weight.shape

torch.Size([50258, 768])

In [38]:
model.transformer.wte.weight[-4:]

tensor([[-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
        [ 0.0071, -0.0033,  0.0120,  ...,  0.0155,  0.0073, -0.0005]],
       grad_fn=<SliceBackward>)

In [39]:
with torch.no_grad():
    x = torch.tensor(ds[0])
    loss, logits, past = model(x, labels=x)

In [40]:
loss

tensor(91.3928)

In [41]:
# logits: (seq_len, vocab_sz), maybe (bs, seq_len, vocab_sz) if bs > 1
logits.shape

torch.Size([50, 50258])

In [42]:
# keys and values in attention blocks: (2, batch_size, num_heads, sequence_length, embed_size_per_head))
attrmap('shape', *past)

[torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64]),
 torch.Size([2, 1, 12, 50, 64])]

In [162]:
x = next(iter(dl))
x.shape

torch.Size([4, 50])

In [85]:
x = torch.tensor(ds[0]).unsqueeze(0)
gen = model.generate(x, max_length=112, min_length=10, 
                      repetition_penalty=10, no_repeat_ngram_size=4,
                      early_stopping=True, do_sample=True, temperature=.7)
tok.decode(gen.numpy()[0])

"xxbrxxPROLOGUE\n\nLift had never robbed a palace before. Seemed like a dangerous thing to try.\nNot because she might get caught, but because once you robbed a starvin'\npalace, where did you go next?xxbrxx PRxxbrxxxxbrxx Prolxxbrxxxxbrxxxxbrxxxxbrxx Thexxbrxxxxbrxxxxbrxx Preludexxbrxxxxbrxxxxbrxx Pxxbrxxxxbrxxxxbrxx Txxbrxxxxbrxxxxbrxx Sxxbrxxxxbrxxxxbrxx _xxbrxxxxbrxxxxbrxx Dxxbrxxxxbrxxxxbrxx Exxbrxxxxbrxxxxbrxx Cxxbrxxxxbrxxxxbrxx Axxbrxxxxbrxxxxbrxx Gxxbrxxxxbrxxxxbrxx -xxbrxxxxbrxxxxbrxx Lxxbrxxxxbrxxxxbrxx N"

In [89]:
gen = model.generate(x, bad_words_ids=[[len(tok)-1]], max_length=112,
                     min_length=10, repetition_penalty=10, no_repeat_ngram_size=4,
                      early_stopping=True, do_sample=True, temperature=.7)
tok.decode(gen.numpy()[0])

"xxbrxxPROLOGUE\n\nLift had never robbed a palace before. Seemed like a dangerous thing to try.\nNot because she might get caught, but because once you robbed a starvin'\npalace, where did you go next? And even then... LIFT was very good at knowing when the enemy's king would come and tell them he didn't want him there anymore! You have no idea how many times I've been on this ship over my life… But that guy just got out of his seat right now (and is still in"

In [90]:
tok.decode(x[0].numpy())

"xxbrxxPROLOGUE\n\nLift had never robbed a palace before. Seemed like a dangerous thing to try.\nNot because she might get caught, but because once you robbed a starvin'\npalace, where did you go next?"

## Generation

In [300]:
tok.decode(list(x))

'Huqin turned to one of his companions. Using their ropes, they scuttled'

In [154]:
# - num_beams: larger -> slower but better results.
# - num_return_sequences: must be <= num_beams
# - output: (bs*num_return_sequences, seq_len)
# - no_repeat_ngram_size: very effective at reducing repetition but use carefully,
# if we generate a long piece of text all at once this persists. Maybe we can generate in chunks?
# - temperature: 0 means greedy, 1 means more randomness which can be risky
# - blog suggests top k search (top_k=50) or top p search may be good for 
# story generation, while beam search is better for machine translation 
# (still need do_sample=True)
pred = model.generate(x.unsqueeze(0), max_length=60, min_length=10, 
                      repetition_penalty=10, no_repeat_ngram_size=3,
                      num_return_sequences=3, num_beams=4,
                      early_stopping=True, do_sample=True, temperature=.7)

In [155]:
tok.decode(list(pred[0]))

'Huqin turned to one of his companions. Using their ropes, they scuttled off into the distance with a loud thud and an explosion shook them all away from each other\'s eyes as if nothing had happened at that moment."\n"What are you talking about?" he asked'

In [156]:
tok.decode(list(pred[1]))

'Huqin turned to one of his companions. Using their ropes, they scuttled off into the forest and disappeared in a flash as if nothing had happened at all."\n"What did you do?" asked Ye Xiwen with an expression that seemed like it was about time for'

In [157]:
tok.decode(list(pred[2]))

'Huqin turned to one of his companions. Using their ropes, they scuttled out the door and ran into a small room with an old man sitting on it waiting for them in order that he might take care not only himself but also those around him who were watching from afar as'

In [302]:
pred = model.generate(torch.tensor(ds[800]).unsqueeze(0),
                      max_length=60, min_length=10, top_k=50, 
                      early_stopping=True, do_sample=True, temperature=.7)

In [304]:
tok.decode(ds[800])

'Something stirred inside of Lift. Like the little swirls of wind at the advent'

In [310]:
tok.decode(ds[801]), tok.decode(ds[802])

('of a storm.', 'Darkness looked at her with a sharp motion. "Something is--"')

In [307]:
pred

tensor([[22210, 33091,  2641,   286, 43711,    13,  4525,   262,  1310, 42835,
            82,   286,  2344,   379,   262, 19980,   286,   262,  2344,    11,
           810,   790,  1657,  3947,   284,  5202,   832,   340,    11,  1865,
           262,  1657,  2346,   373,   991,    13,   198,   198,     6,    40,
          4724,   340,   338,   407,   355,  1290,   503,   355,   314,  1807,
            13,   632,   338,   655,   257,  3155,   286,  2431,   656,   262]])

In [308]:
tok.decode(list(pred[0]))

"Something stirred inside of Lift. Like the little swirls of wind at the advent of the wind, where every light seemed to flow through it, yet the light itself was still.\n\n'I guess it's not as far out as I thought. It's just a couple of minutes into the"

In [63]:
tok.get_vocab()['xxbrxx']

50257

In [159]:
def generate_to_file(model, tok, x, path, *skip_tokens, max_length=112, 
                     mode='a', verbose=True, **kwargs):
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if len(x.shape) == 1:
        x = x.unsqueeze(0)
        
    kwargs_ = dict(max_length=max_length, min_length=10, 
                   repetition_penalty=10, no_repeat_ngram_size=4, 
                   early_stopping=True, do_sample=True, temperature=.7, 
                   top_p=.95, top_k=max_length,
                   bad_words_ids=[[tok.get_vocab()[t] for t in skip_tokens]]
                                  if skip_tokens else None)
    # Update after so defaults are overwritten if user provides value.
    kwargs_.update(kwargs)
    res = model.generate(x, **kwargs_)[0].numpy().tolist()
    old, new = map(tok.decode, (res[:len(x[0])], res[len(x[0]):]))
    if verbose:
        print(old, '\n\n', new)
    save(spacer()+old+new+spacer(), path, mode_pre=mode, verbose=verbose)
    return old, new

In [105]:
res = generate_to_file(model, tok, ds[800], 'data/generated/sample.txt', 
                       'xxbrxx', mode='w', max_length=112)

 or use it again.
Szeth lowered his Shardblade, standing among the cinder-eyed corpses. Here, in
Alethkar, men often spoke of the legends--of mankind's hard-won victory over the
Void 

  dragons and their treacherous servants who came to battle our heroes; they were called "the dead," even now! There is no such thing as a life without death itself... And I am glad that my people can finally make contact with each other once more: for this great city stands today..." He looked at Ale
