## Text generation with a miniature GPT

### Introduction
This example demonstrates how to implement an autoregressive language model using a miniature version of the GPT model. The model consists of a single Transformer block with causal masking in its attention layer. We use the text from the IMDB sentiment classification dataset for training and generate new movie reviews for a given prompt. When using this script with your own dataset, make sure it has at least 1 million words.

In [3]:
import tensorflow as tf

import numpy as np
import os
import re
import string
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import glob
import math

### Prepare the data for word-level language modelling


In [4]:
def get_text_list_from_files(files):
    text_list = []
    for name in files:
        with open(name) as f:
            for line in f:
                text_list.append(line)
    return text_list

# label -> pos : 1, neg : 0 
def get_data_from_text_files(folder_name):

    pos_files = glob.glob("datasets/aclImdb/" + folder_name + "/pos/*.txt")
    pos_texts = get_text_list_from_files(pos_files)
    neg_files = glob.glob("datasets/aclImdb/" + folder_name + "/neg/*.txt")
    neg_texts = get_text_list_from_files(neg_files)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            "sentiment": [1] * len(pos_texts) + [0] * len(neg_texts),
        }
    )
    # sampling 후 reset_index : index 초기화 (https://yganalyst.github.io/data_handling/Pd_2/)
    # 두 데이터 프레임을 합치면서 index를 0부터 초기화, drop=True : 기존 index를 버림
    df = df.sample(len(df)).reset_index(drop=True)
    return df


train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

all_data = train_df.append(test_df)

In [5]:
all_data.head()

Unnamed: 0,review,sentiment
0,The Last Hunt is one of the few westerns ever ...,1
1,I usually try to construct reasonably well-arg...,0
2,Preston Waters is off to a bad summer. Besides...,1
3,This is the single worst movie I have ever see...,0
4,(some spoilers) - as if you wouldn't know how ...,0


In [6]:
print("# of train : ", len(train_df))
print("# of test : ", len(test_df))

# of train :  25000
# of test :  25000


In [7]:
# train_df = train_df.sample(5000)
# test_df = test_df.sample(5000)

### Build Vocab and Encoding Module

In [8]:
# Tensor -> bytes(from numpy) -> string

def custom_standardization(input_string):
    """ Remove html line-break tags and handle punctuation """
    lowercased = tf.strings.lower(input_string)
    stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
    return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1").numpy().decode('UTF-8')

In [9]:
custom_standardization('An interesting idea for a film, both show')

'an interesting idea for a film , both show'

In [10]:
corpus=[]
for i in range(len(train_df.review.values)):
    if (i+1) % (len(train_df.review.values)/10) == 0:
        print(f'{(i+1)/(len(train_df)/100)} % Done')
    tokens = custom_standardization((train_df.review.values[i])).split()
    corpus.append(tokens)

10.0 % Done
20.0 % Done
30.0 % Done
40.0 % Done
50.0 % Done
60.0 % Done
70.0 % Done
80.0 % Done
90.0 % Done
100.0 % Done


In [11]:
corpus[0]

