In [96]:
from typing import Tuple
import tensorflow as tf
from utils import GlobalSumPooling

import tensorflow as tf
import tensorflow.keras.layers as kl
import numpy as np

In [172]:
def broadcast_shape(x, max_number_of_patches: int) -> tf.Tensor:
    x_expanded = tf.expand_dims(x, axis=1)
    x_broadcasted = tf.broadcast_to(x_expanded, [tf.shape(x)[0], max_number_of_patches, tf.shape(x)[-1]])
    return x_broadcasted

In [176]:

def build_model(m_a: int, m_c: int, n_layers: int, input_shape: Tuple[int, int],
                max_number_of_patches: int) -> tf.keras.models.Model:
    '''
    :param m_a: size of the hidden layers in the MLP of the components
    :param m_c: size of the hidden layers in the MLP of the concatenated global sum output and size + n_patches MLP output
    :param n_layers: number of layers in each of the MLPs
    :param input_shape: shape of the input data (number of patches, number of features)
    :param max_number_of_patches: maximum number of patches
    :return: a Keras model
    '''
    input_data = tf.keras.Input(shape=input_shape, name='patches_input')
    size_value = tf.keras.Input(shape=(1,), name='extra_value_input')
    n_patches_hot_encoded_value = tf.keras.Input(
        shape=(max_number_of_patches + 1,), name='hot_encoded_value_input')
    n_patches = tf.argmax(n_patches_hot_encoded_value, axis=1)[..., None]
    n_patches = tf.cast(n_patches, tf.float32)
    
    n_patches_broadcased = broadcast_shape(n_patches, max_number_of_patches)
    size_broadcased = broadcast_shape(size_value, max_number_of_patches)
    
    concat_input_data = tf.keras.layers.Concatenate(axis=-1)([input_data, n_patches_broadcased, size_broadcased])
    masked_input = tf.keras.layers.Masking(mask_value=0.0)(concat_input_data)

    currentOutput = masked_input
    for i in range(n_layers):
        dense_output = tf.keras.layers.Dense(
            m_a, activation='linear')(currentOutput)
        batchNorm = tf.keras.layers.BatchNormalization(
            momentum=0.75)(dense_output)
        activation = tf.keras.layers.ReLU()(batchNorm)
        currentOutput = activation

    global_pooling_output = GlobalSumPooling(
        data_format='channels_last')(currentOutput)

    currentOutput = global_pooling_output
    for i in range(n_layers):
        dense_output = tf.keras.layers.Dense(
            m_c, activation='linear')(currentOutput)
        batchNorm = tf.keras.layers.BatchNormalization(
            momentum=0.75)(dense_output)
        activation = tf.keras.layers.ReLU()(batchNorm)
        currentOutput = activation

    before_sigmoid_output = currentOutput

    output = tf.keras.layers.Dense(
        1, activation='sigmoid')(before_sigmoid_output)
    model = tf.keras.Model(
        inputs=[input_data, size_value, n_patches_hot_encoded_value], outputs=output)
    return model


In [177]:
model = build_model(
    m_a=128, m_c=64, n_layers=4, input_shape=(10, 9), max_number_of_patches=10)

In [178]:
model(input_example)

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.5       ],
       [0.67460376]], dtype=float32)>

In [3]:
tf.config.set_visible_devices([], 'GPU')

2025-04-04 17:07:32.152851: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
2025-04-04 17:07:32.154004: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
2025-04-04 17:07:32.155136: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
2025-04-04 17:07:32.156287: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, whi

In [4]:
input_shape = (10, 9)
n_layers = 5
m_a = 1024
m_c = 1024
max_number_of_patches = 10

In [5]:
input_data = tf.keras.Input(shape=input_shape, name='patches_input')

In [6]:
input_data

<KerasTensor: shape=(None, 10, 9) dtype=float32 (created by layer 'patches_input')>

In [75]:
import numpy as np

n_patches_one_hot = np.zeros((2, max_number_of_patches + 1))
n_patches_one_hot[0][0] = 1
n_patches_one_hot[1][5] = 1
input_example = [tf.zeros((2, 10, 9)), tf.zeros((2, 1)), tf.convert_to_tensor(n_patches_one_hot)]

In [40]:
size_value = tf.keras.Input(shape=(1,), name='extra_value_input')
n_patches_hot_encoded_value = tf.keras.Input(
    shape=(max_number_of_patches + 1,), name='hot_encoded_value_input')
n_patches = tf.argmax(n_patches_hot_encoded_value, axis=1)[..., None]

In [121]:
size_value

<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'extra_value_input')>

In [122]:
n_patches

<KerasTensor: shape=(None, 1) dtype=int64 (created by layer 'tf.__operators__.getitem_1')>

In [168]:
import tensorflow as tf
import tensorflow.keras.layers as kl
import numpy as np

x = tf.keras.Input([1])
y = tf.keras.layers.Reshape([1, 1])(x)  # Need to add empty dims before broadcasting
# Retain the batch and depth dimensions, but broadcast along H and W
print(y.shape)
broadcasted_shape = tf.where([True, False, True],
                           tf.shape(y), [0, max_number_of_patches, 0])
y = tf.broadcast_to(y, broadcasted_shape)  # Broadcast to shape [None, 10, 1]
print(y.shape)

model = tf.keras.Model(inputs=x, outputs=y)

print(model(np.random.random(size=(8, 1))).shape)

(None, 1, 1)
(None, None, None)
(8, 10, 1)


In [170]:
import tensorflow as tf

# Example input tensor of shape [None, 1]
x = tf.keras.Input(shape=(1,))  # Placeholder for the input

# Expand dimensions to [None, 1, 1]
x_expanded = tf.expand_dims(x, axis=1)

# Broadcast to shape [None, 10, 1]
x_broadcasted = tf.broadcast_to(x_expanded, [tf.shape(x)[0], 10, 1])

In [171]:
x_broadcasted

<KerasTensor: shape=(None, 10, 1) dtype=float32 (created by layer 'tf.broadcast_to_85')>

In [163]:
broadcast_shape(x, 10)

<KerasTensor: shape=(None, None, None) dtype=float32 (created by layer 'tf.broadcast_to_78')>

In [41]:
size_value

<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'extra_value_input')>

In [57]:
n_patches.shape[1]

1

In [32]:
n_patches_one_hot

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [33]:
n_patches = tf.argmax(input_example[-1], axis=1)

In [34]:
n_patches

<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 5])>

In [8]:
size_value

<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'extra_value_input')>

In [None]:

input_data = tf.keras.layers.Concatenate()([input_data, size_value, n_patches_hot_encoded_value])
masked_input = tf.keras.layers.Masking(mask_value=0.0)(input_data)