In [None]:
import tensorflow as tf
# Hide every GPU from TensorFlow so the notebook runs on CPU only. TensorFlow only
# allows changing the visible devices before they are initialised, which is why this
# call sits directly after the import, before any TF op runs.
tf.config.set_visible_devices([], 'GPU')

from plotly import express as px
from typing import List, Tuple  # NOTE(review): not used in the visible cells
import pickle

In [None]:
import sys

# Make the repository root importable (presumably for the local `model` module
# imported further down). The path is relative to the kernel's working directory.
# The membership guard keeps the cell idempotent: re-running it no longer appends
# duplicate entries to sys.path.
PROJECT_ROOT = '../../../../'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

## Load the training folds

In [None]:
import os

# Pickled per-fold training dictionaries. The default is a machine-specific absolute
# path; set the FOLDS_TRAINING_DICTS_PATH environment variable to use another copy.
FOLDS_TRAINING_DICTS_PATH = os.environ.get(
    'FOLDS_TRAINING_DICTS_PATH',
    '/home/iscb/wolfson/doririmon/home/order/ubinet/repo/ubinet/datasets/patch_to_score/data_for_training/03_04_with_pesto_and_coord/folds_training_dicts.pkl',
)

# pickle.load can execute arbitrary code: only point this at files you trust.
with open(FOLDS_TRAINING_DICTS_PATH, 'rb') as f:
    folds_training_dicts = pickle.load(f)

In [None]:
# Work with a single fold and show which entries it provides.
FOLD_INDEX = 0
fold = folds_training_dicts[FOLD_INDEX]
fold.keys()

In [None]:
# Pull the training-split arrays out of the selected fold (used by the cells below).
input_data, coordinates, size_value, n_patches_hot_encoded_value = (
    fold[key]
    for key in ('components_train', 'coordinates_train', 'sizes_train', 'num_patches_train')
)
# Passed to create_masked_inputs; presumably the padded number of patches per sample.
max_number_of_patches = 10

## Analyze pairwise coordinate distances

In [None]:
# Euclidean distance between every pair of positions inside each sample, as float32.
# assumes `coordinates` is (B, N, D) -- TODO confirm; the result is then (B, N, N).
pairwise_distances = tf.cast(
    tf.norm(
        tf.expand_dims(coordinates, axis=1)     # (B, 1, N, D)
        - tf.expand_dims(coordinates, axis=2),  # (B, N, 1, D)
        axis=-1,
    ),
    dtype=tf.float32,
)

In [None]:
# Distance matrix of one example sample (the index is arbitrary), with a title and
# labelled axes so the figure stands on its own.
EXAMPLE_SAMPLE_INDEX = 5
px.imshow(
    pairwise_distances[EXAMPLE_SAMPLE_INDEX, :, :],
    labels=dict(x='index j', y='index i', color='distance'),
    title=f'Pairwise distances, sample {EXAMPLE_SAMPLE_INDEX}',
)

In [None]:
# 1-D array of every pairwise distance except exact zeros (e.g. the diagonal).
# NOTE(review): if padded positions carry placeholder coordinates, their distances
# to real positions are non-zero and therefore kept here -- confirm that is intended.
flat_distances = pairwise_distances.numpy().ravel()
flat_distances = flat_distances[flat_distances.nonzero()]

In [None]:
# Number of non-zero pairwise distances collected above.
flat_distances.size

In [None]:
# Only every HIST_SUBSAMPLE_STRIDE-th distance is plotted, to keep the figure light.
HIST_SUBSAMPLE_STRIDE = 100
px.histogram(
    x=flat_distances[::HIST_SUBSAMPLE_STRIDE],
    histnorm='percent',
    labels={'x': 'pairwise distance'},
    title=f'Pairwise distance distribution (every {HIST_SUBSAMPLE_STRIDE}th non-zero value)',
)

## RBF (Gaussian) distance-embedding layer

In [None]:
import tensorflow as tf


