In [2]:
import logging
# Raise the "tensorflow" logger's threshold to ERROR, presumably to keep TF warnings out of cell output.
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# NOTE(review): `os` is not used in any visible cell — confirm before removing.
import os
import tensorflow as tf
import numpy as np

In [39]:
class AmsSketch:
    """
    AMS Sketch class for approximate second moment estimation.

    The sketch of a length-d vector v is a float32 tensor of shape (depth, width).
    In row j, coordinate i is added to column h_j(i) in [0, width) after being
    multiplied by a sign s_j(i) in {-1, +1}. The squared Euclidean norm of v is
    estimated as the median over rows of each row's sum of squares.

    Attributes:
    - depth, width (tf.Tensor): scalar tensors holding the sketch dimensions
      (int32 when built from Python ints).
    - F (tf.Tensor): int32 hash parameters, shape=(6, depth). Rows 0-1 feed the
      bucket hash h_j, rows 2-5 feed the sign hash s_j.
    - zeros_sketch (tf.Tensor): all-zero float32 tensor, shape=(depth, width),
      used as the base of every scatter-add.
    - precomputed_dict (dict): cache of per-length hashes keyed by ('four', d)
      and ('indices', d); filled lazily by `precompute`.
    """

    def __init__(self, depth=5, width=250, with_seed=False, seed=(1, 2)):
        """
        Args:
        - depth (int): Number of rows, i.e. independent estimates combined by the median.
        - width (int): Number of buckets per row.
        - with_seed (bool): If True, draw the hash parameters with
          `tf.random.stateless_uniform(seed=seed)`, so every instance built with the
          same seed (and depth) hashes identically. If False, draw them from the
          stateful global TF generator.
        - seed (tuple[int, int]): Seed for the stateless draw; ignored unless
          `with_seed` is True. The default is the seed previously hard-coded here.
        """
        self.depth = tf.constant(depth)
        self.width = tf.constant(width)

        # Upper bound (exclusive) for the int32 hash parameters.
        max_param = (1 << 31) - 1
        if with_seed:
            self.F = tf.random.stateless_uniform(shape=(6, depth), minval=0, maxval=max_param,
                                                 dtype=tf.int32, seed=seed)
        else:
            self.F = tf.random.uniform(shape=(6, depth), minval=0, maxval=max_param, dtype=tf.int32)

        self.zeros_sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)

        self.precomputed_dict = {}

    def precompute(self, d):
        """ Compute and cache the sign and scatter index of every coordinate 0..d-1 in every row.

        Stores:
        - self.precomputed_dict[('four', d)]: float32, shape=(d, depth), values in {-1, +1}.
        - self.precomputed_dict[('indices', d)]: int32, shape=(d, depth, 2); the last axis
          is the (row, bucket) pair used by `tf.tensor_scatter_nd_add`.
        """
        coords = tf.range(d)

        # Bucket of coordinate i in row j; shape=(d, depth), values in [0, width).
        pos_tensor = self.tensor_hash31(coords, self.F[0], self.F[1]) % self.width

        self.precomputed_dict[('four', d)] = tf.cast(self.tensor_fourwise(coords),
                                                     dtype=tf.float32)  # shape=(d, depth)

        # Row index j repeated for every coordinate: [[0, 1, ..., depth-1]] * d, shape=(d, depth).
        repeated_range_tensor = tf.tile(tf.expand_dims(tf.range(self.depth), 0), [d, 1])

        self.precomputed_dict[('indices', d)] = tf.stack([repeated_range_tensor, pos_tensor],
                                                         axis=-1)  # shape=(d, depth, 2)

    @staticmethod
    def hash31(x, a, b):
        """ Element-wise hash of `a * x + b`, folded to a non-negative value.

        With int32 operands the product wraps (in practice) modulo 2**32, and the
        arithmetic `r >> 31` is 0 or -1, so the XOR maps r to r when r >= 0 and to ~r
        otherwise; the mask then keeps the result in [0, 2**31 - 1].
        """
        r = a * x + b
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        return tf.bitwise.bitwise_and(fold, 2147483647)

    @staticmethod
    def tensor_hash31(x, a, b):
        """ Vectorised `hash31` of every element of `x` against every (a, b) parameter.

        Assumes x is a vector shaped (k,), e.g. tf.range(d). `x` is expanded to (k, 1)
        and broadcast against `a` and `b`; with `a`, `b` of shape (depth,) or
        (k, depth) the result has shape (k, depth).
        """
        x_reshaped = tf.expand_dims(x, axis=-1)  # shape=(k, 1)

        r = tf.multiply(a, x_reshaped) + b  # shape=(k, depth)

        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)

        return tf.bitwise.bitwise_and(fold, 2147483647)

    def tensor_fourwise(self, x):
        """ Vectorised sign hash: returns int32 values in {-1, +1}, shape=(k, depth).

        Assumes x is a vector shaped (k,), e.g. tf.range(d). Chains three hash31 rounds
        using F[2]..F[5] and maps bit 15 of the last round to -1 / +1.
        """
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # shape=(k, depth)

        # Later rounds use the previous hash as the multiplier: hash(x, in_prev, F[i]).
        in2 = self.tensor_hash31(x, in1, self.F[4])  # shape=(k, depth)

        in3 = self.tensor_hash31(x, in2, self.F[5])  # shape=(k, depth)

        in4 = tf.bitwise.bitwise_and(in3, 32768)  # keep only bit 15

        return 2 * (tf.bitwise.right_shift(in4, 15)) - 1  # {0, 1} -> {-1, +1}

    def fourwise(self, x):
        """ Non-vectorised counterpart of `tensor_fourwise`: sign in {-1, +1} for each row.

        Uses the same three hash31 rounds (F[2]..F[5]) and bit-15 extraction.
        """
        result = 2 * (tf.bitwise.right_shift(tf.bitwise.bitwise_and(
            self.hash31(self.hash31(self.hash31(x, self.F[2], self.F[3]), x, self.F[4]), x, self.F[5]), 32768), 15)) - 1
        return result

    def sketch_for_vector(self, v):
        """ Extremely efficient computation of sketch with only using tensors.

        Args:
        - v (tf.Tensor or array-like): Vector to sketch. Shape=(d,). Any numeric dtype
          is accepted and cast to float32, the dtype of the cached signs.

        Returns:
        - tf.Tensor: An AMS - Sketch. Shape=(`depth`, `width`), dtype float32.

        Raises:
        - ValueError: If `v` is not a rank-1 vector.
        """
        # The cached signs are float32; without this cast a float64 input (e.g. a default
        # NumPy array) or an integer tensor fails inside tf.multiply with a dtype mismatch.
        v = tf.cast(tf.convert_to_tensor(v), tf.float32)
        if v.shape.rank != 1:
            raise ValueError(f"sketch_for_vector expects a rank-1 vector, got shape {v.shape}")

        d = v.shape[0]

        if ('four', d) not in self.precomputed_dict:
            self.precompute(d)

        return self._sketch_for_vector(v, self.precomputed_dict[('four', d)], self.precomputed_dict[('indices', d)])

    @tf.function
    def _sketch_for_vector(self, v, four, indices):
        """ Scatter-add the signed coordinates of `v` into a fresh (depth, width) sketch.

        Args:
        - v: float32, shape=(d,).
        - four: float32 signs, shape=(d, depth).
        - indices: int32 (row, bucket) pairs, shape=(d, depth, 2).
        """
        v_expand = tf.expand_dims(v, axis=-1)  # shape=(d, 1)

        # shape=(d, depth): s_j(i) * v_i for every coordinate i and row j
        deltas_tensor = tf.multiply(four, v_expand)

        # Colliding (row, bucket) pairs are summed by the scatter-add.
        sketch = tf.tensor_scatter_nd_add(self.zeros_sketch, indices, deltas_tensor)  # shape=(depth, width)

        return sketch

    @staticmethod
    def estimate_euc_norm_squared(sketch):
        """ Estimate the Euclidean norm squared of a vector using its AMS sketch.

        Args:
        - sketch (tf.Tensor): AMS sketch of a vector. Shape=(`depth`, `width`).

        Returns:
        - NumPy scalar: median over rows of each row's sum of squares (for an even
          depth, np.median averages the two middle rows).
        """
        norm_sq_rows = tf.reduce_sum(tf.square(sketch), axis=1)  # shape=(depth,)
        return np.median(norm_sq_rows)