['the',
 'last',
 'hunt',
 'is',
 'one',
 'of',
 'the',
 'few',
 'westerns',
 'ever',
 'made',
 'to',
 'deal',
 'with',
 'buffalo',
 'hunting',
 ',',
 'both',
 'as',
 'a',
 'sport',
 'and',
 'business',
 'and',
 'as',
 'a',
 'method',
 'of',
 'winning',
 'the',
 'plains',
 'indian',
 'wars',
 '.',
 'before',
 'the',
 'white',
 'man',
 'set',
 'foot',
 'on',
 'the',
 'other',
 'side',
 'of',
 'the',
 'mississippi',
 ',',
 'the',
 'plains',
 'used',
 'to',
 'have',
 'herds',
 'of',
 'american',
 'bison',
 'as',
 'large',
 'as',
 'some',
 'of',
 'our',
 'largest',
 'cities',
 '.',
 'by',
 'the',
 'time',
 'of',
 'the',
 'period',
 'the',
 'last',
 'hunt',
 'is',
 'set',
 'in',
 ',',
 'the',
 'buffalo',
 'had',
 'been',
 'all',
 'but',
 'wiped',
 'out',
 '.',
 'the',
 '20th',
 'century',
 ',',
 'due',
 'to',
 'the',
 'efforts',
 'of',
 'conservationists',
 ',',
 'saw',
 'a',
 'revival',
 'in',
 'population',
 'of',
 'the',
 'species',
 ',',
 'but',
 'not',
 'hardly',
 'like',
 'it',
 'once

In [37]:
from itertools import chain

corpus_flat=list(chain.from_iterable(corpus))

In [None]:
# corpus_flat = sum(corpus, [])

In [38]:
corpus_flat

['the',
 'last',
 'hunt',
 'is',
 'one',
 'of',
 'the',
 'few',
 'westerns',
 'ever',
 'made',
 'to',
 'deal',
 'with',
 'buffalo',
 'hunting',
 ',',
 'both',
 'as',
 'a',
 'sport',
 'and',
 'business',
 'and',
 'as',
 'a',
 'method',
 'of',
 'winning',
 'the',
 'plains',
 'indian',
 'wars',
 '.',
 'before',
 'the',
 'white',
 'man',
 'set',
 'foot',
 'on',
 'the',
 'other',
 'side',
 'of',
 'the',
 'mississippi',
 ',',
 'the',
 'plains',
 'used',
 'to',
 'have',
 'herds',
 'of',
 'american',
 'bison',
 'as',
 'large',
 'as',
 'some',
 'of',
 'our',
 'largest',
 'cities',
 '.',
 'by',
 'the',
 'time',
 'of',
 'the',
 'period',
 'the',
 'last',
 'hunt',
 'is',
 'set',
 'in',
 ',',
 'the',
 'buffalo',
 'had',
 'been',
 'all',
 'but',
 'wiped',
 'out',
 '.',
 'the',
 '20th',
 'century',
 ',',
 'due',
 'to',
 'the',
 'efforts',
 'of',
 'conservationists',
 ',',
 'saw',
 'a',
 'revival',
 'in',
 'population',
 'of',
 'the',
 'species',
 ',',
 'but',
 'not',
 'hardly',
 'like',
 'it',
 'once

In [39]:
from collections import Counter

all_tokens_array = np.array(corpus_flat)

counter = Counter(all_tokens_array)
print(len(counter))

vocab_size = 20000

104815


In [40]:
# We only take (vocab_size - 2) most commons words from the training data since
# the `StringLookup` class uses 2 additional tokens - one denoting an unknown
vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]

In [41]:
len(vocabulary)

19998

In [42]:
import pickle

# save
with open('datasets/pickle_data/vocab_large_text_tg_gpt.pickle', 'wb') as f:
    pickle.dump(vocabulary, f)

* 사실상 padding 거의 사용 안 됨, 긴 문장에서 80 만큼만 잘라서 input, target 구성

In [43]:
import pickle

with open('datasets/pickle_data/vocab_large_text_tg_gpt.pickle', 'rb') as f:
    vocabulary = pickle.load(f)

In [44]:
token_to_id = {}
token_to_id['[UNK]'] = 0
token_to_id['[PAD]'] = 1
for i, token in enumerate(vocabulary):
    token_to_id[token] = i + 2

In [45]:
token_to_id

{'[UNK]': 0,
 '[PAD]': 1,
 'the': 2,
 '.': 3,
 ',': 4,
 'a': 5,
 'and': 6,
 'of': 7,
 'to': 8,
 'is': 9,
 'it': 10,
 'in': 11,
 'i': 12,
 'this': 13,
 'that': 14,
 "'s": 15,
 'was': 16,
 'as': 17,
 'for': 18,
 'with': 19,
 'movie': 20,
 'but': 21,
 'film': 22,
 ')': 23,
 'on': 24,
 'you': 25,
 "'t": 26,
 '"': 27,
 'not': 28,
 'he': 29,
 'are': 30,
 'his': 31,
 'have': 32,
 'be': 33,
 'one': 34,
 '!': 35,
 'all': 36,
 'at': 37,
 'they': 38,
 'by': 39,
 'an': 40,
 'who': 41,
 'from': 42,
 'so': 43,
 'like': 44,
 '-': 45,
 'there': 46,
 'her': 47,
 'just': 48,
 'about': 49,
 'or': 50,
 'has': 51,
 'out': 52,
 'if': 53,
 '?': 54,
 'what': 55,
 'some': 56,
 'good': 57,
 'can': 58,
 'more': 59,
 'when': 60,
 'very': 61,
 'she': 62,
 'would': 63,
 'up': 64,
 'time': 65,
 'even': 66,
 'no': 67,
 'my': 68,
 'story': 69,
 'only': 70,
 'really': 71,
 'their': 72,
 'had': 73,
 'see': 74,
 'which': 75,
 'were': 76,
 'me': 77,
 'we': 78,
 'well': 79,
 "'": 80,
 'than': 81,
 ':': 82,
 'much': 83,
 'b

In [46]:
id_to_token = {v: k for k, v in token_to_id.items()}

In [47]:
# Input : list, output : list
max_seq_len = 80

def encode_tokens(tokens, max_seq_len=max_seq_len):
    
    encoded = []
    
    for i, token in enumerate(tokens):
        if token in list(token_to_id.keys()):
            encoded.append(token_to_id[token])
        else:
            encoded.append(token_to_id['[UNK]']) # unknown token
    # padding
    if len(tokens) < max_seq_len:
        encoded = encoded + [0] * (max_seq_len - len(tokens))
    # truncate
    elif len(tokens) >= max_seq_len:
        encoded = encoded[:max_seq_len]
        
    return encoded

In [48]:
encode_tokens(corpus[0], 81)

[2,
 245,
 2394,
 9,
 34,
 7,
 2,
 178,
 2961,
 133,
 100,
 8,
 856,
 19,
 5365,
 3391,
 4,
 213,
 17,
 5,
 4049,
 6,
 977,
 6,
 17,
 5,
 4791,
 7,
 1884,
 2,
 16896,
 1422,
 1681,
 3,
 169,
 2,
 473,
 139,
 275,
 2175,
 24,
 2,
 88,
 513,
 7,
 2,
 0,
 4,
 2,
 16896,
 351,
 8,
 32,
 0,
 7,
 337,
 0,
 17,
 1062,
 17,
 56,
 7,
 271,
 9388,
 4792,
 3,
 39,
 2,
 65,
 7,
 2,
 826,
 2,
 245,
 2394,
 9,
 275,
 11,
 4,
 2,
 5365]

* `max_seq_len + 1` 만큼 encoding 하는 이유
* time step 1씩 차이가 나도록

```
tokenized_sentences=vectorize_layer(tf.expand_dims('text is good',-1))
x = tokenized_sentences[:, :-1]
y = tokenized_sentences[:, 1:]

>> tokenized_sentences
... <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[1, 1, 1, 0, 0]])>

>> x
... <tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[1, 1, 1, 0]])>

>> y
... <tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[1, 1, 0, 0]])>
```

In [31]:
train_text=[]
for i in range(len(corpus)):
    if (i+1) % (len(corpus)/10) == 0:
        print(f'{(i+1)/(len(corpus)/100)} % Done')
    train_text.append(encode_tokens(corpus[i], max_seq_len = max_seq_len + 1))  # encode reviews with vectorizer
train_text = np.array(train_text)
print("train_text shape :", train_text.shape)

10.0 % Done
20.0 % Done
30.0 % Done
40.0 % Done
50.0 % Done
60.0 % Done
70.0 % Done
80.0 % Done
90.0 % Done
100.0 % Done
train_text shape : (25000, 81)


In [32]:
import pickle

# save
with open('datasets/pickle_data/train_large_text_tg_gpt.pickle', 'wb') as f:
    pickle.dump(train_text, f)

In [33]:
import pickle

with open('datasets/pickle_data/train_large_text_tg_gpt.pickle', 'rb') as f:
    train_text = pickle.load(f)

In [34]:
train_x = train_text[:, :-1]
train_y = train_text[:, 1:]

print('train_x shape :', train_x.shape)
print('train_y shape :', train_y.shape)

train_x shape : (25000, 80)
train_y shape : (25000, 80)


* position_(i) 에 해당하는 input token의 target이 position_(i+1)의 token이 되도록 input target 구성

```Python
# from keras
def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    word at position (i+1). The model will use all words up till position (i)
    to predict the next word.
    """
    text = tf.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y
```


In [50]:
train_x[-1:]

array([[  13,   20,  221,   10,  215,    3,   17,    5, 1135,    0,    0,
           0,    4,   12,   58,  381,   25,   13,   20,   51,   10,   36,
           3,    2, 9157,    7,    2,    0, 1771,    3,    2, 1072,   18,
         693,    3,    2, 1845,    7, 5825,    3,    2, 3266,   49,  264,
          10,  153,    2,  267,    3,    2, 9102,    7,  159,   34,   15,
        4066, 6364,    2,  125,    7,  103,    0,   52,    3,    2, 8665,
           7, 2720,    2, 2123, 1833,   45,    2,  352,    7,    0, 1450,
           0,    3,    2]])

In [51]:
train_y[-1:]

array([[  20,  221,   10,  215,    3,   17,    5, 1135,    0,    0,    0,
           4,   12,   58,  381,   25,   13,   20,   51,   10,   36,    3,
           2, 9157,    7,    2,    0, 1771,    3,    2, 1072,   18,  693,
           3,    2, 1845,    7, 5825,    3,    2, 3266,   49,  264,   10,
         153,    2,  267,    3,    2, 9102,    7,  159,   34,   15, 4066,
        6364,    2,  125,    7,  103,    0,   52,    3,    2, 8665,    7,
        2720,    2, 2123, 1833,   45,    2,  352,    7,    0, 1450,    0,
           3,    2, 3312]])

In [52]:
from torch.utils.data import Dataset, DataLoader

class TextGenDataset(Dataset):
    
    def __init__(self, input_tokens, target_tokens):
        self.input_tokens = input_tokens
        self.target_tokens = target_tokens

    def __len__(self):
        return len(self.input_tokens)

    def __getitem__(self, idx):
        
        inputs = self.input_tokens[idx]
        targets = self.target_tokens[idx]
        
        output = {"inputs": inputs,
                  "targets": targets}

        return output

In [17]:
# sample_train_x = train_x[:5000]
# sample_train_y = train_y[:5000]

In [54]:
batch_size = 128

train_dataset = TextGenDataset(train_x, train_y)
# train_dataset = TextGenDataset(sample_train_x, sample_train_y)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [55]:
print(iter(train_loader).next())

{'inputs': tensor([[   17,     5,   710,  ..., 10601,  3124,  1747],
        [  528,    55,     5,  ...,  1210,     4,    40],
        [   12,   249,  1215,  ...,    15,    43,    90],
        ...,
        [ 1531,     3,    13,  ...,   420,  7132,     6],
        [   10,    15,   398,  ...,    20,     4,     2],
        [   13,     9,    40,  ...,    21,  1842,     6]]), 'targets': tensor([[    5,   710,   278,  ...,  3124,  1747,  1182],
        [   55,     5,  9659,  ...,     4,    40,  2666],
        [  249,  1215,     6,  ...,    43,    90,    12],
        ...,
        [    3,    13,    20,  ...,  7132,     6,   357],
        [   15,   398,    98,  ...,     4,     2,   214],
        [    9,    40, 10010,  ...,  1842,     6,  3164]])}


In [56]:
len(train_loader)

196

### Implement an embedding layer

Create two seperate embedding layers: one for tokens and one for token index (positions).

In [57]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, vocab_size, max_len, embed_size):
        """
        :param vocab_size: total vocab size
        :param max_len: max length of seqeunce
        :param embed_size: embedding size of token embedding
        """
        super().__init__()
        self.token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0)
        self.pos_emb = nn.Embedding(num_embeddings=max_len, embedding_dim=embed_size)

    def forward(self, sequence, device): # [128, 80]

        max_seq_len = sequence.shape[-1]

        positions = torch.tensor(range(max_seq_len)).to(device)
        positions = self.pos_emb(positions)
        emb_out = self.token_emb(sequence) + positions # torch.Size([128, 80, 256])
        return emb_out

