In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
#export
import warnings
warnings.filterwarnings('ignore')

In [3]:
#export
from lib.nb_06 import *
import seaborn as sns

In [5]:
path = Path("../data/pretrained_lm/")

In [6]:
path.ls()

[PosixPath('../data/pretrained_lm/ll_wiki.pkl'),
 PosixPath('../data/pretrained_lm/pretrained.pth'),
 PosixPath('../data/pretrained_lm/vocab.pkl')]

In [347]:
# pickle can execute arbitrary code while loading; only unpickle files from a trusted source.
# The `with` block closes the file handle as soon as the object has been read.
with open(path/"ll_wiki.pkl", "rb") as ll_file:
    ll = pickle.load(ll_file)

In [348]:
# pickle can execute arbitrary code while loading; only unpickle files from a trusted source.
# The `with` block closes the file handle as soon as the object has been read.
with open(path/'vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [349]:
len(ll.train.proc_x[-1].vocab), len(vocab)

(60002, 60002)

In [350]:
len(ll.train), len(ll.valid)

(18417, 60)

In [351]:
bs, bptt = 128, 70
data = lm_databunchify(ll, bs, bptt)

In [352]:
dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2
tok_pad = vocab.index(PAD)

In [353]:
emb_sz, nh, nl = 300, 300, 2
model = get_awd_lstm_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)

In [354]:
cbs = [partial(AvgStatsCallback,accuracy_flat),
       partial(CudaCallback, get_device()), 
       Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, α=2., β=1.),
       ProgressCallback]

In [355]:
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())

In [356]:
learn.model.load_state_dict(torch.load(path/"pretrained.pth", map_location=get_device()))

In [285]:
learn.model

SequentialRNN(
  (0): AWD_LSTM(
    (emb): Embedding(60002, 300, padding_idx=1)
    (emb_dp): EmbeddingDropout(
      (emb): Embedding(60002, 300, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (module): LSTM(300, 300, batch_first=True)
      )
      (1): WeightDropout(
        (module): LSTM(300, 300, batch_first=True)
      )
    )
    (input_dp): RNNDropout()
    (hidden_dps): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
    )
  )
  (1): LinearDecoder(
    (output_dp): RNNDropout()
    (decoder): Linear(in_features=300, out_features=60002, bias=True)
  )
)

Let's try to make one-step-ahead predictions with the model.

In [286]:
test_sent = "once upon a time "

In [287]:
test_sent

'once upon a time '

Tokenize

In [288]:
tp = TokenizeProcessor()

In [289]:
tp([test_sent, test_sent])[0]

['xxbos', 'once', 'upon', 'a', 'time', 'xxeos']

Numericalize

In [290]:
ll.train.proc_x[-1].proc1(tp([test_sent])[0][:-1])

[2, 442, 433, 15, 69]

In [291]:
num_test_sent = ll.train.proc_x[-1].proc1(tp([test_sent])[0][:-1])

In [292]:
inp = torch.stack((torch.tensor(num_test_sent),torch.tensor(num_test_sent)))

In [293]:
inp

tensor([[  2, 442, 433,  15,  69],
        [  2, 442, 433,  15,  69]])

In [306]:
learn.model.eval()
with torch.no_grad():
    outs = learn.model(inp)

In [307]:
inp.shape

torch.Size([2, 5])

In [308]:
decoded, raw_outputs, outputs = outs

The decoded tensor is flattened to `bs * seq_len` by `len(vocab)`:

In [309]:
bs, sl = inp.shape

In [310]:
bs,sl

(2, 5)

In [311]:
decoded.size()

torch.Size([10, 60002])

`raw_outputs` and `outputs` each contain the results of the intermediary layers:

In [312]:
len(raw_outputs), len(outputs)

(2, 2)

In [313]:
[o.size() for o in raw_outputs], [o.size() for o in outputs]

([torch.Size([2, 5, 300]), torch.Size([2, 5, 300])],
 [torch.Size([2, 5, 300]), torch.Size([2, 5, 300])])

In [314]:
pred_1 = decoded.view(bs, sl, -1)[1]

In [315]:
pred_tok = torch.argmax(F.softmax(pred_1[-1], dim=0))#.shape

In [316]:
pred_tok

tensor(11)

In [317]:
vocab[pred_tok.item()]

'of'

Now we need to append `pred_tok` to `inp`.

In [329]:
inp1 = inp.numpy()[0]

In [330]:
num_test_sent = np.append(inp1 , np.array([pred_tok.item()]))

In [331]:
inp = torch.stack((torch.tensor(num_test_sent),torch.tensor(num_test_sent)))

In [332]:
inp

tensor([[  2, 442, 433,  15,  69,  11],
        [  2, 442, 433,  15,  69,  11]])

