<a href="https://colab.research.google.com/github/ggolani/ML/blob/main/GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter


In [None]:
# Mount Google Drive at /content/gdrive (Colab-only); CHECKPOINT_PATH points into this mount
from google.colab import drive
drive.mount('/content/gdrive')



In [None]:
# Load the pretrained 'gpt2' tokenizer from the `transformers` library;
# it is used to encode the corpus and the prompt, and to decode generated ids
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [None]:
# TensorBoard writer shared by the logging helpers and the training loop; writes under 'runs/gpt2'
writer = SummaryWriter('runs/gpt2')

def log_weights_histograms(model, global_step):
    """Write one TensorBoard histogram per trainable parameter of `model`.

    Args:
        model: module whose `named_parameters()` are logged.
        global_step: step value attached to every histogram.

    Histograms go to the module-level `writer` under the tag
    'weights/<parameter name>'. Parameters with `requires_grad == False`
    are skipped.
    """
    trainable = (
        (pname, tensor)
        for pname, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for pname, tensor in trainable:
        writer.add_histogram(f'weights/{pname}', tensor.data, global_step)

def log_gradients_histograms(model, global_step):
    """Write one TensorBoard histogram per available parameter gradient.

    Args:
        model: module whose `named_parameters()` are inspected.
        global_step: step value attached to every histogram.

    Histograms go to the module-level `writer` under the tag
    'gradients/<parameter name>'. Parameters that are frozen or that have
    no gradient yet (`grad is None`) are skipped.
    """
    for pname, tensor in model.named_parameters():
        # guard clause: nothing to log for frozen or gradient-less parameters
        if not tensor.requires_grad or tensor.grad is None:
            continue
        writer.add_histogram(f'gradients/{pname}', tensor.grad, global_step)

# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
#hyperparameters for GPT2-124M
n_vocab = tokenizer.vocab_size  # vocabulary size, taken from the tokenizer
embed_dim = 768  # embedding dimension of tokens and positions
seq_len = 256  # tokens per training sequence; also the size of the position-embedding table
n_heads = 12  # attention heads per transformer block
n_blocks = 12  # number of stacked transformer blocks
batch_size = 32  # sequences per batch returned by get_data_batch
dropout = 0 # dropout probability used for attention, sublayer outputs and embeddings; range [0-1]

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
import requests

# Download the training corpus (Project Gutenberg ebook #829, plain text).
# The timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get('https://www.gutenberg.org/cache/epub/829/pg829.txt', timeout=60)
# Fail loudly on an HTTP error instead of silently tokenizing an error page.
response.raise_for_status()
text = response.text

# Tokenize the whole text into a single tensor of token ids.
gtTokens = torch.tensor(tokenizer.encode(text))
print(len(gtTokens))

In [None]:
import math

train_ratio = 0.9  # fraction of every shard that goes to the training set
n_shards = 10      # the token stream is cut into this many contiguous shards

# The head of every shard is used for training and its tail for testing, so
# both sets draw from all parts of the text rather than only its start/end.
train_parts = []
test_parts = []
for i in range(n_shards):
    shard_min = math.floor(len(gtTokens) / n_shards * i)
    shard_max = math.floor(len(gtTokens) / n_shards * (i + 1))
    train_max = math.ceil(shard_min + (shard_max - shard_min) * train_ratio)

    train_parts.append(gtTokens[shard_min:train_max])
    # Slice ends are exclusive, so [train_max:shard_max] is exactly the rest
    # of the shard: no token is skipped and none is shared with the train part.
    test_parts.append(gtTokens[train_max:shard_max])

# Concatenate once at the end instead of re-allocating on every iteration.
train_data = torch.cat(train_parts)
test_data = torch.cat(test_parts)



In [None]:
# a function that returns a batch of data samples
def get_data_batch(training=True):
    """Sample a random batch of token windows and their next-token targets.

    Args:
        training: when true, sample from `train_data`; otherwise from
            `test_data`.

    Returns:
        Tuple `(X, y)`. Both are indexed out of the chosen dataset with a
        `(batch_size, seq_len)` index grid; `y` uses the same windows as `X`
        shifted one position to the right.
    """
    source = train_data if training else test_data

    # one random start per sample; the upper bound leaves room for the
    # seq_len-token window plus the one-token shift used for the targets
    starts = torch.randint(len(source) - seq_len, size=(batch_size,))

    # a column of starts broadcast against a row of offsets gives the index
    # grid of every window
    offsets = torch.arange(seq_len)
    window = starts[:, None] + offsets

    X = source[window]
    y = source[window + 1]
    return X, y

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    Layer sizes come from the module-level hyperparameters `embed_dim` and
    `n_heads`. Attention dropout uses the module-level `dropout` value and is
    applied only while the module is in training mode.
    """

    def __init__(self):
        super().__init__()

        if embed_dim % n_heads != 0:
            # otherwise the per-head reshape in forward() cannot work
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})'
            )

        self.num_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # a single projection produces queries, keys and values for all heads
        self.QKV = nn.Linear(embed_dim, 3*embed_dim, bias=True)
        # output projection that mixes the concatenated heads
        self.W0 = nn.Linear(embed_dim, embed_dim, bias=True)

    def _split_heads(self, t, B, T):
        """Reshape [B, T, E] into per-head layout [B, nHeads, T, head_dim]."""
        return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape [batch, seq_len, embed_dim].

        Returns:
            Tensor of the same shape as `x`.
        """
        B, T, E = x.shape # [batch, seq_len, embed_dim]

        # project once, then cut the last dimension into q, k and v
        q, k, v = torch.split(self.QKV(x), E, dim=2)
        q, k, v = (self._split_heads(t, B, T) for t in (q, k, v))

        # attention dropout must be off outside of training
        attn_dropout = dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=True, dropout_p=attn_dropout
        ) # [B, nHeads, T, head_dim]

        # recombine heads: [B, nHeads, T, head_dim] -> [B, T, E].
        # reshape() instead of view(): after the transpose the tensor is not
        # guaranteed to be contiguous (the memory layout of the attention
        # output depends on the backend), and view() raises in that case.
        out = out.transpose(1, 2).reshape(B, T, E)

        # finally, linearly mix the attention heads
        return self.W0(out)

