
* Jason Lee, 2022-05-15, for the PatentGPT-J project

* The following is based on https://github.com/kingoflolz/mesh-transformer-jax/blob/master/device_sample.py

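* This notebook shows that PatentGPT-J can perform sentiment analysis via few-shot learning: a handful of labelled examples are placed in the prompt and the model completes the label for a new message, with no fine-tuning.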

* Tested OK with PatentGPT-J-279M and PatentGPT-J-456M.
  * PatentGPT-J-1.6B runs out of memory even on Colab Pro (not tested on Colab Pro+).

* References:
  * https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb#scrollTo=n7xAFw-LOYfe
  * https://colab.research.google.com/drive/17zvUhLcpjUKJdTRg00HYdGMEN3uoMy-M?usp=sharing#scrollTo=Wg3x-WQStYHC
  * https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/GPT-J-6B/Inference_with_GPT_J_6B.ipynb#scrollTo=PYeldWLtFlvi


  


In [1]:
# Model to run (Colab form field). PatentGPT-J-1.6B is not offered because it runs out of memory on Colab Pro.
proj = "PatentGPT-J-456M" #@param ["PatentGPT-J-456M", "PatentGPT-J-279M", "GPT-J-6B"]

In [2]:
!apt install zstd
!pip install -q pip==20.3.1
# for avoiding --> pip takes too long to resolve conflicting dependencies
# https://github.com/pypa/pip/issues/9517

!pip install jaxlib==0.1.67 
!pip install jax==0.2.12
!pip install tensorflow==2.5.0   # 2.8.0 won't work?

#jax==0.2.22, not workable 
#!pip install -q "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

!pip install -q optax==0.0.9
!pip install -q transformers==4.18.0
!pip install -q ray[default]==1.5.1
!pip install -q smart_open[gcs]
!pip install -q dm-haiku==0.0.5
!pip install -q einops==0.3.0

!pip install chex==0.1.2
# Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2
# https://github.com/kingoflolz/mesh-transformer-jax/issues/221
# https://github.com/kingoflolz/mesh-transformer-jax/issues/43

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following packages were automatically installed and are no longer required:
  libnvidia-common-460 nsight-compute-2020.2.0
Use 'apt autoremove' to remove them.
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 42 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (337 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 155203 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
[K     |████████████████████████████████| 1.5 MB 5.2

[?25l[K     |██▊                             | 10 kB 30.8 MB/s eta 0:00:01[K     |█████▌                          | 20 kB 10.6 MB/s eta 0:00:01[K     |████████▎                       | 30 kB 9.0 MB/s eta 0:00:01[K     |███████████                     | 40 kB 8.4 MB/s eta 0:00:01[K     |█████████████▊                  | 51 kB 4.4 MB/s eta 0:00:01[K     |████████████████▌               | 61 kB 5.2 MB/s eta 0:00:01[K     |███████████████████▎            | 71 kB 5.6 MB/s eta 0:00:01[K     |██████████████████████          | 81 kB 5.8 MB/s eta 0:00:01[K     |████████████████████████▉       | 92 kB 6.4 MB/s eta 0:00:01[K     |███████████████████████████▌    | 102 kB 5.2 MB/s eta 0:00:01[K     |██████████████████████████████▎ | 112 kB 5.2 MB/s eta 0:00:01[K     |████████████████████████████████| 118 kB 5.2 MB/s 
[K     |████████████████████████████████| 72 kB 638 kB/s 
[K     |████████████████████████████████| 4.0 MB 4.7 MB/s 
[K     |███████████████████████████████

In [3]:
import os
import transformers

#proj = 'PatentGPT-J-1.6B'  # out of memory on Colab Pro
#proj = 'GPT-J-6B'  # https://github.com/kingoflolz/mesh-transformer-jax

params_file = ''
params_dict = {'PatentGPT-J-279M':'pgj_d_1024_layer_14.json', 
               'PatentGPT-J-456M': 'pgj_d_1024.json', 
               'PatentGPT-J-1.6B': 'pgj_d_2048.json'}

print('project: %s' % proj)
if proj.startswith('PatentGPT'):               
  params_file = params_dict[proj]
  size = proj[proj.find('-J-')+3:]
  url_params = f'https://huggingface.co/patent/patentgpt-j-{size}/raw/main/{params_file}' 
  if os.path.exists(params_file) == False:
    !wget $url_params
    print('Downloaded: %s' % params_file)
  else:
    print('Existed: %s' % params_file)

  encoder_path = 'bpe_output'
  url_bpe = f'https://huggingface.co/patent/patentgpt-j-{size}/resolve/main/{encoder_path}.tgz'
  if os.path.exists(encoder_path) == False:
    !wget $url_bpe
    cmd = f"tar xvfz {encoder_path}.tgz"
    !$cmd     
    print('Downloaded: %s' % encoder_path)
  else:
    print('Existed: %s' % encoder_path)
  #tokenizer = transformers.GPT2TokenizerFast(tokenizer_file='%s/tokenizer.json' % encoder_path)

  ckpt_folder = 'step_350000/'
  if proj in ['PatentGPT-J-456M', 'PatentGPT-J-279M']:
    url_step = f'https://huggingface.co/patent/patentgpt-j-{size}/resolve/main/step.tgz'
    if os.path.exists(ckpt_folder):
      print('Existed: %s' % ckpt_folder)
    else:
      !wget $url_step
  elif proj == 'PatentGPT-J-1.6B': 
    # a single file is too large --> split into multiple files
    if os.path.exists(ckpt_folder):
      print('Existed: %s' % ckpt_folder)
    else:
      final_fn = 'step.tgz'
      for i in range(ord('a'), ord('q')+1):
        ch = chr(i)
        fn = 'step_350000.tgz.parta%s' % chr(i)
        print('donwloading: %s' % fn)
        url = 'https://huggingface.co/patent/patentgpt-j-1.6B/resolve/main/%s' % fn
        !wget $url
        !cat $fn >> $final_fn
        !rm $fn
  !tar xfz step.tgz
else: # using original GPT-J-6B
  params_file = os.path.join('mesh-transformer-jax', 'configs', '6B_roto_256.json')
  d_model = 4096
  encoder_path = 'gpt2'
  ckpt_folder = 'step_383500/'
  if os.path.exists(ckpt_folder) == False:
    !time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
    !time tar -I zstd -xf step_383500_slim.tar.zstd
    print('Downloaded: %s' % ckpt_folder)    
  else:
    print('Existed: %s' % ckpt_folder)

if os.path.exists('mesh-transformer-jax') == False:
  !git clone https://github.com/kingoflolz/mesh-transformer-jax.git
  !pip install mesh-transformer-jax/

print('Checkpoint is ready.')

project: PatentGPT-J-456M
--2022-05-16 03:44:12--  https://huggingface.co/patent/patentgpt-j-456M/raw/main/pgj_d_1024.json
Resolving huggingface.co (huggingface.co)... 34.197.58.156, 3.210.158.153, 18.214.24.217, ...
Connecting to huggingface.co (huggingface.co)|34.197.58.156|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 814 [application/json]
Saving to: ‘pgj_d_1024.json’


2022-05-16 03:44:12 (168 MB/s) - ‘pgj_d_1024.json’ saved [814/814]

Downloaded: pgj_d_1024.json
--2022-05-16 03:44:12--  https://huggingface.co/patent/patentgpt-j-456M/resolve/main/bpe_output.tgz
Resolving huggingface.co (huggingface.co)... 34.197.58.156, 3.210.158.153, 18.214.24.217, ...
Connecting to huggingface.co (huggingface.co)|34.197.58.156|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/51/6d/516d501174169d14a8973546dea0bb30d839231727da773bc6c65ab82a5f5595/6541333d806f4abdc7b0f4503ed39183c4fd0a6a901f01e2bbdc608951

In [4]:
%%writefile sentiment_analysis.py
import argparse
import json
import time
import jax
import numpy as np
import optax
import requests 
import jax.tools.colab_tpu
import transformers
import os

from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer
from jax.experimental import maps
from mesh_transformer.util import clip_by_global_norm
from jax.config import config
from smart_open import open

import pdb

# Use the Colab TPU through the remote `tpu_driver` backend, after asking the TPU
# host (port 8475) for the tpu_driver0.1_dev20210607 driver build.
# (jax.tools.colab_tpu.setup_tpu() is the alternative route and is not used.)
# This setup was added hoping to avoid the warning
#   "TPU Execute is taking a long time. This might be due to a deadlock ..."
# but that warning can still be logged with it in place.
colab_tpu_endpoint = os.environ.get('COLAB_TPU_ADDR')
if not colab_tpu_endpoint:
    raise RuntimeError(
        'COLAB_TPU_ADDR is not set: run this script on a Colab TPU runtime '
        '(Runtime > Change runtime type > TPU).')
colab_tpu_addr = colab_tpu_endpoint.split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
# A generous timeout so an unreachable TPU host fails instead of hanging forever.
requests.post(url, timeout=600)
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + colab_tpu_endpoint

# Few-shot prompt: five labelled "Message: ... / Sentiment: ..." examples
# separated by "###", then one unlabelled message. The model is expected to
# complete the final "Sentiment:" line, and generate_one_record() classifies that
# completion by its leading word.
context = \
"""Message: Support has been terrible for 2 weeks...
Sentiment: Negative
###
Message: I love your API, it is simple and so fast!
Sentiment: Positive
###
Message: It is really bad! How could it be possible? 
Sentiment: Negative
###
Message: The API is great. It is really good!
Sentiment: Positive
###
Message: GPT-J has been released 2 months ago.
Sentiment: Neutral
###
Message: Your team has been amazing, thanks!
Sentiment:"""

# Sampling settings shared by every generation call.
max_count = 100   # how many generations to run before stopping
top_p = 0.9       # nucleus-sampling cumulative-probability cutoff
temp = 0.75       # sampling temperature

# Running tallies, updated through `global` in generate_one_record().
count = 0
positive = 0
negative = 0
neutral = 0
others = 0

def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads sys.argv[1:], exactly as before.

    Returns:
        argparse.Namespace with `config`, `encoder_path` and `ckpt_folder`.
        `--config` and `--encoder_path` are required, because the script cannot
        run without them. `--ckpt_folder` stays optional (default None).
    """
    parser = argparse.ArgumentParser(
        description="Few-shot sentiment analysis with a mesh-transformer-jax checkpoint.")
    parser.add_argument("--config", type=str, required=True,
                        help="path to the model's JSON config file")
    parser.add_argument("--encoder_path", type=str, required=True,
                        help="'gpt2' for the stock GPT-2 tokenizer, or a folder "
                             "containing tokenizer.json")
    parser.add_argument("--ckpt_folder", type=str, default=None,
                        help="checkpoint folder, e.g. step_350000/")
    return parser.parse_args(argv)

def generate_one_record(tokenizer, network, tokens, seq, gen_length=8):
  """Sample completions of the few-shot prompt and tally their sentiment labels.

  The prompt is left-padded to `seq` tokens and replicated `total_batch` times.
  `total_batch` is a module-level global set in __main__. Each decoded
  completion is classified by its leading word: 'Positive', 'Negative' or
  'Neutral', with anything else counted as "others". The matching global
  counter and `count` are both incremented once per completion, so the tallies
  always sum to `count`. A progress line is printed for every completion.

  Args:
    tokenizer: object with a `decode` method (a GPT2TokenizerFast in this script).
    network: CausalTransformer whose state has already been loaded.
    tokens: token ids of the prompt. Only the last `seq` ids are kept, so the
      pad amount can never be negative.
    seq: context length the prompt is left-padded to.
    gen_length: number of tokens to generate per completion (default 8; the
      label is the first word of the completion).
  """
  global count, positive, negative, neutral, others

  tokens = tokens[-seq:]
  provided_ctx = len(tokens)
  pad_amount = seq - provided_ctx
  padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
  batched_tokens = np.array([padded_tokens] * total_batch)
  length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

  output = network.generate(batched_tokens, length, gen_length,
                            {"top_p": np.ones(total_batch) * top_p,
                             "temp": np.ones(total_batch) * temp})

  # NOTE(review): assumes output[1][0] is (batch, gen_length, 1) token ids;
  # inferred from the [:, :, 0] indexing. Confirm against mesh_transformer.
  for sample in output[1][0][:, :, 0]:
    gen_text = str(tokenizer.decode(sample)).strip()
    if gen_text.startswith('Positive'):
      positive += 1
    elif gen_text.startswith('Negative'):
      negative += 1
    elif gen_text.startswith('Neutral'):
      neutral += 1
    else:
      others += 1

    count += 1
    print('[ %s ][ positive: %s][ negative: %s][ neutral: %s][ others: %s] text = [%s]' %
          (count, positive, negative, neutral, others, gen_text))

if __name__ == "__main__":
    args = parse_args()
    params = json.load(open(args.config))

    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    per_replica_batch = params["per_replica_batch"]
    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    # Context length that every prompt is left-padded to in generate_one_record().
    seq = params["seq"]
    # Not used further below; reading them raises KeyError if the config lacks them.
    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    # Passed to CausalTransformer in the PatentGPT-J branch below. No training
    # step is run in this script.
    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)
    # Module-level on purpose: generate_one_record() reads it as a global.
    total_batch = per_replica_batch * jax.device_count() // cores_per_replica

    if args.encoder_path == 'gpt2':
        # Original GPT-J-6B with the stock GPT-2 tokenizer. Honour --ckpt_folder,
        # falling back to the folder that used to be hard-coded here.
        ckpt_folder = args.ckpt_folder or "step_383500/"
        print(f"using checkpoint {ckpt_folder}")
        params["optimizer"] = optax.scale(0)
        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        tokens = tokenizer.encode(context)

        maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
        network = CausalTransformer(params)
        network.state = read_ckpt_lowmem(network.state, ckpt_folder, devices.shape[1])
        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

        # `count` is advanced by generate_one_record().
        while count < max_count:
            generate_one_record(tokenizer, network, tokens, seq)
        print('done')
    else:
        ckpt_folder = args.ckpt_folder  # e.g. '/content/step_350000/'
        print(f"using checkpoint {ckpt_folder}")
        params["optimizer"] = opt
        tokenizer_file = '%s/tokenizer.json' % args.encoder_path
        print('tokenizer_file: %s' % tokenizer_file)
        tokenizer = transformers.GPT2TokenizerFast(tokenizer_file=tokenizer_file)
        tokens = tokenizer.encode(context)

        with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
            network = CausalTransformer(params)
            start = time.time()
            network.state = read_ckpt(network.state, ckpt_folder, devices.shape[1])
            print(f"network loaded in {time.time() - start:.06}s")

            local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
            # Only inference follows, so the optimizer state is dropped.
            del network.state["opt_state"]
            network.state = network.move_xmap(network.state, np.zeros(local_shards))

            # `count` is advanced by generate_one_record().
            while count < max_count:
                generate_one_record(tokenizer, network, tokens, seq)
            print('done')


Writing sentiment_analysis.py


In [5]:
cmd = f"python sentiment_analysis.py --config {params_file} " \
  f"--encoder_path {encoder_path} --ckpt_folder {ckpt_folder}"
!$cmd 

jax devices: 8
jax runtime initialized in 23.9388s
using checkpoint step_350000/
tokenizer_file: bpe_output/tokenizer.json
  warn("xmap is an experimental feature and probably has bugs!")
key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 456418528
read from disk/gcs in 7.23076s
network loaded in 11.1466s
2022-05-16 03:49:36.650143: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:606] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
[ 1 ][ positive: 0][ negative: 0][ neutral: 0][ others: 1] text = [Quickness
###
Message:]
[ 2 ][ positive: 1][ negative: 0][ neutral: 0][ others: 1] text = [Positive
###
Message: Not]
[ 3 ][ positive: 2][ negative: 0][ neutral: 0][ others: 1] text = [Positive
###
Message: B]
[ 4 ][ positive: 3][ negative: 0][ neutral: 0][ others: 1] text = [Positive
###
Message: A]
[ 5 ][ positive: 4][ negative: 0][ neutral: 0][ others: 1] text = [Posi