In [357]:
def get_next_tok(num_test_sent):
    """Greedily pick the id of the most likely next token.

    Runs the module-level `learn.model` in eval mode, without gradient
    tracking, on `num_test_sent` (a sequence of token ids) and returns the
    argmax over the last time step's scores as a Python int.
    """
    ids = torch.tensor(num_test_sent)
    # The sequence is stacked with itself, so the model sees two identical rows.
    batch = torch.stack((ids, ids))
    n_rows, n_steps = batch.shape
    learn.model.eval()
    with torch.no_grad():
        decoded, _raw_outputs, _outputs = learn.model(batch)

    # `decoded` is viewed as (rows, steps, scores); take the final step of the first row.
    last_step = decoded.view(n_rows, n_steps, -1)[0][-1]
    return torch.argmax(F.softmax(last_step, dim=0)).item()

In [358]:
def generate_text(learn, test_sent, vocab = vocab, num_toks=50):
    """Extend `test_sent` with `num_toks` greedily chosen tokens and return the text.

    The prompt is tokenized with the module-level `tp`, its final token is
    dropped, and the rest is numericalized with the last processor in
    `ll.train.proc_x`. Each new id comes from `get_next_tok`.

    NOTE: `learn` is not used in this body; `get_next_tok` takes no learner
    argument. A later cell defines another `generate_text`, which replaces
    this one once it has run.

    Returns the space-joined `vocab` entries for the prompt plus generated ids.
    """
    numericalize = ll.train.proc_x[-1].proc1
    tok_ids = numericalize(tp([test_sent])[0][:-1])

    for _ in range(num_toks):
        predicted = get_next_tok(tok_ids)
        tok_ids = np.append(tok_ids, np.array([predicted]))

    return " ".join(vocab[tok_id] for tok_id in tok_ids)

In [371]:
generate_text(learn, "due to low")

'xxbos due to low - level xxmaj canadian xxmaj american ancestry , the xxmaj united xxmaj states xxmaj army ( xxmaj army ) , xxmaj united xxmaj states xxmaj army ( xxmaj army ) , xxmaj united xxmaj states xxmaj army ( xxmaj army ) , xxmaj united xxmaj states xxmaj army ( xxmaj'

