In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import time
import tensorflow as tf

import numpy as np

In [2]:
from tensorflow.experimental import ExtensionType

class AmsSketch(ExtensionType):
    """AMS sketch of a vector, implemented with TensorFlow ops.

    Fields:
        depth: number of sketch rows.
        width: number of buckets per sketch row.
        F: int32 tensor of shape (6, depth) holding random hash coefficients.
            Rows 0-1 parameterise the bucket hash; rows 2-5 parameterise the
            +/-1 sign hash.
    """
    depth: int
    width: int
    F: tf.Tensor

    def __init__(self, depth=7, width=1500):
        """Draw fresh random hash coefficients for a `depth` x `width` sketch.

        Raises:
            ValueError: if `depth` or `width` is not positive (a zero width
                would make the bucket modulo undefined).
        """
        if depth <= 0 or width <= 0:
            raise ValueError(
                f"depth and width must be positive, got depth={depth}, width={width}"
            )

        self.depth = depth
        self.width = width
        self.F = tf.random.uniform(shape=(6, depth), minval=0, maxval=(1 << 31) - 1, dtype=tf.int32)

    @tf.function
    def hash31(self, x, a, b):
        """Hash `x` with coefficients `a`, `b` into a non-negative 31-bit value.

        Computes r = a * x + b, XORs r with r >> 31 and keeps the low 31 bits.
        With the int32 coefficients stored in `F` the arithmetic is done in
        int32 as well.
        """
        r = a * x + b
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        # 2147483647 == 0x7FFFFFFF: clear the sign bit.
        return tf.bitwise.bitwise_and(fold, 2147483647)

    @tf.function
    def tensor_hash31(self, x, a, b):
        """Vectorized `hash31` for a vector of inputs.

        Args:
            x: tensor of shape (d,), e.g. indices `tf.range(d)`.
            a, b: coefficients broadcastable against shape (d, 1), e.g. a row
                of `F` with shape (depth,) or a (d, depth) tensor.

        Returns:
            Tensor of shape (d, depth) when used with rows of `F`.
        """
        # Adding a trailing axis makes x broadcast across the per-row
        # coefficients; the hashing itself is shared with the scalar version.
        return self.hash31(tf.expand_dims(x, axis=-1), a, b)

    @tf.function
    def tensor_fourwise(self, x):
        """Vectorized +/-1 sign hash.

        Args:
            x: tensor of shape (d,), e.g. indices `tf.range(d)`.

        Returns:
            int32 tensor of shape (d, depth) whose entries are -1 or +1.
        """
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # (d, depth)
        # The previous hash is fed back in as the multiplier `a`; the scalar
        # `fourwise` passes it as `x` instead, which is the same product.
        in2 = self.tensor_hash31(x, in1, self.F[4])  # (d, depth)
        in3 = self.tensor_hash31(x, in2, self.F[5])  # (d, depth)

        # Extract bit 15 (32768 == 1 << 15) and map {0, 1} -> {-1, +1}.
        bit = tf.bitwise.right_shift(tf.bitwise.bitwise_and(in3, 32768), 15)
        return 2 * bit - 1

    @tf.function
    def fourwise(self, x):
        """+/-1 sign hash of a single index `x`; returns one sign per sketch row."""
        h1 = self.hash31(x, self.F[2], self.F[3])
        h2 = self.hash31(h1, x, self.F[4])
        h3 = self.hash31(h2, x, self.F[5])

        # Extract bit 15 (32768 == 1 << 15) and map {0, 1} -> {-1, +1}.
        bit = tf.bitwise.right_shift(tf.bitwise.bitwise_and(h3, 32768), 15)
        result = 2 * bit - 1
        return result

    @tf.function
    def sketch_for_vector(self, v):
        """Compute the sketch of `v` using only tensor ops (no Python-level loop).

        Args:
            v: 1-D tensor of shape (d,). It is multiplied by float32 signs, so
                it is expected to be float32.

        Returns:
            float32 tensor of shape (depth, width).
        """
        sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)

        # Dynamic length: also works when the static shape of `v` is unknown.
        len_v = tf.shape(v)[0]
        index_tensor = tf.range(len_v)

        # Bucket of every index in every sketch row, shape (d, depth).
        pos_tensor = self.tensor_hash31(index_tensor, self.F[0], self.F[1]) % self.width

        # Signed contribution of every index to every sketch row, shape (d, depth).
        signs = tf.cast(self.tensor_fourwise(index_tensor), dtype=tf.float32)
        deltas_tensor = tf.multiply(signs, tf.expand_dims(v, axis=-1))

        # Row ids 0..depth-1 repeated for each index, shape (d, depth).
        row_ids = tf.tile(tf.expand_dims(tf.range(self.depth), 0), [len_v, 1])

        # (row, bucket) pairs, shape (d, depth, 2).
        indices = tf.stack([row_ids, pos_tensor], axis=-1)

        # Colliding (row, bucket) pairs are accumulated.
        return tf.tensor_scatter_nd_add(sketch, indices, deltas_tensor)

    @tf.function
    def sketch_for_vector2(self, v):
        """Reference implementation that updates the sketch one index at a time.

        Kept for comparison with `sketch_for_vector`; the per-index loop is a
        poor fit for TensorFlow.
        """
        sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)

        for i in tf.range(tf.shape(v)[0], dtype=tf.int32):
            pos = self.hash31(i, self.F[0], self.F[1]) % self.width
            delta = tf.cast(self.fourwise(i), dtype=tf.float32) * v[i]
            # One (row, bucket) pair per sketch row for index i.
            indices_to_update = tf.stack([tf.range(self.depth, dtype=tf.int32), pos], axis=1)
            sketch = tf.tensor_scatter_nd_add(sketch, indices_to_update, delta)

        return sketch

    @staticmethod
    @tf.function
    def estimate_euc_norm_squared(sketch):
        """Estimate the squared Euclidean norm of the sketched vector.

        Args:
            sketch: tensor of shape (depth, width) as produced by
                `sketch_for_vector`.

        Returns:
            Scalar tensor: the median over sketch rows of the row-wise sum of
            squares.
        """
        # Sorting is O(n log n), which is fine because n is the sketch depth.
        row_energy = tf.sort(tf.reduce_sum(tf.square(sketch), axis=1))
        n = tf.shape(row_energy)[0]

        # Branch-free median: for odd n both indices select the middle element
        # (and (x + x) / 2 == x); for even n they select the two middle ones.
        lower = row_energy[(n - 1) // 2]
        upper = row_energy[n // 2]
        return (lower + upper) / 2.0

In [3]:
# Sketch with 7 rows of 1700 buckets each, used by every cell below.
ams_sketch = AmsSketch(depth=7, width=1700)

In [7]:
# Random float32 test vector with 10000 entries drawn uniformly from [0, 2).
# NOTE(review): no seed is set, so the values (and the outputs below) change on every run.
vec = tf.random.uniform(shape=(10000,), minval=0, maxval=2, dtype=tf.float32)

In [11]:
#%%timeit -r 1

# Sketch `vec` with the vectorized implementation.
sketch = ams_sketch.sketch_for_vector(vec)

In [12]:
sketch

<tf.Tensor: shape=(7, 1700), dtype=float32, numpy=
array([[-0.71613455,  3.5107355 , -1.4464993 , ..., -0.14513016,
         1.4895973 ,  1.1736515 ],
       [ 2.2559028 , -0.30160308, -2.5438793 , ...,  4.0859685 ,
        -1.2823029 ,  4.0744076 ],
       [ 0.        ,  3.2889836 ,  4.068681  , ...,  0.03727531,
        -1.8788319 ,  0.        ],
       ...,
       [-1.0088367 , -0.1262784 , -3.9266415 , ..., -0.5174842 ,
        -0.2585826 , -1.3709421 ],
       [ 2.2302265 ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  1.7632725 ],
       [-5.12976   ,  0.23835802, -2.274409  , ..., -3.8424373 ,
        -0.3808372 ,  0.10876369]], dtype=float32)>

In [14]:
#%%timeit -r 1

# Sketch the same vector with the per-index loop implementation, for comparison with `sketch`.
sketch2 = ams_sketch.sketch_for_vector2(vec)

In [15]:
sketch2

<tf.Tensor: shape=(7, 1700), dtype=float32, numpy=
array([[-0.71613455,  3.5107355 , -1.4464993 , ..., -0.14513016,
         1.4895973 ,  1.1736515 ],
       [ 2.2559028 , -0.30160308, -2.5438793 , ...,  4.0859685 ,
        -1.2823029 ,  4.0744076 ],
       [ 0.        ,  3.2889836 ,  4.068681  , ...,  0.03727531,
        -1.8788319 ,  0.        ],
       ...,
       [-1.0088367 , -0.1262784 , -3.9266415 , ..., -0.5174842 ,
        -0.2585826 , -1.3709421 ],
       [ 2.2302265 ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  1.7632725 ],
       [-5.12976   ,  0.23835802, -2.274409  , ..., -3.8424373 ,
        -0.3808372 ,  0.10876369]], dtype=float32)>

In [16]:
# Sketch-based estimate of the squared Euclidean norm of `vec`.
ams_sketch.estimate_euc_norm_squared(sketch)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


<tf.Tensor: shape=(), dtype=float32, numpy=13596.997>

In [17]:
# Exact squared Euclidean norm of `vec`, to compare against the sketch estimate.
tf.reduce_sum(tf.square(vec))

<tf.Tensor: shape=(), dtype=float32, numpy=13250.656>

In [314]:
# Scratch globals for the step-by-step cells below: same dimensions as
# `ams_sketch`, plus a handle on its hash-coefficient table.
depth, width = 7, 1700
F = ams_sketch.F

In [315]:
# First, vectorize the `pos = ...` (bucket index) computation.

In [412]:
@tf.function
def tensor_hash31(x, a, b):
    """Hash a whole vector of inputs at once.

    `x` is assumed to be a vector of shape (d,), for example the indices
    `tf.range(d)`. It gets a trailing axis so it broadcasts against `a` and
    `b`; with rows of `F` the result has shape (d, 7).
    """
    # (d,) -> (d, 1)
    column = tf.expand_dims(x, axis=-1)

    # a * x + b, broadcast across the coefficient columns
    mixed = tf.add(tf.multiply(a, column), b)

    # XOR with the value shifted right by 31, then keep the low 31 bits
    shifted = tf.bitwise.right_shift(mixed, 31)
    return tf.bitwise.bitwise_and(tf.bitwise.bitwise_xor(shifted, mixed), 2147483647)

In [414]:
@tf.function
def fourwise(x):
    """Vectorized +/-1 sign hash of the index vector `x`.

    `x` is assumed to be a vector of shape (d,), e.g. `tf.range(d)`. Uses the
    notebook-level coefficient table `F` (rows 2-5) and `tensor_hash31`.
    Returns an int32 tensor with one row per entry of `x` and entries in
    {-1, +1}.
    """
    # Three chained hashes; each result is fed back in as the multiplier.
    h1 = tensor_hash31(x, F[2], F[3])
    h2 = tensor_hash31(x, h1, F[4])
    h3 = tensor_hash31(x, h2, F[5])

    # Extract bit 15 (32768 == 1 << 15) and map {0, 1} -> {-1, +1}.
    bit = tf.bitwise.right_shift(tf.bitwise.bitwise_and(h3, 32768), 15)
    result = 2 * bit - 1
    return result

In [415]:
# Hash indices 0, 1, 2 with coefficient rows F[0] and F[1]: one output row per index.
tensor_hash31(tf.range(3), F[0], F[1])

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[1977271226,  853021704,  389148655,   99585761, 1283153601,
        2109570381, 1216126172],
       [2096472306, 1031553032,  917739298,  229173230, 1373326314,
        1450757621, 1755361742],
       [1875248543, 1210084360, 1446329941,  358760699,  265161065,
         716118328, 2000369983]], dtype=int32)>

