<a href="https://colab.research.google.com/github/mosmos6/MTJ-on-TPU0.2/blob/main/GPT_J_inference_on_TPU_driver0_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-J inference on TPU_driver0.2

## This is a Colab demo for running inference with a modified mesh-transformer-jax on TPU_driver0.2. You need a high-memory TPU runtime.

# Load your data from your Google Cloud Storage bucket (if applicable)

In [None]:
# Sign in to Google from this Colab session; the gcsfuse mount in a later cell
# presumably relies on these credentials to reach your bucket.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Register the gcsfuse apt repository (the "bionic" channel is hard-coded) and its
# signing key, refresh the package index, then install gcsfuse.
# NOTE(review): the repository line uses plain http and the key is added with apt-key;
# confirm both are still appropriate for the current Colab image.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

In [None]:
!mkdir folderOnColab
!gcsfuse --implicit-dirs - your bucket - folderOnColab

# Installing dependencies

In [None]:
pip install --upgrade pip

In [None]:
# Ordering caveat: with jax 0.3.25, run this cell BEFORE the cell that installs jax.
# With jax 0.3.5, run it AFTER installing jax instead (the reason is not known).

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

Copy the `mesh_transformer` folder and the `requirements.txt` file from [this repo](https://github.com/mosmos6/MTJ-on-TPU0.2) and use them to replace the original ones in your repo.

In [None]:
# Clone your repository, install its pinned requirements, then install the package itself.
# Every <your repo> below is a placeholder: substitute your own repository path before running.
!git clone https://github.com/<your repo>
!pip install -r <your repo>/requirements.txt
!pip install <your repo>/

In [None]:
# Install the pinned jax 0.3.25 TPU build; -f adds the libtpu release index as an extra place for pip to look.
!pip install 'jax[tpu]==0.3.25' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Set up the model

In [None]:
import os
import requests 
import jax
import jax.config
import jax.tools
import jax.tools.colab_tpu

# Driver version this notebook asks the Colab TPU host to switch to.
tpu_driver = 'tpu_driver0.2'

# COLAB_TPU_ADDR is read once; the part before the first ':' is used as the host
# for the version request, and the full value is used as the gRPC target below.
tpu_addr = os.environ['COLAB_TPU_ADDR']
colab_tpu_addr = tpu_addr.split(':')[0]

# (An earlier variant of this URL requested 'tpu_driver0.1_dev20210607' instead.)
url = 'http://{}:8475/requestversion/{}'.format(colab_tpu_addr, tpu_driver)
requests.post(url)

# Required to use the TPU driver as JAX's backend.
jax.tools.colab_tpu.TPU_DRIVER_MODE = 1
jax.config.FLAGS.jax_xla_backend = "tpu_driver"
jax.config.FLAGS.jax_backend_target = "grpc://" + tpu_addr

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer
from mesh_transformer.transformer_shard import CausalTransformerV2

In [None]:
# Model hyperparameters and batch/sharding layout, built in one place.
# "sampler" selects the sampling function used during generation, and
# "optimizer" is set to optax.scale(0), which "removes" the optimizer
# parameters from the model (they are not needed for inference).
params = dict(
    layers=28,
    d_model=4096,
    n_heads=16,
    d_head=256,
    n_vocab=50400,
    norm="layernorm",
    pe="rotary",
    pe_rotary_dims=64,
    seq=2048,
    cores_per_replica=8,
    per_replica_batch=1,
    sampler=nucleaus_sample,
    optimizer=optax.scale(0),
)

# Module-level aliases that later cells read directly.
seq = params["seq"]
cores_per_replica = params["cores_per_replica"]
per_replica_batch = params["per_replica_batch"]

# Lay out every visible device on a 2-D grid: (replicas, cores per replica).
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# Build the mesh with axes named 'dp' and 'mp' and make it the active resource environment.
global_mesh = maps.Mesh(devices, ('dp', 'mp'))
maps.thread_resources.env = maps.ResourceEnv(physical_mesh=global_mesh, loops=())

# Tokenizer used to encode prompts and decode generated tokens.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

# Create the network and load your parameters

In [None]:
# Here we use read_ckpt instead of read_ckpt_lowmem because lowmem gets stuck forever in inference for some reason.

# Number of sequences handled per generate() call: per-replica batch times the number of replicas.
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

# Replace <your data> with the folder holding your checkpoint under the mounted bucket.
# devices.shape[1] is the second mesh dimension, i.e. cores_per_replica.
network.state = read_ckpt(network.state, "/content/folderOnColab/<your data>/step_10/", devices.shape[1])

# At least one local shard, even when there are fewer local devices than cores per replica.
local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
# Drop the optimizer state before handing the remaining state to move_xmap.
del network.state["opt_state"]
network.state = network.move_xmap(network.state, np.zeros(local_shards))

# Run the model

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css(info=None):
  """Emit a CSS rule so that <pre> output wraps instead of overflowing.

  Registered below as a `pre_run_cell` callback, so the style is displayed
  before each cell runs.

  Args:
    info: Execution info that IPython passes to `pre_run_cell` callbacks.
      Unused. It defaults to None so the function can still be called
      with no arguments.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=0.9, gen_len=100):
    """Generate a completion for `context` and return the decoded samples.

    Args:
        context: Prompt text. It is tokenized and left-padded with zeros to
            `seq` tokens; if it is longer, only the last `seq` tokens are kept.
        top_p: Value passed to the sampler as "top_p" for every batch entry.
        temp: Value passed to the sampler as "temp" for every batch entry.
        gen_len: Passed to `network.generate` as the generation length.

    Returns:
        A list of strings, one per decoded sequence. Each is the prompt wrapped
        in ANSI bold escape codes followed by the decoded generated tokens.
    """
    tokens = tokenizer.encode(context)

    # Left-pad short prompts up to `seq`; the trailing slice keeps only the
    # last `seq` tokens of prompts that are too long.
    pad_amount = max(seq - len(tokens), 0)
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-seq:]
    batched_tokens = np.array([padded_tokens] * total_batch)

    # Report only the prompt tokens that survive truncation. The raw token
    # count would exceed the `seq`-long input for over-long prompts.
    provided_ctx = min(len(tokens), seq)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    sampler_options = {
        "top_p": np.ones(total_batch) * top_p,
        "temp": np.ones(total_batch) * temp,
    }
    output = network.generate(batched_tokens, length, gen_len, sampler_options)

    # assumes output[1][0] is a 3-D array whose last axis has the token id at index 0 - TODO confirm
    decoded_tokens = output[1][0]
    samples = [
        f"\033[1m{context}\033[0m{tokenizer.decode(o)}"
        for o in decoded_tokens[:, :, 0]
    ]

    print(f"completion done in {time.time() - start:06}s")
    return samples

# Quick check: generate with the default sampling settings and print the first sample.
print(infer("Google colab is")[0])

In [None]:
#@title  { form-width: "300px" }
top_p = 1 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

# Prompt to complete: edit the text between the triple quotes.
context = """Your context here"""


# Run generation with the slider values above and gen_len=100, then print the first sample.
print(infer(top_p=top_p, temp=temp, gen_len=100, context=context)[0])

*Enjoy!*