In [5]:
import numpy as np
from tqdm.auto import tqdm

from hcrot import layers, optim

In [6]:
sentences = [
    "The quick brown fox jumps over the lazy dog",
    "A journey of a thousand miles begins with a single step",
    "To be or not to be that is the question",
    "All that glitters is not gold but it is very valuable",
    "Knowledge is power but enthusiasm pulls the switch",
    "The only thing we have to fear is fear itself",
    "In the end we will remember not the words of our enemies",
    "Life is what happens when you’re busy making other plans",
    "To succeed in life you need two things ignorance and confidence",
    "The future belongs to those who believe in the beauty of their dreams"
]

vocab = {}
for sentence in sentences:
    for word in sentence.split():
        if word not in vocab:
            vocab[word] = len(vocab)

vocab['<pad>'] = len(vocab.keys())
vocab['<eos>'] = len(vocab.keys())

inverse_vocab = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)

def tokenize(sentence):
    return [vocab[word] for word in sentence.split()]

data = [tokenize(sentence) for sentence in sentences]

max_len = max(len(sentence) for sentence in data)
padded_data = [sentence + [vocab['<pad>']] * (max_len - len(sentence)) + [vocab['<eos>']] for sentence in data]
padded_data = np.array(padded_data)

In [7]:
def get_sinusoid_encoding_table(n_seq, d_hidn):
    # refs: https://paul-hyun.github.io/transformer-01/
    def cal_angle(position, i_hidn):
        return position / np.power(10000, 2 * (i_hidn // 2) / d_hidn)
    def get_posi_angle_vec(position):
        return [cal_angle(position, i_hidn) for i_hidn in range(d_hidn)]

    sinusoid_table = np.array([get_posi_angle_vec(i_seq) for i_seq in range(n_seq)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even index sin 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd index cos

    return sinusoid_table

class GPT(layers.Module):
    def __init__(self, vocab_size, embed_size, num_heads, num_layers, max_len=512):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = layers.Embedding(vocab_size, embed_size)
        self.positional_encoding = np.expand_dims(get_sinusoid_encoding_table(max_len, embed_size), axis=0)
        self.transformer_decoder_layer = layers.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=embed_size * 4,
            batch_first=True
        )
        self.transformer_decoder = layers.TransformerDecoder(
            self.transformer_decoder_layer,
            num_layers=num_layers,
        )
        self.fc_out = layers.Linear(embed_size, vocab_size)

    def forward(self, tgt):
        tgt_len = tgt.shape[1]
        tgt_mask = self._generate_square_subsequent_mask(tgt_len)

        tgt_emb = self.embedding(tgt) + self.positional_encoding[:, :tgt.shape[1], :]

        output = self.transformer_decoder(tgt_emb, tgt_emb, tgt_mask=tgt_mask)
        output = self.fc_out(output)

        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = np.triu(np.ones((sz, sz)), 1)
        return mask

In [30]:
embed_size = 256
num_heads = 4
num_layers = 3

model = GPT(vocab_size, embed_size, num_heads, num_layers)
criterion = layers.CrossEntropyLoss()
optimizer = optim.Adam(model, lr_rate=1e-3)

inputs = padded_data[:, :-1]
targets = padded_data[:, 1:]
bsz, seq_len = inputs.shape

num_epochs = 50
for epoch in range(num_epochs):
    outputs = model.forward(inputs)
    
    outputs = outputs.reshape(-1, vocab_size)
    targets = targets.reshape(-1)
    loss = criterion(outputs, targets)

    dz = criterion.backward()
    dz = dz.reshape(bsz, seq_len, -1)
    optimizer.update(dz)
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

Epoch 1/50, Loss: 4.470611729039826
Epoch 2/50, Loss: 3.822672433568201
Epoch 3/50, Loss: 3.305046282255573
Epoch 4/50, Loss: 2.948956121846243
Epoch 5/50, Loss: 2.347445292140538
Epoch 6/50, Loss: 1.891882693010152
Epoch 7/50, Loss: 1.541612920097274
Epoch 8/50, Loss: 1.1817331319095838
Epoch 9/50, Loss: 0.9197488995104414
Epoch 10/50, Loss: 0.684036989145343
Epoch 11/50, Loss: 0.5085725404061618
Epoch 12/50, Loss: 0.36315022370106964
Epoch 13/50, Loss: 0.2658197968044208
Epoch 14/50, Loss: 0.2063143595293862
Epoch 15/50, Loss: 0.1597128879574305
Epoch 16/50, Loss: 0.12694673902448858
Epoch 17/50, Loss: 0.1050292605532066
Epoch 18/50, Loss: 0.08627863200116671
Epoch 19/50, Loss: 0.07179282088260787
Epoch 20/50, Loss: 0.06232451579747673
Epoch 21/50, Loss: 0.05056625525856857
Epoch 22/50, Loss: 0.045285526152731036
Epoch 23/50, Loss: 0.039436409540196414
Epoch 24/50, Loss: 0.03550112554714281
Epoch 25/50, Loss: 0.03308612833875222
Epoch 26/50, Loss: 0.027807124843697965
Epoch 27/50, Lo

In [31]:
sentences

['The quick brown fox jumps over the lazy dog',
 'A journey of a thousand miles begins with a single step',
 'To be or not to be that is the question',
 'All that glitters is not gold but it is very valuable',
 'Knowledge is power but enthusiasm pulls the switch',
 'The only thing we have to fear is fear itself',
 'In the end we will remember not the words of our enemies',
 'Life is what happens when you’re busy making other plans',
 'To succeed in life you need two things ignorance and confidence',
 'The future belongs to those who believe in the beauty of their dreams']

In [41]:
def generate_sentence(model, start_sentence, max_len):
    generated = [vocab[token] for token in start_sentence.split()]
    input_seq = np.expand_dims(np.array(generated),0)
    
    while len(input_seq) < max_len:
        output = model.forward(input_seq)
        next_token_logits = output[-1, -1]
        next_token = np.argmax(next_token_logits).item()
        generated.append(next_token)
        if next_token == vocab['<eos>']:
            break
        input_seq = np.array([generated])
        
    return ' '.join(inverse_vocab[token] for token in generated)

print(generate_sentence(model, 'The', max_len=100))

The future belongs to those who believe in the beauty of their <eos>