In [None]:
class TransformerBlock(nn.Module):
    """One transformer block: attention sublayer followed by an MLP sublayer.

    Each sublayer normalises its input with LayerNorm, applies dropout to its
    output and adds the result back onto its un-normalised input.
    Sizes and the dropout probability come from the module-level
    `embed_dim` and `dropout`.
    """

    def __init__(self):
        super().__init__()

        # --- attention sublayer ---
        self.layernorm_1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.attn = MultiHeadAttention()

        # --- feed-forward sublayer (expand 4x, GELU, contract) ---
        self.layernorm_2 = nn.LayerNorm(embed_dim, eps=1e-5)
        hidden_dim = 4*embed_dim
        self.mlp_1 = nn.Linear(embed_dim, hidden_dim, bias=True)
        self.gelu = nn.GELU()
        self.mlp_2 = nn.Linear(hidden_dim, embed_dim, bias=True)

        # one dropout module, used on the output of both sublayers
        self.trn_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sublayers on `x` and return a tensor of the same shape."""
        # attention on the normalised input, added back to the raw input
        attended = self.attn(self.layernorm_1(x))
        after_attn = self.trn_dropout(attended) + x

        # expansion-contraction MLP on the normalised attention result
        expanded = self.mlp_1(self.layernorm_2(after_attn))
        contracted = self.mlp_2(self.gelu(expanded))

        return after_attn + self.trn_dropout(contracted)

In [None]:
class LLM(nn.Module):
    """GPT-2 style language model: embeddings, transformer blocks, tied output head.

    Sizes are read from the module-level hyperparameters `n_vocab`,
    `embed_dim`, `seq_len`, `n_blocks` and `dropout`.
    `forward` returns log-probabilities (not raw logits), so it pairs with
    `nn.NLLLoss`.
    """

    def __init__(self):
        super().__init__()

        self.wte = nn.Embedding(n_vocab, embed_dim)   # token embeddings
        self.wpe = nn.Embedding(seq_len, embed_dim)   # position embeddings
        #n dropout
        self.emb_dropout = nn.Dropout(dropout)

        self.transformerBlocks = nn.Sequential(*[TransformerBlock() for _ in range(n_blocks)])

        self.layernorm_final = nn.LayerNorm(embed_dim, eps=1e-5)

        self.final_head = nn.Linear(embed_dim, n_vocab, bias=False)
        # Weight tying: assign the *same* Parameter object. Wrapping it in a new
        # nn.Parameter would create a second parameter over the same storage,
        # which the optimizer would then update twice per step.
        # NOTE(review): optimizer state saved before this change presumably has
        # one more parameter entry - confirm before resuming an old checkpoint.
        self.final_head.weight = self.wte.weight

        self.apply(self.weightInits)

    def weightInits(self, module):
        """Initialise `module` in place; meant to be passed to `nn.Module.apply`.

        Linear and Embedding weights get Xavier-normal values, Linear biases
        are zeroed. Other module types are left untouched.
        """
        # revisit initialization to optimize for choice of activation function
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.xavier_normal_(module.weight)

    def forward(self, idx):
        """Return next-token log-probabilities for every position of `idx`.

        Args:
            idx: integer tensor of token ids whose last dimension is the
                sequence (at most `seq_len` tokens, the size of `wpe`).

        Returns:
            Tensor of log-probabilities with a trailing dimension of `n_vocab`.
        """
        # build positions on the input's own device rather than a global one
        positions = torch.arange(idx.shape[-1], device=idx.device)

        x = self.wte(idx) + self.wpe(positions)
        x = self.emb_dropout(x) #n dropout after summing E+P

        x = self.transformerBlocks(x)
        x = self.layernorm_final(x)

        logits = self.final_head(x)

        # logits are scaled down by sqrt(embed_dim) before the log-softmax
        return F.log_softmax(logits/np.sqrt(embed_dim), dim=-1)

    def generate(self, idx, max_new_tokens=50):
        """Autoregressively sample `max_new_tokens` tokens after `idx`.

        Args:
            idx: integer tensor of shape [batch, n_tokens] holding the prompt.
            max_new_tokens: number of tokens to append.

        Returns:
            Tensor of shape [batch, n_tokens + max_new_tokens].

        Sampling runs without gradient tracking and in eval mode (dropout
        off); the previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # the position table only covers the last seq_len tokens
                    idx_cond = idx[:, -seq_len:]
                    # log-probabilities at the final position only
                    log_probs = self(idx_cond)[:, -1, :]
                    probs = torch.exp(log_probs)

                    idx_next = torch.multinomial(probs, num_samples=1)
                    idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)

        return idx

