In [1]:
# !pip install transformers --upgrade

In [2]:
import os

# Local BLOOM-3B checkpoint directory. Set BLOOM_MODEL_DIR to point the notebook
# at a different copy without editing this cell; defaults to the original path.
model_dir = os.environ.get("BLOOM_MODEL_DIR", "../mii/bloom-3b")

In [3]:
import torch

In [4]:
from transformers import BloomForCausalLM

# Load the weights in an explicit precision so the benchmark does not silently
# change if the library's default dtype changes. Pass torch.float16 here
# instead to benchmark half precision.
model = BloomForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32)

In [5]:
# Show the precision the weights were loaded in.
loaded_dtype = model.dtype
loaded_dtype

torch.float32

In [6]:
from transformers import AutoTokenizer, pipeline

# Tokenizer from the same checkpoint, then a text-generation pipeline that
# reuses the already-loaded model on device index 0.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0,
)

## Greedy Search

In [7]:
start_text = "Testing BLOOM-3B without DeepSpeed (greedy)"
# Prompt length in tokens; the generation length below is measured relative to it.
tokens_start_text = len(tokenizer(start_text)["input_ids"])
tokens_start_text

13

In [8]:
import time

new_tokens = 1000
# For decoder-only models `min_length`/`max_length` include the prompt, so pinning
# both to prompt + new_tokens forces exactly `new_tokens` generated tokens.
gen_length = new_tokens + tokens_start_text

# Short untimed warm-up call. The first generation typically pays one-off
# device setup costs that would otherwise be charged to whichever benchmark runs first.
pipe(start_text, max_length=tokens_start_text + 8)

# perf_counter is monotonic and high-resolution, unlike the wall-clock
# time.time(), so it is the appropriate clock for measuring an interval.
t0 = time.perf_counter()
gen_text = pipe(start_text, min_length=gen_length, max_length=gen_length)[0]['generated_text']
t1 = time.perf_counter()
# NOTE(review): re-tokenizing the decoded text approximates the generated-token
# count; decode/encode round-trips are not guaranteed to be exact.
tokens_gen_text = len(tokenizer(gen_text, return_tensors="pt").input_ids[0])

In [9]:
# Throughput = new tokens (generated length minus prompt length) per second of generation.
elapsed = t1 - t0
total_new_tokens_generated = tokens_gen_text - tokens_start_text
throughput = total_new_tokens_generated / elapsed
print(f"Tokens generated: {total_new_tokens_generated}; Time: {elapsed:.1f} seconds; Tokens per second: {throughput:.1f}")

Tokens generated: 1000; Time: 32.4 seconds; Tokens per second: 30.9


## Sampling

In [10]:
start_text = "Testing BLOOM-3B without DeepSpeed (sampling)"
# Prompt length in tokens; the generation length below is measured relative to it.
tokens_start_text = len(tokenizer(start_text)["input_ids"])
tokens_start_text

13

In [11]:
import time

import torch

new_tokens = 1000
# `min_length`/`max_length` include the prompt for decoder-only models, so pinning
# both to prompt + new_tokens forces exactly `new_tokens` generated tokens.
gen_length = new_tokens + tokens_start_text

# Sampling is stochastic. Seed torch's RNG (all devices) so re-runs sample the
# same sequence, and therefore the same token count (modulo GPU nondeterminism).
torch.manual_seed(42)

# perf_counter is monotonic and high-resolution; use it for interval timing.
t0 = time.perf_counter()
gen_text = pipe(start_text, min_length=gen_length, max_length=gen_length, do_sample=True, top_k=50)[0]['generated_text']
t1 = time.perf_counter()
# NOTE(review): re-tokenizing the decoded text approximates the generated-token
# count; decode/encode round-trips are not guaranteed to be exact.
tokens_gen_text = len(tokenizer(gen_text, return_tensors="pt").input_ids[0])

In [12]:
# Throughput = new tokens (generated length minus prompt length) per second of generation.
elapsed = t1 - t0
total_new_tokens_generated = tokens_gen_text - tokens_start_text
throughput = total_new_tokens_generated / elapsed
print(f"Tokens generated: {total_new_tokens_generated}; Time: {elapsed:.1f} seconds; Tokens per second: {throughput:.1f}")

Tokens generated: 1000; Time: 33.6 seconds; Tokens per second: 29.8