### Implement a Transformer block as a layer

In [58]:
# https://github.com/SamLynnEvans/Transformer/tree/e06ae2810f119c75aa34585442872026875e6462

class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """
    def __init__(self):
        super().__init__()
        
    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        
        scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        output = torch.matmul(scores, v)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout = 0.0):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        
        scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        output = torch.matmul(scores, v)
        
        return output
    
    def forward(self, q, k, v, mask=None):
                
        bs = q.size(0)
        
        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        # transpose to get dimensions bs * h * sl * d_model
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)

       # calculate attention using function we will define next
        scores = self.attention(q, k, v, self.d_k, mask)
        
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
    
        return output

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.0):
        super().__init__() 
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

class Norm(nn.Module):
    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
    
        self.size = d_model
        # create two learnable parameters to calibrate normalisation
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
        
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
        / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

#### Decoder의 Masked Self-attention 을 위한 Causal mask 생성
https://machinereads.wordpress.com/2020/05/10/xlnet-generalized-autoregressive-pretraining-for-language-understanding-1-3/

**질문?** query, key의 길이가 달라지는 경우가 있는지?
- Keras 구현에 의하면, query, key의 길이를 각각 받아서 mask 생성
- Self-attention이면 query 길이 하나만 받아도 되지 않나?

In [59]:
#####
#  Keras Implementation
#####

def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to current token.
    1's in the lower triangle, counting from the lower right corner.
    """
    i = tf.range(n_dest)[:, None]
    j = tf.range(n_src)
    m = i >= j - n_src + n_dest
    mask = tf.cast(m, dtype)
    mask = tf.reshape(mask, [1, n_dest, n_src])
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    
    return tf.tile(mask, mult)