Top-k sampling and nucleus sampling, shamelessly copied from [this Twitter thread](https://twitter.com/thom_wolf/status/1124263846674345985?s=12) by Thomas Wolf of Hugging Face.

Some research pointers:
- [Importance of a Search Strategy in Neural Dialogue Modelling](https://arxiv.org/pdf/1811.00907.pdf)
- [Correcting Length Bias in Neural Machine Translation](https://arxiv.org/abs/1808.10006)
- [Breaking the Beam Search Curse: A Study of (Re-)Scoring Methods and Stopping Criteria for Neural Machine Translation](https://arxiv.org/abs/1808.09582)
- [Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- [Better Language Models and Their Implications](https://openai.com/blog/better-language-models/)
- [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)

> Today, the most promising candidates for high-entropy tasks decoders seem to be top-k & nucleus sampling
General principle: at each step, sample from the next-token distribution filtered to keep only the top-k tokens or the top tokens with cumulative prob above a threshold.

https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

In [372]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits tensor of shape (..., vocabulary size). It is modified
            in place and the same tensor is returned.
        top_k: if > 0, keep only the `top_k` highest logits along the last
            axis (logits tied with the k-th highest are kept as well).
        top_p: if > 0.0, keep the smallest set of highest-probability tokens
            whose cumulative probability exceeds `top_p` (nucleus filtering).
            The most likely token is always kept.
        filter_value: value written over every removed logit.

    Returns:
        `logits`, with the removed entries set to `filter_value`.
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove every token scoring below the k-th best token of its own row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability is above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token crossing the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order along the last
        # axis. Indexing `logits` with `sorted_indices[mask]` is only valid for
        # 1-D input; scatter handles any number of leading (batch) dimensions.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

In [384]:
tp = TokenizeProcessor()
test_sent = "once upon a time, "

In [385]:
num_test_sent = ll.train.proc_x[-1].proc1(tp([test_sent])[0][:-1])
inp = torch.stack((torch.tensor(num_test_sent),torch.tensor(num_test_sent)))

In [386]:
learn.model.eval()
with torch.no_grad():
    outs = learn.model(inp)

decoded, raw_outputs, outputs = outs

In [387]:
logits = decoded

In [388]:
# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Keep only the last token predictions, apply a temperature coefficient and filter.
# Slice from `decoded` (the forward-pass output computed above) rather than from
# `logits` itself, so that re-running this cell always starts from the full model
# output. `decoded` is 2-D here, so this selects its final row.
logits = decoded[..., -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)

In [389]:
next_token

tensor([2036])

In [390]:
vocab[next_token.item()]

'joint'

In [422]:
def get_next_tok_sample(logits, top_k=0, top_p=0.9, tempreature=1.0, temperature=None):
    """Sample the next token id from `logits` using top-k / nucleus filtering.

    Args:
        logits: model scores; only the slice `logits[..., -1, :]` is used. The
            slice must be a single distribution (e.g. `logits` is 2-D), because
            the sampled value is returned with `.item()`.
        top_k: forwarded to `top_k_top_p_filtering`.
        top_p: forwarded to `top_k_top_p_filtering`.
        tempreature: value the logits are divided by before filtering. The
            misspelled name is kept so existing calls keep working.
        temperature: correctly spelled alias; takes precedence when given.

    Returns:
        The sampled token id as a Python int.
    """
    # Use the argument rather than whatever global `temperature` happens to exist.
    if temperature is None:
        temperature = tempreature

    # Keep only the last token predictions, apply a temperature coefficient and filter.
    # The clone keeps the in-place filtering from touching the caller's tensor.
    logits = logits[..., -1, :].clone().detach() / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    # Sample from the filtered distribution
    probabilities = F.softmax(filtered_logits, dim=-1)
    next_token = torch.multinomial(probabilities, 1)
    return next_token.item()

In [423]:
def get_logits(learn, num_test_sent):
    """Run `learn.model` on `num_test_sent` and return its decoded output.

    The token ids are stacked with themselves into a two-row batch, the model
    is switched to eval mode, and the forward pass runs under
    `torch.no_grad()`. Of the three values the model returns, only the first
    (`decoded`) is handed back.
    """
    ids = torch.tensor(num_test_sent)
    batch = torch.stack((ids, ids))
    learn.model.eval()
    with torch.no_grad():
        decoded, _raw_outputs, _outputs = learn.model(batch)
    return decoded

In [424]:
def generate_text(learn, test_sent, vocab = vocab, num_toks=50):
    """Extend `test_sent` with `num_toks` sampled tokens and return the text.

    The prompt is tokenized with the module-level `tp`, its final token is
    dropped, and the rest is numericalized with the last processor in
    `ll.train.proc_x`. On every step the current ids go through
    `get_logits(learn, ...)` and the next id is drawn by `get_next_tok_sample`
    with its default settings.

    Returns the space-joined `vocab` entries for the prompt plus sampled ids.
    """
    numericalize = ll.train.proc_x[-1].proc1
    tok_ids = numericalize(tp([test_sent])[0][:-1])
    for _ in range(num_toks):
        sampled = get_next_tok_sample(get_logits(learn, tok_ids))
        tok_ids = np.append(tok_ids, np.array([sampled]))

    return " ".join(vocab[tok_id] for tok_id in tok_ids)

In [425]:
generate_text(learn, "once upon a time, there", vocab=vocab)

'xxbos once upon a time , there are several tribal bands associated with the xxmaj middle xxmaj east . xxmaj the original for a settlement is on the corner of xxmaj hale and xxmaj lowry xxmaj road , and began a productive modification of a school built during the early stages of the boom . xxmaj the'

In [426]:
generate_text(learn, "once upon a time, there", vocab=vocab, num_toks=100)

'xxbos once upon a time , there were several variations between attitudes to students , who eventually had learned over the political spectrum , and their inspirations : and the younger sense of interest in the xxmaj western house was broken up in favor of planning to explore separate " treasure " models of the original . xxmaj when xxmaj guests chose to stay with xxmaj gertrude , she produced the photographs for the xxmaj queen \'s and xxmaj albert xxmaj park xxmaj row standards ; however , xxmaj jacqueline knew that xxmaj miss xxmaj lee , though uncertain about the connection , wanted to win by'

In [428]:
generate_text(learn, "There is a strong chance", vocab=vocab, num_toks=500)

'xxbos xxmaj there is a strong chance of acquiring information that takes the information on the earth and asserts that most land upon which it can not be read by the individual have been encouraged , and the level has been governed by the short , more extensive term that was in full development . xxmaj others argue that news can not be obtained , while xxmaj ancient xxmaj histories , today how scholars discuss the a language of material which had broken down , would have created a stable definition for the economy . xxmaj this theory gives rise to xxmaj christian beliefs in a xxmaj history of xxmaj absolute xxmaj change , xxmaj questions , and the tradition of history of every city . xxmaj within the xxmaj eighteenth xxmaj century , a true study of the present nation was carried by the director of the social law system , the xxmaj people \'s xxmaj republic of xxmaj ireland xxmaj who , which " took some jobs " for several organizations . xxmaj new figure of xxmaj alice \'s youth used t