<a href="https://colab.research.google.com/github/ggolani/ML/blob/main/GPT2_FineTuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM,GPT2Tokenizer
import requests

# vector plots: render inline matplotlib figures as SVG (vector graphics)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# load the pretrained GPT-2 tokenizer first, then the causal-LM model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no padding token of its own, so reuse end-of-text for padding
tokenizer.pad_token = tokenizer.eos_token
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')

In [None]:
# hyperparameters
seq_len    = 256 # max sequence length (tokens per training window)
batch_size =  16 # training windows per optimizer step

# seed torch's RNGs (torch.manual_seed seeds every device) so that batch
# sampling and text generation are reproducible across Restart & Run All
# (up to nondeterministic GPU kernels)
seed = 42
torch.manual_seed(seed)

# use GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
# download and tokenize the text
# Gulliver's travels :)
url = 'https://www.gutenberg.org/cache/epub/829/pg829.txt'
resp = requests.get(url, timeout=60)
# fail loudly rather than silently training on an HTML error page
resp.raise_for_status()
# requests falls back to ISO-8859-1 for text/* responses without a charset
# header; Gutenberg's pg*.txt files are UTF-8, so decode them explicitly
resp.encoding = 'utf-8'
text = resp.text

gtTokens = tokenizer.encode(text,return_tensors='pt')
print(gtTokens.shape)

# encode(..., return_tensors='pt') adds a batch dimension of size 1; the
# rest of the code expects a 1-D tensor of token ids, so drop it
gtTokens = gtTokens[0]
print(gtTokens.shape)

In [None]:
# rank the distinct token ids of the corpus by occurrence count (descending)
# and keep the 100 most frequent ones
uniq, counts = np.unique(gtTokens, return_counts=True)
freqidx = np.flip(np.argsort(counts))
top100 = np.take(uniq, freqidx[:100])

In [None]:
numreps =  10 # number of random repetitions
numtoks = 100 # new tokens generated per repetition

# generate() needs the model and its inputs on the same device; the
# "move the model to the GPU" cell comes later, so move it here as well
# (calling .to(device) again later is harmless)
gpt2 = gpt2.to(device)

# random starting tokens, one per repetition: shape (numreps, 1)
randstarts = torch.randint(tokenizer.vocab_size,(numreps,1)).to(device)

# sample exactly numtoks new tokens per sequence: min_length == max_length
# fixes the length and banning EOS keeps any sequence from stopping early
out = gpt2.generate(
  randstarts,
  max_length = numtoks+1,
  min_length = numtoks+1,
  do_sample  = True,
  bad_words_ids = [[tokenizer.eos_token_id]],
  pad_token_id = tokenizer.eos_token_id
).cpu()

# percentage of generated tokens (random start token excluded) that are
# among the corpus's 100 most frequent tokens
percentFreqTokens_pre = np.mean(100*np.isin(out[:,1:],top100).flatten())
print(f"Gulliver's travels common tokens appeared in {percentFreqTokens_pre:.2f}% of new tokens.")

In [None]:
class KLDivergenceLoss_x(nn.Module):
  """KL-divergence loss that pulls the model's token distribution toward a
  uniform distribution over the vocabulary tokens containing `target`.

  Args:
    target: substring searched for in every decoded vocabulary token
      (default 'x', the original behaviour).
    tok: tokenizer exposing `vocab_size` and `decode`; defaults to the
      notebook's global `tokenizer`.
    dev: device for the target distribution; defaults to the global `device`.

  Raises:
    ValueError: if no vocabulary token contains `target`. Normalising an
      all-zero mask would otherwise silently produce NaNs.
  """
  def __init__(self, target='x', tok=None, dev=None):
    super().__init__()
    tok = tokenizer if tok is None else tok
    dev = device if dev is None else dev

    # ids of tokens whose decoded text contains the target substring
    hits = [t for t in range(tok.vocab_size) if target in tok.decode([t])]
    if not hits:
      raise ValueError(f'no vocabulary token contains {target!r}')

    # mask: 1 if token contains the target, 0 otherwise (one vectorised
    # assignment instead of one device write per token)
    mask = torch.zeros(tok.vocab_size, device=dev)
    mask[hits] = 1

    # normalize to a probability distribution; registered as a non-persistent
    # buffer so module.to(...) moves it without changing the state_dict
    self.register_buffer('mask', mask/mask.sum(), persistent=False)

  def forward(self, log_probs):
    """Return KL(target distribution || model distribution), batch-averaged.

    Assumes log-softmax-prob input whose last dimension equals the vocabulary
    size (the commented training code passes shape (N, vocab)); TODO confirm
    for other layouts.
    """
    return F.kl_div(log_probs, self.mask, reduction='batchmean')

In [None]:
# create a loss function instance
#loss_function = KLDivergenceLoss_x().to(device)