In [416]:
# Scalar reference: the +/-1 signs of index 5, one per sketch row.
ams_sketch.fourwise(5)

<tf.Tensor: shape=(7,), dtype=int32, numpy=array([-1,  1, -1, -1,  1,  1, -1], dtype=int32)>

In [432]:
x = tf.range(3) # small stand-in for the index vector (presumably tf.range(len_v) in the real code)

In [435]:
# Step 1: hash the indices with coefficient rows F[2] and F[3].
in1 = tensor_hash31(x, F[2], F[3]) # (`x_dim`, 7)

In [436]:
in1

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[1809828449, 1344524341, 1281281549,  778659535,  831478468,
         892795837, 2147363248],
       [1354134907,  912368886, 1282422683, 1176130933, 1321417572,
        2041092559,  560225867],
       [ 223130968, 1125705181, 1283563817, 1573602331,  820653682,
        1105578014, 1027152312]], dtype=int32)>

In [447]:
# Step 2: feed `in1` back in as the multiplier `a`. The scalar `fourwise` passes
# the previous hash as `x` instead; the product is the same either way.
in2 = tensor_hash31(x, in1, F[4])

In [448]:
in2

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[1182004548,  996894528,  737908117, 1246823400, 1414580191,
         602577027, 1942232991],
       [1758827840, 1909263414, 2020330800, 1872012962, 1558969532,
        1651297709, 1792508437],
       [1628266484, 1046662405,  989931544,   99060766, 1239079740,
        1481234240,  298429680]], dtype=int32)>

