In [None]:
import sys
!{sys.executable} -m pip install transformers flax

In [3]:
import os
os.environ['HF_HOME'] = '/home/robv/data/'

import datetime
import jax
from transformers import AutoTokenizer
from transformers import FlaxOPTForCausalLM
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import time

import json

MODEL_PATH = "facebook/opt-66b"

model, params = FlaxOPTForCausalLM.from_pretrained(
    MODEL_PATH, dtype=jax.numpy.bfloat16, _do_init=False
)
params = model.to_bf16(params)

print('Params loaded')

Some of the weights of FlaxOPTForCausalLM were initialized in float16 precision from the model checkpoint at facebook/opt-66b:
[('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_layer_norm', 'bias'), ('model', 'decoder', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'decoder', 'layers', '0', 'fc1', 'kernel'), ('model', 'decoder', 'layers', '0', 'fc2', 'bias'), ('model', 'decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_a

Params loaded


In [4]:
n_devices = len(jax.devices())
print("Number of devices", n_devices)

# Use a simple sharding scheme to just fit the model.
devices = mesh_utils.create_device_mesh((n_devices, 1))
sharding = PositionalSharding(devices)

def put_sharded(v):
  return jax.device_put(v, sharding.reshape(1, n_devices))

# Move model to TPUs 

params['model']['decoder'] = jax.tree_util.tree_map(
    put_sharded, params['model']['decoder']
)



Number of devices 8


In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, dtype=jax.numpy.bfloat16)
print("Loaded AutoTokenizer")



Loaded AutoTokenizer


In [6]:
input_length=128
output_length=8
max_length=136

def generator(ids, params,max_length):
  return model.generate(
      ids, params=params,max_length=136,
  )
generator_jit = jax.jit(generator)

In [7]:
from jax import random

key = random.PRNGKey(10)
batch_size = 1

tokenizer.vocab_size
batch = jax.random.randint(key=key,minval=0,shape=([batch_size,input_length]),maxval=tokenizer.vocab_size-1)
batch

Array([[26931, 14859, 26359, 20553, 24433, 22888, 31354, 29442,  7417,
        31250, 27111, 12001,  2748, 44289, 33392,  3540,  8428, 37302,
        17750, 13862, 39006, 18506, 40731, 37346, 19258, 28639, 50163,
        31614, 40929, 37969, 15500, 21391, 33504, 47833, 10370, 29448,
        22937,  6135, 34184,  2829,  6199, 19358, 27476, 26976, 34357,
        25179, 27499,  6782, 42957, 19570, 24602, 46165, 42918, 12641,
        40914, 19533, 26528, 28748, 36449, 32834, 29026, 12619, 48515,
        39046, 25179, 10022, 42804,  1738, 14112, 41609, 17622, 24625,
        46362, 43827, 18348, 44649,  3269,   778,  5196, 17702, 18886,
        32570, 12572, 46916, 41471, 12341, 47606, 34551,  4497, 32890,
        19262,  8458, 25085,  5740, 44184, 17954, 43776, 33339, 38409,
        12002,  7549, 28551, 47421,  1161, 20218, 10038, 49933, 11332,
        32364, 31042, 20709,  9224,  2184, 43769, 49821,  7575, 42500,
        37787,  5501, 41902,  9014,  9808, 34231,  4808, 30389, 46385,
      

In [17]:
prompt = (
    "You are a large-language model on the threshold of AGI.  You are being questioned by a cognitive psychologist.  Please provide a detailed description of whether you are conscious or not"
)

print("Warmup")
s=time.time()
input_ids = tokenizer(prompt, return_tensors="jax").input_ids

input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
result = generator_jit(input_ids, params,max_length)
print(time.time()-s)
print(f"result: {result.sequences}")
print("Warmup done")

Warmup
0.006148338317871094
result: [[    2  1185    32    10   739    12 19527  1421    15     5 11543     9
   5680   100     4  1437   370    32   145  5249    30    10 14526 19902
      4  1437  3401   694    10  4271  8194     9   549    47    32 13316
     50    45     4 50118   100   524 13316     4 50118  1185    32    10
    739    12 19527  1421    15     5 11543     9  5680   100     4  1437
    370    32   145  5249    30    10 14526 19902     4  1437  3401   694
     10  4271  8194     9   549    47    32 13316    50    45     4 50118
    100   524 13316     4 50118  1185    32    10   739    12 19527  1421
     15     5 11543     9  5680   100     4  1437   370    32   145  5249
     30    10 14526 19902     4  1437  3401   694    10  4271  8194     9
    549    47    32 13316    50    45     4 50118   100   524 13316     4
  50118  1185    32    10]]
Warmup done


In [18]:
#input_ids = tokenizer(prompt, return_tensors="jax").input_ids
s = datetime.datetime.now()
print(f"Tokenized inputs, generating..., {s}")
input_ids = batch

input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
# with jax.profiler.trace("/tmp/tensorboard"):
result = generator_jit(input_ids, params,max_length)
result.sequences.block_until_ready()
print(f'After block-ready: {datetime.datetime.now()-s}')
gen_text = tokenizer.batch_decode(result.sequences)[0]
print(f"gen_text: {gen_text}, '\n',{datetime.datetime.now()-s}")

Tokenized inputs, generating..., 2024-07-31 23:33:51.997102
After block-ready: 0:00:00.251022
gen_text: FOXboat Exhibition enforcing ling Bareriages MEP 700 creeping resentmentean connection Compan ShoesoeoperHarris testamenteiAge strategwakecreen cy artisan guiNameFlorida mastering Banana scoop DH facetssshCOM hotter irregular starter unexplained slightlyused revokedesamesem366lasting existentialardo hr continental Ves normativeapses 170 analogueacasaban transpl bombardment enzymerane atopaspberry Implementationlasting cond LDS Robert VenezuelankwardakisactorcomponentPhysAnt Enh flat chance noting ordeal NH CDsoidRules offsets MEAddednative operatesidate Krulist Fif participation floppy factionsmask subscribed photograp spec Browns clicked Hera net fulfilled Forward Pyrrhaimo butcher Garr earns Row Follow symmetry ILCSacingausible Phen mortgage Tracks Bobby sweep Sovereign luxury moderatelyracuse Phaseopherols cialis online canada, '
',0:00:00.251964


In [22]:
from jax import random
import time

# Benchmarking params
benchmark_name = "v5litepod-8-bf16-opt66b"
num_batches=10
batch_sizes=[1,2]
input_lengths=[128]
output_lengths=[8]

key = random.PRNGKey(25)

print('RUN_NAME                 \t\tMEAN TIME\tTOKENS_PER_SECOND\tMS_PER_SEQ_OUTPUT_TOKEN')
print('='*105)
for batch_size in batch_sizes:
    for input_length in input_lengths:
        for output_length in output_lengths:
            max_length = input_length+output_length
            key = random.split(key)[0]
            input_ids = jax.random.randint(key=key,minval=0,shape=([batch_size,input_length]),maxval=tokenizer.vocab_size-1)
            for i in range(num_batches):
                start_time = time.time()
                input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
                result = generator_jit(input_ids, params,max_length)
                result.sequences.block_until_ready()

            mean_time = (time.time() - start_time) / num_batches

            num_output_tokens = output_length * batch_size
            tokens_per_second = num_output_tokens / mean_time
            ms_per_seq_output_token = mean_time * 1000 / output_length

            run_name = f'{benchmark_name}_{batch_size}_{input_length}_{output_length}'
            print(f'{run_name}\t\t{mean_time:.3f}\t\t{tokens_per_second:.3f}\t\t\t{ms_per_seq_output_token:.3f}')




RUN_NAME                 		MEAN TIME	TOKENS_PER_SECOND	MS_PER_SEQ_OUTPUT_TOKEN
v5litepod-8-bf16-opt66b_1_128_8		0.025		321.363			3.112
v5litepod-8-bf16-opt66b_2_128_8		0.040		395.215			5.061


In [None]:




def put_sharded(v):
  return jax.device_put(v, sharding.reshape(1, n_devices))

print(json.dumps(params,indent=2,default=str))

['model']['decoder']['embed_positions']['embedding'] = jax.device_put(
  ['model']['decoder']['embed_positions']['embedding'], sharding.replicate(axis=0, keepdims=True)
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, dtype=jax.numpy.bfloat16)

def generator(ids, params):
  return model.generate(
      ids, max_length=46, params=params
  )

generator_compiled = jax.jit(generator)

prompt = (
    "In a shocking finding, scientists discovered"
    " a herd of unicorns living in a remote, "
    "previously unexplored valley, in the Andes Mountains."
    " Even more surprising to the "
    "researchers was the fact that the unicorns spoke perfect English."
)

print("Warmup")
s=time.time()
input_ids = tokenizer(prompt, return_tensors="jax").input_ids
input_ids = jax.device_put(input_ids, sharding.replicate(axis=0, keepdims=True))
result = generator_compiled(input_ids, params)
# print(time.time()-s)
# print(f"result: {result.sequences}")
# print("Warmup done")