### Import libraries

In [None]:
import math
import time

from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tensorflow.keras.utils import get_file

# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

### Download and preprocess data

In [None]:
# Download "shakespeare.txt" dataset
path_to_file = get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

# Read the raw bytes and decode as UTF-8; `with` closes the file handle right away
# instead of leaving it open until garbage collection.
with open(path_to_file, 'rb') as text_file:
  text = text_file.read().decode(encoding='utf-8')

In [None]:
# Character vocabulary: every distinct character in the corpus, in sorted order.
vocab = sorted(set(text))

# Two-way lookup tables between characters and their integer ids.
idx_to_char = dict(enumerate(vocab))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}

def text_to_idx(text_line, mapping=None):
  '''Converts a string (or any iterable of characters) into a list of character indices.

  Args:
    text_line: characters to encode.
    mapping: optional char -> index dict; defaults to the notebook-level `char_to_idx`.

  Returns:
    list[int] with one index per character.

  Raises:
    KeyError: if a character is missing from the mapping.
  '''
  table = char_to_idx if mapping is None else mapping
  return [table[char] for char in text_line]

def idx_to_text(idx_line, mapping=None):
  '''Converts a sequence of character indices back into a string.

  Args:
    idx_line: iterable of integer indices (e.g. the output of `tensor.tolist()`).
    mapping: optional index -> char dict; defaults to the notebook-level `idx_to_char`.

  Returns:
    str formed by concatenating the decoded characters.

  Raises:
    KeyError: if an index is missing from the mapping.
  '''
  table = idx_to_char if mapping is None else mapping
  return ''.join(table[idx] for idx in idx_line)

In [None]:
def slice_text(text_line, seq_len=129):
    '''Cuts `text_line` into consecutive, non-overlapping chunks of `seq_len` items.

    The final chunk is shorter when the length is not a multiple of `seq_len`.
    Works for any sliceable sequence (str, list, ...).
    '''
    chunks = []
    for start in range(0, len(text_line), seq_len):
        chunks.append(text_line[start:start + seq_len])
    return chunks

def pairs_to_dataloader(pairs, batch_size=32, shuffle=True):
    '''Batches (input, target) pairs and moves every batch onto `device`.

    The pairs are wrapped in a map-style Dataset and batched by a torch DataLoader,
    which drops an incomplete trailing batch. The DataLoader is iterated once, right
    here, so the result is a plain list of `(inputs, targets)` tensor tuples that
    already live on `device`.

    NOTE: because the batches are materialized immediately, `shuffle` fixes one
    order at creation time; iterating the returned list again (e.g. every epoch)
    yields the same batch order.
    '''
    class _PairDataset(Dataset):
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, position):
            features, label = self.items[position]
            return torch.tensor(features), torch.tensor(label)

    loader = DataLoader(_PairDataset(pairs), batch_size=batch_size,
                        shuffle=shuffle, drop_last=True)

    batches = []
    for batch in loader:
        batches.append((batch[0].to(device), batch[-1].to(device)))
    return batches

def generate_dataloader(text, seq_len=129, batch_size=1, shuffle=True):
    '''Directly converts text into batched (input, target) pairs on `device`.

    The text is encoded to character indices and cut into consecutive chunks of
    `seq_len` indices. Each full-length chunk becomes one pair: its first
    `seq_len - 1` indices are the input and its last index is the single
    next-character target. A shorter trailing chunk is discarded.
    '''
    encoded = text_to_idx(text)
    chunks = slice_text(encoded, seq_len=seq_len)

    pairs = []
    for chunk in chunks:
        if len(chunk) != seq_len:
            continue
        pairs.append((chunk[:-1], chunk[-1]))

    return pairs_to_dataloader(pairs, batch_size=batch_size, shuffle=shuffle)

### Build a model

Token embedding layer

In [None]:
class TokenEmbedding(nn.Module):
  '''Looks up a learned `d_model`-sized vector for each token id, then applies dropout.'''

  def __init__(self, d_model, vocab_size, dropout):
    super().__init__()

    # One trainable row per vocabulary entry.
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    vectors = self.embedding(x)
    return self.dropout(vectors)

Positional encoding layer

