GPT-J-6B Inference Demo

This notebook demonstrates how to run the GPT-J-6B model. See the link for more details about the model, including evaluation metrics and credits.

Install Dependencies
First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

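!!! Make sure you are using a TPU runtime! The model is sharded across 8 TPU cores (cores_per_replica below), so it will not load on a runtime with fewer devices. !!!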

In [2]:
!apt install -y zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

# extract the checkpoint into step_383500/ (read by read_ckpt further down)
!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Note: you may need to restart the kernel to use updated packages.


Run the cell below only when you are running this notebook on Google Colab; uncomment its lines first.

In [None]:
#import os
#import requests 
#from jax.config import config

#colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
#url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
#requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
#config.FLAGS.jax_xla_backend = "tpu_driver"
#config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

In [6]:
!pip install optax
!pip install transformers
print('optax and transformers installed successfully')

In [9]:
import jax

# should report 8 on a TPU v2-8 / v3-8 runtime, matching cores_per_replica below
print("number of devices are:", jax.device_count())


number of devices are: 1


In [5]:
import time
import jax

from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample  # spelled this way in mesh_transformer; do not "fix" it
from mesh_transformer.transformer_shard import CausalTransformer


In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

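# arrange the devices into a (data-parallel, model-parallel) mesh; this needs a multiple of cores_per_replica devices
assert jax.device_count() >= cores_per_replica and jax.device_count() % cores_per_replica == 0, \
    f"expected a multiple of {cores_per_replica} devices, found {jax.device_count()}: use a TPU runtime"
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)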

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

print('Created Tokenizer successfully')
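The "slim" checkpoint comment in the install cell says that storing only bf16 weights minimizes bandwidth and memory. As a rough sanity check, the sketch below estimates the model size from the `params` values, using the common rule of thumb of about 12·layers·d_model² weights in the transformer blocks (this assumes a feed-forward width of 4·d_model) plus untied input and output embeddings, and ignoring biases and norms. The per-core figure further assumes the weights are split evenly over the 8 model-parallel cores. Treat the results as estimates, not an exact count of the checkpoint.

```python
# Rough size estimate for GPT-J-6B from the hyperparameters in `params`.
# Assumptions: feed-forward width = 4 * d_model, untied input/output embeddings,
# biases and layer norms ignored (they are a tiny fraction of the total),
# weights split evenly across the model-parallel cores.
layers, d_model, n_vocab = 28, 4096, 50400
cores_per_replica = 8

block_params = 12 * layers * d_model ** 2  # attention (~4*d^2) + MLP (~8*d^2) per layer
embed_params = 2 * n_vocab * d_model       # input embedding + output projection
total_params = block_params + embed_params

bytes_bf16 = 2 * total_params  # "slim" checkpoint: bf16 weights, no optimizer state
bytes_fp32 = 4 * total_params  # for comparison: the same weights in fp32

print(f"~{total_params / 1e9:.2f}B parameters")
print(f"bf16 weights: ~{bytes_bf16 / 1e9:.1f} GB total, "
      f"~{bytes_bf16 / cores_per_replica / 1e9:.2f} GB per core across {cores_per_replica} cores")
print(f"fp32 weights: ~{bytes_fp32 / 1e9:.1f} GB total")
```

This prints roughly 6.05B parameters and about 12.1 GB of bf16 weights, half of the fp32 figure.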

In [None]:
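# number of sequences per generate() call: per_replica_batch for each data-parallel replica
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

# load the extracted checkpoint, split into devices.shape[1] (model-parallel) shards
network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

# place the loaded state on the device mesh
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))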

Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (top_p and temp) and with the generation length (gen_len; changing it triggers a recompile).
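For intuition about what `top_p` and `temp` control, here is a minimal NumPy sketch of temperature scaling followed by nucleus (top-p) sampling over a single vector of logits. It is a generic illustration of the technique, not the code of `mesh_transformer.sampling.nucleaus_sample`, and the helper name `nucleus_sample_np` is hypothetical.

```python
import numpy as np


def nucleus_sample_np(logits, top_p=0.9, temp=1.0, rng=None):
    """Sample one token id from `logits` with temperature + nucleus (top-p) sampling.

    Hypothetical helper for illustration; not part of mesh_transformer.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Temperature: values < 1 sharpen the distribution, values > 1 flatten it.
    # A tiny floor avoids dividing by zero when temp is set to 0.
    scaled = np.asarray(logits, dtype=np.float64) / max(temp, 1e-6)
    probs = np.exp(scaled - scaled.max())  # subtract the max for numerical stability
    probs /= probs.sum()

    # Sort tokens by probability (descending) and keep the smallest prefix whose
    # cumulative probability reaches top_p; at least one token is always kept.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]

    # Renormalise over the kept tokens and sample one of them.
    kept_probs = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=kept_probs))


rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1, -1.0])
# token 0 alone has ~64% of the probability mass, so top_p=0.5 keeps only it
print([nucleus_sample_np(logits, top_p=0.5, rng=rng) for _ in range(5)])  # [0, 0, 0, 0, 0]
```

Lowering `top_p` or `temp` makes generations more conservative; raising them admits lower-probability tokens and more varied text.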

You can also change other settings, such as per_replica_batch in the earlier cells, to control how many generations run in parallel. A larger batch has higher latency but also higher throughput, measured in generated tokens per second, which is useful for best-of-n cherry-picking.
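The batch arithmetic from the earlier cells can be checked without a TPU. This sketch mirrors how `mesh_shape` and `total_batch` are computed above: the devices are split into groups of `cores_per_replica` model-parallel cores, and each data-parallel replica generates `per_replica_batch` samples. The function name `plan_mesh` is hypothetical, and the device count is passed in as a plain integer instead of calling `jax.device_count()`.

```python
def plan_mesh(device_count, cores_per_replica=8, per_replica_batch=1):
    """Return (mesh_shape, total_batch) computed the same way as the notebook cells.

    Hypothetical helper: mesh_shape is (data-parallel replicas, model-parallel cores).
    """
    if device_count < cores_per_replica or device_count % cores_per_replica != 0:
        raise ValueError(
            f"{device_count} device(s) cannot be split into groups of "
            f"{cores_per_replica} model-parallel cores")
    mesh_shape = (device_count // cores_per_replica, cores_per_replica)
    total_batch = per_replica_batch * device_count // cores_per_replica
    return mesh_shape, total_batch


# An 8-device TPU runtime gives one replica spanning all 8 cores.
print(plan_mesh(8))                       # ((1, 8), 1)
print(plan_mesh(8, per_replica_batch=4))  # ((1, 8), 4): four generations per call
```

On a single-device runtime, like the one that reported 1 device earlier, `plan_mesh(1)` raises a ValueError, which is why a TPU runtime is required.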

Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
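def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    # left-pad the prompt to the full context length `seq`; the real prompt length is passed separately
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx
    if pad_amount < 0:
        raise ValueError(f"prompt is {provided_ctx} tokens long, but the model context is only {seq} tokens")

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    # one copy of the prompt per sequence in the batch
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    # top_p and temp are supplied per batch element
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    # show the prompt in bold (ANSI escape codes), followed by the generated continuation
    for o in decoded_tokens[:, :, 0]:
        samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:.2f}s")
    return samples

print(infer("EleutherAI is")[0])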
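`infer` left-pads the prompt to the full context length `seq`, repeats it once per batch element, and passes the true prompt length separately. The sketch below reproduces only that input-preparation step with made-up token ids, so the shapes can be inspected without loading the model; `prepare_batch` is a hypothetical name, and the ids are placeholders rather than real GPT-2 tokens.

```python
import numpy as np


def prepare_batch(tokens, seq=2048, total_batch=1):
    """Left-pad `tokens` with zeros to length `seq` and repeat them `total_batch` times.

    Hypothetical helper mirroring the padding logic inside `infer`.
    """
    pad_amount = seq - len(tokens)
    if pad_amount < 0:
        raise ValueError(f"prompt has {len(tokens)} tokens but seq is only {seq}")
    padded = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched = np.array([padded] * total_batch)                 # shape (total_batch, seq)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
    return batched, length


batched, length = prepare_batch([101, 102, 103], seq=8, total_batch=2)
print(batched)  # each row: five zeros of padding, then the three prompt ids
print(length)   # [3 3]
```

Padding on the left keeps the prompt's last token at the end of the context window, right where generation continues.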

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])