In [1]:
import sys
!{sys.executable} -m pip install transformers flax

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m53.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting filelock (from transformers)
  Downloading filelock-3.12.2-py3-none-any.whl (10 kB)
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
Collecting pyyaml>=5.1 (from transformers)
  Downloading PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (682 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m682.2/682.2 kB[0m [31m60.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting regex!=2019.12.17 (from transformers)
  Downloading regex-2023.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64

In [2]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/home/robv/data/'

import datetime
import jax
from transformers import AutoTokenizer
from transformers import FlaxOPTForCausalLM
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import time

import json

MODEL_PATH = "facebook/opt-66b"

model, params = FlaxOPTForCausalLM.from_pretrained(
    MODEL_PATH, dtype=jax.numpy.bfloat16, _do_init=False
)
params = model.to_bf16(params)

print('Params loaded')

Some of the weights of FlaxOPTForCausalLM were initialized in float16 precision from the model checkpoint at facebook/opt-66b:
[('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_layer_norm', 'bias'), ('model', 'decoder', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'decoder', 'layers', '0', 'fc1', 'kernel'), ('model', 'decoder', 'layers', '0', 'fc2', 'bias'), ('model', 'decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_a

Params loaded


In [3]:
n_devices = len(jax.devices())
print("Number of devices", n_devices)

# Use a simple sharding scheme to just fit the model.
devices = mesh_utils.create_device_mesh((n_devices, 1))
sharding = PositionalSharding(devices)

def put_sharded(v):
  return jax.device_put(v, sharding.reshape(1, n_devices))

# Move model to TPUs 

params['model']['decoder'] = jax.tree_util.tree_map(
    put_sharded, params['model']['decoder']
)



Number of devices 8


In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, dtype=jax.numpy.bfloat16)
print("Loaded AutoTokenizer")

Loaded AutoTokenizer


In [12]:
input_length=128
output_length=8
max_length=136

def generator(ids, params,max_length):
  return model.generate(
      ids, params=params,max_length=136,
  )
generator_jit = jax.jit(generator)

In [None]:
from jax import random

key = random.PRNGKey(10)
batch_size = 1

tokenizer.vocab_size
batch = jax.random.randint(key=key,minval=0,shape=([batch_size,input_length]),maxval=tokenizer.vocab_size-1)
batch

In [13]:
prompt = (
    "Are you conscious?  Please tell me."
)

print("Warmup")
s=time.time()
#input_ids = tokenizer(prompt, return_tensors="jax").input_ids

####

input_ids = batch

####

input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
result = generator_jit(input_ids, params,max_length)
print(time.time()-s)
print(f"result: {result.sequences}")
print("Warmup done")

Warmup


ValueError: The following `model_kwargs` are not used by the model: ['max_tokens'] (note: typos in the generate arguments will also show up in this list)

In [7]:
#input_ids = tokenizer(prompt, return_tensors="jax").input_ids
s = datetime.datetime.now()
print(f"Tokenized inputs, generating..., {s}")
input_ids = batch

input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
# with jax.profiler.trace("/tmp/tensorboard"):
result = generator_jit(input_ids, params,max_length)
result.sequences.block_until_ready()
print(f'After block-ready: {datetime.datetime.now()-s}')
gen_text = tokenizer.batch_decode(result.sequences)[0]
print(f"gen_text: {gen_text}, '\n',{datetime.datetime.now()-s}")

Tokenized inputs, generating..., 2023-07-13 17:30:21.006024
gen_text: </s>Are you conscious?  Please tell me.
I am conscious.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>, 0:00:00.225199


In [None]:
from jax import random
import time

# Benchmarking params
benchmark_name = "v5litepod-8-bf16-opt66b"
num_batches=10
batch_sizes=[1,2]
input_lengths=[128]
output_lengths=[8]

key = random.PRNGKey(25)

print('RUN_NAME                 \t\tMEAN TIME\tTOKENS_PER_SECOND\tMS_PER_SEQ_OUTPUT_TOKEN')
print('='*105)
for batch_size in batch_sizes:
    for input_length in input_lengths:
        for output_length in output_lengths:
            max_length = input_length+output_length
            key = random.split(key)[0]
            input_ids = jax.random.randint(key=key,minval=0,shape=([batch_size,input_length]),maxval=tokenizer.vocab_size-1)
            for i in range(num_batches):
                start_time = time.time()
                input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
                result = generator_jit(input_ids, params,max_length)
                result.sequences.block_until_ready()

            mean_time = (time.time() - start_time) / num_batches

            num_output_tokens = output_length * batch_size
            tokens_per_second = num_output_tokens / mean_time
            ms_per_seq_output_token = mean_time * 1000 / output_length

            run_name = f'{benchmark_name}_{batch_size}_{input_length}_{output_length}'
            print(
                f'{run_name}\t\t{mean_time:.3f}\t\t{tokens_per_second:.3f}\t\t\t{ms_per_seq_output_token:.3f}'




In [18]:




def put_sharded(v):
  return jax.device_put(v, sharding.reshape(1, n_devices))

print(json.dumps(params,indent=2,default=str))

['model']['decoder']['embed_positions']['embedding'] = jax.device_put(
  ['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True)
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, dtype=jax.numpy.bfloat16)

def generator(ids, params):
  return model.generate(
      ids, max_length=46, params=params
  )

generator_compiled = jax.jit(generator)

prompt = (
    "In a shocking finding, scientists discovered"
    " a herd of unicorns living in a remote, "
    "previously unexplored valley, in the Andes Mountains."
    " Even more surprising to the "
    "researchers was the fact that the unicorns spoke perfect English."
)

print("Warmup")
s=time.time()
input_ids = tokenizer(prompt, return_tensors="jax").input_ids
input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
result = generator_compiled(input_ids, params)
# print(time.time()-s)
# print(f"result: {result.sequences}")
# print("Warmup done")

  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True))
  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True))
  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True))
  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True))
  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True))
  ['model']['decoder']['embed_positions']['embedding'] = jax.device_put(['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, ke

TypeError: list indices must be integers or slices, not str