# Faster Text Generation with TensorFlow and XLA

This notebook is a companion to the 🤗 [blog post with the same title](https://huggingface.co/blog/tf-xla-generate). 
It is meant to illustrate how to use XLA with TensorFlow text generation.

It contains two stand-alone examples, one for encoder-decoder models and another for decoder-only models.

⚠️ If you are running this on colab, you might not have access to a GPU. The benefits of XLA are best observed with a GPU!

In [None]:
# Preparing the environment
!pip install transformers>=4.21.0

In [None]:
# Stand-alone TF XLA generate example for Encoder-Decoder Models.

# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend you to try this example there!
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# 1. Load model and tokenizer
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 32}

# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶
# This is the only change with respect to original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.
input_prompts = [
    f"translate English to {language}: I have four cats and three dogs." for language in ["German", "French", "Romanian"]
]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

In [1]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 KB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m108.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.4 tokenizers-0.13.3 transformers-4.27.4


In [2]:
!pip install --quiet bitsandbytes
!pip install --quiet accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.2/84.2 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 KB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [60]:
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)

#######to split  a very large model between cpu and gpu

In [73]:
device_map = {
    "transformer.word_embeddings": 'cpu',
    "transformer.word_embeddings_layernorm": 'cpu',
    "lm_head": "cpu",
    "transformer.h": 'cpu',
    "transformer.ln_f": 'cpu',
}

In [3]:
# Stand-alone TF XLA generate example for Decoder-Only Models.

# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend you to try this example there!
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# 1. Load model and tokenizer
name = "gpt2"

# remember: decoder-only models need left-padding
model_8bit = AutoModelForCausalLM.from_pretrained(name,load_in_8bit=True,device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(name)
tokenizer.pad_token = tokenizer.eos_token


# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶
# This is the only change with respect to original generate workflow!


Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]



Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /usr/local/lib/python3.9/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
import tensorflow as tf

In [18]:

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 32}

In [25]:
def most_likely_next_token(input_ids,attention_mask):
    model_output = model_8bit.generate(input_ids=input_ids,attention_mask=attention_mask,**generation_kwargs)
    return tf.argmax(model_output.logits[:, -1, :], axis=-1)

# print("Calling regular function with TensorFlow code...")
# most_likely_next_token(inputs)

In [26]:
# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.
xla_generate = tf.function(most_likely_next_token, jit_compile=True)
input_prompts = ["The best thing about Spain"]
inputs = tokenizer(input_prompts, **tokenization_kwargs)

In [28]:
inputs = tokenizer(["TensorFlow is"], return_tensors="tf")

In [27]:
start = time.time_ns()
generated_text = xla_generate(inputs['input_ids'],inputs['attention_mask'])
end = time.time_ns()
decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [30]:
print(f"Original prompt -- {input_prompt}")
print(f"Generated -- {decoded_text}")
print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

In [None]:
decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
print(f"Original prompt -- {input_prompt}")
print(f"Generated -- {decoded_text}")
print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")