class RBFGaussianEmbedding(tf.keras.layers.Layer):
    """Gaussian radial-basis embedding of a pairwise distance matrix.

    Every distance d is expanded into `num_kernels` features
    exp(-(d - c_k)^2 / (2 * s_k^2)), where the centers c_k and standard
    deviations s_k are trainable weights.
    """

    def __init__(self, num_kernels: int,
                 init_std: float,
                 init_range: tuple,
                 **kwargs):
        """
        Args:
            num_kernels: Number of RBF kernels (channel dimension C).
            init_std: Initial standard deviation, shared by all kernels.
            init_range: (min, max) pair; centers are initialised evenly spaced
                over this closed interval (both endpoints included).
            **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).
        """
        super().__init__(**kwargs)
        self.supports_masking = True
        self.num_kernels = num_kernels
        self.init_std = init_std
        self.init_range = init_range

    def build(self, input_shape):
        """Create the trainable centers and stds, each of shape [C]."""
        low, high = self.init_range

        def centers_initializer(shape, dtype=None, **_):
            # float() lets integer ranges such as (0, 100) work with tf.linspace,
            # which only accepts floating-point endpoints.
            centers = tf.linspace(float(low), float(high), self.num_kernels)
            return tf.cast(tf.reshape(centers, shape), dtype or centers.dtype)

        # add_weight (rather than a bare tf.Variable) guarantees these are tracked
        # as trainable weights under both Keras 2 and Keras 3.
        self.centers = self.add_weight(
            name="rbf_centers",
            shape=(self.num_kernels,),
            initializer=centers_initializer,
            trainable=True,
        )
        self.stds = self.add_weight(
            name="rbf_stds",
            shape=(self.num_kernels,),
            initializer=tf.keras.initializers.Constant(self.init_std),
            trainable=True,
        )

    def call(self, inputs, training=False, mask=None):
        """
        Args:
            inputs: Tensor of shape (B, N, N) - pairwise distances.
            training: Unused; kept for the Keras call signature.
            mask: Optional mask, either per position (B, N) or per pair
                (B, N, N). Zero distances are NOT treated as masked; only this
                mask is applied.
        Returns:
            Tensor of shape (B, N, N, C); entries of masked pairs are 0.
        """
        dists = tf.expand_dims(inputs, axis=-1)  # (B, N, N, 1)

        # Broadcast the per-kernel parameters over batch and pair axes.
        centers = tf.reshape(self.centers, shape=[1, 1, 1, self.num_kernels])  # (1, 1, 1, C)
        stds = tf.reshape(self.stds, shape=[1, 1, 1, self.num_kernels])        # (1, 1, 1, C)

        # RBF: exp( - (d - c)^2 / (2 * std^2) )
        rbf = tf.exp(- tf.square(dists - centers) / (2.0 * tf.square(stds)))  # (B, N, N, C)

        if mask is not None:
            rbf = rbf * self._pair_mask(mask, rbf.dtype)  # (B, N, N, C)

        return rbf

    @staticmethod
    def _pair_mask(mask, dtype):
        """Turn a (B, N) or (B, N, N) mask into a (B, N, N, 1) multiplier."""
        mask = tf.cast(mask, dtype)
        if mask.shape.rank == 3:
            # Already one flag per (i, j) pair.
            return mask[..., tf.newaxis]
        # Per-position mask: a pair is valid only if BOTH i and j are valid.
        # Masking only the row axis would let padded columns j leak into valid rows.
        return mask[:, :, tf.newaxis, tf.newaxis] * mask[:, tf.newaxis, :, tf.newaxis]

    def compute_mask(self, inputs, mask=None):
        # Propagate the incoming mask unchanged to downstream layers.
        return mask

    def get_config(self):
        """Constructor arguments, so models using this layer can be serialized."""
        config = super().get_config()
        config.update({
            "num_kernels": self.num_kernels,
            "init_std": self.init_std,
            "init_range": self.init_range,
        })
        return config

In [None]:
# 64 Gaussian kernels, centers evenly spaced over [0, 100], initial width 0.5.
rbf_layer = RBFGaussianEmbedding(
    init_range=(0.0, 100.0),
    init_std=0.5,
    num_kernels=64,
)

## Build masked inputs and inspect the mask

In [None]:
from model import (
    create_masked_inputs,
    create_broadcasted_features,
    mask_inputs,
)

# NOTE(review): this rebinds `pairwise_distances`, shadowing the raw distance
# tensor computed in the analysis section above.
features, pairwise_distances = create_masked_inputs(
    input_data,
    coordinates,
    size_value,
    n_patches_hot_encoded_value,
    max_number_of_patches,
)

In [None]:
# Mask attached to the masked distances (private `_keras_mask` attribute), first sample.
first_sample_mask = pairwise_distances._keras_mask[0]
first_sample_mask

In [None]:
# Masked distance matrix of the first sample, titled and labelled.
px.imshow(
    pairwise_distances[0],
    labels=dict(x='index j', y='index i', color='distance'),
    title='Masked pairwise distances, sample 0',
)

## Visualize RBF outputs and kernels

In [None]:
# RBF features of a single pair (sample 0, row 0, last column): one value per kernel.
rbf_features = rbf_layer(pairwise_distances)
rbf_features[0, 0, -1]

In [None]:
import tensorflow as tf
import plotly.graph_objs as go
import numpy as np


# Plot every Gaussian kernel using the layer's current centers and stds.
centers = rbf_layer.centers.numpy()
stds = rbf_layer.stds.numpy()

# Distance axis: follow the layer's init_range instead of repeating its literal
# bounds, so the plot stays in sync if the layer is re-created with another range.
NUM_PLOT_POINTS = 1_500
range_low, range_high = rbf_layer.init_range
x = np.linspace(range_low, range_high, NUM_PLOT_POINTS)

# One trace per kernel: exp(-(x - c)^2 / (2 * s^2)), the same formula the layer applies.
traces = [
    go.Scatter(
        x=x,
        y=np.exp(-((x - c) ** 2) / (2 * s ** 2)),
        mode='lines',
        name=f'Kernel {i+1}<br>c={c:.2f}, σ={s:.2f}',
    )
    for i, (c, s) in enumerate(zip(centers, stds))
]

layout = go.Layout(
    title='RBF Gaussian Kernels',
    xaxis=dict(title='Distance'),
    yaxis=dict(title='RBF Activation')
)
fig = go.Figure(data=traces, layout=layout)
fig.show()