In [32]:
# Stateful draw with no op-level seed: the printed parameters change on every execution.
tf.random.uniform((6, 250), minval=0, maxval=2**31 - 1, dtype=tf.int32)

<tf.Tensor: shape=(6, 250), dtype=int32, numpy=
array([[1410560183, 1907384848, 1827937125, ..., 1331422091,   27810037,
         546561963],
       [ 613211234, 1919293827,  526580420, ..., 1906185497,  712411972,
         438014311],
       [1209732390,  283849182,  505917244, ...,   92051480,  902810865,
        1141802837],
       [ 484906441,  517211724, 1520503281, ...,  972311044, 1674748914,
        1678547307],
       [  14692046, 1753925916, 1789949480, ...,  260130179,  306354629,
         133210961],
       [ 155437793, 1199598632, 1143681153, ...,  217408221, 1899174095,
          64727188]], dtype=int32)>

In [37]:
# Stateless draw with a fixed seed: the printed parameters are the same on every execution.
tf.random.stateless_uniform((6, 250), seed=(1, 2), minval=0, maxval=2**31 - 1, dtype=tf.int32)

<tf.Tensor: shape=(6, 250), dtype=int32, numpy=
array([[1105988140, 1738052849, 1811907647, ...,  989478327,  628720460,
         579147759],
       [1858208210, 1188421655,  448884013, ..., 1882590668,  987193485,
        1639371654],
       [2109818543, 1346033867,   47518302, ..., 2050294505,  776851798,
         686051164],
       [1489305460,  476799945, 1354703521, ..., 1850962211, 2076744425,
         737011359],
       [1820924812,  579203535, 1225972498, ..., 1336018225, 2135331679,
         308569661],
       [2126544770,  826392160,  365741576, ..., 1267200896,  797299305,
        1884681349]], dtype=int32)>

In [29]:
tf.random.set_seed(None)

In [48]:
v1 = tf.random.uniform(shape=(d,))

In [41]:
d = 1000

In [42]:
s1 = AmsSketch(with_seed=True)

In [43]:
s2 = AmsSketch(with_seed=True)

In [44]:
s1.precompute(d)

In [45]:
s2.precompute(d)

In [46]:
tf.reduce_sum(s1.precomputed_dict[('four', d)] - s2.precomputed_dict[('four', d)])

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [47]:
tf.reduce_sum(s1.precomputed_dict[('indices', d)] - s2.precomputed_dict[('indices', d)])

<tf.Tensor: shape=(), dtype=int32, numpy=0>

In [49]:
sk1 = s1.sketch_for_vector(v1)

In [51]:
sk2 = s2.sketch_for_vector(v1)

In [52]:
tf.reduce_sum(sk1 - sk2)

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [53]:
AmsSketch.estimate_euc_norm_squared(sk1)

338.2024

In [54]:
AmsSketch.estimate_euc_norm_squared(sk2)

338.2024