In [449]:
# Step 3: same again, with `in2` as the multiplier and F[5] as the offset.
in3 = tensor_hash31(x, in2, F[5])

In [450]:
in3

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[ 981790249, 1700078957, 2135000611, 1465974970, 2021823790,
         878007205,  600160209],
       [1554349206,  685624924,  139635884,  956979363,  714173973,
        1765662381, 1902298649],
       [  56644078,  501563528,  180103596, 1664096502,  205015974,
         454491610, 1197019569]], dtype=int32)>

In [459]:
# Isolate bit 15 of each hash (32768 == 1 << 15).
in4 = tf.bitwise.bitwise_and(in3, 32768)

In [460]:
in4

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[32768,     0, 32768,     0, 32768,     0, 32768],
       [    0, 32768, 32768,     0,     0, 32768, 32768],
       [    0,     0,     0,     0,     0, 32768,     0]], dtype=int32)>

In [462]:
# Shift the isolated bit down so every entry is 0 or 1.
in5 = tf.bitwise.right_shift(in4, 15)

In [463]:
in5

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 1, 0]], dtype=int32)>

In [465]:
# Map {0, 1} -> {-1, +1} to obtain the signs.
out = 2 * (in5) - 1

In [466]:
out

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[ 1, -1,  1, -1,  1, -1,  1],
       [-1,  1,  1, -1, -1,  1,  1],
       [-1, -1, -1, -1, -1,  1, -1]], dtype=int32)>

