In [1]:
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer, GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
   
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import trange
from datasets import load_dataset

import os
import torch
import random
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset

from types import SimpleNamespace
from scipy.signal import savgol_filter
from sentencepiece import SentencePieceTrainer, SentencePieceProcessor

from sklearn.metrics import plot_confusion_matrix, confusion_matrix, ConfusionMatrixDisplay
from transformers import GPT2Model, GPT2Tokenizer, GPT2Config

### Add heavy-weight functions first!

In [83]:
# ANSI escape sequences for background colours, used to tint tokens in terminal output.
CGREYBG   = '\33[100m'
CREDBG    = '\33[41m'
CGREENBG  = '\33[42m'
CYELLOWBG = '\33[43m'
CBLUEBG   = '\33[44m'
CVIOLETBG = '\33[45m'
CBEIGE2  = '\33[46m'   # NOTE: SGR code 46 is the cyan background, despite the name
CEND = '\033[0m'       # reset all attributes

# Cluster index (0-based) -> background colour; index k is labelled "(k + 1) star".
color_by_stars = dict(enumerate([CREDBG, CYELLOWBG, CBEIGE2, CVIOLETBG, CBLUEBG]))

# One legend entry per cluster: a coloured blank cell followed by its label.
color_names = [f"{bg} {CEND}: {stars + 1} star" for stars, bg in color_by_stars.items()]

def get_colored_help():
  """Return the colour legend: every entry of `color_names`, separated by single spaces."""
  legend = " ".join(color_names)
  return legend

def get_colored_text(tokens, scores, from_idx = 0, h = True):
  """Render tokens as one string, tinting each token at or after `from_idx` by its score.

  Args:
    tokens: token strings; each is followed by a single space in the output.
    scores: cluster indices (keys of `color_by_stars`), paired positionally with
      `tokens`. Pairing uses `zip`, so output stops at the shorter of the two.
    from_idx: tokens before this position are emitted uncoloured (e.g. the prompt).
    h: when truthy, prefix the output with the colour legend and a newline.

  Returns:
    The assembled string containing ANSI colour codes.
  """
  pieces = [get_colored_help() + "\n"] if h else []
  for idx, (token, star) in enumerate(zip(tokens, scores)):
    if idx < from_idx:
      pieces.append(token + " ")
    else:
      # the trailing space sits inside the coloured span, before the reset code
      pieces.append(color_by_stars[star] + token + " " + CEND)
  return "".join(pieces)


# Smoke-test the colouring: "In" (index 0) stays plain because from_idx=1, and every
# later token is tinted by its score. Exactly one score per token (the original passed
# a ninth score that zip() silently dropped).
print(get_colored_text(
  ["In", "this", "world", "of", "fast", "mov", "ing" , "text"],
  [0, 1, 2, 3, 4, 2, 3, 1],
  1
))
print("target ->", color_names[2])

