diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index 4477df4a2c..bbcbdbb04c 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -76,14 +76,14 @@ def __init__(self, *args, dtype=None, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True - # Before Keras 3.2, there is no `keras.dtype_policies.get`. - if hasattr(keras.dtype_policies, "get"): - self.dtype_policy = keras.dtype_policies.get(dtype) - else: - if isinstance(dtype, keras.dtype_policies.DTypePolicy): - dtype = dtype.name - dtype = dtype or keras.config.dtype_policy().name - self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype) + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) def __setattr__(self, name, value): # Work around setattr issues for Keras 2 and Keras 3 torch backend. @@ -121,7 +121,7 @@ def get_config(self): } # Add quantization support by utilizing `DTypePolicyMap` - if hasattr(keras.dtype_policies, "DTypePolicyMap"): + try: if isinstance( self.dtype_policy, keras.dtype_policies.DTypePolicyMap ): @@ -133,6 +133,9 @@ def get_config(self): policy_map[layer.path] = layer.dtype_policy if len(policy_map) > 0: config.update({"dtype": policy_map}) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + pass return config @classmethod