In [None]:

CHECKPOINT_PATH = '/content/gdrive/My Drive/my_checkpoint.pth'


def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Serialize the training state to `filepath` with `torch.save`.

    Args:
        model: module whose `state_dict()` is stored.
        optimizer: optimizer whose `state_dict()` is stored.
        epoch: value stored under the 'epoch' key.
        loss: value stored under the 'loss' key.
        filepath: destination passed to `torch.save`.

    The saved object is a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss'.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)

In [None]:
# create a freshly initialised model and move it to the selected device
model = LLM().to(device)
# To continue from a saved checkpoint instead, uncomment the lines below;
# they read CHECKPOINT_PATH and restore the weights stored under 'model_state_dict'
#checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device(device))
#model.load_state_dict(checkpoint['model_state_dict'])
#model.to(device)

In [None]:
%tensorboard --logdir=runs

In [None]:
# The model's forward pass ends in log_softmax, so negative log-likelihood is the matching loss
loss_function = nn.NLLLoss().to(device)
# AdamW over all model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=.001, weight_decay=.01)

In [None]:
# Each "epoch" below is a single optimisation step on one random batch.
num_epochs = 4001

log_every = 100          # TensorBoard histograms / loss scalar
eval_every = 100         # test-set evaluation
checkpoint_every = 1000  # checkpoint to CHECKPOINT_PATH

# initialize losses
train_loss = []
test_loss = []

# make sure dropout (if enabled) is active for training
model.train()

for epoch in range(num_epochs):

    # get a batch of data and move it to the model's device
    X, y = get_data_batch()
    X, y = X.to(device), y.to(device)

    # clear previous gradients
    optimizer.zero_grad(set_to_none=True)

    # forward pass
    log_probs = model(X)

    # calculate the loss on the (flattened) predictions and targets
    loss = loss_function(log_probs.view(-1, log_probs.shape[-1]), y.view(-1))

    # backprop
    loss.backward()

    # plain Python float: safe to log, store and checkpoint without keeping
    # the autograd graph or a device tensor alive
    loss_value = loss.item()

    # log before the optimizer step so the gradients match the logged weights
    if epoch % log_every == 0:
        log_weights_histograms(model, epoch)
        log_gradients_histograms(model, epoch)
        writer.add_scalar('loss', loss_value, epoch)

    optimizer.step()

    # store the mean loss of this batch
    train_loss.append(loss_value)

    if epoch % checkpoint_every == 0:
        save_checkpoint(model, optimizer, epoch, loss_value, CHECKPOINT_PATH)

    # evaluate the model with the test set
    if epoch % eval_every == 0:

        # eval mode switches dropout off for the measurement
        model.eval()
        with torch.no_grad():
            X, y = get_data_batch(False)       # False -> testset data
            X, y = X.to(device), y.to(device)  # push it to the GPU
            out = model(X)                     # forward pass
            thisloss = loss_function(out.view(-1, out.shape[-1]), y.view(-1)) # calculate loss
            test_loss.append(thisloss.item())
        # back to training mode for the next step
        model.train()

        # update our progress :)
        print(f'Epoch {epoch:4}, train loss: {train_loss[-1]:5.2f}, test loss: {test_loss[-1]:5.2f}')

In [None]:
# plot the losses
plt.plot(train_loss,'k',label='Train loss')

# The test loss is recorded once every 100 training iterations, starting at
# iteration 0, so its x positions are 0, 100, 200, ... Deriving them from
# len(test_loss) keeps x and y the same length even after an interrupted run.
test_epochs = [100 * i for i in range(len(test_loss))]
plt.plot(test_epochs,test_loss,'rs-',markerfacecolor='w',markersize=8,label='Test loss')

plt.legend()
plt.gca().set(xlabel='Epoch',ylabel='Loss')
plt.show()

In [None]:
# Text the model is asked to continue
prompt = 'I find likewise that your printer has been so'
# Encode the prompt to token ids and add a leading batch dimension of size 1
in2gpt = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

# Sample 5 new tokens; the result holds the prompt ids followed by the new ids
output = model.generate(in2gpt,max_new_tokens=5)
# Decode the single sequence in the batch, printing carriage returns as newlines
print(tokenizer.decode(output[0]).replace('\r','\n'))