In [476]:
# Toy value vector with one entry per index in `x`.
v = tf.constant([1, -1, 1], shape=(3,))

In [477]:
# (3,) -> (3, 1) so the values broadcast across the sign columns.
v_reshape = tf.expand_dims(v, axis=-1)

In [478]:
# Signed contribution of each index to each sketch row: sign * value.
tf.multiply(out, v_reshape)

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[ 1, -1,  1, -1,  1, -1,  1],
       [ 1, -1, -1,  1,  1, -1, -1],
       [-1, -1, -1, -1, -1,  1, -1]], dtype=int32)>

In [360]:
# NOTE(review): the traceback recorded under this cell shows a `hash31`-based
# `fourwise` body that differs from the `fourwise` defined in this notebook, so
# it appears to come from an earlier definition — re-run to refresh the output.
fourwise(tf.range(3))

ValueError: in user code:

    File "/tmp/ipykernel_5827/1903414477.py", line 4, in fourwise  *
        result = 2 * (tf.bitwise.right_shift(tf.bitwise.bitwise_and(hash31(hash31(hash31(x, F[2], F[3]), x, F[4]), x, F[5]), 32768), 15)) - 1
    File "/tmp/ipykernel_5827/3155749302.py", line 7, in hash31  *
        r = tf.multiply(a, x_reshaped) + b

    ValueError: Dimensions must be equal, but are 3 and 7 for '{{node add}} = AddV2[T=DT_INT32](Mul, b)' with input shapes: [3,7,3], [7].


