<a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This is a simple Colab showing how to re-author a T5 (encoder-decoder) model,
# convert it, and run it in a Colab Python environment.

Note: When running notebooks in this repository with Google Colab, some users may see
the following warning message:

![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)

Please click `Restart Session` and run again.

In [None]:
!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
!pip install ai-edge-torch-nightly

## Download model checkpoint
First, we download the T5 PyTorch checkpoint from Hugging Face: https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base.

In [None]:
!curl -O -L https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base/resolve/main/pytorch_model.bin
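
Because `curl` is called without `--fail`, a failed request can still leave a small error page saved as `pytorch_model.bin`. The sketch below is one way to sanity-check the download before conversion. The helper name `looks_like_checkpoint` and the 1 MB threshold are hypothetical choices for illustration, not part of ai-edge-torch.

```python
import os


def looks_like_checkpoint(path: str, min_bytes: int = 1_000_000) -> bool:
  """Returns True if `path` is a file of at least `min_bytes` bytes.

  Hypothetical helper: the size threshold is an assumption meant only to
  catch an empty file or a saved error page, not to validate the contents.
  """
  return os.path.isfile(path) and os.path.getsize(path) >= min_bytes
```

In Colab this could be called as `looks_like_checkpoint("/content/pytorch_model.bin")` right after the download cell.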

## T5 Model Authoring and Conversion
Next, we import the T5 encoder/decoder implementation from `ai_edge_torch/generative/examples/t5` and convert it to a TFLite model with two signatures: `encode` and `decode`.

In [2]:
import numpy as np
import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.t5 import t5
from ai_edge_torch.generative.quantize import quant_recipes


def convert_t5_to_tflite_multisig(checkpoint_path: str):
  """Converts the T5 checkpoint to a TFLite model with `encode` and `decode` signatures."""
  config = t5.get_model_config_t5()
  # Temporarily disable HLFB until custom op issue is fixed.
  config.enable_hlfb = False
  embedding_layer = torch.nn.Embedding(
      config.vocab_size, config.embedding_dim, padding_idx=0
  )
  t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
  t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)

  # Sample inputs for the `encode` signature: a short prompt padded to `seq_len`.
  seq_len = 512
  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
  prompt_e_token = [1, 2, 3, 4, 5, 6]
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
      prompt_e_token, dtype=torch.long
  )
  prefill_e_input_pos = torch.arange(0, seq_len)

  # Sample inputs for the `decode` signature: a single token at position 0.
  decode_d_token = torch.tensor([[1]], dtype=torch.long)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)

  # Additive pad mask so that self-attention only covers "real" tokens:
  # 0 keeps a position, `-inf` masks out a token index that isn't desired.
  # It is all zeros here because it only serves as a sample input.
  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
  # Placeholder for the encoder output consumed by the decoder.
  hidden_states = torch.zeros((1, seq_len, 768), dtype=torch.float32)
  quant_config = quant_recipes.full_int8_dynamic_recipe()

  edge_model = (
      ai_edge_torch.signature(
          'encode',
          t5_encoder_model.eval(),
          (
              prefill_e_tokens,
              prefill_e_input_pos,
              pad_mask,
          ),
      )
      .signature(
          'decode',
          t5_decoder_model.eval(),
          (
              hidden_states,
              decode_d_token,
              decode_d_input_pos,
              pad_mask,
          ),
      )
      .convert(quant_config=quant_config)
  )

  edge_model.export('t5_encode_decode_2_sigs.tflite')
  return edge_model
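
The conversion function uses an all-zero pad mask as a sample input. For a real prompt, the code comment suggests setting the mask to `-inf` at padded positions. The sketch below shows one way to build the padded tokens, positions, and mask together. The helper name `build_padded_inputs` and its `pad_id` parameter are hypothetical, and the `-inf` placement is an assumption based on that comment rather than a confirmed requirement of the T5 example.

```python
import torch


def build_padded_inputs(token_ids, seq_len=512, pad_id=0):
  """Pads `token_ids` to `seq_len` and builds a matching additive pad mask.

  Hypothetical helper for illustration; not part of ai-edge-torch.
  """
  if len(token_ids) > seq_len:
    raise ValueError("prompt is longer than seq_len")
  # Token tensor of shape (1, seq_len), filled with the padding id.
  tokens = torch.full((1, seq_len), pad_id, dtype=torch.long)
  tokens[0, : len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
  # One position index per slot, as in the conversion function.
  input_pos = torch.arange(0, seq_len)
  # 0 for real tokens, -inf for padding (assumed additive-mask convention).
  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
  pad_mask[len(token_ids):] = float("-inf")
  return tokens, input_pos, pad_mask


tokens, input_pos, pad_mask = build_padded_inputs([1, 2, 3, 4, 5, 6], seq_len=8)
```

A short `seq_len` is used here only to keep the tensors easy to inspect; the converted model expects the same `seq_len` it was traced with (512 above).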


Finally, we call the conversion function. This might take a few minutes to finish.

In [None]:
print('converting T5 to tflite.')
edge_model = convert_t5_to_tflite_multisig("/content/pytorch_model.bin")