Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def next(prompt, state, index):
)
print("Time taken: ", time_taken)
res_handler.write(
f"{sampler},{execution_method}," f"{time_taken}\n"
f"{sampler},{execution_method},{time_taken}\n"
)
print()
print("*************************************")
Expand Down
5 changes: 3 additions & 2 deletions keras_hub/src/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,9 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
return tf.squeeze(inputs, axis=-1)
else:
raise ValueError(
f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
f"{tensor_name} must be of rank {base_rank}, "
f"{base_rank + 1}, or {base_rank + 2}. "
f"Found rank: {inputs.shape.rank}"
)

y_true = validate_and_fix_rank(y_true, "y_true", 1)
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/basnet/basnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_resnet_block(_resnet, block_num):
else:
x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]]
y = _resnet.get_layer(
f"stack{block_num}_block{num_blocks[block_num]-1}_add"
f"stack{block_num}_block{num_blocks[block_num] - 1}_add"
).output
return keras.models.Model(
inputs=x,
Expand Down
6 changes: 3 additions & 3 deletions keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def build(self, input_shape):
dilation_rate=dilation_rate,
use_bias=False,
data_format=self.data_format,
name=f"aspp_conv_{i+2}",
name=f"aspp_conv_{i + 2}",
),
keras.layers.BatchNormalization(
axis=self.channel_axis, name=f"aspp_bn_{i+2}"
axis=self.channel_axis, name=f"aspp_bn_{i + 2}"
),
keras.layers.Activation(
self.activation, name=f"aspp_activation_{i+2}"
self.activation, name=f"aspp_activation_{i + 2}"
),
]
)
Expand Down
6 changes: 3 additions & 3 deletions keras_hub/src/models/densenet/densenet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
channel_axis,
stackwise_num_repeats[stack_index],
growth_rate,
name=f"stack{stack_index+1}",
name=f"stack{stack_index + 1}",
)
pyramid_outputs[f"P{index}"] = x
x = apply_transition_block(
x,
channel_axis,
compression_ratio,
name=f"transition{stack_index+1}",
name=f"transition{stack_index + 1}",
)

x = apply_dense_block(
Expand Down Expand Up @@ -140,7 +140,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):

