In [105]:
from jax import lax, random, numpy as jnp
from flax import linen as nn
from train.trainer import create_train_state, Trainer, load_train_state, save_train_state
from model import GPT
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [106]:
class TextDataset(Dataset):
    """Character-level next-character dataset over a single text string.

    Item ``idx`` is a pair ``(x, y)`` of int32 tensors: ``x`` encodes the
    ``block_size`` characters starting at ``idx`` and ``y`` is the same window
    shifted one character to the right (the prediction targets).

    Args:
        data: the full text corpus.
        block_size: number of characters in each input window.
        max_samples: upper bound on ``len(self)``. Defaults to 100000 to keep
            epochs short; pass ``None`` to expose every available window.
    """

    def __init__(self, data, block_size, max_samples=100000):
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('The input data has %d characters. %d of these characters are unique. These characters include uppercase and lower case letters, as well as punctuations.'
        % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}  # inverse of stoi; used to decode generated text
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
        self.max_samples = max_samples

    def __getitem__(self, idx):
        # Slice block_size + 1 characters so both the input and the shifted
        # target end up with exactly block_size entries.
        text_block = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        encoded_txt = [self.stoi[char] for char in text_block]
        x = torch.tensor(encoded_txt[:-1], dtype=torch.int)
        y = torch.tensor(encoded_txt[1:], dtype=torch.int)
        return x, y

    def __len__(self):
        # Only start positions that leave room for a full (block_size + 1)
        # character window are valid; clamp at 0 for texts shorter than that.
        n_windows = max(len(self.data) - self.block_size, 0)
        if self.max_samples is None:
            return n_windows
        # Cap the epoch size, but never report more windows than really exist.
        return min(n_windows, self.max_samples)

In [107]:
# Read the raw corpus. `with` closes the file handle; an explicit encoding keeps
# the character vocabulary independent of the platform's default locale.
with open('./gpt_text_input/shakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    shakespeare_txt = corpus_file.read()

dataset = TextDataset(shakespeare_txt, block_size = 48)
# Seed the split so the train/test partition is identical across kernel restarts.
# Otherwise a checkpoint reloaded below may have been trained on windows that
# land in this run's test set, inflating test metrics.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, (0.9, 0.1), generator=split_generator
)

The input data has 1115393 characters. 65 of these characters are unique. These characters include uppercase and lower case letters, as well as punctuations.


In [108]:
# Model hyper-parameters. Vocabulary size and context length are taken from the
# dataset so the model's embedding tables match the character encoding.
config = dict(
    n_layers=4,
    n_head=4,
    n_embd=48,
    vocab_size=dataset.vocab_size,
    block_size=dataset.block_size,
    embd_pdrop=0.0,
)
epochs = 2  # number of iterations of the training loop below
continue_training_from_checkpoint = True  # if True, the fresh state is replaced via load_train_state

In [109]:
# Instantiate the GPT module from the hyper-parameter dict (parameters are
# presumably initialised inside create_train_state below).
model = GPT(**dict(config))

In [110]:
# Derive three keys from one fixed root seed:
#   key1        -> passed to model.generate in the sampling cell
#   key2        -> the "params" RNG stream at initialisation
#   dropout_key -> the "dropout" RNG stream at initialisation
# NOTE(review): dropout_key is also passed as `key=` to create_train_state, so the
# same key may be consumed twice — confirm that is intended.
root_key = random.PRNGKey(1)
key1, key2, dropout_key = random.split(root_key, num=3)
init_rng = dict(params=key2, dropout=dropout_key)

In [111]:
# Build a freshly initialised train state; when resuming, hand it to
# load_train_state (presumably used as the restore template) and keep the result.
state = create_train_state(model, init_rng, config, key=dropout_key)
state = load_train_state(state) if continue_training_from_checkpoint else state

In [112]:
# Wrap both dataset splits and the (fresh or checkpoint-restored) train state in the project Trainer.
trainer = Trainer(train_dataset, test_dataset, train_state=state)

In [None]:
# Run the trainer one epoch per call and save a checkpoint after each call, so an
# interrupted run loses at most the epoch in progress.
for _epoch in range(epochs):
    trainer.run_trainer(epochs=1)
    save_train_state(trainer.train_state)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1407/1407 [05:54<00:00,  3.97it/s]


train_accuracy: 0.2386350929737091
train_loss: 2.8138303756713867
test_accuracy: 0.29055947065353394
test_loss: 2.4855010509490967


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍        | 1345/1407 [05:15<00:14,  4.16it/s]

In [None]:
# Take the trained state from the trainer; its params are used for text generation below.
state=trainer.train_state

In [None]:
# Training curves. Loss and accuracy live on very different scales, so each gets
# its own panel. The x axis is the index into metrics_history; whether that is
# per epoch or per logging step is decided by Trainer (TODO confirm).
train_loss = trainer.metrics_history["train_loss"]
train_accuracy = trainer.metrics_history["train_accuracy"]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(np.arange(len(train_loss)), train_loss, label="loss")
ax_loss.set(title="Train loss", xlabel="metrics_history index", ylabel="loss")
ax_acc.plot(np.arange(len(train_accuracy)), train_accuracy, label="accuracy")
ax_acc.set(title="Train accuracy", xlabel="metrics_history index", ylabel="accuracy")
fig.tight_layout()

In [None]:
# Held-out curves, laid out like the training plot. Loss and accuracy get
# separate panels because their scales differ. The x axis is the index into
# metrics_history (granularity set by Trainer — TODO confirm).
test_loss = trainer.metrics_history["test_loss"]
test_accuracy = trainer.metrics_history["test_accuracy"]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(np.arange(len(test_loss)), test_loss, label="loss")
ax_loss.set(title="Test loss", xlabel="metrics_history index", ylabel="loss")
ax_acc.plot(np.arange(len(test_accuracy)), test_accuracy, label="accuracy")
ax_acc.set(title="Test accuracy", xlabel="metrics_history index", ylabel="accuracy")
fig.tight_layout()

In [None]:
# Qualitative check: take one random held-out window as the prompt and let the
# model continue it.
def decode(token_ids):
    """Map an iterable of integer token ids back to text via ``dataset.itos``."""
    # int() turns each element into a plain Python int. 0-d torch tensors hash by
    # identity and JAX arrays are unhashable, so neither works as an itos key.
    return "".join(dataset.itos[int(token_id)] for token_id in token_ids)


# NOTE(review): meaning of GPT.generate's positional args is inferred, not shown
# here — presumably (params, prompt, number of tokens to sample, PRNG key,
# sampling temperature). Confirm against model.GPT.generate.
GENERATE_LENGTH = 96
GENERATE_TEMPERATURE = 0.4

dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
x, y = next(iter(dataloader))
prompt_text = decode(x[0])
print("input sentence: ", prompt_text)
print()

x = jnp.array(x)
sequence = model.generate(state.params, x, GENERATE_LENGTH, key1, GENERATE_TEMPERATURE)
generated_text = decode(sequence)
print(generated_text)
print(", ".join(dataset.itos[i] for i in range(config["vocab_size"])))


