In [17]:
# # download the complete works of shakespeare as a text file and save it in the home directory
# !wget https://www.gutenberg.org/files/100/100-0.txt -O ./shakespeare.txt

![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/derpyplops/arena/blob/main/shakespeare.ipynb)

zsh:1: bad pattern: [Open


In [18]:
import torch as t
import torch.nn as nn
from einops import rearrange, repeat
from fancy_einsum import einsum
from torch import optim
from impl.transformer_modules import DecoderTransformer, TransformerConfig
from torch.utils.data import DataLoader, Dataset, random_split

In [19]:
import re

# read the file
with open('./shakespeare.txt', 'r') as f:
    text = f.read()

print(re.split(r"\b", text)[:100])

['\ufeff', 'The', ' ', 'Project', ' ', 'Gutenberg', ' ', 'eBook', ' ', 'of', ' ', 'The', ' ', 'Complete', ' ', 'Works', ' ', 'of', ' ', 'William', ' ', 'Shakespeare', ', ', 'by', ' ', 'William', ' ', 'Shakespeare', '\n\n', 'This', ' ', 'eBook', ' ', 'is', ' ', 'for', ' ', 'the', ' ', 'use', ' ', 'of', ' ', 'anyone', ' ', 'anywhere', ' ', 'in', ' ', 'the', ' ', 'United', ' ', 'States', ' ', 'and', '\n', 'most', ' ', 'other', ' ', 'parts', ' ', 'of', ' ', 'the', ' ', 'world', ' ', 'at', ' ', 'no', ' ', 'cost', ' ', 'and', ' ', 'with', ' ', 'almost', ' ', 'no', ' ', 'restrictions', '\n', 'whatsoever', '. ', 'You', ' ', 'may', ' ', 'copy', ' ', 'it', ', ', 'give', ' ', 'it', ' ', 'away']


In [20]:
# dataset containing shakespeare
from torch.utils.data import Dataset

import re

# read the file
with open('./shakespeare.txt', 'r') as f:
    text = f.read()

class ShakespeareDataset(Dataset):
    def __init__(self, text, seq_size):
        super().__init__()
        self.text = text
        self.vocab = sorted(set(text))
        self.vocab_size = len(self.vocab)
        self.char_to_idx = {c: i for i, c in enumerate(self.vocab)}
        self.idx_to_char = {i: c for i, c in enumerate(self.vocab)}
        self.text_as_int = t.tensor([self.char_to_idx[c] for c in self.text])

        self.seq_size = seq_size

        self.num_batches = int(len(text) / (seq_size))

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        text = self.text_as_int[idx * self.seq_size:(idx + 1) * self.seq_size]
        label = self.text_as_int[idx * self.seq_size + 1:(idx + 1) * self.seq_size + 1]
        return (text, label)

    def to_text(self, idxs):
        return ''.join([self.idx_to_char[idx] for idx in idxs])

    def to_int(self, text):
        return [self.char_to_idx[c] for c in text]

    def to_one_hot(self, idxs):
        return t.eye(self.vocab_size)[idxs]

    def to_text_from_one_hot(self, one_hot):
        return self.to_text(t.argmax(one_hot, dim=-1))

# create the dataset
shakespeare_dataset = ShakespeareDataset(re.split(r"\b", text), 100)

In [21]:
# print(shakespeare_dataset.text[0:52])
# print(shakespeare_dataset.vocab)

for x, y in shakespeare_dataset:
    print(x)
    print(y)
    break


tensor([34542,  9992,   113,  8237,   113,  5523,   113, 17830,   113, 24979,
          113,  9992,   113,  3477,   113, 10995,   113, 24979,   113, 10916,
          113,  9165,   480, 14228,   113, 10916,   113,  9165,     1, 10039,
          113, 17830,   113, 22293,   113, 19582,   113, 31392,   113, 33037,
          113, 24979,   113, 12315,   113, 12317,   113, 21768,   113, 31392,
          113, 10403,   113,  9566,   113, 12244,     0, 24317,   113, 25174,
          113, 25577,   113, 24979,   113, 31392,   113, 34221,   113, 12640,
          113, 24717,   113, 15883,   113, 12244,   113, 34099,   113, 12140,
          113, 24717,   113, 28036,     0, 33837,   786, 11076,   113, 23721,
          113, 15820,   113, 22310,   480, 20265,   113, 22310,   113, 12779])
tensor([ 9992,   113,  8237,   113,  5523,   113, 17830,   113, 24979,   113,
         9992,   113,  3477,   113, 10995,   113, 24979,   113, 10916,   113,
         9165,   480, 14228,   113, 10916,   113,  9165,     1,

In [22]:
def train(config: TransformerConfig, model, train_dataloader: DataLoader, optimizer, criterion):
    criterion = nn.CrossEntropyLoss()
    model = DecoderTransformer(config)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    accuracy_list = []

    for epoch in range(3):  # loop over the dataset multiple times
        accuracy = 0
        total = 0

        running_loss = 0.0
        for i, data in enumerate(train_dataloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(
                rearrange(outputs, 'batch seq vocab -> batch vocab seq'),
                labels
            )
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            print(i)
            # if i % 20 == 19:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss:.5f}')
            running_loss = 0.0

            
            
        # for (x, y) in train_dataloader:

        #     x = x.to(config.device)
        #     y = y.to(config.device)

        #     y_hat = model(x)
        #     y_predictions = y_hat.argmax(2)
        #     accuracy += (y_predictions == y).sum().item()
        #     total += y.size(0) * 6

        #     accuracy_list.append(accuracy/total)
        # print(f'accuracy: {accuracy/total:.3f}')

    print('Finished Training')

    return accuracy_list

In [23]:
shakespeare_dataloader = DataLoader(shakespeare_dataset, batch_size=128, shuffle=False)

config = TransformerConfig(
    vocab_size=shakespeare_dataset.vocab_size,
    hidden_size=256,
    num_heads=4,
    num_layers=2
)

model = DecoderTransformer(config)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

accuracy_list = train(config, model, shakespeare_dataloader, optimizer, criterion)

0
[1,     1] loss: 0.09867
1
[1,     2] loss: 0.09457
2
[1,     3] loss: 0.08988
3
[1,     4] loss: 0.08599
4
[1,     5] loss: 0.08139
5
[1,     6] loss: 0.07603
6
[1,     7] loss: 0.07139
7
[1,     8] loss: 0.06652
8
[1,     9] loss: 0.06217


KeyboardInterrupt: 