<a href="https://colab.research.google.com/github/maxmatical/ml-cheatsheet/blob/master/GPT_J_6B_Topic_Modelling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
# full fp32 model with optimizer params (61 GB):
# https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 40 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (299 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 160837 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
--2021-07-22 18:55:29--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
Resolving the-eye.eu (the-eye.eu)... 162.213.130.242
Connecting to the-eye.eu (the-eye.eu)|162.213.130.242|:443... connected.
HT

Processing ./mesh-transformer-jax
  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.
Collecting jax==0.2.12
  Downloading jax-0.2.12.tar.gz (590 kB)
     |████████████████████████████████| 590 kB 5.4 MB/s 
Building wheels for collected packages: mesh-transformer, jax
  Building wheel for mesh-transformer (setup.py) ... done
  Created wheel for mesh-transformer: filename=mesh_transformer-0.0.0-py3-none-any.whl size=24001 sha256=49e9ee9d0c301e4f06d874b96bee77a4b54599f7af2156f6336b609f1698e67f
  Stored in directory: /root/.cache/pip/wheels/56/bd/89/b1f6b2f3d6b938d0c5812ee97756a1afd32521bea293543863
  Building wheel for jax (se

## Setup Model


In [None]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

The next cell sometimes fails with an error for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from tqdm import tqdm

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')


Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
read from disk/gcs in 43.234s


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the generation length (`gen_len`; changing it triggers a recompile).
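To make the two sampling knobs concrete, here is a minimal NumPy sketch of temperature scaling followed by top-p (nucleus) filtering. It is a generic illustration of the technique, not the code behind `mesh_transformer.sampling.nucleaus_sample`, and the names `top_p_filter` and `sample` are hypothetical. It assumes `temp > 0` (as `temp` approaches 0 the distribution approaches greedy decoding); how the library itself handles the `temp=0.` used later in this notebook is not covered here.

```python
import numpy as np


def top_p_filter(logits, top_p=0.9, temp=1.0):
    """Return next-token probabilities restricted to the smallest set of
    most likely tokens whose cumulative probability reaches top_p."""
    # Temperature rescales the logits before the softmax; lower temp = sharper.
    scaled = np.asarray(logits, dtype=np.float64) / temp
    probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    probs /= probs.sum()

    order = np.argsort(-probs)             # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep every token up to and including the one that crosses top_p.
    cutoff = np.searchsorted(cumulative, top_p) + 1
    keep = order[:cutoff]

    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()       # renormalise over the nucleus


def sample(logits, rng, top_p=0.9, temp=1.0):
    """Draw one token id from the filtered distribution."""
    p = top_p_filter(logits, top_p=top_p, temp=temp)
    return int(rng.choice(len(p), p=p))


rng = np.random.default_rng(0)
print(top_p_filter([2.0, 1.0, 0.0, -1.0], top_p=0.8))  # only the two likeliest tokens stay non-zero
print(sample([2.0, 1.0, 0.0, -1.0], rng, top_p=0.8, temp=0.7))
```

Lowering `top_p` shrinks the set of candidate tokens, while lowering `temp` sharpens the distribution before that cut is made.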

You can also change other settings in the previous cells, such as `per_replica_batch`, to control how many generations run in parallel. A larger batch has higher latency but higher throughput, measured in generated tokens per second. This is useful for things like best-of-n cherry-picking.
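The batch size follows from how the setup cells split the TPU cores into a `('dp', 'mp')` mesh. The hypothetical helper below only mirrors that arithmetic (`mesh_shape` and `total_batch`) in plain Python, assuming `device_count` is a multiple of `cores_per_replica`.

```python
def batch_layout(device_count, cores_per_replica=8, per_replica_batch=1):
    """Return (mesh_shape, total_batch) the way the setup cells compute them."""
    if device_count % cores_per_replica:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    # 'dp' (data parallel) = number of model replicas; 'mp' (model parallel) =
    # number of cores each replica's weights are sharded across.
    dp = device_count // cores_per_replica
    mesh_shape = (dp, cores_per_replica)
    # Every replica generates per_replica_batch sequences in each call.
    total_batch = per_replica_batch * dp
    return mesh_shape, total_batch


print(batch_layout(8))                       # ((1, 8), 1)
print(batch_layout(8, per_replica_batch=4))  # ((1, 8), 4)
```

With the 8 cores reported in the loading output above (`dp 1`, `mp 8`), `total_batch` equals `per_replica_batch`, so each increment of `per_replica_batch` adds one generation per call.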

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    # left-pad the prompt to the model's full context length `seq`
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    # repeat the same prompt once per generation in the batch
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    # prepend the prompt in bold (ANSI escape codes) to each decoded completion
    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    # print(f"completion done in {time.time() - start:06}s")
    return samples

# print(infer("EleutherAI is")[0])
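The preprocessing in `infer` left-pads the prompt to `seq` tokens and repeats it once per batch slot. Below is a standalone NumPy sketch of that step (`left_pad_batch` is a hypothetical name). It adds one assumption that is not in the original: `np.pad` rejects a negative pad width, so a prompt longer than `seq` would make `infer` fail, and the sketch instead keeps only the last `seq` tokens, which may or may not suit your use case.

```python
import numpy as np


def left_pad_batch(tokens, seq, batch_size):
    """Left-pad a token list to `seq` positions and tile it `batch_size` times.

    Returns (batched_tokens, lengths): a (batch_size, seq) uint32 array and
    the number of real (non-padding) tokens in each row.
    """
    if len(tokens) > seq:
        # Assumption (not in the original): keep only the most recent tokens.
        tokens = tokens[-seq:]
    pad_amount = seq - len(tokens)
    padded = np.pad(np.asarray(tokens, dtype=np.uint32), (pad_amount, 0))
    batched = np.tile(padded, (batch_size, 1))
    lengths = np.full(batch_size, len(tokens), dtype=np.uint32)
    return batched, lengths


batched, lengths = left_pad_batch([5, 6, 7], seq=5, batch_size=2)
print(batched.shape)  # (2, 5); each row is [0 0 5 6 7]
print(lengths)        # [3 3]
```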

In [None]:
import pandas as pd
df = pd.read_csv("gpt3-test-july20.csv", sep="\t")
df.head(5)

|   | message |
|---|---|
| 0 | cancel my account |
| 1 | i want to add new line |
| 2 | speak with an agent |
| 3 | am i able to cancel my plan and close ts numbe... |
| 4 | unfortunately i am moving somewhere without fr... |


In [None]:
prompt = """
classify the following text
text: i need help.
class: help generic
text: need assistance with my billing.
class: billing issue
text: help logging into my freedom account.
class: account login issue
text: can you help me i got locked out of my iphone.
class: unlock phone
text: want to speak to a representative to help me please.
class: speak agent
text: help with my freedom network.
class: network issue
text: need to change my current address.
class: change address
text: i need a person now.
class: speak agent
text: unlock my huawei device.
class: unlock phone
text: help with billing.
class: billing issue
text: check my account number.
class: account issue
text: we might need to change s plan because he's incurring overage charges
class: change plan
text: i will be moving and need to change the address on my account
class: change address
text: help logging into my freedom account.
class: account login issue
text: """
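Writing few-shot prompts by hand makes inconsistencies easy to miss (the prompt above repeats the `help logging into my freedom account.` example, for instance). A minimal sketch, using a hypothetical `build_few_shot_prompt` helper, that generates the same `text:` / `class:` layout from a list of `(text, label)` pairs and ends with an open `text: ` slot for the next message:

```python
def build_few_shot_prompt(instruction, examples, label_field="class"):
    """Assemble a few-shot prompt: a leading newline, the instruction, one
    'text:' / '<label_field>:' pair per example, then an open 'text: ' slot."""
    lines = ["", instruction]
    for text, label in examples:
        lines.append(f"text: {text}")
        lines.append(f"{label_field}: {label}")
    lines.append("text: ")
    return "\n".join(lines)


examples = [
    ("i need help.", "help generic"),
    ("need assistance with my billing.", "billing issue"),
    ("help logging into my freedom account.", "account login issue"),
]
prompt_sketch = build_few_shot_prompt("classify the following text", examples)
print(prompt_sketch + "speak with an agent.")
```

Passing `label_field="context"` gives the same layout with `context:` lines, although the context prompt later in this notebook also has a blank line after its instruction, which this sketch does not add.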


In [None]:
query = prompt + "so that's awesome and all i needed."
# print(query)

In [None]:
out = infer(query, top_p=0.4, temp=0., gen_len=64)[0]
print(out)

classify the following text
text: i need help.
class: help generic
text: need assistance with my billing.
class: billing issue
text: help logging into my freedom account.
class: account login issue
text: can you help me i got locked out of my iphone.
class: unlock phone
text: want to speak to a representative to help me please.
class: speak agent
text: help with my freedom network.
class: network issue
text: need to change my current address.
class: change address
text: i need a person now.
class: speak agent
text: unlock my huawei device.
class: unlock phone
text: help with billing.
class: billing issue
text: check my account number.
class: account issue
text: we might need to change s plan because he's incurring overage charges
class: change plan
text: i will be moving and need to change the address on my account
class: change address
text: help logging into my freedom account.
class: account login issue

text: so that's awesome and all i needed.
class: help generic
text: ca

In [None]:
out.split("\x1b[0m")[1].split("\n")[1].split(":")[1].strip()

'help generic'
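The chained `split` calls above assume the completion starts with a newline followed by `class: ...`, and `split(":")[1]` drops anything after a second colon. A slightly more defensive sketch (hypothetical helper `extract_label`) looks only at the generated text after the ANSI reset code that `infer` inserts, and returns an empty string when the first generated line is not the requested field:

```python
RESET = "\x1b[0m"  # ANSI reset code that `infer` places right after the prompt


def extract_label(sample, field="class"):
    """Return the value of the first generated line if it starts with
    '<field>:', otherwise an empty string."""
    # Only look at the generated text, i.e. what follows the bold prompt.
    completion = sample.split(RESET, 1)[-1]
    prefix = f"{field}:"
    for line in completion.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines such as the leading newline
        if line.startswith(prefix):
            return line[len(prefix):].strip()
        return ""  # first non-empty line is not the field we asked for
    return ""


demo = "\x1b[1m\nclassify the following text\ntext: hi.\x1b[0m\nclass: help generic\ntext: ca"
print(extract_label(demo))  # help generic
```

Inside the classification loop, `c = extract_label(out)` could stand in for the `try`/`except` block, and `extract_label(out, field="context")` for the context loop.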

In [None]:
data = []
for msg in tqdm(list(df.message)):
  query = prompt + msg + "."
  out = infer(query, top_p=0.5, temp=0., gen_len=64)[0]
  try:
    # the generated text follows the ANSI reset code; its first line should be "class: <label>"
    c = out.split("\x1b[0m")[1].split("\n")[1].split(":")[1]
  except IndexError:
    print("error with the following message")
    print(out)
    c = ""
  data.append(
      {
          "text": msg,
          "class": c.strip()
      }
  )

100%|██████████| 118/118 [03:48<00:00,  1.94s/it]


In [None]:
df_res = pd.DataFrame(data=data)
df_res

|   | text | class |
|---|---|---|
| 0 | cancel my account | cancel account |
| 1 | i want to add new line | add new line |
| 2 | speak with an agent | speak agent |
| 3 | am i able to cancel my plan and close ts numbe... | cancel plan |
| 4 | unfortunately i am moving somewhere without fr... | help |
| ... | ... | ... |
| 113 | i will cancel close my account and pay the rem... | cancel |
| 114 | i would like to close my account | close account |
| 115 | i just want to close my account | close account |
| 116 | so i have three prepaid numbers that i paid ye... | add a new number |


In [None]:
len(df_res["class"].unique())

60

In [None]:
# df_res.to_csv("gpt3_group_7.csv", index=False)

## Context summarization

In [None]:
prompt = """
extract context from the following text

text: need assistance with my billing.
context: no context given
text: i am unable to access it due to the fact that my email associated with the account is no longer able to be used
context: email no longer used
text: i will be moving and need to change the address on my account
context: will be moving
text: help logging into my freedom account.
context: no context given
text: we might need to change s plan because he's incurring overage charges
context: incurring overage charges
text: and received a message saying we need a stronger pin
context: require stronger pin
text: """


In [None]:
contexts = []
for msg in tqdm(list(df.message)):
  query = prompt + msg + "."
  out = infer(query, top_p=0.5, temp=0., gen_len=64)[0]
  try:
    # the first generated line should be "context: <summary>"
    c = out.split("\x1b[0m")[1].split("\n")[1].split(":")[1]
  except IndexError:
    print("error with the following message")
    print(out)
    c = ""

  contexts.append(c.strip())

df_res["context"] = contexts

 40%|███▉      | 47/118 [01:31<02:17,  1.94s/it]

error with the following message
extract context from the following text

text: need assistance with my billing.
context: no context given
text: i am unable to access it due to the fact that my email associated with the account is no longer able to be used
context: email no longer used
text: i will be moving and need to change the address on my account
context: will be moving
text: help logging into my freedom account.
context: no context given
text: we might need to change s plan because he's incurring overage charges
context: incurring overage charges
text: and received a message saying we need a stronger pin
context: require stronger pin
text: i 'm on direct deposit to pay my bill and you have n't taken money from me yet for march.
context:n't taken money
text: we are a part of a p&g.
context: p&g
text: we are a p&g, and the p&g is a p&g.
context: p&g
text: the p&g is a p&


100%|██████████| 118/118 [03:48<00:00,  1.94s/it]


In [None]:
df_res.head(5)

|   | text | class | context |
|---|---|---|---|
| 0 | cancel my account | cancel account | no context given |
| 1 | i want to add new line | add new line | no context given |
| 2 | speak with an agent | speak agent | no context given |
| 3 | am i able to cancel my plan and close ts numbe... | cancel plan | not here |
| 4 | unfortunately i am moving somewhere without fr... | help | will be moving |


In [None]:
df_res.to_csv("gpt3-topic-mining-july-results.csv", index=False)

In [None]:
"""
text:
content: ... \n
context:  ... \n
"""

# Topic clustering

In [None]:
!pip install sentence-transformers
!pip install umap-learn
!pip install hdbscan

In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer


In [None]:
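# the results cell above saves "gpt3-topic-mining-july-results.csv"; point this at the results file you want to cluster
df = pd.read_csv("gpt3-topic-mining-results.csv")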

In [None]:
# settings
# a couple of models to try:
#   paraphrase-distilroberta-base-v2
#   paraphrase-TinyBERT-L6-v2
#   stsb-distilroberta-base-v2
pretrained_model = "paraphrase-distilroberta-base-v2"
model = SentenceTransformer(pretrained_model)



In [None]:
topics = list(df["class"])
embeddings = model.encode(topics)
print(embeddings.shape)

(118, 768)


## Reduce dimensionality with UMAP

In [None]:
import umap
# project the 768-d sentence embeddings down to 5 dimensions before clustering
umap_embeddings = umap.UMAP(n_neighbors=15,
                            n_components=5,
                            metric='cosine').fit_transform(embeddings)



## Clustering - HDBSCAN

In [None]:
import hdbscan
# cluster the full 768-d sentence embeddings
cluster_hdbscan = hdbscan.HDBSCAN(min_cluster_size=15,
                                  metric='euclidean',
                                  cluster_selection_method='eom').fit(embeddings)

# cluster the 5-d UMAP embeddings
cluster_hdbscan_reduced = hdbscan.HDBSCAN(min_cluster_size=15,
                                          metric='euclidean',
                                          cluster_selection_method='eom').fit(umap_embeddings)

In [None]:
df["grouped_topics_hdbscan"] = cluster_hdbscan.labels_
df["grouped_topics_hdbscan_reduced"] = cluster_hdbscan_reduced.labels_

In [None]:
cluster_hdbscan_reduced.labels_  # label -1 marks points HDBSCAN treats as noise

array([ 2,  0, -1,  2,  0, -1,  0,  1,  1, -1,  2,  2,  2,  1,  1,  0, -1,
        1,  1, -1,  0,  0, -1, -1, -1,  2,  2,  1,  1,  0,  0,  2,  2,  2,
        1,  1, -1, -1, -1,  2,  2,  2,  2, -1,  0,  1, -1, -1, -1, -1,  1,
        1, -1,  2,  2,  1,  0,  0, -1,  2,  2,  0,  0,  1,  2, -1,  1,  1,
       -1, -1,  2, -1,  2,  1,  2, -1,  1, -1,  1,  2,  2, -1,  0,  1, -1,
       -1,  0,  2,  2,  2,  0,  1,  1,  1,  2, -1, -1,  1,  0,  1,  1,  1,
        2, -1,  2, -1,  2,  1,  0,  2,  2,  2,  2,  2,  2,  2,  1, -1])
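HDBSCAN assigns the label `-1` to points it treats as noise rather than forcing them into a cluster, so it helps to count both clusters and noise before comparing runs. A small NumPy sketch (hypothetical helper `summarize_labels`) that could be applied to `cluster_hdbscan.labels_` or `cluster_hdbscan_reduced.labels_`:

```python
import numpy as np


def summarize_labels(labels):
    """Count clusters and noise points (label -1) in a clustering result."""
    labels = np.asarray(labels)
    n_noise = int((labels == -1).sum())
    n_clusters = len(set(labels.tolist()) - {-1})
    return {
        "n_clusters": n_clusters,
        "n_noise": n_noise,
        "noise_fraction": n_noise / len(labels),
    }


print(summarize_labels([2, 0, -1, 2, 0, -1, 0, 1]))
# {'n_clusters': 3, 'n_noise': 2, 'noise_fraction': 0.25}
```

`min_cluster_size` is one of the parameters that affects how many points end up labelled as noise, so this summary is a quick way to compare settings.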

In [None]:
df.head(20)

|   | text | class | context | grouped_topics | grouped_topics_kmeans | grouped_topics_kmeans_reduced | grouped_topics_hdbscan | grouped_topics_hdbscan_reduced |
|---|---|---|---|---|---|---|---|---|
| 0 | cancel my account | cancel account | no context given | 2 | 3 | 1 | -1 | 2 |
| 1 | i want to add new line | add new line | no context given | 0 | 6 | 3 | -1 | 1 |
| 2 | speak with an agent | speak agent | no context given | -1 | 7 | 7 | -1 | 2 |
| 3 | am i able to cancel my plan and close ts numbe... | cancel plan | not here | 2 | 3 | 1 | -1 | 2 |
| 4 | unfortunately i am moving somewhere without fr... | help | will be moving | 0 | 1 | 2 | -1 | 0 |
| 5 | i paid the bill it reflected on my bank statem... | pay bill | it did not | -1 | 2 | 4 | -1 | -1 |
| 6 | i 'm filing my taxes today and i need your hel... | help with my taxes | need help with my phone plan payments | 0 | 1 | 2 | -1 | 0 |
| 7 | good afternoon freedom mobile customer service... | block number | can you please block my number | 1 | 5 | 8 | -1 | 1 |
| 8 | please block my phone number from scammer | block number | no context given | 1 | 5 | 8 | -1 | 1 |
| 9 | i would like to pay my phone bill | pay bill | no context given | -1 | 2 | 4 | -1 | -1 |


## Clustering - KMeans

In [None]:
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=10).fit(embeddings)

In [None]:
kmeans.labels_

array([6, 4, 9, 1, 3, 7, 3, 4, 4, 7, 0, 6, 8, 4, 1, 3, 4, 4, 5, 2, 3, 3,
       9, 2, 7, 1, 8, 5, 4, 3, 3, 9, 8, 8, 4, 5, 2, 7, 7, 8, 1, 8, 6, 7,
       3, 4, 7, 9, 7, 7, 5, 5, 9, 1, 6, 4, 3, 3, 7, 1, 8, 3, 3, 4, 0, 2,
       4, 4, 2, 2, 8, 9, 6, 5, 1, 2, 4, 7, 5, 1, 1, 7, 3, 4, 2, 2, 3, 1,
       8, 8, 3, 5, 5, 5, 6, 7, 2, 4, 3, 4, 5, 5, 1, 2, 0, 7, 1, 4, 3, 6,
       0, 6, 6, 1, 6, 6, 4, 2], dtype=int32)

In [None]:
# cluster the 5-d UMAP embeddings
kmeans_reduced = KMeans(n_clusters=10).fit(umap_embeddings)

In [None]:
kmeans_reduced.labels_

array([3, 8, 6, 3, 2, 4, 2, 5, 5, 4, 8, 3, 9, 5, 8, 8, 8, 8, 7, 1, 2, 2,
       6, 1, 4, 3, 9, 7, 8, 2, 2, 6, 9, 9, 0, 7, 1, 4, 4, 9, 8, 9, 8, 0,
       2, 8, 4, 6, 4, 4, 7, 7, 6, 3, 3, 5, 2, 2, 4, 6, 9, 2, 2, 5, 8, 1,
       8, 8, 1, 1, 9, 6, 8, 7, 3, 1, 8, 4, 7, 3, 3, 4, 2, 7, 1, 1, 2, 6,
       9, 9, 2, 7, 7, 7, 3, 4, 1, 5, 2, 0, 7, 7, 3, 1, 8, 4, 3, 7, 2, 3,
       8, 3, 3, 3, 3, 3, 5, 1], dtype=int32)

In [None]:
df["grouped_topics_kmeans"] = kmeans.labels_
df["grouped_topics_kmeans_reduced"] = kmeans_reduced.labels_

In [None]:
df.head(20)

|   | text | class | context | grouped_topics | grouped_topics_kmeans | grouped_topics_kmeans_reduced | grouped_topics_hdbscan | grouped_topics_hdbscan_reduced |
|---|---|---|---|---|---|---|---|---|
| 0 | cancel my account | cancel account | no context given | 2 | 6 | 3 | -1 | 2 |
| 1 | i want to add new line | add new line | no context given | 0 | 4 | 8 | -1 | 1 |
| 2 | speak with an agent | speak agent | no context given | -1 | 9 | 6 | -1 | 2 |
| 3 | am i able to cancel my plan and close ts numbe... | cancel plan | not here | 2 | 1 | 3 | -1 | 2 |
| 4 | unfortunately i am moving somewhere without fr... | help | will be moving | 0 | 3 | 2 | -1 | 0 |
| 5 | i paid the bill it reflected on my bank statem... | pay bill | it did not | -1 | 7 | 4 | -1 | -1 |
| 6 | i 'm filing my taxes today and i need your hel... | help with my taxes | need help with my phone plan payments | 0 | 3 | 2 | -1 | 0 |
| 7 | good afternoon freedom mobile customer service... | block number | can you please block my number | 1 | 4 | 5 | -1 | 1 |
| 8 | please block my phone number from scammer | block number | no context given | 1 | 4 | 5 | -1 | 1 |
| 9 | i would like to pay my phone bill | pay bill | no context given | -1 | 7 | 4 | -1 | -1 |
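Cluster ids on their own are hard to interpret. A minimal pandas sketch (hypothetical helper `topics_per_cluster`) that lists the distinct GPT-J topics assigned to each cluster, for example for the `grouped_topics_kmeans` column:

```python
import pandas as pd


def topics_per_cluster(df, label_col, topic_col="class"):
    """Map each cluster id to the sorted list of distinct topics in it."""
    groups = {}
    for label, group in df.groupby(label_col):
        groups[label] = sorted(group[topic_col].unique())
    return groups


toy = pd.DataFrame({
    "class": ["pay bill", "pay bill", "block number", "cancel plan"],
    "grouped_topics_kmeans": [7, 7, 4, 7],
})
for cluster_id, topics in topics_per_cluster(toy, "grouped_topics_kmeans").items():
    print(cluster_id, topics)
```

On the real data, `topics_per_cluster(df, "grouped_topics_kmeans")` and `topics_per_cluster(df, "grouped_topics_kmeans_reduced")` make it easier to judge which labelling groups the 60 distinct topics more sensibly.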
