<a href="https://colab.research.google.com/github/dvschultz/ml-art-colabs/blob/master/GPT_J_6B_Inference_Demo_(Fixed).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

Looking for a dead simple version without configuration? Try the [Hugging Face demo](https://huggingface.co/EleutherAI/gpt-j-6B?text=Do+I+contradict+myself%3F%0AVery+well+then+I+contradict+myself%2C%0A%28I+am+large%2C+I+contain+multitudes.%29).

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
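
One quick way to confirm the runtime before starting the long download is to look for the `COLAB_TPU_ADDR` environment variable, which the setup cell later in this notebook reads. The helper below is only a sketch: `tpu_runtime_available` is a hypothetical name, and it assumes Colab sets this variable only on TPU runtimes.

```python
import os

def tpu_runtime_available(env=None):
    """Return True if the environment looks like a Colab TPU runtime.

    Hypothetical helper: it only checks for COLAB_TPU_ADDR, the variable
    the setup cell of this notebook reads. It does not contact the TPU.
    """
    env = os.environ if env is None else env
    return bool(env.get("COLAB_TPU_ADDR"))

if tpu_runtime_available():
    print("TPU runtime detected")
else:
    print("No TPU address found: switch the runtime type to TPU before continuing")
```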

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
#!pip install -r /content/mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 #tensorflow==2.5.0
!pip install git+https://github.com/deepmind/optax.git
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness@c406a62047
!pip install transformers fabric wandb google-cloud-storage cloudpickle Flask ray[default] einops smart_open[gcs] func_timeout ftfy fastapi uvicorn lm_dataformat pathy
!pip install git+https://github.com/deepmind/dm-haiku
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness/

Reading package lists... Done
Building dependency tree       
Reading state information... Done
Note, selecting 'libzstd-dev' for regex 'zstd.'
Note, selecting 'libzstd1' for regex 'zstd.'
Note, selecting 'libzstd1-dev' for regex 'zstd.'
The following NEW packages will be installed:
  libzstd-dev libzstd1-dev
The following packages will be upgraded:
  libzstd1
1 upgraded, 2 newly installed, 0 to remove and 36 not upgraded.
Need to get 423 kB of archives.
After this operation, 844 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libzstd1 amd64 1.3.3+dfsg-2ubuntu1.2 [189 kB]
Get:2 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libzstd-dev amd64 1.3.3+dfsg-2ubuntu1.2 [230 kB]
Get:3 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libzstd1-dev amd64 1.3.3+dfsg-2ubuntu1.2 [4,492 B]
Fetched 423 kB in 1s (458 kB/s)
(Reading database ... 155219 files and directories currently installed.)
Preparing to unpack .../lib

## Setup Model


In [None]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

The next step sometimes fails with an error. If it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer


In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
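
The mesh shape and batch arithmetic in the cells above can be checked without a TPU. The sketch below assumes `jax.device_count()` returns 8 (the value that makes `cores_per_replica = 8` yield a single replica); `mesh_layout` and `total_batch_size` are hypothetical helper names, and plain integers stand in for the real device objects.

```python
import numpy as np

def mesh_layout(device_count, cores_per_replica):
    """Return (data-parallel replicas, model-parallel cores per replica)."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    return (device_count // cores_per_replica, cores_per_replica)

def total_batch_size(per_replica_batch, device_count, cores_per_replica):
    """Same expression as `total_batch` in the notebook."""
    return per_replica_batch * device_count // cores_per_replica

shape = mesh_layout(8, 8)                      # assumed: 8 devices, 8 cores per replica
fake_devices = np.arange(8).reshape(shape)     # integers stand in for jax.devices()

print(shape)                       # (dp, mp) axes of the mesh
print(fake_devices.shape)
print(total_batch_size(1, 8, 8))   # sequences generated in parallel
```

With these values there is one replica spanning all eight cores, so raising `per_replica_batch` is the only way to generate more sequences in parallel.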

## Run Model

Finally, we are ready to infer with the model! The first sample takes around a minute due to compilation, but after that it should only take about 10 seconds per sample.

Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the length of the generations (`gen_len`; changing it triggers a recompile).

You can also change other things like per_replica_batch in the previous cells to change how many generations are done in parallel. A larger batch has higher latency but higher throughput when measured in tokens generated/s. This is useful for doing things like best-of-n cherry picking.
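
To make the throughput measure concrete, it can be estimated as generated tokens divided by wall-clock time. This is a minimal sketch: `tokens_per_second` is a hypothetical helper, and the 13.53 s figure is taken from a run recorded later in this notebook, assuming a batch of 1 and `gen_len=512`.

```python
def tokens_per_second(total_batch, gen_len, elapsed_seconds):
    """Tokens generated across the whole batch per second of wall-clock time."""
    if elapsed_seconds <= 0:
        raise ValueError("elapsed_seconds must be positive")
    return total_batch * gen_len / elapsed_seconds

# One sequence of 512 tokens in about 13.53 seconds.
rate = tokens_per_second(total_batch=1, gen_len=512, elapsed_seconds=13.53)
print(f"{rate:.1f} tokens/s")  # 37.8 tokens/s
```

Timing the same call with a larger `per_replica_batch` and comparing the two rates shows whether the extra latency is worth it for a given workload.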

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
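
One way to apply this tip is to strip trailing spaces before passing the prompt to the model. The helper below is a sketch with a hypothetical name. It removes only trailing spaces and tabs and leaves trailing newlines alone, because the tip above mentions spaces only.

```python
def strip_trailing_spaces(prompt):
    """Remove spaces and tabs from the very end of the prompt."""
    return prompt.rstrip(" \t")

print(repr(strip_trailing_spaces("EleutherAI is  ")))  # 'EleutherAI is'
```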

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples

print(infer("EleutherAI is")[0])

completion done in 52.59220790863037s
[1mEleutherAI is[0m a creative company and engine integrator

If you are looking for a highly skilled audio visual production company with award winning creative team, looking for a reliable audio visual company that has over 12 years of experience and has built a reputation for creating great creative work with the following skills:

1. Videographer

2. Creative Services

3. Graphics/motion graphics

4. Audiovisual/Cinematography

5. Photographer

6. Title Design

7. Animation and Technical Illustration

We at EuKai believe you need not just a partner but a creative company that will provide all your company's video needs that needs to be done, such as:

Video Proposal, Storyboards, Pre-Production, Production, Post-Production, VR and 360° Videos, Broadcast, Web, Corporate Videos, Packaging Videos, Documentaries, Television, Online Ads, Promotional Videos, Broadcast News, International Channel, Multimedia, Music Videos, Musicals, Music Production
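
The `infer` function above places the prompt at the end of a fixed-length window of `seq` tokens, fills the front with zeros, and passes the true prompt length separately. The sketch below isolates that padding step with NumPy. `left_pad` is a hypothetical helper, and truncating over-long prompts to their last `seq` tokens is an assumption added here; the original `infer` does not handle that case.

```python
import numpy as np

def left_pad(tokens, seq):
    """Left-pad token ids with zeros to length `seq`.

    Returns the padded uint32 array and the number of real tokens.
    """
    # Assumption (not in the original notebook): keep only the most
    # recent `seq` tokens when the prompt is too long.
    tokens = list(tokens)[-seq:]
    pad_amount = seq - len(tokens)
    padded = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    return padded, len(tokens)

padded, length = left_pad([5, 6, 7], seq=8)
print(padded)   # [0 0 0 0 0 5 6 7]
print(length)   # 3
```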

`temperature` (`temp` in the code) - What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. We generally recommend altering this or `top_p` but not both.
  
`top_p` - An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.

We generally recommend altering this or temperature but not both.

(Definitions taken from https://docs.zeroqode.com/plugins/artifical-intelligence-gpt-3-by-openai)
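
To illustrate how the two parameters interact, here is a small NumPy sketch of temperature scaling followed by nucleus (top-p) filtering. It is not the `nucleaus_sample` implementation from `mesh_transformer`; `nucleus_filter` is a hypothetical function written only to show the idea on a toy four-token vocabulary.

```python
import numpy as np

def nucleus_filter(logits, top_p=0.9, temp=1.0):
    """Return the renormalized distribution over the top-p nucleus.

    temp must be > 0 in this sketch; lower values sharpen the distribution.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temp
    probs = np.exp(scaled - scaled.max())      # numerically stable softmax
    probs /= probs.sum()

    order = np.argsort(-probs)                 # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Smallest prefix whose cumulative mass reaches top_p.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]

    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

toy_logits = np.log([0.5, 0.3, 0.15, 0.05])
nucleus = nucleus_filter(toy_logits, top_p=0.7)
print(nucleus)  # only the two most likely tokens keep non-zero probability

# Draw one token id from the filtered distribution.
rng = np.random.default_rng(0)
token_id = rng.choice(len(nucleus), p=nucleus)
```

With `top_p=0.7` the first two tokens (mass 0.5 + 0.3) form the nucleus, so the remaining two can never be sampled. Lowering `temp` concentrates mass on the top token and tends to shrink the nucleus further, which is one reason to adjust only one of the two parameters at a time.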

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """Do I contradict myself?
Very well then I contradict myself,
(I am large, I contain multitudes.)"""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])

completion done in 13.52977204322815s
**Do I contradict myself?
Very well then I contradict myself,
(I am large, I contain multitudes.)**

St. Paul

28.8.08

I'm reading Barbara Ehrenreich's Nickle and Dimed. The most striking story she tells (and it's the lead off story) is about making her house a homestay and getting a teaching job as a way to pay for some of the bills. Here is the last line of her book.

It's like waking up in a strange house that's starting to fall apart, to have to figure out how to fix everything at once. And then when you do, you realize that there's more out there than you can imagine. In some ways it feels like winning the lottery. You're suddenly rich, and everything's fresh. Your laundry might even smell good for the first time in years.

24.7.08

I got back to LA on Saturday and as soon as I got to the house, I called work. I spoke to one of my co-workers and she said that our new chief, Sandeep, would call me this morning. He had a big lunch so I don'