In [None]:
class PositionalEncoding(nn.Module):
  '''Adds fixed sinusoidal position encodings to its input.

  Args:
    d_model: feature size (last dimension of the input). May be odd.
    maxlen: number of positions in the precomputed table.

  forward(x) treats dimension -2 as the sequence and -1 as features, e.g.
  (batch, seq_len, d_model), and returns `x + encoding[:seq_len]`.
  seq_len must not exceed `maxlen`.
  '''
  def __init__(self, d_model, maxlen):
    super().__init__()

    # 1 / 10000^(2i / d_model) for each even feature index 2i.
    den = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pos = torch.arange(0, maxlen, dtype=torch.float).unsqueeze(1)
    encoding = torch.zeros(maxlen, d_model)
    encoding[:, 0::2] = torch.sin(pos * den)
    # With an odd d_model there is one fewer odd column than even columns.
    encoding[:, 1::2] = torch.cos(pos * den[: d_model // 2])

    # A buffer follows module.to(...) together with the parameters; persistent=False
    # keeps it out of state_dict so existing checkpoints still load unchanged.
    self.register_buffer('encoding', encoding, persistent=False)

  def forward(self, x):
    seq_len = x.size(-2)
    # Slice to the actual sequence length so shorter inputs also work; .to() is a
    # no-op when the buffer is already on x's device.
    return x + self.encoding[:seq_len].to(x.device)

Self-attention mechanism

In [None]:
class SelfAttention(nn.Module):
  '''Multi-head self-attention followed by a residual connection and LayerNorm.

  Args:
    d_model: feature size of the input and output.
    n_heads: number of attention heads; must divide `d_model`.

  forward(x) unpacks `x` as (batch, seq_len, d_model) and returns
  `(output, attention_weights)`. `output` has the same shape as `x`, and
  `attention_weights` is the detached (batch, seq_len, seq_len) attention map
  averaged over heads. No causal mask is applied: every position attends to
  every other position.

  Raises:
    ValueError: if `d_model` is not divisible by `n_heads`.
  '''
  def __init__(self, d_model, n_heads):
    super().__init__()

    if d_model % n_heads != 0:
      raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')

    self.n_heads = n_heads
    self.head_dim = d_model // n_heads

    self.query = nn.Linear(d_model, d_model)
    self.key = nn.Linear(d_model, d_model)
    self.value = nn.Linear(d_model, d_model)

    # Scaled dot-product attention divides by sqrt(per-head key size). A plain
    # float (instead of a lambda attribute) keeps the module picklable.
    self.scale = 1.0 / math.sqrt(self.head_dim)

    self.out = nn.Linear(d_model, d_model)

    self.norm = nn.LayerNorm(d_model)

  def _split_heads(self, t, batch_size, seq_len):
    # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
    return t.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

  def forward(self, x):
    batch_size, seq_len, emb_dim = x.size()

    Q = self._split_heads(self.query(x), batch_size, seq_len)
    K = self._split_heads(self.key(x), batch_size, seq_len)
    V = self._split_heads(self.value(x), batch_size, seq_len)

    scores = (Q @ K.transpose(-1, -2)) * self.scale     # (batch, heads, seq, seq)
    attention_weights = torch.softmax(scores, -1)
    attended_values = attention_weights @ V              # (batch, heads, seq, head_dim)

    # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, d_model).
    attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, emb_dim)

    output = self.out(attended_values)

    # Residual connection, done out-of-place, then post-LayerNorm.
    output = self.norm(output + x)
    return output, attention_weights.mean(dim=1).detach()

FeedForward layer for output

In [None]:
class FeedForward(nn.Module):
  '''Position-wise feed-forward block with a residual connection and LayerNorm.

  Computes `norm(x + dropout(W2(relu(W1(x)))))`.

  Args:
    d_model: size of the input/output features (last dimension of `x`).
    d_ff: size of the hidden layer.
    dropout: dropout probability applied to the block's output before the residual add.
  '''
  def __init__(self, d_model, d_ff, dropout):
    super().__init__()

    # Linear layers (attribute names kept so existing state_dict keys still match)
    self.pickles = nn.Linear(d_model, d_ff)     # expansion: d_model -> d_ff
    self.tomatoes = nn.Linear(d_ff, d_model)    # projection: d_ff -> d_model

    self.norm = nn.LayerNorm(d_model)

    self.dropout = nn.Dropout(dropout)

    # Weight initialization (Kaiming/He normal)
    nn.init.kaiming_normal_(self.pickles.weight, nonlinearity='relu')
    nn.init.kaiming_normal_(self.tomatoes.weight, nonlinearity='relu')

  def forward(self, x):
    hidden = F.relu(self.pickles(x))
    projected = self.dropout(self.tomatoes(hidden))

    # Residual connection, matching the post-norm pattern used by SelfAttention;
    # without it the block would discard its input entirely.
    return self.norm(x + projected)

Decoder-only transformer model

In [None]:
class DecoderTransformer(nn.Module):
  '''Character-level transformer that predicts the next character of a sequence.

  Pipeline: token embedding -> positional encoding -> `n_att` stacked
  SelfAttention layers -> FeedForward -> flatten all positions -> Linear over
  the vocabulary. Because every position is flattened into the output layer,
  inputs must be exactly `maxlen` tokens long.

  forward(x) returns `(logits, att_Ws)`, where `att_Ws` is a list with the
  attention weights returned by each SelfAttention layer.
  '''
  def __init__(self, d_model, maxlen, vocab_size, dropout, n_heads, d_ff, n_att):
    super().__init__()

    # Sub-modules. Device placement is done by the caller via model.to(device),
    # which reaches every registered sub-module (including self.out).
    self.embedding = TokenEmbedding(d_model, vocab_size, dropout)
    self.posencoding = PositionalEncoding(d_model, maxlen)
    # nn.ModuleList (not a plain list) registers the layers so their parameters
    # reach model.parameters() / the optimizer, state_dict(), .to() and .train()/.eval().
    self.sequential_attention = nn.ModuleList(
        [SelfAttention(d_model, n_heads) for _ in range(n_att)])
    self.neuralnet = FeedForward(d_model, d_ff, dropout)

    # (batch, maxlen, d_model) -> (batch, maxlen * d_model). A module instead of
    # a lambda keeps the model picklable (e.g. torch.save(model)).
    self.flatten = nn.Flatten(start_dim=1)
    self.out = nn.Linear(maxlen * d_model, vocab_size)

  def forward(self, x):
    embedded = self.embedding(x)
    attended = self.posencoding(embedded)

    att_Ws = []
    for lil_attention in self.sequential_attention:
      attended, att_W = lil_attention(attended)
      att_Ws.append(att_W)

    features = self.neuralnet(attended)
    output = self.out(self.flatten(features))
    return output, att_Ws

### Hyperparameters

In [None]:
D_MODEL     = 64            # embedding / hidden feature size
MAXLEN      = 128           # input length the model expects (flattened into the output layer)
VOCAB_SIZE  = len(vocab)    # number of distinct characters in the corpus
DROPOUT     = .05
N_HEADS     = 8             # passed to every SelfAttention layer
BATCH_SIZE  = 1
D_FF        = 1024          # hidden size of the feed-forward block
N_ATT       = 2             # number of stacked SelfAttention layers
lr          = 3e-4          # Adam learning rate
SEED        = 42            # seeds torch's global RNG (weight init, dropout, shuffling)

# Seed before the model and dataloader are built so runs are reproducible.
torch.manual_seed(SEED)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=ShakespeareGPT.parameters(), lr=lr)

