In [None]:
# https://huggingface.co/stabilityai/stablelm-2-1_6b
import os
import time
import warnings

import torch
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Load the tokenizer published alongside the checkpoint. AutoTokenizer resolves
# to Arcade100kTokenizer here (not yet documented as of 2024-01-20).
# trust_remote_code=True is required because that class ships with the model
# repository rather than with the transformers package.
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-2-1_6b",
    trust_remote_code=True,
)
# print("Using %s" % tokenizer)

In [None]:
# Instantiates StableLMEpochForCausalLM on the first backend that reports itself
# available. Backends are probed in this order: MPS, then CUDA, then CPU.
model_id = "stabilityai/stablelm-2-1_6b"
mps_ready = torch.backends.mps.is_available()
# CUDA is only probed when MPS is not available.
cuda_ready = (not mps_ready) and torch.cuda.is_available()

if mps_ready:
    print("Using MPS")
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    model = model.to(torch.device("mps"))
elif cuda_ready:
    print("Using CUDA")
    # torch_dtype="auto" is requested on this path only.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype="auto",
    )
    model.cuda()
else:
    # CPU fallback. On Intel without a GPU, the cuda backend will spew errors
    # and warnings, so logging is turned down before the model is loaded.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    # Blanket filter: every Python warning is ignored from this point on.
    warnings.filterwarnings("ignore")
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    try:
        # Optional dependency: imported here so the notebook still runs without it.
        import intel_extension_for_pytorch as ipex
        model = ipex.optimize(model, dtype=torch.float32)  # bfloat16
        print("Using Intel (accelerated)")
    except ImportError:
        print("Using Intel (non-accelerated)")
# print("Using %s" % model)

In [None]:
# Sample a completion for a fixed prompt, streaming the text to stdout while it
# is generated, then report a rough words-per-second throughput figure.
# Named `prompt` (not `input`) so the builtin input() is not shadowed for the
# rest of the kernel session.
prompt = "Write a poem about the sky"

# Tokenize before starting the clock so the timing covers generation only.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# perf_counter() is monotonic, so the interval is unaffected by wall-clock changes.
start = time.perf_counter()
tokens = model.generate(
    **inputs,
    max_new_tokens=128,
    temperature=0.70,
    top_p=0.95,
    do_sample=True,
    # Prints the continuation as it is produced; the prompt itself is not echoed.
    streamer=TextStreamer(tokenizer, skip_prompt=True),
    # Reuse the end-of-sequence id as the padding id.
    pad_token_id=tokenizer.eos_token_id,
)
duration = time.perf_counter() - start

# Full text (prompt + continuation), kept under the original name.
decoded = tokenizer.decode(tokens[0], skip_special_tokens=True)

# generate() returns the prompt tokens followed by the new tokens. Drop the
# prompt before counting, otherwise its words inflate the reported total.
prompt_length = inputs["input_ids"].shape[-1]
generated_text = tokenizer.decode(tokens[0][prompt_length:], skip_special_tokens=True)
words = len(generated_text.split())
print(f"Generated {words} words in {duration:.1f}s, {(words/duration):.1f} words/s")