In [None]:
# transfer the model's parameters and buffers to the selected device
# (the GPU when one is available)
gpt2 = gpt2.to(device=device)

In [None]:
# check out some text: sampled continuation of a fixed prompt before fine-tuning
prompt = 'I cannot believe that'
in2gpt = tokenizer.encode(prompt,return_tensors='pt').to(device)

# the pad token was set equal to EOS, so generate() cannot infer the
# attention mask from the input; the prompt has no padding, so attend to
# every position explicitly
output = gpt2.generate(
  in2gpt,
  attention_mask = torch.ones_like(in2gpt),
  max_length     = 100,
  pad_token_id   = tokenizer.eos_token_id,
  do_sample      = True
).cpu()
print(tokenizer.decode(output[0]))

In [None]:
 # create the optimizer functions (note the small learning rate)
# AdamW over all GPT-2 weights; no separate loss function is created because
# the Hugging Face model computes the LM loss itself when given labels
optimizer = torch.optim.AdamW(params=gpt2.parameters(), weight_decay=0.01, lr=5e-5)

In [None]:
num_samples = 1234

# initialize losses (one entry per optimizer step)
train_loss = np.zeros(num_samples)

# from_pretrained() returns the model in eval mode (dropout disabled);
# switch to train mode for fine-tuning and back to eval mode afterwards
# so the generation cells that follow run without dropout
gpt2.train()

for sampli in range(num_samples):

  # get a batch of data: batch_size random windows of seq_len tokens.
  # randint's upper bound is exclusive, so +1 lets a window end on the
  # very last token of the text
  ix = torch.randint(len(gtTokens)-seq_len+1,size=(batch_size,))
  X  = gtTokens[ix[:,None] + torch.arange(seq_len)]

  # move data to GPU
  X = X.to(device)

  # clear previous gradients
  optimizer.zero_grad()

  # forward pass (Hugging Face shifts the labels internally, so labels=X
  # trains next-token prediction)
  output = gpt2(X, labels=X)
  # calculate the losses
  loss = output.loss
  # if using the KL-divergence custom loss instead (requires creating
  # loss_function first; labels are not needed for this path):
  #logits = gpt2(X).logits
  #logits_reshape = logits.view(-1,tokenizer.vocab_size)
  #logprobs_reshape = F.log_softmax(logits_reshape,dim=-1)
  #loss = loss_function(logprobs_reshape)

  # backprop
  loss.backward()
  optimizer.step()

  # store the per-step loss
  train_loss[sampli] = loss.item()

  # update progress display
  if sampli%77==0:
    print(f'Sample {sampli:4}/{num_samples}, train loss: {train_loss[sampli]:.4f}')

# back to inference mode for the evaluation cells
gpt2.eval()

In [None]:
# plot the per-step training loss (each step is one batch of sequences);
# the original markersize argument had no effect without a marker, so it is dropped
fig, ax = plt.subplots(figsize=(8,4))
ax.plot(train_loss,'k')

ax.set(
  xlabel='Training step (batch)',
  ylabel='Train loss',
  xlim=[-1,num_samples],
  title="GPT-2 fine-tuning loss on Gulliver's Travels"
)
plt.show()

In [None]:
# Qualitative assessment: same prompt and same sampled decoding as the
# pre-fine-tuning check, so the two continuations are directly comparable
prompt = 'I cannot believe that'
in2gpt = tokenizer.encode(prompt,return_tensors='pt').to(device)

# pad id equals EOS here, so pass the (all-ones) attention mask explicitly
output = gpt2.generate(
  in2gpt,
  attention_mask = torch.ones_like(in2gpt),
  max_length     = 100,
  pad_token_id   = tokenizer.eos_token_id,
  do_sample      = True
).cpu()
print(tokenizer.decode(output[0]))

# Calculate percentage of GT tokens generated

In [None]:
# fresh random single-token prompts for the post-fine-tuning run
randstarts = torch.randint(tokenizer.vocab_size,(numreps,1)).to(device)

# sample numtoks tokens per prompt: min_length == max_length fixes the
# length, and banning EOS keeps any sequence from terminating early
out = gpt2.generate(
  randstarts,
  do_sample     = True,
  min_length    = numtoks+1,
  max_length    = numtoks+1,
  pad_token_id  = tokenizer.eos_token_id,
  bad_words_ids = [[tokenizer.eos_token_id]],
).cpu()

# show every generated sequence
for seq in out:
  print('\n*** Next batch of output:')
  print(tokenizer.decode(seq))

In [None]:
# share (in %) of generated tokens, excluding each random start token, that
# are among the corpus's 100 most frequent tokens; compare with the baseline
percentFreqTokens_pst = np.mean(100 * np.isin(out[:, 1:], top100))

print(f'Common GT tokens usage went from {percentFreqTokens_pre:.2f}% to {percentFreqTokens_pst:.2f}% after fine-tuning.')