Commit a00efc2
Porting PaliGemma transformers checkpoint (#1686)
* chore: adding paligemma
* chore: adding paligemma conversion
* chore: adding tests
1 parent 29c85c0 commit a00efc2

File tree

3 files changed: +388 -0 lines changed


keras_nlp/src/utils/transformers/convert.py

Lines changed: 10 additions & 0 deletions
@@ -20,6 +20,12 @@
 from keras_nlp.src.utils.transformers.convert_llama3 import (
     load_llama3_tokenizer,
 )
+from keras_nlp.src.utils.transformers.convert_pali_gemma import (
+    load_pali_gemma_backbone,
+)
+from keras_nlp.src.utils.transformers.convert_pali_gemma import (
+    load_pali_gemma_tokenizer,
+)
 
 
 def load_transformers_backbone(cls, preset, load_weights):
@@ -29,6 +35,8 @@ def load_transformers_backbone(cls, preset, load_weights):
         return load_gemma_backbone(cls, preset, load_weights)
     if cls.__name__ == "Llama3Backbone":
         return load_llama3_backbone(cls, preset, load_weights)
+    if cls.__name__ == "PaliGemmaBackbone":
+        return load_pali_gemma_backbone(cls, preset, load_weights)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
@@ -42,6 +50,8 @@ def load_transformers_tokenizer(cls, preset):
         return load_gemma_tokenizer(cls, preset)
     if cls.__name__ == "Llama3Tokenizer":
         return load_llama3_tokenizer(cls, preset)
+    if cls.__name__ == "PaliGemmaTokenizer":
+        return load_pali_gemma_tokenizer(cls, preset)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
keras_nlp/src/utils/transformers/convert_pali_gemma.py

Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import numpy as np

from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.transformers.safetensor_utils import set_keras_weight

def load_pali_gemma_backbone(cls, preset, load_weights):
    """
    Load and initialize the PaliGemma backbone model.

    Args:
        cls (class): Keras model class.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to load the weights.

    Returns:
        backbone: Initialized Keras model backbone.
    """
    transformers_config = load_config(preset, HF_CONFIG_FILE)
    text_config = transformers_config["text_config"]
    vision_config = transformers_config["vision_config"]
    backbone = cls(
        vocabulary_size=transformers_config["image_token_index"],
        image_size=(
            vision_config["image_size"]
            if "image_size" in vision_config.keys()
            else 224
        ),
        num_layers=text_config["num_hidden_layers"],
        num_query_heads=text_config["num_attention_heads"],
        num_key_value_heads=text_config["num_key_value_heads"],
        hidden_dim=text_config["hidden_size"],
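        # Keras Gemma's intermediate_dim appears to span both halves of the
        # gated FFN (gate and up projections), hence the * 2 on HF's
        # intermediate_size.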
        intermediate_dim=text_config["intermediate_size"] * 2,
        head_dim=text_config["num_image_tokens"],
        vit_patch_size=vision_config["patch_size"],
        vit_num_heads=vision_config["num_attention_heads"],
        vit_hidden_dim=vision_config["hidden_size"],
        vit_num_layers=vision_config["num_hidden_layers"],
        vit_intermediate_dim=vision_config["intermediate_size"],
    )

    if not load_weights:
        return backbone

    jax_memory_cleanup(backbone)
    # Code to port the weights from safetensors into the keras nlp model
    safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE)
    safetensor_files = {
        fname: get_file(preset, fname)
        for fname in set(safetensor_config["weight_map"].values())
    }
    port_weight = partial(
        set_keras_weight,
        safetensor_files=safetensor_files,
        safetensor_config=safetensor_config,
    )

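    # port_weight binds the safetensor handles and weight map once, so each
    # call below only names a Keras variable, an HF weight key, and an
    # optional hook_fn that rearranges the torch tensor into the Keras layout.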
    ############################################################################
    # Image Tower
    ############################################################################
    image_encoder = backbone.vit_encoder.get_layer("image_encoder")

    # Embedding
    port_weight(
        keras_variable=image_encoder.vision_embeddings.patch_embedding.bias,
        hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
    )

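    # Torch Conv2d kernels are laid out as (out_ch, in_ch, kh, kw); Keras
    # Conv2D expects (kh, kw, in_ch, out_ch), hence the (2, 3, 1, 0) transpose.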
    port_weight(
        keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel,
        hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(2, 3, 1, 0),
        ),
    )

    # Positional Embedding
    port_weight(
        keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings,
        hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
    )

    # Normalization
    port_weight(
        keras_variable=image_encoder.encoder_layer_norm.gamma,
        hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
    )

    port_weight(
        keras_variable=image_encoder.encoder_layer_norm.beta,
        hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
    )

    # ResBlocks
    for index in range(image_encoder.num_layers):
        block = image_encoder.resblocks[index]

        port_weight(
            keras_variable=block.layer_norm_1.beta,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.bias",
        )

        port_weight(
            keras_variable=block.layer_norm_1.gamma,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.weight",
        )

        port_weight(
            keras_variable=block.layer_norm_2.beta,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.bias",
        )

        port_weight(
            keras_variable=block.layer_norm_2.gamma,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.weight",
        )

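        # Torch Linear weights are (out_features, in_features); the Keras
        # dense kernels here are (in_features, out_features), hence the
        # (1, 0) transposes below.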
        port_weight(
            keras_variable=block.mlp_dense_1.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

        port_weight(
            keras_variable=block.mlp_dense_1.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.bias",
        )

        port_weight(
            keras_variable=block.mlp_dense_2.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

        port_weight(
            keras_variable=block.mlp_dense_2.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.bias",
        )

        port_weight(
            keras_variable=block.attn.key_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.bias",
        )

        port_weight(
            keras_variable=block.attn.key_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

        port_weight(
            keras_variable=block.attn.out_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.bias",
        )

        port_weight(
            keras_variable=block.attn.out_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

        port_weight(
            keras_variable=block.attn.query_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.bias",
        )

        port_weight(
            keras_variable=block.attn.query_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

        port_weight(
            keras_variable=block.attn.value_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.bias",
        )

        port_weight(
            keras_variable=block.attn.value_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor,
                axes=(1, 0),
            ),
        )

    # Multi Modal Projection
    port_weight(
        keras_variable=backbone.vit_encoder.get_layer(
            "image_classifier"
        ).kernel,
        hf_weight_key="multi_modal_projector.linear.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.transpose(
            hf_tensor,
            axes=(1, 0),
        ),
    )

    port_weight(
        keras_variable=backbone.vit_encoder.get_layer("image_classifier").bias,
        hf_weight_key="multi_modal_projector.linear.bias",
    )

    ############################################################################
    # Language Tower
    ############################################################################
    for index in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[index]

        # Norm layers
        port_weight(
            keras_variable=decoder_layer.pre_attention_norm.scale,
            hf_weight_key=f"language_model.model.layers.{index}.input_layernorm.weight",
        )
        port_weight(
            keras_variable=decoder_layer.pre_ffw_norm.scale,
            hf_weight_key=f"language_model.model.layers.{index}.post_attention_layernorm.weight",
        )

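        # HF stores q/k/v projections as 2-D (num_heads * head_dim, hidden_dim)
        # matrices, while the Keras per-head kernels are 3-D; each tensor is
        # reshaped per head and transposed into (num_heads, hidden_dim,
        # head_dim). The output projection applies the inverse mapping.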
        # Attention layers
        port_weight(
            keras_variable=decoder_layer.attention.query_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        port_weight(
            keras_variable=decoder_layer.attention.key_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        port_weight(
            keras_variable=decoder_layer.attention.value_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        port_weight(
            keras_variable=decoder_layer.attention.output_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.o_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[2], keras_shape[0], keras_shape[1]),
                ),
                axes=(1, 2, 0),
            ),
        )

        # MLP layers
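        # The gated FFN kernels are addressed positionally (variables[0] is
        # the kernel); HF stores each as (out, in), hence the (1, 0) transpose.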
        port_weight(
            keras_variable=decoder_layer.gating_ffw.variables[0],
            hf_weight_key=f"language_model.model.layers.{index}.mlp.gate_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        port_weight(
            keras_variable=decoder_layer.gating_ffw_2.variables[0],
            hf_weight_key=f"language_model.model.layers.{index}.mlp.up_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        port_weight(
            keras_variable=decoder_layer.ffw_linear.variables[0],
            hf_weight_key=f"language_model.model.layers.{index}.mlp.down_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    # Normalization
    port_weight(
        keras_variable=backbone.layer_norm.scale,
        hf_weight_key="language_model.model.norm.weight",
    )

    # Embedding
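    # The HF embedding table may carry padded rows beyond the Keras
    # vocabulary, so only the first keras_shape[0] rows are copied.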
    port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="language_model.model.embed_tokens.weight",
        hook_fn=lambda hf_tensor, keras_shape: hf_tensor[: keras_shape[0]],
    )

    return backbone

def load_pali_gemma_tokenizer(cls, preset):
    """
    Load the PaliGemma tokenizer.

    Args:
        cls (class): Tokenizer class.
        preset (str): Preset configuration name.

    Returns:
        tokenizer: Initialized tokenizer.
    """
    return cls(get_file(preset, "tokenizer.model"))
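Most of the conversion above is shape bookkeeping between torch and Keras layouts. A minimal numpy sketch (dummy shapes chosen for illustration, not read from a real checkpoint) checks the two recurring hook_fn patterns:

    import numpy as np

    # Conv patch embedding: torch Conv2d stores (out_ch, in_ch, kh, kw);
    # Keras Conv2D expects (kh, kw, in_ch, out_ch).
    hf_conv = np.zeros((1152, 3, 14, 14))
    assert np.transpose(hf_conv, axes=(2, 3, 1, 0)).shape == (14, 14, 3, 1152)

    # Fused attention projection: torch q_proj.weight is
    # (num_heads * head_dim, hidden_dim); the Keras kernel is
    # (num_heads, hidden_dim, head_dim), as implied by the hooks above.
    num_heads, head_dim, hidden_dim = 8, 256, 2048
    keras_shape = (num_heads, hidden_dim, head_dim)
    hf_q = np.zeros((num_heads * head_dim, hidden_dim))
    keras_q = np.transpose(
        np.reshape(hf_q, (keras_shape[0], keras_shape[2], keras_shape[1])),
        axes=(0, 2, 1),
    )
    assert keras_q.shape == keras_shape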
