Skip to content

Commit 6db198e

Browse files
committed
Another fix
1 parent 2f35556 commit 6db198e

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_saved_model(self):
8585
input_data=self.input_data,
8686
)
8787

88-
@pytest.mark.large
88+
@pytest.mark.extra_large
8989
def test_smallest_preset(self):
9090
self.run_preset_test(
9191
cls=PaliGemmaBackbone,

keras_nlp/src/utils/tensor_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,21 +137,23 @@ def convert_preprocessing_inputs(x):
137137
if numpy_x.dtype == np.int64:
138138
numpy_x = numpy_x.astype(np.int32)
139139
# We have non-ragged, non-string input. Use backbend type.
140-
return ops.convert_to_tensor(numpy_x)
140+
x = ops.convert_to_tensor(numpy_x)
141+
# Torch will complain about device placement for GPU tensors.
142+
if keras.config.backend() == "torch":
143+
x = x.cpu()
144+
return x
141145
if is_tensor_type(x):
142146
# String or ragged types we keep as tf.
143147
if isinstance(x, tf.RaggedTensor) or x.dtype == tf.string:
144148
return x
145149
# If we have a string input, use tf.tensor.
146150
if isinstance(x, np.ndarray) and x.dtype.type is np.str_:
147151
return tf.convert_to_tensor(x)
152+
x = ops.convert_to_tensor(x)
148153
# Torch will complain about device placement for GPU tensors.
149154
if keras.config.backend() == "torch":
150-
import torch
151-
152-
if isinstance(x, torch.Tensor):
153-
x = x.cpu()
154-
return ops.convert_to_tensor(x)
155+
x = x.cpu()
156+
return x
155157
return x
156158

157159

keras_nlp/src/utils/transformers/preset_loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
7171

7272
def load_tokenizer(self, cls, **kwargs):
7373
return self.converter.convert_tokenizer(cls, self.preset, **kwargs)
74+
75+
def load_image_converter(self, cls, **kwargs):
76+
# TODO: auto resize.
77+
return None

0 commit comments

Comments
 (0)