Binary file added keras_hub/.DS_Store
Binary file not shown.
9 changes: 5 additions & 4 deletions keras_hub/src/models/t5/t5_backbone.py
@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
projections in the multi-head attention layers. Defaults to
hidden_dim / num_heads.
dropout: float. Dropout probability for the Transformer layers.
activation: activation function (or activation string name). The
activation to be used in the inner dense blocks of the
Transformer layers. Defaults to `"relu"`.
activation: string. The activation function to use in the inner dense
blocks of the Transformer layers.
use_gated_activation: boolean. Whether to use activation gating in
the inner dense blocks of the Transformer layers.
the inner dense blocks of the Transformer layers. When used with
the GELU activation function, this is referred to as GEGLU, a
GLU variant using GELU (https://arxiv.org/pdf/2002.05202).
The original T5 architecture didn't use gating, but more
recent versions do. Defaults to `True`.
layer_norm_epsilon: float. Epsilon factor to be used in the
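To make the `use_gated_activation` behavior concrete, here is a minimal sketch of a GEGLU-style gated feed-forward block in the spirit of the linked paper. The class and attribute names are illustrative assumptions, not the actual `T5Backbone` internals.

import keras
from keras import layers


class GatedFeedForward(layers.Layer):
    """Illustrative GEGLU-style block: down(act(gate(x)) * up(x))."""

    def __init__(self, hidden_dim, intermediate_dim, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.gate_proj = layers.Dense(intermediate_dim, use_bias=False)
        self.up_proj = layers.Dense(intermediate_dim, use_bias=False)
        self.down_proj = layers.Dense(hidden_dim, use_bias=False)
        self.activation = keras.activations.get(activation)

    def call(self, x):
        # Elementwise product of the activated "gate" projection and a plain
        # linear projection, followed by a projection back to hidden_dim.
        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))

With `use_gated_activation=False`, the block would instead reduce to a single activated projection followed by the output projection, which matches the original (non-gated) T5 formulation mentioned in the docstring.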
32 changes: 31 additions & 1 deletion keras_hub/src/models/t5/t5_presets.py
@@ -1,4 +1,4 @@
"""XLM-RoBERTa model preset configurations."""
"""T5 model preset configurations."""

backbone_presets = {
"t5_small_multi": {
@@ -14,6 +14,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/2",
},
"t5_1.1_small": {
"metadata": {
"description": (""),
"params": 60511616,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small",
},
"t5_base_multi": {
"metadata": {
"description": (
@@ -27,6 +37,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/2",
},
"t5_1.1_base": {
"metadata": {
"description": (""),
"params": 247577856,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base",
},
"t5_large_multi": {
"metadata": {
"description": (
@@ -40,6 +60,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/2",
},
"t5_1.1_large": {
"metadata": {
"description": (""),
"params": 750251008,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large",
},
"flan_small_multi": {
"metadata": {
"description": (
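Once the corresponding Kaggle artifacts are uploaded, the new presets should be loadable like the existing ones. A minimal usage sketch, assuming the `t5_1.1_small` upload is live:

import keras_hub

# Load the new preset by name. This assumes the Kaggle artifact referenced by
# `kaggle_handle` above has already been uploaded.
backbone = keras_hub.models.T5Backbone.from_preset("t5_1.1_small")
# The parameter count should roughly match the `params` metadata entry above.
print(backbone.count_params())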
Binary file added tools/.DS_Store
Binary file not shown.
110 changes: 92 additions & 18 deletions tools/checkpoint_conversion/convert_t5_checkpoints.py
@@ -7,19 +7,98 @@
from absl import app
from absl import flags
from checkpoint_conversion_utils import get_md5_checksum
from keras import ops

import keras_hub

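# Map each KerasHub preset name to the Hugging Face model id used as the
# conversion source.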
PRESET_MAP = {
"t5_small_multi": "t5-small",
"t5_base_multi": "t5-base",
"t5_large_multi": "t5-large",
"t5_1.1_small": "google/t5-v1_1-small",
"t5_1.1_base": "google/t5-v1_1-base",
"t5_1.1_large": "google/t5-v1_1-large",
"t5_1.1_xl": "google/t5-v1_1-xl",
"t5_1.1_xxl": "google/t5-v1_1-xxl",
"flan_small_multi": "google/flan-t5-small",
"flan_base_multi": "google/flan-t5-base",
"flan_large_multi": "google/flan-t5-large",
}


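# Backbone configurations for the T5 1.1 checkpoints (GELU-gated feed-forward
# blocks, untied embedding/classifier weights), keyed by KerasHub preset name.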
PARAM_MAP = {
"t5_1.1_small": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 512,
"intermediate_dim": 1024,
"num_layers": 8,
"num_heads": 6,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_base": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 768,
"intermediate_dim": 2048,
"num_layers": 12,
"num_heads": 12,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_large": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 1024,
"intermediate_dim": 2816,
"num_layers": 24,
"num_heads": 16,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_xl": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 2048,
"intermediate_dim": 5120,
"num_layers": 24,
"num_heads": 32,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_xxl": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 4096,
"intermediate_dim": 10240,
"num_layers": 24,
"num_heads": 64,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
}


FLAGS = flags.FLAGS

flags.DEFINE_string(
@@ -52,9 +131,7 @@ def extract_vocab(hf_tokenizer):


def convert_checkpoints(hf_model):
keras_hub_model = keras_hub.models.T5Backbone.from_preset(
FLAGS.preset, load_weights=False
)
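# Build the backbone directly from the hard-coded configuration above.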
keras_hub_model = keras_hub.models.T5Backbone(**PARAM_MAP[FLAGS.preset])

hf_wts = hf_model.state_dict()
print("Original weights:")
@@ -308,17 +385,12 @@ def check_output(
keras_hidden_states = keras_out["decoder_sequence_output"]
hf_hidden_states = hf_out.decoder_hidden_states[-1]

keras_outputs = ops.take_along_axis(
keras_hidden_states, ops.where(decoder_padding_mask)
)
hf_outputs = ops.take_along_axis(
hf_hidden_states, ops.where(decoder_padding_mask)
)

print("-> KerasHub output:", keras_outputs[0:5])
print("-> HF output:", hf_outputs[0:5])
print("-> KerasHub output:", keras_hidden_states[0:5])
print("-> HF output:", hf_hidden_states[0:5])
np.testing.assert_allclose(
keras_outputs.detach().numpy(), hf_outputs.detach().numpy(), atol=1e-5
keras_hidden_states.numpy(),
hf_hidden_states.detach().numpy(),
atol=1e-2,
)

if keras_model.tie_embedding_weights:
@@ -333,7 +405,7 @@ def check_output(
print("-> KerasHub logits:", keras_logits[0:5])
print("-> HF logits:", hf_logits[0:5])
np.testing.assert_allclose(
keras_logits.detach().numpy(), hf_logits.detach().numpy(), atol=1e-3
keras_logits.numpy(), hf_logits.detach().numpy(), atol=1e-1
)


@@ -352,16 +424,18 @@ def main(_):
keras_model = convert_checkpoints(hf_model)

# Save the model.
model_path = f"./{FLAGS.preset}/model.weights.h5"
model_path = f"./{FLAGS.preset}"
weight_path = os.path.join(model_path, "model.weights.h5")
print(f"\n-> Save KerasHub model weights to `{model_path}`.")
keras_model.save_weights(model_path)
keras_model.save_to_preset(model_path)
print("-> Print MD5 checksum of the model weights files.")
print(f"`{model_path}` md5sum: ", get_md5_checksum(model_path))
print(f"`{model_path}` md5sum: ", get_md5_checksum(weight_path))
print(f"-> Param count {count_params(keras_model.weights)}")

print("\n-> Convert vocab.")
hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_id)
keras_tokenizer = extract_vocab(hf_tokenizer)
keras_tokenizer.save_to_preset(model_path)

check_output(
keras_model,