In [60]:
# masked self attention

causal_mask = causal_attention_mask(batch_size=2, n_dest=5, n_src=5, dtype=tf.bool)
causal_mask

<tf.Tensor: shape=(2, 5, 5), dtype=bool, numpy=
array([[[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]],

       [[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]]])>

In [61]:
causal_attention_mask(1,3,5, dtype=tf.bool)

<tf.Tensor: shape=(1, 3, 5), dtype=bool, numpy=
array([[[ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]]])>

In [62]:
causal_attention_mask(1,5,3, dtype=tf.bool)

<tf.Tensor: shape=(1, 5, 3), dtype=bool, numpy=
array([[[False, False, False],
        [False, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]]])>

In [63]:
causal_attention_mask(1,5,5, dtype=tf.bool)

<tf.Tensor: shape=(1, 5, 5), dtype=bool, numpy=
array([[[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]]])>

In [64]:
#####
#  Simple Implementation
#####

# https://numpy.org/doc/stable/reference/generated/numpy.tril.html

def simple_causal_attention_mask(batch_size, n_dest, n_src):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to current token.
    1's in the lower triangle, counting from the lower right corner.
    """
#     print(np.ones((batch_size, n_dest, n_src)))
#     print(np.tril(np.ones((batch_size, n_dest, n_src))))
    mask = np.tril(np.ones((batch_size, n_dest, n_src)), n_src-n_dest)

    return mask

In [65]:
# 3-5=-2, 음수 값 만큼 내려가서 diagonal 시작
simple_causal_attention_mask(1,5,3)

array([[[0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]]])

In [66]:
# 5-3=2, 양수 값 만큼 올라간 후 diagonal 시작
simple_causal_attention_mask(1,3,5)

array([[[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]]])

In [67]:
# ! No Encoder-decoder attention

class TransformerDecoderLayer(nn.Module):
    def __init__(self, seq_len, d_model, heads, d_ff, dropout = 0.0):
        super().__init__()
        self.attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, d_ff)
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        
        self.seq_len = seq_len
        
    def forward(self, batch_size, x, device, mask=None):
        
        causal_mask = simple_causal_attention_mask(batch_size, self.seq_len, self.seq_len)
        causal_mask = torch.tensor(np.array(causal_mask)).to(device)
        
        # self attention + add&norm
        attention_output = self.attn(x, x, x, mask=causal_mask) # Masked self attention
        attention_output = self.dropout_1(attention_output)
        masked_self_att_out = self.norm_1(x + attention_output)

        # ffn + add&norm
        ffn_output = self.ff(masked_self_att_out)
        ffn_output = self.dropout_2(ffn_output)
        out = self.norm_2(masked_self_att_out + ffn_output)
        
        return out

### Implement the miniature GPT model


In [68]:
class MiniGPT(nn.Module):
    def __init__(self, device, vocab_size, max_seq_len, d_model, heads, d_ff):
        super(MiniGPT, self).__init__()
        self.device = device
        self.token_pos_embedding = TokenAndPositionEmbedding(vocab_size=vocab_size, max_len=max_seq_len, embed_size=d_model)
        self.transformer_block = TransformerDecoderLayer(seq_len=max_seq_len, d_model=d_model, heads=heads, d_ff=d_ff, dropout=0.1)
        self.ff = nn.Linear(d_model, vocab_size)
        
    def forward(self, batch_size, inputs):
        x = self.token_pos_embedding(inputs, self.device) 
        x = self.transformer_block(batch_size, x, self.device)
        x = self.ff(x) # torch.Size([batch_size, max_seq_len, vocab_size])
        
        return x

### Implement a callback for generating text -> PyTorch

In [69]:
# TODO : reimp not using tf

from tensorflow import keras

class TextGenerator:
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, model, device, max_seq_len, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.model = model
        self.device = device
        self.max_seq_len = max_seq_len
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k
        
    # random choince from top_k predicted tokens
    # logits : (20000,)
    def sample_from(self, logits):
        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = [_ for _ in self.start_tokens]
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            pad_len = self.max_seq_len - len(start_tokens)
            # current token index
            current_index = len(start_tokens) - 1
            if pad_len < 0:
                x = start_tokens[:self.max_seq_len]
                current_index = self.max_seq_len - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = torch.tensor([x]).to(self.device)
            y = self.model(x.shape[0], x)
            y = y.to('cpu').detach().numpy() # torch.Size([128, 80, 20000])
            
            sample_token = self.sample_from(y[0][current_index]) # token id
            
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
            
        txt = " ".join(
            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")

### Train the model

In [70]:
batch_size = 128
vocab_size = 20000  # Only consider the top 20k words
max_seq_len = 80  # Max sequence size
d_model = 256  # Embedding size for each token
heads = 2  # Number of attention heads
d_ff = 256  # Hidden layer size in feed forward network inside transformer

from tqdm.notebook import tqdm
from einops import rearrange
# https://githubmemory.com/repo/arogozhnikov/einops/issues

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print ('Current device :', device)
model = MiniGPT(device=device, vocab_size=vocab_size, max_seq_len=max_seq_len, d_model=d_model, heads=heads, d_ff=d_ff)

learning_rate = 0.001 

criterion =  nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Current device : cuda


In [71]:
import tensorflow as tf
print(tf.__version__)

2.5.0


In [None]:
num_epochs=25
total_step = len(train_loader)

model = model.to(device)
start_prompt = "this movie is"
start_tokens = [token_to_id.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40

for epoch in tqdm(range(0, num_epochs)):    
    for i_batch, sample_batched in enumerate(train_loader):

        batch_inputs = sample_batched['inputs'].to(device) # torch.Size([128, 80])
        batch_targets = sample_batched['targets'].to(device)
        
        batch_size = batch_targets.size(0)
        
        # Forward
        outputs = model(batch_size, batch_inputs) # torch.Size([128, 80, 20000])
        
        # Compute loss
        batch_predicts = rearrange(outputs, 'b c l -> b l c') # torch.Size([128, 20000, 80])
        loss = criterion(batch_predicts, batch_targets) # torch.Size([128, 80])

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if (i_batch+1) % 5 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i_batch, total_step, loss.item())) 
    
    # Save the model checkpoints
    torch.save(model.state_dict(), './models/text_gen_gpt-{}.ckpt'.format(epoch+1))
    text_gen_callback = TextGenerator(model, device, max_seq_len, num_tokens_generated, start_tokens, id_to_token)
    print(f'>> Epoch {epoch + 1}')
    print(text_gen_callback.on_epoch_end(epoch))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

Epoch [1/25], Step [4/196], Loss: 9.4184
Epoch [1/25], Step [9/196], Loss: 7.7152
Epoch [1/25], Step [14/196], Loss: 6.8842
Epoch [1/25], Step [19/196], Loss: 6.5720
Epoch [1/25], Step [24/196], Loss: 6.6629
Epoch [1/25], Step [29/196], Loss: 6.7069
Epoch [1/25], Step [34/196], Loss: 6.5952
Epoch [1/25], Step [39/196], Loss: 6.6938
Epoch [1/25], Step [44/196], Loss: 6.5765
Epoch [1/25], Step [49/196], Loss: 6.5991
Epoch [1/25], Step [54/196], Loss: 6.6001
Epoch [1/25], Step [59/196], Loss: 6.5826
Epoch [1/25], Step [64/196], Loss: 6.5045
Epoch [1/25], Step [69/196], Loss: 6.5399
Epoch [1/25], Step [74/196], Loss: 6.5148
Epoch [1/25], Step [79/196], Loss: 6.4928
Epoch [1/25], Step [84/196], Loss: 6.4652
Epoch [1/25], Step [89/196], Loss: 6.4218
Epoch [1/25], Step [94/196], Loss: 6.3699
Epoch [1/25], Step [99/196], Loss: 6.4072
Epoch [1/25], Step [104/196], Loss: 6.2384
Epoch [1/25], Step [109/196], Loss: 6.2437
Epoch [1/25], Step [114/196], Loss: 6.2327
Epoch [1/25], Step [119/196], Los