Initialize the model

In [None]:
ShakespeareGPT = DecoderTransformer(D_MODEL, MAXLEN, VOCAB_SIZE, DROPOUT, N_HEADS, D_FF, N_ATT).to(device)

In [None]:
# Generate a dataloader. Each chunk is MAXLEN input characters plus one target
# character, so seq_len is tied explicitly to the model's expected input length.
dataloader = generate_dataloader(text, seq_len=MAXLEN + 1, batch_size=BATCH_SIZE)

### Train the model

In [None]:
def train_epoch(model, print_loss, loader=None, opt=None, loss_fn=None):
  '''Runs a single training epoch and returns the mean batch loss.

  Args:
    model: module whose forward returns `(logits, extra)`; only logits are used.
    print_loss: print the current batch loss every `print_loss` batches
      (0 or None disables printing).
    loader: sized iterable of `(input, target)` batches; defaults to the
      notebook-level `dataloader`.
    opt: optimizer; defaults to the notebook-level `optimizer`.
    loss_fn: callable `loss_fn(logits, target)`; defaults to `criterion`.

  Returns:
    float: average loss over the epoch's batches, or nan if there were none.
  '''
  loader = dataloader if loader is None else loader
  opt = optimizer if opt is None else opt
  loss_fn = criterion if loss_fn is None else loss_fn

  model.train()

  total_loss = 0.0
  n_batches = len(loader)

  for i, (inputs, target) in enumerate(loader):
    opt.zero_grad()

    logits, _ = model(inputs)

    loss = loss_fn(logits, target)
    total_loss += loss.item()
    loss.backward()

    opt.step()

    # Guard against print_loss == 0, which would otherwise raise ZeroDivisionError.
    if print_loss and (i+1) % print_loss == 0:
      print(f'Training epoch {i+1}/{n_batches}: {loss.item():.5f}')

  if n_batches == 0:
    return float('nan')
  return total_loss / n_batches