In [351]:
# `hash31_test` is not defined anywhere in this notebook (it looks like an
# earlier name of `tensor_hash31`, whose output for these arguments is
# identical), so call the function that actually exists.
tensor_hash31(tf.range(3), F[0], F[1])

<tf.Tensor: shape=(3, 7), dtype=int32, numpy=
array([[1977271226,  853021704,  389148655,   99585761, 1283153601,
        2109570381, 1216126172],
       [2096472306, 1031553032,  917739298,  229173230, 1373326314,
        1450757621, 1755361742],
       [1875248543, 1210084360, 1446329941,  358760699,  265161065,
         716118328, 2000369983]], dtype=int32)>

In [324]:
# a * x + b for x = 1, before folding and masking. The coefficients are drawn
# non-negative, so negative entries here show the int32 sum wrapping around.
tf.add(F[0], F[1])

<tf.Tensor: shape=(7,), dtype=int32, numpy=
array([-2096472307,  1031553032,   917739298,   229173230, -1373326315,
       -1450757622,  1755361742], dtype=int32)>

In [318]:
# Scalar reference for index 0; compare with row 0 of the vectorized hash output.
ams_sketch.hash31(0, F[0], F[1])

<tf.Tensor: shape=(7,), dtype=int32, numpy=
array([1977271226,  853021704,  389148655,   99585761, 1283153601,
       2109570381, 1216126172], dtype=int32)>

In [319]:
# Scalar reference for index 1; compare with row 1 of the vectorized hash output.
ams_sketch.hash31(1, F[0], F[1])

<tf.Tensor: shape=(7,), dtype=int32, numpy=
array([2096472306, 1031553032,  917739298,  229173230, 1373326314,
       1450757621, 1755361742], dtype=int32)>

In [320]:
# Scalar reference for index 2; compare with row 2 of the vectorized hash output.
ams_sketch.hash31(2, F[0], F[1])

<tf.Tensor: shape=(7,), dtype=int32, numpy=
array([1875248543, 1210084360, 1446329941,  358760699,  265161065,
        716118328, 2000369983], dtype=int32)>

In [None]:
@tf.function
def sketch_for_vector2(self, v):
    """ Quicker implementation? Not really...

    Experimental variant that still hashes one index at a time, but collects
    the buckets and signed values in TensorArrays and applies them to the
    sketch with a single scatter-add.

    `self` is expected to provide `depth`, `width`, `F`, `hash31` and
    `fourwise` (presumably an `AmsSketch` instance passed explicitly, since
    this is a free function). `v` is a 1-D tensor with a statically known
    length. Returns a float32 tensor of shape (depth, width).
    """

    # Python-level print: it runs while tf.function traces, not on every call.
    print("retrace")

    sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)

    len_v = v.shape[0]

    pos_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    deltas_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

    for i in tf.range(len_v, dtype=tf.int32):
        pos = self.hash31(i, self.F[0], self.F[1]) % self.width
        delta = tf.cast(self.fourwise(i), dtype=tf.float32) * v[i]

        pos_array = pos_array.write(i, pos)
        deltas_array = deltas_array.write(i, delta)

    # shape=(`len_v`, 7)
    pos_stacked = pos_array.stack()

    # shape=(`len_v`, 7)
    deltas_stacked = deltas_array.stack()

    range_tensor = tf.range(self.depth)

    # Expand dimensions to create a 2D tensor with shape (1, depth)
    range_tensor_expanded = tf.expand_dims(range_tensor, 0)

    # Use tf.tile to repeat the tensor `len_v` times.
    # (The original referenced an undefined `input_tensor_expanded` here.)
    repeated_range_tensor = tf.tile(range_tensor_expanded, [len_v, 1])

    # shape=(`len_v`, 7, 2)
    indices = tf.stack([repeated_range_tensor, pos_stacked], axis=-1)

    sketch = tf.tensor_scatter_nd_add(sketch, indices, deltas_stacked)

    return sketch