In [44]:
import numpy as np
from tqdm.auto import tqdm

from hcrot import layers, optim
from hcrot.utils import softmax

In [45]:
sentences = [
    "The sun always rises after the darkest night, so keep hope alive",
    "A bird in the hand is worth two in the bush, so hold on tightly",
    "Fortune favors the brave, but wisdom guides their steps",
    "The squeaky wheel gets the grease, so speak up when needed",
    "A watched pot never boils, so do not waste your time staring",
    "He who hesitates is lost, so act decisively when the time comes",
    "A penny saved is a penny earned, so start saving early",
    "A chain is only as strong as its weakest link, so strengthen every part",
    "Absence makes the heart grow fonder, but do not forget to stay in touch",
    "Do unto others as you would have them do unto you, and live with kindness",
    "The harder you work, the luckier you get, so never stop trying",
    "Patience is a virtue, but persistence brings success",
    "A rolling stone gathers no moss, so keep moving forward",
    "Great minds think alike, but they also think differently",
    "Birds of a feather flock together, so choose your company wisely",
    "A fool and his money are soon parted, so spend with caution",
    "Time heals all wounds, but some scars may remain",
    "A stitch in time saves nine, so do not delay your efforts",
    "Honesty is the best policy, but tact is also important",
    "If it is not broken, do not fix it, but always seek improvement",
    "Necessity is the mother of invention, so embrace challenges",
    "Do not bite the hand that feeds you, but show gratitude instead",
    "An ounce of prevention is worth a pound of cure, so plan ahead",
    "A little knowledge is a dangerous thing, so always keep learning",
    "Practice what you preach, but also be open to feedback",
    "Where there is smoke, there is fire, so investigate further",
    "Every dog has its day, so your time will come too",
    "Do not put all your eggs in one basket, but diversify your risks",
    "The best things in life are free, but they require effort",
    "A journey well planned is a journey half completed"
]

In [46]:
vocab = {}
for sentence in sentences:
    for word in sentence.split():
        if word not in vocab:
            vocab[word] = len(vocab)

vocab['<pad>'] = len(vocab)
vocab['<sos>'] = len(vocab)
vocab['<eos>'] = len(vocab)

inverse_vocab = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)

def tokenize(sentence):
    return [vocab[word] for word in sentence.split()]

data = [tokenize(sentence) for sentence in sentences]

max_len = max(len(sentence) for sentence in data)
padded_data = [[vocab['<sos>']] + sentence + [vocab['<pad>']] * (max_len - len(sentence)) + [vocab['<eos>']] for sentence in data]
padded_data = np.array(padded_data)

