<a href="https://colab.research.google.com/github/maxmatical/fast.ai/blob/master/GPT_J_6B_Topic_Modelling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!

In [1]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
# for full fp32 model w/ optimizer params, 61gb
# https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd 

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (447 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
--2021-06-15 13:46:47--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
Resolving the-eye.eu (the-eye.eu)... 162.213.130.242
Connecting to the-eye.eu (the-eye.eu)|162.213.130.242|:443... connected.
HT

Processing ./mesh-transformer-jax
Collecting jax==0.2.12
  Downloading https://files.pythonhosted.org/packages/9a/67/d1a9c94104c559b49bbcb72e9efc33859e982d741ea4902d2a00e66e09d9/jax-0.2.12.tar.gz (590kB)
     |████████████████████████████████| 593kB 7.1MB/s 
Building wheels for collected packages: jax, mesh-transformer
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.12-cp37-none-any.whl size=682484 sha256=4fe8b4c4d46700321d6552bb561de54fa81cdaea9f1f342aa3925a217bff0bc6
  Stored in directory: /root/.cache/pip/wheels/cf/00/88/75c2043dff473f58e892c7e6adfd2c44ccefb6111fcc021e5b
  Building wheel for mesh-transformer (setup.py) ... done
  Created wheel for mesh-transformer: filename=mesh_transformer-0.0.0-cp37-none-any.whl size=20010 sha256=506067c0623bc90421a7f375099ec3bfeb9ec0b953a58053425c0c9e8900c59c
  Stored in directory: /root/.cache/pip/wheels/de/a9/d2/2be3e25299342b60fca7965d4e416264ff8b6d8a7e8def76da
Successfull

## Setup Model


In [2]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
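
The cell above builds two strings from `COLAB_TPU_ADDR`: the URL that requests a specific TPU driver version and the gRPC target that JAX connects to. As a minimal sketch of that string handling, the helpers below do the same thing but fail early with a readable message when the variable is missing. The function names and the example address `10.0.0.2:8470` are hypothetical and only serve the illustration; the port 8475 and the driver version string are copied from the cell above.

```python
import os

# Driver version string copied from the setup cell above.
DRIVER_VERSION = "tpu_driver0.1_dev20210607"


def tpu_endpoints(tpu_addr, driver_version=DRIVER_VERSION):
    """Return (version_request_url, backend_target) for a 'host:port' TPU address."""
    host, sep, port = tpu_addr.partition(":")
    if not host or not sep or not port:
        raise ValueError(f"expected 'host:port', got {tpu_addr!r}")
    # The version request goes to port 8475 on the TPU host.
    request_url = f"http://{host}:8475/requestversion/{driver_version}"
    # The JAX backend target keeps the original host:port pair.
    backend_target = f"grpc://{tpu_addr}"
    return request_url, backend_target


def tpu_endpoints_from_env(env=os.environ):
    """Read COLAB_TPU_ADDR and complain clearly if the runtime has no TPU."""
    tpu_addr = env.get("COLAB_TPU_ADDR")
    if tpu_addr is None:
        raise RuntimeError("COLAB_TPU_ADDR is not set; is the runtime type set to TPU?")
    return tpu_endpoints(tpu_addr)


# Hypothetical address, used only to show the resulting strings.
print(tpu_endpoints("10.0.0.2:8470"))
```

A missing `COLAB_TPU_ADDR` would otherwise surface as a bare `KeyError`, which is easy to misread when the real problem is the runtime type.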

The next step sometimes fails for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [4]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from tqdm import tqdm

In [5]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
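
The mesh shape above splits the available devices into a data-parallel axis (`dp`) and a model-parallel axis (`mp`), and the total batch size computed later in the notebook follows from the same numbers. The sketch below reproduces that arithmetic without a TPU. The helper `mesh_layout` and the string stand-ins for `jax.devices()` are hypothetical; only the formulas come from the notebook.

```python
import numpy as np


def mesh_layout(device_count, cores_per_replica, per_replica_batch):
    """Return ((dp, mp), total_batch) for the given device and batch settings."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    dp = device_count // cores_per_replica   # number of model replicas
    mesh_shape = (dp, cores_per_replica)     # same formula as mesh_shape above
    total_batch = per_replica_batch * dp     # sequences generated in parallel
    return mesh_shape, total_batch


# Stand-ins for jax.devices(): eight cores, as the "mp 8" log line in this notebook suggests.
fake_devices = [f"tpu:{i}" for i in range(8)]
mesh_shape, total_batch = mesh_layout(len(fake_devices), cores_per_replica=8, per_replica_batch=1)
grid = np.array(fake_devices).reshape(mesh_shape)

print(mesh_shape, total_batch)  # (1, 8) 1
print(grid.shape)               # (1, 8)
```

With eight cores and `cores_per_replica` set to 8 there is a single replica, so raising `per_replica_batch` is the only way to generate more than one sequence at a time in this setup.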

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [6]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
read from disk/gcs in 32.0624s


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the length of the generations (`gen_len`; changing it triggers a recompile).

You can also change settings such as `per_replica_batch` in the previous cells to control how many generations run in parallel. A larger batch has higher latency but also higher throughput, measured in tokens generated per second, which is useful for things like best-of-n cherry picking.
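
To make the effect of `top_p` and `temp` concrete, here is a small pure-Python sketch of temperature scaling followed by top-p (nucleus) filtering for a single token. It is an illustration of the general technique, not the implementation used by `mesh_transformer.sampling`; the function name `sample_top_p` is hypothetical, and treating a non-positive temperature as greedy decoding is an assumption of this sketch.

```python
import math
import random


def sample_top_p(logits, top_p=0.9, temp=1.0, rng=random):
    """Pick a token index from `logits` using temperature and top-p filtering."""
    if temp <= 0:
        # Assumption of this sketch: non-positive temperature means greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature rescales the logits before the softmax:
    # values below 1 sharpen the distribution, values above 1 flatten it.
    scaled = [x / temp for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Keep the most likely tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Sample among the kept tokens, proportionally to their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


logits = [2.0, 1.0, 0.1]
print(sample_top_p(logits, top_p=0.5))  # index 0 alone already covers half the mass
print(sample_top_p(logits, temp=0.0))   # greedy choice in this sketch
```

A lower `top_p` shrinks the set of candidate tokens, so output becomes more predictable; for classification-style prompts, as used later in this notebook, that is usually what you want.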

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
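
GPT-2 style BPE usually attaches a space to the beginning of the following word, so a prompt that ends in a space leaves the model with a dangling token it rarely saw in that position. A minimal way to guard against this is to strip trailing spaces before calling the model. The helper name below is hypothetical. In this notebook the few-shot prompt ends with `text: `, but a message is always appended before inference, so the final string only ends in a space if the message itself does.

```python
def strip_trailing_spaces(prompt):
    """Remove spaces and tabs at the very end of a prompt, leaving newlines untouched."""
    return prompt.rstrip(" \t")


print(repr(strip_trailing_spaces("EleutherAI is  ")))  # 'EleutherAI is'
```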

In [7]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [8]:
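def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    # the prompt has to fit into the context window of `seq` tokens
    if provided_ctx > seq:
        raise ValueError(f"prompt has {provided_ctx} tokens, but the context window is {seq}")
    pad_amount = seq - provided_ctx

    # left-pad with zeros so the prompt sits at the end of the context window
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    # every entry in the batch gets the same prompt
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    output = network.generate(
        batched_tokens,
        length,
        gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    samples = []
    decoded_tokens = output[1][0]

    # one sample per batch entry: the prompt in bold (ANSI escape codes), then the completion
    for o in decoded_tokens[:, :, 0]:
        samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    # print(f"completion done in {time.time() - start:06}s")
    return samples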

# print(infer("EleutherAI is")[0])
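
The `np.pad` call in `infer` pads on the left only: the pad width `((pad_amount, 0),)` means `pad_amount` zeros before the tokens and none after. The sketch below shows this on a toy context window of 8 tokens. The helper `left_pad` and the token ids are made up for the example.

```python
import numpy as np


def left_pad(tokens, seq):
    """Left-pad a token list with zeros up to length `seq`, as infer() does."""
    pad_amount = seq - len(tokens)
    if pad_amount < 0:
        raise ValueError(f"{len(tokens)} tokens do not fit into a window of {seq}")
    return np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)


padded = left_pad([5, 6, 7], seq=8)
print(padded)        # [0 0 0 0 0 5 6 7]
print(padded.dtype)  # uint32
```

Because the prompt is right-aligned, the `length` array passed to `network.generate` is what tells the model how many of the trailing positions hold real tokens.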

In [None]:
import pandas as pd
df = pd.read_csv("gpt3_test_group_7.tsv", sep="\t")
df.head(5)

In [12]:
prompt = """
text: i need help logging into my freedom account
class: help_login
text: i need help with my freedom network
class: help_network
text: i need to change my current address
class: change_address
text: i need a person now
class: speak_agent
text: i need is a representative to help me please
class: speak_agent
text: i need help with billing
class: help_billing
text: """
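
The prompt above follows a fixed pattern: alternating `text:` and `class:` lines, ending with an open `text: ` that the message to classify is appended to. If you want to swap examples in and out, the same string can be generated from a list of pairs. This is a sketch; the helper `build_prompt` is hypothetical and not part of the notebook.

```python
def build_prompt(examples):
    """Format (text, label) pairs as 'text:'/'class:' lines and leave a final 'text: ' open."""
    lines = [""]  # the notebook's prompt starts with a newline
    for text, label in examples:
        lines.append(f"text: {text}")
        lines.append(f"class: {label}")
    lines.append("text: ")
    return "\n".join(lines)


few_shot = [
    ("i need help logging into my freedom account", "help_login"),
    ("i need help with billing", "help_billing"),
]
print(build_prompt(few_shot) + "i need to get copies of my phone bills for")
```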



In [13]:
input = prompt+"i need to get copies of my phone bills for"
print(input)


text: i need help logging into my freedom account
class: help_login
text: i need help with my freedom network
class: help_network
text: i need to change my current address
class: change_address
text: i need a person now
class: speak_agent
text: i need is a representative to help me please
class: speak_agent
text: i need help with billing
class: help_billing
text: i need to get copies of my phone bills for


In [15]:
out = infer(input, top_p=0.5, temp=0., gen_len = 32)[0]
print(out)

[1m
text: i need help logging into my freedom account
class: help_login
text: i need help with my freedom network
class: help_network
text: i need to change my current address
class: change_address
text: i need a person now
class: speak_agent
text: i need is a representative to help me please
class: speak_agent
text: i need help with billing
class: help_billing
text: i need to get copies of my phone bills for[0m a year
class: get_phone_bills
text: i need to get copies of my phone bills for a year
class: get_phone


In [16]:
out.split("\x1b[0m")[1].split("\n")[1].split(": ")[1]

'get_phone_bills'
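
The chained `split` calls above assume that the completion always contains a second line of the form `class: <label>`, and they raise an `IndexError` when it does not. A more forgiving variant, sketched below, looks for the first `class:` line after the bold-off escape sequence and returns `None` if there is none. The function name `extract_label` is hypothetical.

```python
def extract_label(sample, marker="\x1b[0m"):
    """Return the first 'class:' label in the completion part of `sample`, or None."""
    # Everything after the first bold-off marker is model output.
    _, _, completion = sample.partition(marker)
    for line in completion.splitlines():
        if line.startswith("class:"):
            return line[len("class:"):].strip() or None
    return None


sample = (
    "\x1b[1m\ntext: i need help with billing\nclass: help_billing\n"
    "text: i need to get copies of my phone bills for\x1b[0m a year\n"
    "class: get_phone_bills\ntext: i need to get copies of my phone bills for a year\n"
    "class: get_phone"
)
print(extract_label(sample))  # get_phone_bills
```

Returning `None` instead of raising lets a long batch run finish even if a single completion comes back malformed; the missing labels can be inspected afterwards.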

In [None]:
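data = []
for msg in tqdm(list(df.message)):
  # few-shot prompt followed by the message to classify
  query = prompt + msg
  out = infer(query, top_p=0.5, temp=0., gen_len=32)[0]
  # the second line of the completion is expected to read "class: <label>"
  c = out.split("\x1b[0m")[1].split("\n")[1].split(": ")[1]
  data.append({"text": msg, "class": c})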

100%|██████████| 94/94 [01:44<00:00,  1.11s/it]

completion done in 1.1028633117675781s





In [None]:
df_res = pd.DataFrame(data=data)
df_res

| | text | class |
|---|---|---|
| 0 | i need help logging into my freedom account | help_login |
| 1 | how much do i need to add to my top up bill | add_topup |
| 2 | i need help with my freedom network | help_network |
| 3 | i need help to get my old number back | get_number |
| 4 | i need to reset my voicemail pin as i cannot a... | reset_pin |
| ... | ... | ... |
| 89 | need to check the sr created for my case | help_case |
| 90 | i need to rest my password | rest_password |
| 91 | there i needed some information to port over m... | help_billing |
| 92 | i need to be reimbursed asap | help_reimburse |


In [None]:
df_res.to_csv("gpt3_group_7.csv", index=False)