diff --git a/keras/api/__init__.py b/keras/api/__init__.py index 133437917237..bc3338fa2b29 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -63,5 +63,5 @@ from keras.src.optimizers.optimizer import Optimizer as Optimizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer -from keras.src.version import __version__ as __version__ from keras.src.version import version as version +from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 3457f05233e4..659b4534b3de 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -27,11 +27,6 @@ from keras import utils as utils from keras import visualization as visualization from keras import wrappers as wrappers -from keras._tf_keras.keras import backend as backend -from keras._tf_keras.keras import layers as layers -from keras._tf_keras.keras import losses as losses -from keras._tf_keras.keras import metrics as metrics -from keras._tf_keras.keras import preprocessing as preprocessing from keras.src.backend import Variable as Variable from keras.src.backend import device as device from keras.src.backend import name_scope as name_scope @@ -61,5 +56,12 @@ from keras.src.optimizers.optimizer import Optimizer as Optimizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer -from keras.src.version import __version__ as __version__ from keras.src.version import version as version +from keras.src.version import __version__ as __version__ + + +from keras._tf_keras.keras import backend as backend +from keras._tf_keras.keras import layers as layers +from keras._tf_keras.keras import losses as losses +from keras._tf_keras.keras import metrics as metrics +from keras._tf_keras.keras import preprocessing as preprocessing diff --git a/keras/api/_tf_keras/keras/applications/__init__.py b/keras/api/_tf_keras/keras/applications/__init__.py index 7c030b36bd4e..d9bd5494e777 100644 --- a/keras/api/_tf_keras/keras/applications/__init__.py +++ b/keras/api/_tf_keras/keras/applications/__init__.py @@ -72,12 +72,12 @@ ) from keras.src.applications.nasnet import NASNetLarge as NASNetLarge from keras.src.applications.nasnet import NASNetMobile as NASNetMobile -from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ResNet101 as ResNet101 from keras.src.applications.resnet import ResNet152 as ResNet152 -from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.vgg16 import VGG16 as VGG16 from keras.src.applications.vgg19 import VGG19 as VGG19 from keras.src.applications.xception import Xception as Xception diff --git a/keras/api/_tf_keras/keras/applications/resnet/__init__.py b/keras/api/_tf_keras/keras/applications/resnet/__init__.py index b8a25644e1d9..cae7b6a5c05b 100644 --- a/keras/api/_tf_keras/keras/applications/resnet/__init__.py +++ b/keras/api/_tf_keras/keras/applications/resnet/__init__.py @@ -4,9 +4,9 @@ since your modifications would be overwritten. """ -from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ResNet101 as ResNet101 from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ( decode_predictions as decode_predictions, ) diff --git a/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py b/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py index 7f92dd56f374..45b0ced50c2b 100644 --- a/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py +++ b/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py @@ -4,9 +4,9 @@ since your modifications would be overwritten. """ -from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.resnet_v2 import ( decode_predictions as decode_predictions, ) diff --git a/keras/api/_tf_keras/keras/backend/__init__.py b/keras/api/_tf_keras/keras/backend/__init__.py index cd9037bcf4d6..1a87e31676aa 100644 --- a/keras/api/_tf_keras/keras/backend/__init__.py +++ b/keras/api/_tf_keras/keras/backend/__init__.py @@ -23,6 +23,9 @@ from keras.src.backend.config import ( set_image_data_format as set_image_data_format, ) +from keras.src.utils.naming import get_uid as get_uid + + from keras.src.legacy.backend import abs as abs from keras.src.legacy.backend import all as all from keras.src.legacy.backend import any as any @@ -162,4 +165,3 @@ from keras.src.legacy.backend import variable as variable from keras.src.legacy.backend import zeros as zeros from keras.src.legacy.backend import zeros_like as zeros_like -from keras.src.utils.naming import get_uid as get_uid diff --git a/keras/api/_tf_keras/keras/initializers/__init__.py b/keras/api/_tf_keras/keras/initializers/__init__.py index e88013d97315..6cd6240e1bdf 100644 --- a/keras/api/_tf_keras/keras/initializers/__init__.py +++ b/keras/api/_tf_keras/keras/initializers/__init__.py @@ -7,9 +7,6 @@ from keras.src.initializers import deserialize as deserialize from keras.src.initializers import get as get from keras.src.initializers import serialize as serialize -from keras.src.initializers.constant_initializers import STFT as STFT -from keras.src.initializers.constant_initializers import STFT as STFTInitializer -from keras.src.initializers.constant_initializers import STFT as stft from keras.src.initializers.constant_initializers import Constant as Constant from keras.src.initializers.constant_initializers import Constant as constant from keras.src.initializers.constant_initializers import Identity as Identity @@ -19,6 +16,9 @@ from keras.src.initializers.constant_initializers import Identity as identity from keras.src.initializers.constant_initializers import Ones as Ones from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import STFT as STFT +from keras.src.initializers.constant_initializers import STFT as STFTInitializer +from keras.src.initializers.constant_initializers import STFT as stft from keras.src.initializers.constant_initializers import Zeros as Zeros from keras.src.initializers.constant_initializers import Zeros as zeros from keras.src.initializers.initializer import Initializer as Initializer diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index ac7e0e12cca5..6dbd68cc83a8 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -356,10 +356,12 @@ from keras.src.layers.rnn.time_distributed import ( TimeDistributed as TimeDistributed, ) +from keras.src.utils.jax_layer import FlaxLayer as FlaxLayer +from keras.src.utils.jax_layer import JaxLayer as JaxLayer +from keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper + + from keras.src.legacy.layers import AlphaDropout as AlphaDropout from keras.src.legacy.layers import RandomHeight as RandomHeight from keras.src.legacy.layers import RandomWidth as RandomWidth from keras.src.legacy.layers import ThresholdedReLU as ThresholdedReLU -from keras.src.utils.jax_layer import FlaxLayer as FlaxLayer -from keras.src.utils.jax_layer import JaxLayer as JaxLayer -from keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper diff --git a/keras/api/_tf_keras/keras/losses/__init__.py b/keras/api/_tf_keras/keras/losses/__init__.py index 73cc8e82db82..d96fafe148ef 100644 --- a/keras/api/_tf_keras/keras/losses/__init__.py +++ b/keras/api/_tf_keras/keras/losses/__init__.py @@ -4,16 +4,15 @@ since your modifications would be overwritten. """ -from keras.src.legacy.losses import Reduction as Reduction from keras.src.losses import deserialize as deserialize from keras.src.losses import get as get from keras.src.losses import serialize as serialize from keras.src.losses.loss import Loss as Loss -from keras.src.losses.losses import CTC as CTC from keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy from keras.src.losses.losses import ( BinaryFocalCrossentropy as BinaryFocalCrossentropy, ) +from keras.src.losses.losses import CTC as CTC from keras.src.losses.losses import ( CategoricalCrossentropy as CategoricalCrossentropy, ) @@ -65,6 +64,15 @@ from keras.src.losses.losses import dice as dice from keras.src.losses.losses import hinge as hinge from keras.src.losses.losses import huber as huber +from keras.src.losses.losses import poisson as poisson +from keras.src.losses.losses import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.losses.losses import squared_hinge as squared_hinge +from keras.src.losses.losses import tversky as tversky + + +from keras.src.legacy.losses import Reduction as Reduction from keras.src.losses.losses import kl_divergence as KLD from keras.src.losses.losses import kl_divergence as kld from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence @@ -77,9 +85,3 @@ from keras.src.losses.losses import mean_squared_error as mse from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE from keras.src.losses.losses import mean_squared_logarithmic_error as msle -from keras.src.losses.losses import poisson as poisson -from keras.src.losses.losses import ( - sparse_categorical_crossentropy as sparse_categorical_crossentropy, -) -from keras.src.losses.losses import squared_hinge as squared_hinge -from keras.src.losses.losses import tversky as tversky diff --git a/keras/api/_tf_keras/keras/metrics/__init__.py b/keras/api/_tf_keras/keras/metrics/__init__.py index 11fd5db493cd..31f94c0cb1fe 100644 --- a/keras/api/_tf_keras/keras/metrics/__init__.py +++ b/keras/api/_tf_keras/keras/metrics/__init__.py @@ -17,18 +17,6 @@ from keras.src.losses.losses import categorical_hinge as categorical_hinge from keras.src.losses.losses import hinge as hinge from keras.src.losses.losses import huber as huber -from keras.src.losses.losses import kl_divergence as KLD -from keras.src.losses.losses import kl_divergence as kld -from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence -from keras.src.losses.losses import log_cosh as logcosh -from keras.src.losses.losses import mean_absolute_error as MAE -from keras.src.losses.losses import mean_absolute_error as mae -from keras.src.losses.losses import mean_absolute_percentage_error as MAPE -from keras.src.losses.losses import mean_absolute_percentage_error as mape -from keras.src.losses.losses import mean_squared_error as MSE -from keras.src.losses.losses import mean_squared_error as mse -from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE -from keras.src.losses.losses import mean_squared_logarithmic_error as msle from keras.src.losses.losses import poisson as poisson from keras.src.losses.losses import ( sparse_categorical_crossentropy as sparse_categorical_crossentropy, @@ -144,3 +132,17 @@ from keras.src.metrics.regression_metrics import ( RootMeanSquaredError as RootMeanSquaredError, ) + + +from keras.src.losses.losses import kl_divergence as KLD +from keras.src.losses.losses import kl_divergence as kld +from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence +from keras.src.losses.losses import log_cosh as logcosh +from keras.src.losses.losses import mean_absolute_error as MAE +from keras.src.losses.losses import mean_absolute_error as mae +from keras.src.losses.losses import mean_absolute_percentage_error as MAPE +from keras.src.losses.losses import mean_absolute_percentage_error as mape +from keras.src.losses.losses import mean_squared_error as MSE +from keras.src.losses.losses import mean_squared_error as mse +from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE +from keras.src.losses.losses import mean_squared_logarithmic_error as msle diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index e22715971d62..b7f4702e0752 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -220,9 +220,9 @@ from keras.src.ops.numpy import less_equal as less_equal from keras.src.ops.numpy import linspace as linspace from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import log1p as log1p from keras.src.ops.numpy import log2 as log2 -from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import logaddexp as logaddexp from keras.src.ops.numpy import logaddexp2 as logaddexp2 from keras.src.ops.numpy import logical_and as logical_and diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 82b6b6dff363..96d38066175a 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -106,9 +106,9 @@ from keras.src.ops.numpy import less_equal as less_equal from keras.src.ops.numpy import linspace as linspace from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import log1p as log1p from keras.src.ops.numpy import log2 as log2 -from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import logaddexp as logaddexp from keras.src.ops.numpy import logaddexp2 as logaddexp2 from keras.src.ops.numpy import logical_and as logical_and diff --git a/keras/api/_tf_keras/keras/preprocessing/__init__.py b/keras/api/_tf_keras/keras/preprocessing/__init__.py index b11b4f3fd272..955c7f64df03 100644 --- a/keras/api/_tf_keras/keras/preprocessing/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/__init__.py @@ -4,9 +4,6 @@ since your modifications would be overwritten. """ -from keras._tf_keras.keras.preprocessing import image as image -from keras._tf_keras.keras.preprocessing import sequence as sequence -from keras._tf_keras.keras.preprocessing import text as text from keras.src.utils.image_dataset_utils import ( image_dataset_from_directory as image_dataset_from_directory, ) @@ -16,3 +13,8 @@ from keras.src.utils.timeseries_dataset_utils import ( timeseries_dataset_from_array as timeseries_dataset_from_array, ) + + +from keras._tf_keras.keras.preprocessing import image as image +from keras._tf_keras.keras.preprocessing import sequence as sequence +from keras._tf_keras.keras.preprocessing import text as text diff --git a/keras/api/_tf_keras/keras/preprocessing/image/__init__.py b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py index 43986878eb40..3af4024c3f61 100644 --- a/keras/api/_tf_keras/keras/preprocessing/image/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py @@ -4,6 +4,13 @@ since your modifications would be overwritten. """ +from keras.src.utils.image_utils import array_to_img as array_to_img +from keras.src.utils.image_utils import img_to_array as img_to_array +from keras.src.utils.image_utils import load_img as load_img +from keras.src.utils.image_utils import save_img as save_img +from keras.src.utils.image_utils import smart_resize as smart_resize + + from keras.src.legacy.preprocessing.image import ( DirectoryIterator as DirectoryIterator, ) @@ -35,8 +42,3 @@ from keras.src.legacy.preprocessing.image import random_shear as random_shear from keras.src.legacy.preprocessing.image import random_shift as random_shift from keras.src.legacy.preprocessing.image import random_zoom as random_zoom -from keras.src.utils.image_utils import array_to_img as array_to_img -from keras.src.utils.image_utils import img_to_array as img_to_array -from keras.src.utils.image_utils import load_img as load_img -from keras.src.utils.image_utils import save_img as save_img -from keras.src.utils.image_utils import smart_resize as smart_resize diff --git a/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py index 501c1f1123de..b721d18974d6 100644 --- a/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. """ +from keras.src.utils.sequence_utils import pad_sequences as pad_sequences + + from keras.src.legacy.preprocessing.sequence import ( TimeseriesGenerator as TimeseriesGenerator, ) @@ -11,4 +14,3 @@ make_sampling_table as make_sampling_table, ) from keras.src.legacy.preprocessing.sequence import skipgrams as skipgrams -from keras.src.utils.sequence_utils import pad_sequences as pad_sequences diff --git a/keras/api/applications/__init__.py b/keras/api/applications/__init__.py index 7c030b36bd4e..d9bd5494e777 100644 --- a/keras/api/applications/__init__.py +++ b/keras/api/applications/__init__.py @@ -72,12 +72,12 @@ ) from keras.src.applications.nasnet import NASNetLarge as NASNetLarge from keras.src.applications.nasnet import NASNetMobile as NASNetMobile -from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ResNet101 as ResNet101 from keras.src.applications.resnet import ResNet152 as ResNet152 -from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.vgg16 import VGG16 as VGG16 from keras.src.applications.vgg19 import VGG19 as VGG19 from keras.src.applications.xception import Xception as Xception diff --git a/keras/api/applications/resnet/__init__.py b/keras/api/applications/resnet/__init__.py index b8a25644e1d9..cae7b6a5c05b 100644 --- a/keras/api/applications/resnet/__init__.py +++ b/keras/api/applications/resnet/__init__.py @@ -4,9 +4,9 @@ since your modifications would be overwritten. """ -from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ResNet101 as ResNet101 from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet import ResNet50 as ResNet50 from keras.src.applications.resnet import ( decode_predictions as decode_predictions, ) diff --git a/keras/api/applications/resnet_v2/__init__.py b/keras/api/applications/resnet_v2/__init__.py index 7f92dd56f374..45b0ced50c2b 100644 --- a/keras/api/applications/resnet_v2/__init__.py +++ b/keras/api/applications/resnet_v2/__init__.py @@ -4,9 +4,9 @@ since your modifications would be overwritten. """ -from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 from keras.src.applications.resnet_v2 import ( decode_predictions as decode_predictions, ) diff --git a/keras/api/initializers/__init__.py b/keras/api/initializers/__init__.py index e88013d97315..6cd6240e1bdf 100644 --- a/keras/api/initializers/__init__.py +++ b/keras/api/initializers/__init__.py @@ -7,9 +7,6 @@ from keras.src.initializers import deserialize as deserialize from keras.src.initializers import get as get from keras.src.initializers import serialize as serialize -from keras.src.initializers.constant_initializers import STFT as STFT -from keras.src.initializers.constant_initializers import STFT as STFTInitializer -from keras.src.initializers.constant_initializers import STFT as stft from keras.src.initializers.constant_initializers import Constant as Constant from keras.src.initializers.constant_initializers import Constant as constant from keras.src.initializers.constant_initializers import Identity as Identity @@ -19,6 +16,9 @@ from keras.src.initializers.constant_initializers import Identity as identity from keras.src.initializers.constant_initializers import Ones as Ones from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import STFT as STFT +from keras.src.initializers.constant_initializers import STFT as STFTInitializer +from keras.src.initializers.constant_initializers import STFT as stft from keras.src.initializers.constant_initializers import Zeros as Zeros from keras.src.initializers.constant_initializers import Zeros as zeros from keras.src.initializers.initializer import Initializer as Initializer diff --git a/keras/api/losses/__init__.py b/keras/api/losses/__init__.py index 60414fe301d0..93da12285afe 100644 --- a/keras/api/losses/__init__.py +++ b/keras/api/losses/__init__.py @@ -8,11 +8,11 @@ from keras.src.losses import get as get from keras.src.losses import serialize as serialize from keras.src.losses.loss import Loss as Loss -from keras.src.losses.losses import CTC as CTC from keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy from keras.src.losses.losses import ( BinaryFocalCrossentropy as BinaryFocalCrossentropy, ) +from keras.src.losses.losses import CTC as CTC from keras.src.losses.losses import ( CategoricalCrossentropy as CategoricalCrossentropy, ) diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index e22715971d62..b7f4702e0752 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -220,9 +220,9 @@ from keras.src.ops.numpy import less_equal as less_equal from keras.src.ops.numpy import linspace as linspace from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import log1p as log1p from keras.src.ops.numpy import log2 as log2 -from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import logaddexp as logaddexp from keras.src.ops.numpy import logaddexp2 as logaddexp2 from keras.src.ops.numpy import logical_and as logical_and diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 82b6b6dff363..96d38066175a 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -106,9 +106,9 @@ from keras.src.ops.numpy import less_equal as less_equal from keras.src.ops.numpy import linspace as linspace from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import log1p as log1p from keras.src.ops.numpy import log2 as log2 -from keras.src.ops.numpy import log10 as log10 from keras.src.ops.numpy import logaddexp as logaddexp from keras.src.ops.numpy import logaddexp2 as logaddexp2 from keras.src.ops.numpy import logical_and as logical_and diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index c1f4e8066e37..c8d2e2dce63d 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2742,6 +2742,41 @@ def round(x, decimals=0): def tile(x, repeats): x = convert_to_tensor(x) + + # Check if repeats contains only concrete integers + # If so, keep it as a Python list/tuple for better shape inference + try: + if isinstance(repeats, (list, tuple)): + # Try to extract concrete integer values + concrete_repeats = [] + for r in repeats: + if isinstance(r, int): + concrete_repeats.append(r) + elif hasattr(r, 'numpy') and r.shape == (): + # Scalar tensor with concrete value + concrete_repeats.append(int(r.numpy())) + else: + # Not a concrete value, fall back to tensor path + concrete_repeats = None + break + + if concrete_repeats is not None: + # Use concrete repeats directly for better shape inference + repeats = concrete_repeats + # Pad or trim repeats to match x rank + x_rank = x.shape.rank + if x_rank is not None: + if len(repeats) < x_rank: + repeats = [1] * (x_rank - len(repeats)) + repeats + elif len(repeats) > x_rank: + # Need to reshape x to match repeats length + x_shape_list = [1] * (len(repeats) - x_rank) + [d if d is not None else -1 for d in x.shape.as_list()] + x = tf.reshape(x, x_shape_list) + return tf.tile(x, repeats) + except (AttributeError, TypeError, tf.errors.OperatorNotAllowedInGraphError): + pass + + # Original dynamic implementation for non-concrete repeats repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1]) repeats_size = tf.size(repeats) repeats = tf.pad( diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 5190ff2cd807..14ba45d2db5f 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -6411,6 +6411,15 @@ def compute_output_spec(self, x): repeats = self.repeats if isinstance(repeats, int): repeats = [repeats] + + # Convert repeats to list if it's a tuple or other iterable + # and extract concrete integer values + if not isinstance(repeats, list): + try: + repeats = list(repeats) + except TypeError: + repeats = [repeats] + if len(x_shape) > len(repeats): repeats = [1] * (len(x_shape) - len(repeats)) + repeats else: @@ -6418,10 +6427,15 @@ def compute_output_spec(self, x): output_shape = [] for x_size, repeat in zip(x_shape, repeats): + # Check if repeat is a concrete integer value + # If it's a symbolic tensor or unknown, we can't infer the size if x_size is None: output_shape.append(None) - else: + elif isinstance(repeat, int): output_shape.append(x_size * repeat) + else: + # repeat is symbolic (e.g., KerasTensor, tf.Tensor, etc.) + output_shape.append(None) return KerasTensor(output_shape, dtype=x.dtype) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 42a8c37b49e3..79bf0816ea6b 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1820,6 +1820,10 @@ def test_tile(self): self.assertEqual(knp.tile(x, [2]).shape, (None, 6)) self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6)) self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6)) + + # Test with multi-dimensional input + x = KerasTensor((None, 3, 2, 2)) + self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2)) def test_trace(self): x = KerasTensor((None, 3, None, 5)) @@ -9507,3 +9511,23 @@ def call(self, x): model.compile(jit_compile=jit_compile) model.predict(np.random.randn(1, 8)) + + def test_tile_shape_inference_in_layer(self): + """Test that ops.tile properly infers output shape when used in a Layer. + + This is a regression test for issue #20914 where TensorFlow backend + would return all-None shapes when tile was called inside a Layer's + call method with concrete integer repeats. + """ + class TileLayer(keras.layers.Layer): + def call(self, x): + # Use concrete integer repeats + repeats = [1, 2, 1, 1] + return knp.tile(x, repeats) + + inputs = keras.Input(shape=(3, 2, 2)) + output = TileLayer()(inputs) + + # With the fix, output shape should be (None, 6, 2, 2) + # Before the fix, it was (None, None, None, None) + self.assertEqual(output.shape, (None, 6, 2, 2))