In [78]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math
import time
import tiktoken
from datasets import load_dataset
from io import open


In [104]:
# Hyperparameters
nhead = 4
ninp = 128
em_size = 1500
nhid = 1500
nlayers = 4
dropout = 0.4
bptt = 150
batch_size = 32
eval_batch_size = 16
learning_rate = 4
clip = 0.25
epochs = 40
log_interval = 100
models_dir = '../models/'
data_dir = '../data/wikitext'


In [105]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [106]:
## Taken from https://github.com/pytorch/examples/blob/main/word_language_model/data.py on 01/11/2023
# class Dictionary(object):
#     def __init__(self):
#         self.word2idx = {}
#         self.idx2word = []
#         self.counter = 0
    
#     def add_word(self, word):
#         if word not in self.idx2word:
#             self.idx2word.append(word)
#             self.word2idx[word] = self.counter
#             self.counter += 1
    
#     def __len__(self):
#         return len(self.idx2word)

In [107]:

encoder = tiktoken.get_encoding("cl100k_base")
hello_world_encoded = encoder.encode("Hello world!")
assert encoder.decode(hello_world_encoded) == "Hello world!"

hello_world_encoded

[9906, 1917, 0]

In [108]:
encoder.special_tokens_set

{'<|endofprompt|>',
 '<|endoftext|>',
 '<|fim_middle|>',
 '<|fim_prefix|>',
 '<|fim_suffix|>'}

In [109]:
encoder.n_vocab

100277

In [110]:
# def tokenise_prompt(input, corpus):
#     words = input.split() + ['<eos>']
#     ids = torch.LongTensor(len(words))
#     for i, word in enumerate(words):
#         ids[i] = corpus.dictionary.word2idx[word]
#     return ids

In [111]:
## Taken from https://github.com/pytorch/examples/blob/main/word_language_model/data.py on 01/11/2023
class Corpus(object):
    def __init__(self, path):
        # self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        assert os.path.exists(path)
        text = ''
        chunk_size = 1024
        with open(path, 'r', encoding='utf-8') as f:
            for chunk in f.read(chunk_size):
                text += chunk

            # for line in f:
            #     words = line.split()
            #     for word in words:
            #         text += word
            #     text += "\n"
        return encoder.encode(text)
            # return ids

In [112]:
if not os.path.exists(data_dir):
    os.mkdir('data_dir')
    data = load_dataset('wikitext', 'wikitext-2-v1')
    data.save_to_disk(os.path.join(data_dir, 'wikitext-2'))
data = load_dataset('wikitext', 'wikitext-2-v1', data_dir=data_dir)



In [113]:
data

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [114]:
## Converting from HF dataset to plain text files can probably improve this
train_data = '\n'.join(data['train']['text'])
valid_data = '\n'.join(data['validation']['text'])
test_data = '\n'.join(data['test']['text'])
with open(os.path.join(data_dir, 'train.txt'), 'w') as f:
    f.write(train_data)
with open(os.path.join(data_dir, 'valid.txt'), 'w') as f:
    f.write(valid_data)
with open(os.path.join(data_dir, 'test.txt'), 'w') as f:
    f.write(test_data)

In [115]:
corpus = Corpus(data_dir)

In [116]:
def batchify(data, bsz):
    data = torch.tensor(data)
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch*bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


In [117]:
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

In [118]:
class LSTMModel(nn.Module):
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.lstm = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers
        self.ntokens = ntoken
        self.model_type = 'LSTM' 

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.lstm(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        decoded = decoded.view(-1, self.ntokens)
        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

In [119]:
class PoisitionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PoisitionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)*(-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position*div_term)
        pe[:, 1::2] = torch.cos(position*div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [120]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_decoder_layers=nlayers)
        # self.encoder = nn.Embedding(ntoken, ninp)
        self.src_mask = None
        self.pos_encoder = PoisitionalEncoding(ninp, dropout)

        self.input_emb = nn.Embedding(ntoken, ninp)
        self.decoder = nn.Linear(ninp, ntoken)
        self.ninp = ninp
        self.nhid = nhid
        self.nlayers = nlayers
        self.ntokens = ntoken
        self.model_type = 'Transformer'

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz) == 1).transpose(0, 1))
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def init_weights(self):
        initrange = 0.1
        self.input_emb.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None
        
        src = self.input_emb(src)*math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

In [121]:
ntokens = encoder.n_vocab
transformer_model = TransformerModel(ntoken=ntokens, ninp=ninp, nhead=nhead, nhid=nhid, nlayers=nlayers, dropout=dropout).to(device)
lstm_model = LSTMModel(ntoken=ntokens, ninp=ninp, nhid=nhid, nlayers=nlayers,dropout=dropout).to(device)

criterion = nn.NLLLoss()


In [122]:
def repackage_hidden(h):
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)


In [123]:
def get_batch(source, i):
    seq_len = min(bptt, len(source)-1-i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

In [124]:
def evaluate(model_type, model, data_source, ntokens):
    model.eval()
    total_loss = 0
    if model_type != 'Transformer':
        hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0)-1, bptt):
            data, targets = get_batch(data_source, i)
            if model_type == 'Transformer':
                output = model(data)
                output = output.view(-1, ntokens)
            else:
                output, hidden = model(data, hidden)
                hidden = repackage_hidden(hidden)
            total_loss += len(data)*criterion(output, targets).item()
    return total_loss/(len(data_source)-1)