for i in range(num_repeats):
x = apply_conv_block(
x, channel_axis, growth_rate, name=f"{name}_block{i+1}"
x, channel_axis, growth_rate, name=f"{name}_block{i + 1}"
)
return x

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/flux/flux_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(

def fit(self, *args, **kwargs):
raise NotImplementedError(
"Currently, `fit` is not supported for " "`FluxTextToImage`."
"Currently, `fit` is not supported for `FluxTextToImage`."
)

def generate_step(
Expand Down
4 changes: 2 additions & 2 deletions keras_hub/src/models/pali_gemma/pali_gemma_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"pali_gemma_3b_mix_224": {
"metadata": {
"description": (
"image size 224, mix fine tuned, text sequence " "length is 256"
"image size 224, mix fine tuned, text sequence length is 256"
),
"params": 2923335408,
"path": "pali_gemma",
Expand Down Expand Up @@ -45,7 +45,7 @@
"pali_gemma_3b_896": {
"metadata": {
"description": (
"image size 896, pre trained, text sequence length " "is 512"
"image size 896, pre trained, text sequence length is 512"
),
"params": 2927759088,
"path": "pali_gemma",
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/resnet/resnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
use_bias=False,
padding="same",
dtype=dtype,
name=f"conv{conv_index+1}_conv",
name=f"conv{conv_index + 1}_conv",
)(x)

if not use_pre_activation:
Expand Down
10 changes: 5 additions & 5 deletions keras_hub/src/models/retinanet/feature_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def build(self, input_shapes):
)
if i == backbone_max_level + 1 and self.use_p5:
self.output_conv_layers[level].build(
(None, None, None, input_shapes[f"P{i-1}"][-1])
(None, None, None, input_shapes[f"P{i - 1}"][-1])
if self.data_format == "channels_last"
else (None, input_shapes[f"P{i-1}"][1], None, None)
else (None, input_shapes[f"P{i - 1}"][1], None, None)
)
else:
self.output_conv_layers[level].build(
Expand Down Expand Up @@ -277,7 +277,7 @@ def call(self, inputs):
if i < backbone_max_level:
# for the top most output, it doesn't need to merge with any
# upper stream outputs
upstream_output = self.top_down_op(output_features[f"P{i+1}"])
upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
output = self.merge_op([output, upstream_output])
output_features[level] = (
self.lateral_batch_norm_layers[level](output)
Expand All @@ -296,9 +296,9 @@ def call(self, inputs):
for i in range(backbone_max_level + 1, self.max_level + 1):
level = f"P{i}"
feats_in = (
inputs[f"P{i-1}"]
inputs[f"P{i - 1}"]
if i == backbone_max_level + 1 and self.use_p5
else output_features[f"P{i-1}"]
else output_features[f"P{i - 1}"]
)
if i > backbone_max_level + 1:
feats_in = self.activation(feats_in)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ def __init__(

def fit(self, *args, **kwargs):
raise NotImplementedError(
"Currently, `fit` is not supported for "
"`StableDiffusion3Inpaint`."
"Currently, `fit` is not supported for `StableDiffusion3Inpaint`."
)

def generate_step(
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/vit/vit_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def build(self, input_shape):
attention_dropout=self.attention_dropout,
layer_norm_epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name=f"tranformer_block_{i+1}",
name=f"tranformer_block_{i + 1}",
)
encoder_block.build((None, None, self.hidden_dim))
self.encoder_layers.append(encoder_block)
Expand Down
3 changes: 1 addition & 2 deletions keras_hub/src/tokenizers/byte_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def __init__(
):
if not is_int_dtype(dtype):
raise ValueError(
"Output dtype must be an integer type. "
f"Received: dtype={dtype}"
f"Output dtype must be an integer type. Received: dtype={dtype}"
)

# Check normalization_form.
Expand Down
3 changes: 1 addition & 2 deletions keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def __init__(
) -> None:
if not is_int_dtype(dtype):
raise ValueError(
"Output dtype must be an integer type. "
f"Received: dtype={dtype}"
f"Output dtype must be an integer type. Received: dtype={dtype}"
)

# Check normalization_form.
Expand Down
10 changes: 6 additions & 4 deletions keras_hub/src/utils/timm/convert_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,20 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
num_stacks = len(backbone.stackwise_num_repeats)
for stack_index in range(num_stacks):
for block_idx in range(backbone.stackwise_num_repeats[stack_index]):
keras_name = f"stack{stack_index+1}_block{block_idx+1}"
keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
hf_name = (
f"features.denseblock{stack_index+1}.denselayer{block_idx+1}"
"features."
f"denseblock{stack_index + 1}"
f".denselayer{block_idx + 1}"
)
port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1")
port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.norm2")
port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")

for stack_index in range(num_stacks - 1):
keras_transition_name = f"transition{stack_index+1}"
hf_transition_name = f"features.transition{stack_index+1}"
keras_transition_name = f"transition{stack_index + 1}"
hf_transition_name = f"features.transition{stack_index + 1}"
port_batch_normalization(
f"{keras_transition_name}_bn", f"{hf_transition_name}.norm"
)
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
# 97 is the start of the lowercase alphabet.
letter_identifier = chr(block_idx + 97)

keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
hf_block_prefix = f"blocks.{stack_index}.{block_idx}."

if block_type == "v1":
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
if version == "v1":
keras_name = f"stack{stack_index}_block{block_idx}"
hf_name = f"layer{stack_index+1}.{block_idx}"
hf_name = f"layer{stack_index + 1}.{block_idx}"
else:
keras_name = f"stack{stack_index}_block{block_idx}"
hf_name = f"stages.{stack_index}.blocks.{block_idx}"
Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_albert_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_bart_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
9 changes: 5 additions & 4 deletions tools/checkpoint_conversion/convert_bloom_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


flags.DEFINE_string(
"preset", None, f'Must be one of {", ".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {', '.join(PRESET_MAP.keys())}"
)
flags.mark_flag_as_required("preset")
flags.DEFINE_boolean(
Expand Down Expand Up @@ -244,9 +244,10 @@ def preprocessor_call(input_str):

def main(_):
preset = FLAGS.preset
assert (
preset in PRESET_MAP.keys()
), f'Invalid preset {preset}. Must be one of {", ".join(PRESET_MAP.keys())}'
assert preset in PRESET_MAP.keys(), (
f"Invalid preset {preset}. "
f"Must be one of {', '.join(PRESET_MAP.keys())}"
)

validate_only = FLAGS.validate_only

Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_clip_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
flags.DEFINE_string(
"preset",
None,
f'Must be one of {",".join(PRESET_MAP.keys())}',
f"Must be one of {','.join(PRESET_MAP.keys())}",
required=True,
)
flags.DEFINE_string(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_electra_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
flags.DEFINE_string(
"preset",
"electra_base_discriminator_en",
f'Must be one of {",".join(PRESET_MAP)}',
f"Must be one of {','.join(PRESET_MAP)}",
)
flags.mark_flag_as_required("preset")

Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_f_net_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_falcon_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
absl.flags.DEFINE_string(
"preset",
"falcon_refinedweb_1b_en",
f'Must be one of {",".join(PRESET_MAP.keys())}.',
f"Must be one of {','.join(PRESET_MAP.keys())}.",
)


Expand Down
8 changes: 4 additions & 4 deletions tools/checkpoint_conversion/convert_gemma_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
flags.DEFINE_string(
"preset",
None,
f'Must be one of {",".join(PRESET_MAP.keys())}',
f"Must be one of {','.join(PRESET_MAP.keys())}",
required=True,
)

Expand Down Expand Up @@ -228,9 +228,9 @@ def main(_):
flax_dir = FLAGS.flax_dir
else:
presets = PRESET_MAP.keys()
assert (
preset in presets
), f'Invalid preset {preset}. Must be one of {",".join(presets)}'
assert preset in presets, (
f"Invalid preset {preset}. Must be one of {','.join(presets)}"
)
handle = PRESET_MAP[preset]
flax_dir = download_flax_model(handle)

Expand Down
6 changes: 3 additions & 3 deletions tools/checkpoint_conversion/convert_gpt2_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down Expand Up @@ -236,8 +236,8 @@ def check_output(

def main(_):
assert FLAGS.preset in PRESET_MAP.keys(), (
f'Invalid preset {FLAGS.preset}. '
f'Must be one of {",".join(PRESET_MAP.keys())}'
f"Invalid preset {FLAGS.preset}. "
f"Must be one of {','.join(PRESET_MAP.keys())}"
)
num_params = PRESET_MAP[FLAGS.preset][0]
hf_model_name = PRESET_MAP[FLAGS.preset][1]
Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_llama3_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_llama_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)

flags.DEFINE_string(
Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_mistral_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
"preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)


Expand Down
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_mix_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
}

flags.DEFINE_string(
"preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}'
"preset", None, f"Must be one of {','.join(DOWNLOAD_URLS.keys())}"
)


Expand Down
Loading
Loading