# Making GPT2 crack jokes

This is a simple experimental notebook for fine-tuning a pretrained GPT2 model on a jokes dataset. Let's see if it can learn to crack some jokes on its own.

For this purpose I will use pretrained models from the Hugging Face [transformers repository](https://github.com/huggingface/transformers).

In [None]:
!pip install transformers

In [7]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import warnings
warnings.filterwarnings('ignore')

In [20]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

I1029 22:23:33.436547 4458931648 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /Users/martinsf/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
I1029 22:23:33.445147 4458931648 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /Users/martinsf/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
I1029 22:23:34.091546 4458931648 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /Users/martinsf/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d272f80

In [12]:
cur_ids = torch.tensor(tokenizer.encode(" The Matrix is everywhere. It is all around us. Even now, in this very room. You can see it when you look out your window or when you turn on your television. You can feel it when you go to work... when you go to church... when you pay your taxes. It is the world that has been pulled over your eyes to blind you from the truth. ")).unsqueeze(0)

In [18]:
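def choose_from_top(probs, n=5):
    """Sample a token id from the n most probable entries of probs."""
    ind = np.argpartition(probs, -n)[-n:] # Indices of the n largest probabilities (unordered)
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    choice = np.random.choice(n, 1, p = top_prob) # Position within the top n, drawn by probability
    token_id = ind[choice][0]
    
    print(f"top_prob: {top_prob} choice: {choice}") # Debug output
    
    return token_id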
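
To see what `choose_from_top` does, here is a minimal sketch on a toy distribution. It assumes a made-up 6-token vocabulary and a hypothetical helper `top_n_sample` that mirrors the function above without the debug print; the seeded generator is an addition for reproducibility and is not part of the notebook.

```python
import numpy as np

def top_n_sample(probs, n=5, rng=None):
    """Sample an index from the n most probable entries of probs (hypothetical helper)."""
    if rng is None:
        rng = np.random.default_rng()
    ind = np.argpartition(probs, -n)[-n:]      # indices of the n largest probabilities (unordered)
    top_prob = probs[ind] / probs[ind].sum()   # renormalize so the kept probabilities sum to 1
    return int(rng.choice(ind, p=top_prob))    # draw one of the kept indices

# Toy distribution over a made-up 6-token vocabulary
toy_probs = np.array([0.05, 0.40, 0.02, 0.30, 0.20, 0.03])
rng = np.random.default_rng(0)
samples = [top_n_sample(toy_probs, n=3, rng=rng) for _ in range(1000)]

# Only the three most probable tokens (indices 1, 3 and 4) can ever be drawn
print(sorted(set(samples)))
```

Everything outside the top n gets zero probability, which is what keeps the sampling away from very unlikely tokens.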

In [19]:
model.eval()
with torch.no_grad():
    
    for i in range(100):
        outputs = model(cur_ids, labels=cur_ids)
        loss, logits = outputs[:2]
        softmax_logits = torch.softmax(logits[0,-1], dim=0) # Take the first (and only) batch element and the logits at the last position
        next_token_id = choose_from_top(softmax_logits.numpy(), n=5) # Sample the next token from the renormalized top n probabilities
        cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long() * next_token_id], dim = 1) # Append the sampled token to the sequence

    output_list = list(cur_ids.squeeze().numpy())
    output_text = tokenizer.decode(output_list)
    print(output_text)

top_prob: [9.3155722e-05 2.7247870e-04 9.8721393e-05 3.6263184e-04 9.9917305e-01] choice: [4]
top_prob: [7.6523706e-05 9.9719161e-01 1.5862266e-04 2.4882520e-03 8.4984953e-05] choice: [1]
top_prob: [1.2930277e-04 1.4282188e-04 7.4180745e-04 3.1581570e-03 9.9582785e-01] choice: [4]
top_prob: [1.05494240e-04 3.54819087e-04 1.06825755e-04 9.99136746e-01
 2.96109793e-04] choice: [3]
top_prob: [6.9632268e-05 7.3165764e-05 2.5476173e-03 9.9714917e-01 1.6041707e-04] choice: [3]
top_prob: [1.1009357e-04 1.1233866e-04 9.9666548e-01 6.1074464e-04 2.5013629e-03] choice: [2]
top_prob: [1.11698915e-04 1.12981099e-04 3.05250345e-04 9.99158502e-01
 3.11551237e-04] choice: [3]
top_prob: [6.9391303e-05 7.8807207e-05 1.5586827e-04 2.6714613e-03 9.9702448e-01] choice: [4]
top_prob: [9.6811564e-05 1.0963494e-04 9.9674791e-01 5.3356669e-04 2.5120892e-03] choice: [2]
top_prob: [1.0886505e-04 1.1686112e-04 2.7823006e-04 3.0914496e-04 9.9918687e-01] choice: [4]
top_prob: [7.6498691e-05 8.4220170e-05 1.4726172

top_prob: [1.1546009e-04 1.3526282e-04 1.6789863e-04 4.3118335e-03 9.9526954e-01] choice: [4]
top_prob: [1.6469926e-04 2.1858762e-04 8.2392924e-02 9.1583848e-01 1.3853372e-03] choice: [3]
top_prob: [1.1469612e-04 1.1911415e-04 9.9920946e-01 2.3727746e-04 3.1947202e-04] choice: [2]
top_prob: [1.1615106e-04 1.3431473e-04 9.9511808e-01 1.6874884e-04 4.4627129e-03] choice: [2]
top_prob: [1.7086152e-04 2.3052000e-04 1.3579886e-03 9.2097145e-01 7.7269159e-02] choice: [3]
top_prob: [1.14097755e-04 2.37847329e-04 9.99212563e-01 3.17364378e-04
 1.18133474e-04] choice: [2]
top_prob: [1.1741950e-04 9.9496341e-01 4.6155457e-03 1.7046939e-04 1.3322030e-04] choice: [1]
top_prob: [1.8262157e-04 2.3517874e-04 1.3104015e-03 9.2934632e-01 6.8925470e-02] choice: [3]
top_prob: [1.12473746e-04 1.16514224e-04 9.99220073e-01 2.37102620e-04
 3.13846482e-04] choice: [2]
top_prob: [1.1724269e-04 1.7016077e-04 1.3180544e-04 4.7654207e-03 9.9481535e-01] choice: [4]
top_prob: [1.8861878e-04 2.3218214e-04 1.2562952
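
The generation loop above turns the logits at the last sequence position into a probability distribution with softmax. Below is a minimal NumPy sketch of that single step. The logits are made up and `softmax` is a hypothetical helper written for illustration; the notebook itself uses `torch.softmax`.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array (hypothetical helper)."""
    shifted = x - np.max(x)   # subtracting the max avoids overflow in exp
    exps = np.exp(shifted)
    return exps / exps.sum()

# Made-up logits with shape (batch=1, sequence=3, vocab=4)
logits = np.array([[[0.1, 0.2, 0.3, 0.4],
                    [1.0, 0.0, -1.0, 2.0],
                    [2.0, 1.0, 0.5, -3.0]]])

# Same indexing as in the loop: first batch element, last sequence position
last_probs = softmax(logits[0, -1])
print(last_probs, last_probs.sum())
```

Only the last position matters here, because it holds the model's prediction for the token that follows the current sequence.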