def get_time(epoch_time):
  '''Formats a duration in seconds as "Time taken: <m> m. <s> s.".

  The duration is rounded to one decimal before it is split into minutes and
  seconds, so 59.96 s is shown as "1 m. 0.0 s." rather than "0 m. 60.0 s.".
  '''
  minutes, seconds = divmod(round(epoch_time, 1), 60)
  return f'Time taken: {int(minutes)} m. {seconds:.1f} s.'

def predict_char(input, model, temperature=None):
  '''Predicts the next character index for each sequence in `input`.

  Switches `model` to eval mode (and leaves it there) and runs it without gradients.

  Args:
    input: batch of index sequences accepted by `model`.
    model: module whose forward returns `(logits, extra)`, with the vocabulary
      on the last dimension of `logits`.
    temperature: None (default) for greedy argmax decoding; a positive float
      samples from softmax(logits / temperature) instead.

  Returns:
    Tensor of predicted indices, i.e. `logits` with its last dimension reduced.

  Raises:
    ValueError: if `temperature` is given and is not positive.
  '''
  if temperature is not None and temperature <= 0:
    raise ValueError(f'temperature must be positive, got {temperature}')

  model.eval()
  with torch.no_grad():
    logits, _ = model(input)
    if temperature is None:
      idx = torch.argmax(logits, -1)
    else:
      probs = torch.softmax(logits / temperature, -1)
      # multinomial expects a 2-D (rows, vocab) matrix; restore the leading shape afterwards.
      flat = probs.reshape(-1, probs.size(-1))
      idx = torch.multinomial(flat, 1).view(probs.shape[:-1])
  return idx

In [None]:
epochs = 10
print_loss = 1000
loss_list = []

# Training loop
for epoch in tqdm(range(1, epochs+1)):
  # perf_counter is monotonic, so system clock adjustments cannot distort epoch timings.
  start_time = time.perf_counter()
  loss = train_epoch(ShakespeareGPT, print_loss)
  loss_list.append(loss)
  epoch_time = time.perf_counter() - start_time
  print(f'Epoch #{epoch}: Loss = {loss:.5f}\n{get_time(epoch_time)}')

### Make predictions with a trained model

In [None]:
max_token = 1000

initial, target = next(iter(dataloader))
# Generate from the first sequence of the batch only, so the loop also works
# when BATCH_SIZE > 1. `context` avoids shadowing the builtin `input`.
context = initial[:1]
top = context.tolist()[0]   # prompt indices, extended with each prediction

for _ in range(max_token):
  pred = predict_char(context, ShakespeareGPT)
  top.append(pred.tolist()[0])
  # Slide the window: drop the oldest character and append the prediction.
  context = torch.cat([context[:, 1:], pred.view(1, -1)], -1)

In [None]:
# Show the prompt, then the prompt followed by everything the model generated.
print('START SEQUENCE:', idx_to_text(initial.tolist()[0]), sep='\n', end='\n' * 3)
print('PREDICTED:', idx_to_text(top), sep='\n')

**Example of text generated by ShakespeareGPT**
START SEQUENCE:
ortuned him by any means?

MONTAGUE:
Both by myself and many other friends:
But he, his own affections' counsellor,
Is to himsel




PREDICTED:
ortuned him by any means?

MONTAGUE:
Both by myself and many other friends:
But he, his own affections' counsellor,
Is to himselfeld
 sasone sundrntwe ie tipt cerstres ee ingt

Asse so se elout:
I ho wolntee te te s yous noth yound trlind en meis nof, bole sh sorsiit ari riph, gheld go topdows:notth thmunue p mo w ch then lme ist at se sut
I hig noull fseisthu f hal  hiy owwott bongonn yotiad your eatrln yol my toatt sourrowertay tou doin, you
he I bea thy blay, pesctl
They amont hir weis ghatd ton, I annedeator w yount in, tuene hs l oresain tenot

bo no shot singe, I le rey lefne, sotoowise ho