In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
print(f"Using device: {device}")

In [None]:
from pathlib import Path

# Read the training corpus. The encoding is pinned to UTF-8 so the decoded
# characters do not depend on the platform's default locale encoding.
text = Path('tiny-shakespeare.txt').read_text(encoding='utf-8')

In [None]:
text[0:1000]

In [None]:
class CharTokenizer:
  """Character-level tokenizer mapping each character to an integer id.

  Ids are assigned by enumerating `vocabulary` in the order it is iterated,
  so callers must pass an ordered collection to get reproducible ids.
  """

  def __init__(self, vocabulary):
    """Build the two lookup tables from an iterable of unique characters."""
    self.id_for_char = {char: index for index, char in enumerate(vocabulary)}
    self.char_for_id = {index: char for index, char in enumerate(vocabulary)}

  @staticmethod
  def train_from_text(text):
    """Create a tokenizer whose vocabulary is every distinct character of `text`.

    The characters are sorted before ids are assigned. Iterating a bare `set`
    of strings has no stable order between interpreter sessions (string
    hashing is randomized), which would give every kernel restart a different
    character-to-id mapping and make saved checkpoints decode to garbage.
    """
    vocabulary = sorted(set(text))
    return CharTokenizer(vocabulary)

  def encode(self, text):
    """Return a 1-D `torch.long` tensor with one id per character of `text`.

    Raises:
      KeyError: if `text` contains a character outside the vocabulary.
    """
    ids = [self.id_for_char[char] for char in text]
    return torch.tensor(ids, dtype=torch.long)

  def decode(self, ids):
    """Return the string for `ids`, an object exposing `.tolist()` (e.g. a 1-D tensor).

    Raises:
      KeyError: if an id is not part of the vocabulary.
    """
    return ''.join(self.char_for_id[index] for index in ids.tolist())

  def vocabulary_size(self):
    """Return the number of distinct characters known to the tokenizer."""
    return len(self.id_for_char)

In [None]:
tokenizer = CharTokenizer.train_from_text(text)

In [None]:
print(tokenizer.encode("Hello world"))
print(tokenizer.decode(tokenizer.encode("Hello world")))

In [None]:
tokenizer.vocabulary_size()

In [None]:
# Model hyperparameters shared by every module below.
config = dict(
  vocabulary_size=tokenizer.vocabulary_size(),
  context_size=256,
  d_embed=768,
  heads_num=12,
  layers_num=10,
  dropout_rate=0.1,
  use_bias=False,
)