In [47]:
def get_sinusoid_encoding_table(n_seq, d_hidn):
    # refs: https://paul-hyun.github.io/transformer-01/
    def cal_angle(position, i_hidn):
        return position / np.power(10000, 2 * (i_hidn // 2) / d_hidn)
    def get_posi_angle_vec(position):
        return [cal_angle(position, i_hidn) for i_hidn in range(d_hidn)]

    sinusoid_table = np.array([get_posi_angle_vec(i_seq) for i_seq in range(n_seq)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even index sin 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd index cos

    return sinusoid_table

class GPT(layers.Module):
    def __init__(self, vocab_size, embed_size, num_heads, num_layers, max_len=16):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = layers.Embedding(vocab_size, embed_size)
        self.positional_encoding = np.expand_dims(get_sinusoid_encoding_table(max_len, embed_size), axis=0)
        self.transformer_decoder_layer = layers.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=embed_size * 4,
            batch_first=True
        )
        self.transformer_decoder = layers.TransformerDecoder(
            self.transformer_decoder_layer,
            num_layers=num_layers,
        )
        self.fc_out = layers.Linear(embed_size, vocab_size)

    def forward(self, tgt):
        tgt_len = tgt.shape[1]
        tgt_mask = self._generate_square_subsequent_mask(tgt_len)

        tgt_emb = self.embedding(tgt) + self.positional_encoding[:, :tgt_len, :]

        output = self.transformer_decoder(tgt_emb, tgt_emb, tgt_mask=tgt_mask)
        output = self.fc_out(output)

        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = np.triu(np.ones((sz, sz)), 1)
        return mask

In [48]:
embed_size = 128
num_heads = 8
num_layers = 2

model = GPT(vocab_size, embed_size, num_heads, num_layers)
criterion = layers.CrossEntropyLoss()
optimizer = optim.Adam(model, lr_rate=1e-3)

inputs = padded_data[:, :-1]
targets = padded_data[:, 1:]
bsz, seq_len = inputs.shape

num_epochs = 100
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    outputs = model.forward(inputs)
    
    outputs = outputs.reshape(-1, vocab_size)
    targets = targets.reshape(-1)
    loss = criterion(outputs, targets)

    dz = criterion.backward()
    dz = dz.reshape(bsz, seq_len, -1)
    optimizer.update(dz)
    
    pbar.set_description(f'Loss: {loss.item():5f}')

  0%|          | 0/100 [00:00<?, ?it/s]

Loss: 0.045446: 100%|██████████| 100/100 [01:11<00:00,  1.39it/s]


In [49]:
sentences

['The sun always rises after the darkest night, so keep hope alive',
 'A bird in the hand is worth two in the bush, so hold on tightly',
 'Fortune favors the brave, but wisdom guides their steps',
 'The squeaky wheel gets the grease, so speak up when needed',
 'A watched pot never boils, so do not waste your time staring',
 'He who hesitates is lost, so act decisively when the time comes',
 'A penny saved is a penny earned, so start saving early',
 'A chain is only as strong as its weakest link, so strengthen every part',
 'Absence makes the heart grow fonder, but do not forget to stay in touch',
 'Do unto others as you would have them do unto you, and live with kindness',
 'The harder you work, the luckier you get, so never stop trying',
 'Patience is a virtue, but persistence brings success',
 'A rolling stone gathers no moss, so keep moving forward',
 'Great minds think alike, but they also think differently',
 'Birds of a feather flock together, so choose your company wisely',
 'A 

In [90]:
def generate_sentence(model, start_sentence, max_len):
    generated = [vocab[token] for token in start_sentence.split()]
    input_seq = np.expand_dims(np.array(generated),0)
    
    while len(input_seq[0]) < max_len:
        # print(' '.join(inverse_vocab[token] for token in input_seq[0]))
        output = model.forward(input_seq)
        next_token_logits = output[-1, -1]
        next_token = np.argmax(next_token_logits).item()
        generated.append(next_token)
        if next_token == vocab['<eos>']:
            break
        input_seq = np.array([generated])
        
    return ' '.join(inverse_vocab[token] for token in generated)

def generate_with_top_k_and_top_p(model, start_sentence, max_len, top_k=0, top_p=1.0):
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.1):
        filter_value = float('-inf')
        if top_k > 0:
            indices_to_remove = logits < np.take_along_axis(logits, np.argsort(-logits,axis=-1)[:,:top_k], axis=-1)[..., -1, None]
            logits[indices_to_remove] = filter_value
        
        if top_p < 1.:
            sorted_indices = np.argsort(-logits)
            sorted_logits = -np.sort(-logits)
            
            cumulative_probs = np.cumsum(softmax(sorted_logits, dim=-1), axis=-1)
            
            sorted_indices_to_remove = cumulative_probs > top_p
            
            # shift
            sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1)
            sorted_indices_to_remove[:,0] = 0
            
            indices_to_remove = np.zeros_like(logits, dtype=bool)
            np.put_along_axis(indices_to_remove, sorted_indices, sorted_indices_to_remove, axis=-1)
            
            logits[indices_to_remove] = filter_value
            
        return logits
    
    generated = [vocab[token] for token in start_sentence.split()]
    input_seq = np.expand_dims(np.array(generated),0)
    
    while len(input_seq[0]) < max_len:
        output = model.forward(input_seq)
        next_token_logits = output[-1]
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        probs = softmax(next_token_logits, dim=-1)
        next_token = np.random.choice(np.arange(probs[0].shape[0]),size=1,p=probs[0])
        generated.append(next_token[0])
        if next_token == vocab['<eos>']:
            break
        input_seq = np.array([generated])
        
    return ' '.join(inverse_vocab[token] for token in generated)

sentence = "Great minds think alike, but" # Original: Great minds think alike, but they also think differently
print(generate_sentence(model, sentence, max_len=16))
print(generate_with_top_k_and_top_p(model, sentence, max_len=16, top_k=64, top_p=0.8))

Great minds think alike, but they also Great minds think alike, with saving <eos>
Great minds think alike, but minds minds feedback minds minds minds decisively after minds minds is