In [125]:
def train(model_type, model, epoch, lr, ntokens):
    model.train()
    total_loss = 0
    start_time = time.time()
    if model_type != 'Transformer':
        hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_data.size(0)-1, bptt)):
        data, targets = get_batch(train_data, i)
        model.zero_grad()
        if model_type == 'Transformer':
            output = model(data)
            output = output.view(-1, ntokens)
        else:
            output, hidden = model(data, hidden)
            hidden = repackage_hidden(hidden)
        loss = criterion(output, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-lr)
        
        total_loss += loss.item()

        if batch%log_interval == 0 and batch > 0:
            cur_loss = total_loss/log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch, len(train_data)//bptt, elapsed*1000/log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        


In [126]:
def export_onnx(path, model, batch_size, seq_len):
    model.eval()
    x = torch.rand(seq_len, batch_size).to(device)
    hidden = model.init_hidden(batch_size)
    torch.onnx.export(model, (x, hidden), path)

In [127]:
def run_training(model, epochs, learning_rate, ntokens):
    lr = learning_rate
    best_val_loss = None
    model_type = model.model_type
    model_path = os.path.join(models_dir, (model_type + '.pt'))
    for epoch in range(1, epochs):
        epoch_start_time = time.time()
        train(model_type, model, epoch, lr, ntokens)
        val_loss = evaluate(model_type, model, val_data, ntokens)
        print('-'*89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch, (time.time()-epoch_start_time), val_loss, math.exp(val_loss)))
        print('-'*89)
        if not best_val_loss or val_loss < best_val_loss:
            with open(model_path, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            lr /= 4.0
    # with open(model_path, 'rb') as f:
    #     model.torch.load(f)
    #     if model_type == 'LSTM':
    #         model.flatten_parameters()

    test_loss = evaluate(model_type, model, test_data, ntokens)
    print('='*89)
    print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
    print('='*89)
    # export_onnx(os.path.join(onnx_export_dir, model_type), model, batch_size=1, seq_len=bptt)


In [128]:
run_training(transformer_model, epochs, learning_rate, encoder.n_vocab)

-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  1.33s | valid loss 11.35 | valid ppl 85320.73
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  0.13s | valid loss 10.71 | valid ppl 44877.89
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  0.11s | valid loss 13.41 | valid ppl 667764.41
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   4 | time:  0.11s | valid loss 10.65 | valid ppl 42099.85
-------------------------------------------------------------------------

In [130]:
run_training(lstm_model, epochs, learning_rate, encoder.n_vocab)

-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  2.46s | valid loss 11.51 | valid ppl 99861.40
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  0.56s | valid loss 11.51 | valid ppl 99222.61
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  0.57s | valid loss 11.50 | valid ppl 98565.07
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   4 | time:  0.57s | valid loss 11.49 | valid ppl 97942.33
--------------------------------------------------------------------------

In [131]:
# tokenise_prompt('The meaning of life is', corpus)

NameError: name 'tokenise_prompt' is not defined

In [None]:
if transformer_model is None:
    transformer_model = torch.load('../models/Transformer.pt', map_location=device)


In [149]:
def generate_text(model, prompt, output_file, corpus=corpus, temp=0.4, device=device, ntokens=encoder.n_vocab):
    if model.model_type != 'Transformer':
        hidden = model.init_hidden(1)
        # lstm_model.flatten_parameters()
    model.to(device)
    # input = tokenise_prompt(prompt, corpus).to(device)
    input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
    with open(output_file, 'w') as f:
        with torch.no_grad():
            for i in range(1000):
                if model.model_type == 'Transformer':
                    output = model(input, has_mask=False)
                    word_weights = output[-1].squeeze().div(temp).exp().cpu()
                    word_idx = torch.multinomial(word_weights, 1)[0]
                    word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                    input = torch.cat([input, word_tensor], 0)
                else:
                    output, hidden = model(input, hidden)
                    word_weights = output.squeeze().div(temp).exp().cpu()
                    word_idx = torch.multinomial(word_weights, 1)[0]
                    input.fill_(word_idx)
                word = corpus.dictionary.idx2word[word_idx]
                f.write(word + ('\n' if i%20 == 19 else ' '))
                if i % log_interval == 0:
                    print('| Generated {}/{} words'.format(i, 1000))
            print('Done')

In [150]:
generate_text(transformer_model, 'The meaning of life is', '../output/tokenized_transformer_generated.txt')

3200
| Generated 0/1000 words
330
1174
4298
66416
1174
279
279
18
86262
330
1847
18
279
279
662
279
315
1174
323
18
279
1847
279
86262
86262
279
279
4298
366
88
86262
66416
279
279
279
279
220
279
279
279
279
304
279
279
279
279
1174
279
86262
279
279
439
66416
311
1174
279
551
1174
279
86262
279
1174
279
304
88
4298
315
66416
330
5089
279
279
86262
279
279
279
86262
279
66416
86262
1174
279
1174
315
279
330
279
86262
1174
264
662
1847
279
279
315
662
304
279
88
279
662
86262
279
279
662
662
304
86262
279
279
279
1174
279
279
279
279
86262
86262
1174
4298
1174
66416
279
3200
433
1174
1174
4298
279
279
279
1174
279
1174
315
279
88
315
4298
88
88
4298
279
88
279
88
279
1174
366
279
279
662
571
304
88
279
279
1174
1174
315
279
1847
86262
279
279
3200
1174
279
88
264
66416
279
86262
279
88
330
279
4298
279
1174
88
304
18
88
279
86262
3200
366
4298
315
279
439
1174
220
279
279
662
279
279
279
18
66416
3200
279
323
315
88
264
279
304
279
279
662
3200
86262
279
279
1174
1174
279
4298
662
279


In [None]:
# generate_text('LSTM.pt', 'LSTM', corpus, device)