<a href="https://colab.research.google.com/github/mosmos6/MTJ-on-TPU0.2/blob/main/GPT_J_inference_attempt_on_TPU_driver0_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-J inference attempt on TPU_driver0.2

Load your data from your Google Cloud Storage bucket (skip this section if your data is stored elsewhere)

In [None]:
# Authenticate this Colab session with a Google account.
# NOTE(review): presumably the gcsfuse mount performed further down relies on
# these credentials to reach the bucket — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Register Google's gcsfuse apt repository, trust its signing key, and install
# the gcsfuse package.
# - The repository is addressed over HTTPS rather than plain HTTP.
# - curl -f makes an HTTP error fail the pipeline instead of feeding an error
#   page to apt-key; -sS hides the progress meter but still prints errors.
# - apt install gets -y because a notebook cell cannot answer a prompt.
!echo "deb https://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install -y gcsfuse

In [None]:
# Create the mount point and mount the bucket onto it with gcsfuse.
# mkdir -p keeps the cell re-runnable: it does not fail when the directory
# already exists. Replace <your bucket> with the name of your bucket.
!mkdir -p folderOnColab
!gcsfuse --implicit-dirs <your bucket> folderOnColab

Installing dependencies

In [None]:
# Upgrade pip itself before installing the project dependencies.
# The explicit %pip magic does not depend on IPython's automagic setting.
%pip install --upgrade pip

In [None]:
# Clone the project repository, then install its pinned requirements and the
# package itself. Replace every <your repo> placeholder before running.
# %pip (rather than !pip) installs into the environment of the running kernel.
!git clone https://github.com/<your repo>
%pip install -r <your repo>/requirements.txt
%pip install <your repo>/

In [None]:
# Install the pinned TPU build of JAX; -f points pip at the libtpu release index.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install 'jax[tpu]==0.3.5' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
# Import JAX's Colab TPU helper module and invoke its setup_tpu() routine.
# NOTE(review): presumably this points JAX at the Colab TPU runtime; check the
# jax.tools.colab_tpu source for exactly what it configures.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

Set up the model

In [None]:
import os
import requests 
import jax
import jax.config
import jax.tools
import jax.tools.colab_tpu

# Author's note: the nightly driver behaves the same.

# TPU driver version to request from the Colab TPU host.
tpu_driver = 'tpu_driver0.2'

# Read the TPU address once; the part before the first ':' is used as the host
# for the version request, and the full value as the gRPC backend target.
tpu_addr = os.environ['COLAB_TPU_ADDR']
colab_tpu_addr = tpu_addr.split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/{tpu_driver}'

# Bound the request so an unreachable host raises instead of hanging the cell,
# and surface a rejected version request instead of silently ignoring it.
response = requests.post(url, timeout=60)
if not response.ok:
    print(f'WARNING: TPU driver version request returned HTTP {response.status_code} for {url}')

# NOTE(review): setting TPU_DRIVER_MODE presumably tells jax.tools.colab_tpu
# that a driver version has already been requested — confirm against its source.
jax.tools.colab_tpu.TPU_DRIVER_MODE = 1

# Select the TPU driver backend and point JAX at the Colab TPU over gRPC.
jax.config.FLAGS.jax_xla_backend = "tpu_driver"
jax.config.FLAGS.jax_backend_target = f"grpc://{tpu_addr}"

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
# Configuration dictionary handed to CausalTransformer further down.
params = {
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,

    "seq": 2048,
    "cores_per_replica": 8,
    "per_replica_batch": 1,
}

# Pull the layout-related settings out into plain names for the cells below.
seq = params["seq"]
cores_per_replica = params["cores_per_replica"]
per_replica_batch = params["per_replica_batch"]

# Inference only: register the sampling function, and stand in a zero-scale
# optax transform for the optimizer since its parameters are not needed here.
params["sampler"] = nucleaus_sample
params["optimizer"] = optax.scale(0)

# Lay out every visible device on a 2-D grid: one row per replica, with
# cores_per_replica devices in each row.
replica_count = jax.device_count() // cores_per_replica
mesh_shape = (replica_count, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# Name the two grid axes 'dp' and 'mp' and install the mesh as the active
# resource environment.
mesh_axis_names = ('dp', 'mp')
global_mesh = maps.Mesh(devices, mesh_axis_names)
maps.thread_resources.env = maps.ResourceEnv(physical_mesh=global_mesh, loops=())

# Pretrained GPT-2 tokenizer.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

Create the network and load your parameters

In [None]:
# Batch size summed over all replicas.
total_batch = (per_replica_batch * jax.device_count()) // cores_per_replica

# Build the transformer from the configuration dictionary.
network = CausalTransformer(params)

# Load the checkpoint into the network state. Replace <your data> with the
# folder inside the mounted bucket that holds the checkpoint.
checkpoint_dir = "/content/folderOnColab/<your data>/step_10/"
mp_axis_size = devices.shape[1]  # size of the second ('mp') axis of the device grid
network.state = read_ckpt_lowmem(network.state, checkpoint_dir, mp_axis_size)

# Pass the loaded state through move_xmap together with a zero vector that has
# one entry per core in a replica.
per_core_zeros = np.zeros(cores_per_replica)
network.state = network.move_xmap(network.state, per_core_zeros)