[41m [0m: 1 star [43m [0m: 2 star [46m [0m: 3 star [45m [0m: 4 star [44m [0m: 5 star
In [43mthis [0m[46mworld [0m[45mof [0m[44mfast [0m[46mmov [0m[45ming [0m[43mtext [0m
target -> [46m [0m: 3 star


In [3]:
def fetch(url):
  """Download `url` once and cache the raw bytes in the system temp directory.

  Quick hack for downloading and caching a file, so the next fetch() of the same URL
  is loaded locally. The cache file name is the MD5 hex digest of the URL.
  Adapted from https://github.com/geohot/tinygrad/blob/master/extra/utils.py

  Args:
    url: address to download.

  Returns:
    The response body as bytes (from the cache when a non-empty copy exists).

  Raises:
    requests.HTTPError: if the server answers with an error status; nothing is cached.
  """
  import os, hashlib, tempfile
  fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
  if os.path.isfile(fp) and os.stat(fp).st_size > 0:
    with open(fp, "rb") as f:
      return f.read()
  import requests  # only needed on a cache miss
  print("fetching %s" % url)
  resp = requests.get(url, timeout=60)
  # Fail loudly rather than caching an error page that would be served forever.
  resp.raise_for_status()
  dat = resp.content
  # Write to a temp name first so a partial download never looks like a cache hit.
  with open(fp + ".tmp", "wb") as f:
    f.write(dat)
  # os.replace (unlike os.rename) also overwrites an existing target on Windows.
  os.replace(fp + ".tmp", fp)
  return dat

In [4]:
# Load the pretrained GPT-2 checkpoint and its tokenizer, and move the model to the GPU when available.
MODEL_NAME = "gpt2-xl"
# Always a torch.device; the original yielded a plain "cpu" string on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device).eval()
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)

In [178]:
# Number of buckets tokens are projected into; the colour tables above define one colour per bucket.
NUM_CLUSTERS = 5

# Prompts to seed generation with; the generation cell below uses the last one.
RANDOM_TEXT_SAMPLES = [
  "Science Today: Sirius, designated Alpha Canis Majoris, is the brightest star in the night sky",
  "Did you know Ava Cherry (pictured), David Bowie's partner and muse, spent a year searching for him in Europe?",
  "In the news: Samia Suluhu (pictured) becomes the first female president of Tanzania",
  "In the US, a mass shootings at three massage parlors in Atlanta leaves eight dead, including six women of Asian descent.",
  "These representations, learned in an unsupervised manner, achieve state of the art on the binary subset of the Stanford Sentiment Treebank.",
  "Can it extract more precise slices of code? Yes. First submit a pull request telling us the desired",
  "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard",
  "Roman society under the Republic was primarily a cultural mix of Latin and",
  "Two of the Mahājanapadas were most probably ganatantras (oligarchic republics) and others",
  "The Magadha was one of the most prominent and prosperous of mahajanapadas. The capital city Pataliputra",
  "Parts of western India were dominated by tribes who had a slightly different culture, considered non-Vedic by the mainstream",
]

In [179]:
# Fixed random projection from the model's hidden size to NUM_CLUSTERS buckets. A dedicated,
# seeded generator makes the token colouring reproducible across kernel restarts without
# disturbing torch's global RNG (which the sampling cell seeds separately).
RANDOM_MAT_SEED = 0
_random_mat_gen = torch.Generator(device=device).manual_seed(RANDOM_MAT_SEED)
RANDOM_MAT = torch.randn(model.config.n_embd, NUM_CLUSTERS, generator = _random_mat_gen, requires_grad = False, device = device)

In [180]:
%%time
# get encodings
# text = random.choice(RANDOM_TEXT_SAMPLES)
text = RANDOM_TEXT_SAMPLES[-1]
data = {k:v.to(device) for k,v in tokenizer(text, return_tensors="pt").items()}
B, S = data["input_ids"].shape
print(B, S)
beam_outputs = model.generate(
  **data,
  max_length = S + 100,
  return_str = True,
  seed = 4,
  do_sample = True,
  temperature = 0.9,
  top_k = 40,
  top_p=0.95,
  num_beams = 1,
  early_stopping=True,
  num_return_sequences = 10,
  return_dict_in_generate=True,
  output_hidden_states=True
)
beam_tokens_expanded = [[tokenizer.decode(y, skip_special_tokens = True) for y in x] for x in beam_outputs.sequences.cpu().tolist()]

# cluster values
logits = torch.cat([x.unsqueeze(0) for x in beam_outputs.hidden_states[-1]])[:, :, 0, :]
logits_clus = logits @ RANDOM_MAT
logits_clus = logits_clus.permute((1, 0, 2)).argmax(-1)
logits_clus = logits_clus.tolist()

all_text = get_colored_help()
all_text += "\n"
for _beam_tokens, _beam_logits in zip(beam_tokens_expanded, logits_clus):
  all_text += "-"*70 + "\n"
  all_text += get_colored_text(_beam_tokens, _beam_logits, S, False) + "\n"
all_text += "-"*70
print(all_text)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


1 24
[41m [0m: 1 star [43m [0m: 2 star [46m [0m: 3 star [45m [0m: 4 star [44m [0m: 5 star
----------------------------------------------------------------------
Parts  of  western  India  were  dominated  by  tribes  who  had  a  slightly  different  culture ,  considered  non - V ed ic  by  the  mainstream [46m Hindu [0m[46m groups [0m[46m of [0m[46m South [0m[46m India [0m[46m. [0m[46m They [0m[46m are [0m[46m the [0m[46m ancestors [0m[46m of [0m[46m the [0m[46m Dal [0m[46mits [0m[46m. [0m[46m The [0m[46m two [0m[46m main [0m[46m groups [0m[46m of [0m[46m those [0m[46m tribes [0m[46m were [0m[46m the [0m[45m A [0m
----------------------------------------------------------------------
Parts  of  western  India  were  dominated  by  tribes  who  had  a  slightly  different  culture ,  considered  non - V ed ic  by  the  mainstream [46m, [0m[46m but [0m[46m had [0m[46m traditions [0m[46m that [0m[46m were [0m[46m v