# Each attention head works on an equal slice of the embedding width.
config.update(head_size=config["d_embed"] // config["heads_num"])

In [None]:
class AttentionHead(nn.Module):
  """Single causal self-attention head.

  Projects the input to queries, keys and values of width
  `config["head_size"]` and lets each position attend only to itself and
  earlier positions.
  """

  def __init__(self, config):
    super().__init__()
    d_embed = config["d_embed"]
    head_size = config["head_size"]
    use_bias = config["use_bias"]

    self.Q_weights = nn.Linear(d_embed, head_size, bias=use_bias)
    self.K_weights = nn.Linear(d_embed, head_size, bias=use_bias)
    self.V_weights = nn.Linear(d_embed, head_size, bias=use_bias)

    self.dropout = nn.Dropout(config["dropout_rate"])

    # Lower-triangular matrix of ones: entry (i, j) is 1 when position i may
    # attend to position j. The buffer name (spelling included) is part of the
    # module's state_dict keys, so it is kept as is.
    context_size = config["context_size"]
    self.register_buffer(
        'casual_attention_mask',
        torch.ones(context_size, context_size).tril(),
    )

  def forward(self, input):
    """Return the attended values for a (batch, tokens, d_embed) input."""
    _, tokens_num, _ = input.shape

    queries = self.Q_weights(input)
    keys = self.K_weights(input)
    values = self.V_weights(input)

    # Positions a token must not see (everything after it in the sequence).
    hidden_positions = self.casual_attention_mask[:tokens_num, :tokens_num] == 0
    scale = keys.shape[-1] ** 0.5

    scores = torch.matmul(queries, keys.transpose(1, 2))
    scores = scores.masked_fill(hidden_positions, -torch.inf) / scale

    weights = self.dropout(torch.softmax(scores, dim=-1))
    return torch.matmul(weights, values)

In [None]:
input = torch.rand(8, config["context_size"], config["d_embed"])

In [None]:
ah = AttentionHead(config)

In [None]:
output = ah(input)

In [None]:
output.shape

In [None]:
class MultiHeadAttention(nn.Module):
  """Runs `config["heads_num"]` attention heads and mixes their outputs.

  The head outputs are concatenated along the feature axis, passed through a
  linear projection and then through dropout.
  """

  def __init__(self, config):
    super().__init__()

    self.heads = nn.ModuleList()
    for _ in range(config["heads_num"]):
      self.heads.append(AttentionHead(config))

    self.linear = nn.Linear(config["d_embed"], config["d_embed"])
    self.dropout = nn.Dropout(config["dropout_rate"])

  def forward(self, input):
    """Apply every head to `input` and return the projected combination."""
    concatenated = torch.cat([head(input) for head in self.heads], dim=-1)
    projected = self.linear(concatenated)
    return self.dropout(projected)

In [None]:
mha = MultiHeadAttention(config)

In [None]:
output = mha(input)

In [None]:
output.shape

In [None]:
class FeedForward(nn.Module):
  """Position-wise MLP: expand to 4x the embedding width, GELU, project back, dropout."""

  def __init__(self, config):
    super().__init__()

    width = config["d_embed"]
    hidden_width = width * 4

    expand = nn.Linear(width, hidden_width)
    activation = nn.GELU()
    contract = nn.Linear(hidden_width, width)
    dropout = nn.Dropout(config["dropout_rate"])

    self.linear_layers = nn.Sequential(expand, activation, contract, dropout)

  def forward(self, input):
    """Return the MLP output; the last dimension keeps the input's width."""
    return self.linear_layers(input)

In [None]:
ff = FeedForward(config)

In [None]:
# Bind the feed-forward result to `output` so the shape check in the next
# cell inspects this result rather than the earlier attention output.
output = ff(input)

In [None]:
output.shape

In [None]:
class Block(nn.Module):
  """Pre-norm transformer block: attention and feed-forward, each with a residual connection."""

  def __init__(self, config):
    super().__init__()

    d_embed = config["d_embed"]

    self.multi_head = MultiHeadAttention(config)
    self.layer_norm_1 = nn.LayerNorm(d_embed)

    self.feed_forward = FeedForward(config)
    self.layer_norm_2 = nn.LayerNorm(d_embed)

  def forward(self, input):
    """Apply both sub-layers to `input` and return a tensor of the same shape."""
    # Attention sub-layer: normalise first, then add the untouched input back.
    attended = self.multi_head(self.layer_norm_1(input)) + input

    # Feed-forward sub-layer, same pattern on the attention result.
    return self.feed_forward(self.layer_norm_2(attended)) + attended

In [None]:
b = Block(config)

In [None]:
# Bind the transformer block result to `output` so the shape check in the
# next cell inspects this result rather than a stale value.
output = b(input)

In [None]:
output.shape

In [None]:
class DemoGPT(nn.Module):
  """Decoder-only transformer language model.

  Token and learned positional embeddings are summed, passed through
  `config["layers_num"]` `Block`s, normalised, and projected to one logit per
  vocabulary entry.
  """

  def __init__(self, config):
    super().__init__()

    self.token_embedding_layer = nn.Embedding(config["vocabulary_size"], config["d_embed"])
    self.positional_embedding_layer = nn.Embedding(config["context_size"], config["d_embed"])

    blocks = [Block(config) for _ in range(config["layers_num"])]
    self.layers = nn.Sequential(*blocks)

    self.layer_norm = nn.LayerNorm(config["d_embed"])
    self.unembedding = nn.Linear(config["d_embed"], config["vocabulary_size"], bias=False)

  def forward(self, token_ids):
    """Compute next-token logits.

    Args:
      token_ids: integer tensor of shape (batch, tokens). The number of tokens
        must not exceed `config["context_size"]`, because positions index the
        positional embedding table.

    Returns:
      Tensor of shape (batch, tokens, vocabulary_size) holding unnormalised logits.
    """
    _, tokens_num = token_ids.shape

    x = self.token_embedding_layer(token_ids)
    # Build the position indices on the same device as the input rather than
    # on a notebook-level global, so the model works wherever it was moved to.
    positions = torch.arange(tokens_num, device=token_ids.device)
    x = x + self.positional_embedding_layer(positions)

    x = self.layers(x)
    x = self.layer_norm(x)
    x = self.unembedding(x)

    return x


In [None]:
model = DemoGPT(config).to(device)

In [None]:
output = model(tokenizer.encode("Hi").unsqueeze(dim=0).to(device))

In [None]:
output.shape

In [None]:
def generate(model, prompt_ids, max_tokens):
    """Autoregressively sample up to `max_tokens` tokens after `prompt_ids`.

    `prompt_ids` is a (batch, tokens) tensor of token ids. Sampling stops
    early once the sequence length reaches `config["context_size"]`.
    Returns the prompt with the sampled tokens appended along the last axis.
    """
    tokens = prompt_ids
    for _ in range(max_tokens):
        if tokens.shape[1] < config["context_size"]:
            with torch.no_grad():
                # Only the logits of the final position predict the next token.
                last_logits = model(tokens)[:, -1, :]

            distribution = F.softmax(last_logits, dim=-1)
            # Draw one token id per sequence from the predicted distribution.
            sampled = torch.multinomial(distribution, num_samples=1)
            tokens = torch.cat([tokens, sampled], dim=-1)
        else:
            break
    return tokens

In [None]:
def generate_with_prompt(model, tokenizer, prompt, max_tokens=100):
  """Generate text that continues `prompt` and return it as a string.

  Args:
    model: module called by `generate` to produce logits.
    tokenizer: object providing `encode(str)` and `decode(ids)`.
    prompt: text used as the start of the sequence.
    max_tokens: upper bound on the number of sampled tokens.

  Returns:
    The decoded first sequence of the batch, prompt included.

  The model is switched to eval mode for sampling and put back into whatever
  mode it was in before the call, so calling this in the middle of training
  does not silently leave dropout disabled.
  """
  was_training = model.training
  model.eval()
  try:
    prompt_ids = tokenizer.encode(prompt).unsqueeze(dim=0)

    # Place the prompt on the device that holds the model's weights instead
    # of relying on a notebook-level global.
    first_parameter = next(model.parameters(), None)
    if first_parameter is not None:
      prompt_ids = prompt_ids.to(first_parameter.device)

    generated_ids = generate(model, prompt_ids, max_tokens=max_tokens)
    return tokenizer.decode(generated_ids[0])
  finally:
    model.train(was_training)

In [None]:
generate_with_prompt(model, tokenizer, "First Citizen:\n")

In [None]:
# Data: fraction of the tokens used for training, and sequences per batch.
train_split = 0.9
batch_size = 64

# Optimisation: number of training batches, how often (in steps) the model
# is evaluated, and the optimizer learning rate.
train_iterations = 5_000
evaluation_interval = 10
learning_rate = 4e-4

In [None]:
# Encode the whole corpus once and keep it on the training device.
tokenized_text = tokenizer.encode(text).to(device)

# The first `train_split` fraction is used for training, the rest for validation.
train_count = int(train_split * len(tokenized_text))
train_data = tokenized_text[:train_count]
validation_data = tokenized_text[train_count:]

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class IdsDataset(Dataset):
    """Sliding-window dataset over a sequence of token ids.

    Item `i` is a pair `(x, y)` of `block_size`-long windows where `y` is `x`
    shifted one position to the right, i.e. the next-token targets.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Number of start positions for which both windows fit in the data.
        return len(self.data) - self.block_size

    def __getitem__(self, index):
        # Wrap the index so it always lands on a valid start position.
        window_count = len(self.data) - self.block_size
        start = index % window_count
        stop = start + self.block_size

        inputs = self.data[start:stop]
        targets = self.data[start + 1:stop + 1]
        return inputs, targets

In [None]:
train_dataset = IdsDataset(train_data, config["context_size"])
validation_dataset = IdsDataset(validation_data, config["context_size"])

In [None]:
from torch.utils.data import Dataset, DataLoader, RandomSampler

# Training: draw exactly `train_iterations` batches of random windows (with replacement).
train_sampler = RandomSampler(
    train_dataset,
    replacement=True,
    num_samples=batch_size * train_iterations,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Validation: random windows with replacement, using the sampler's default sample count.
validation_sampler = RandomSampler(validation_dataset, replacement=True)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    sampler=validation_sampler,
)

In [None]:
@torch.no_grad()
def calculate_validation_loss(model, batches_num):
  """Return the mean cross-entropy over `batches_num` validation batches.

  Batches are drawn from a fresh iterator over the notebook-level
  `validation_dataloader`. The model is put into eval mode and left there;
  callers that continue training must switch it back with `model.train()`.

  Raises:
    StopIteration: if the dataloader yields fewer than `batches_num` batches.
    ZeroDivisionError: if `batches_num` is 0.
  """
  model.eval()
  total_loss = 0.0

  validation_iter = iter(validation_dataloader)

  for _ in range(batches_num):
    idx, targets = next(validation_iter)
    logits = model(idx)

    # Flatten (batch, tokens, vocab) -> (batch * tokens, vocab) using the
    # tensors' own sizes, so a batch of any shape (e.g. a smaller final
    # batch) is handled instead of assuming batch_size x context_size.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    loss = F.cross_entropy(flat_logits, flat_targets)

    total_loss += loss.item()

  return total_loss / batches_num

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Remove it later!!

# from google.colab import drive
# drive.mount('/content/drive')

%pip install ipywidgets

In [None]:

import os
from IPython.display import display, clear_output
from matplotlib import pyplot as plt
from IPython.display import display
import ipywidgets as widgets
%matplotlib inline

# Create an output widget to handle the plot updates
plot_output = widgets.Output()

# Display the output widget once, outside of the function
display(plot_output)

def update_plot(train_losses, train_steps, validation_losses, validation_steps):
  """Redraw the loss curves inside the `plot_output` widget.

  Args:
    train_losses: training loss values.
    train_steps: step numbers matching `train_losses`.
    validation_losses: validation loss values.
    validation_steps: step numbers matching `validation_losses`.
  """
  with plot_output:
    clear_output(wait=True)  # Clear only the plot output, not the text
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot(train_steps, train_losses, label='Training Loss')
    ax.plot(validation_steps, validation_losses, label='Validation Loss')
    ax.set_title('Training and Validation Loss')
    # The x values are training step numbers, not epochs.
    ax.set_xlabel('step')
    ax.set_ylabel('loss')
    ax.legend(loc='center left')
    ax.grid(True)
    plt.show()
    # This runs on every training step; release the figure explicitly so
    # figures cannot pile up on backends that do not close them after show().
    plt.close(fig)


# Loss history used by update_plot: values and the step each one was recorded at.
train_losses = []
train_steps = []
eval_losses = []
eval_steps = []


# Each batch from the dataloader is an (inputs, targets) pair of token-id
# tensors; both must be unpacked, otherwise `targets` is never bound.
for step_num, (idx, targets) in enumerate(train_dataloader):

  model.train()
  logits = model(idx)

  # Flatten (batch, tokens, vocab) -> (batch * tokens, vocab) from the
  # tensors' own sizes instead of assuming batch_size x context_size.
  flat_logits = logits.reshape(-1, logits.size(-1))
  flat_targets = targets.reshape(-1)
  loss = F.cross_entropy(flat_logits, flat_targets)
  # Backward propagation
  loss.backward()
  # Update model parameters
  optimizer.step()
  # Set to None to reduce memory usage
  optimizer.zero_grad(set_to_none=True)

  loss_value = loss.item()
  train_losses.append(loss_value)
  train_steps.append(step_num)

  print(f"Step {step_num}. Loss {loss_value:.3f}")

  if step_num % evaluation_interval == 0:
    # Sample some text and measure validation loss to monitor progress.
    print("Demo GPT:\n" + generate_with_prompt(model, tokenizer, "\n"))

    validation_loss = calculate_validation_loss(model, batches_num=10)
    eval_losses.append(validation_loss)
    eval_steps.append(step_num)

    print(f"Step {step_num}. Validation loss: {validation_loss:.3f}")


  update_plot(train_losses, train_steps, eval_losses, eval_steps)

  # Checkpointing is disabled; the checkpoint-loading cell further down
  # expects a file written by this code.
  # if step_num in {100, 3000, 4000, 4999}:
  #   dir = f"/content/drive/MyDrive/demo-gpt/{step_num}"
  #   os.makedirs(dir)
  #   torch.save({
  #       "model_state": model.state_dict(),
  #       # "optimizer_state": optimizer.state_dict(),
  #   },
  #   f"{dir}/checkpoint.pth",
  # )




In [None]:
# Load the saved weights into a fresh model.
# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load("/content/drive/MyDrive/demo-gpt/4999/checkpoint.pth", map_location=device)
model = DemoGPT(config).to(device)
model.load_state_dict(checkpoint["model_state"])

In [None]:
print(generate_with_prompt(model, tokenizer, "ROMEO:\n"))