In [2]:
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

In [3]:
from fdavg.models.advanced_cnn import get_compiled_and_built_advanced_cnn

In [168]:

class AmsSketch:
    """ 
    AMS Sketch class for approximate second moment estimation.
    """
        
    def __init__(self, depth=5, width=250, d=2592202):
        self.depth = depth
        self.width = width
        self.d = d
        self.F = tf.random.uniform(shape=(6, depth), minval=0, maxval=(1 << 31) - 1, dtype=tf.int32)
        
        self.initialize()

        
    def initialize(self):
        
        pos_tensor = self.tensor_hash31(tf.range(self.d), self.F[0], self.F[1]) % self.width  # shape=(d, 5)
        
        self.four = tf.cast(self.tensor_fourwise(tf.range(self.d)), dtype=tf.float32)  # shape=(d, 5)
        
        range_tensor = tf.range(self.depth)  # shape=(5,)
        
        # Expand dimensions to create a 2D tensor with shape (1, `self.depth`)
        range_tensor_expanded = tf.expand_dims(range_tensor, 0)  # shape=(1, 5)
        
        # Use tf.tile to repeat the range `len_v` times
        repeated_range_tensor = tf.tile(range_tensor_expanded, [len_v, 1])  # shape=(d, 5)
        
        # shape=(`len_v`, `self.depth`, 2)
        self.indices = tf.stack([repeated_range_tensor, pos_tensor], axis=-1)  # shape=(d, 5, 2)
    
        
    @staticmethod
    def hash31(x, a, b):
        r = a * x + b
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        return tf.bitwise.bitwise_and(fold, 2147483647)

    @staticmethod
    def tensor_hash31(x, a, b):  # GOOD
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., tf.range(d)) """

        # Reshape x to have an extra dimension, resulting in a shape of (k, 1)
        x_reshaped = tf.expand_dims(x, axis=-1)

        # shape=(`v_dim`, 7)
        r = tf.multiply(a, x_reshaped) + b

        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        
        return tf.bitwise.bitwise_and(fold, 2147483647)

    def tensor_fourwise(self, x):
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., tf.range(d)) """
        # 1st use the tensor hash31
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # shape = (`x_dim`,  `self.depth`)
        
        # 2st use the tensor hash31
        in2 = self.tensor_hash31(x, in1, self.F[4])  # shape = (`x_dim`,  `self.depth`)
        
        # 3rd use the tensor hash31
        in3 = self.tensor_hash31(x, in2, self.F[5])  # shape = (`x_dim`,  `self.depth`)
        
        in4 = tf.bitwise.bitwise_and(in3, 32768)  # shape = (`x_dim`,  `self.depth`)
        
        return 2 * (tf.bitwise.right_shift(in4, 15)) - 1  # shape = (`x_dim`,  `self.depth`)

    def fourwise(self, x):
        result = 2 * (tf.bitwise.right_shift(tf.bitwise.bitwise_and(self.hash31(self.hash31(self.hash31(x, self.F[2], self.F[3]), x, self.F[4]), x, self.F[5]), 32768), 15)) - 1
        return result

    @tf.function
    def sketch_for_vector(self, v):
        """ Extremely efficient computation of sketch with only using tensors.

        Args:
        - v (tf.Tensor): Vector to sketch. Shape=(d,).

        Returns:
        - tf.Tensor: An AMS - Sketch. Shape=(`depth`, `width`).
        """
        
        sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)  # shape=(5, 250)
        
        v_expand = tf.expand_dims(v, axis=-1)  # shape=(d, 1)
        
        # shape=(d, 5): +- for each value v_i , i = 1, ..., d
        deltas_tensor = tf.multiply(self.four, v_expand)
        
        sketch = tf.tensor_scatter_nd_add(sketch, self.indices, deltas_tensor)  # shape=(5, 250)
        
        return sketch

    @staticmethod
    def estimate_euc_norm_squared(sketch):
        """ Estimate the Euclidean norm squared of a vector using its AMS sketch.

        Args:
        - sketch (tf.Tensor): AMS sketch of a vector. Shape=(`depth`, `width`).

        Returns:
        - tf.Tensor: Estimated squared Euclidean norm.
        """

        def _median(v):
            """ Median of tensor `v` with shape=(n,). Note: Suboptimal O(nlogn) but it's ok bcz n = `depth`"""
            length = tf.shape(v)[0]
            sorted_v = tf.sort(v)
            middle = length // 2

            return tf.cond(
                tf.equal(length % 2, 0),
                lambda: (sorted_v[middle - 1] + sorted_v[middle]) / 2.0,
                lambda: sorted_v[middle]
            )

        return _median(tf.reduce_sum(tf.square(sketch), axis=1))

In [143]:
num_clients = 10

In [144]:
cnns = [get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10) for _ in range(num_clients)]

In [145]:
v = get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10).trainable_vars_as_vector()

In [146]:
server = get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10).trainable_vars_as_vector()
server2 = get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10).trainable_vars_as_vector()

In [147]:
from math import sqrt

In [202]:
ams_sketch, epsilon = AmsSketch(11, 1500, 2592202), 1. / sqrt(250)

In [106]:
pos1 = ams_sketch.pos_tensor[:, 0]

pos1

<tf.Tensor: shape=(2592202,), dtype=int32, numpy=array([173, 181, 189, ..., 174, 166, 158], dtype=int32)>

In [107]:
v

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([ 0.03389673, -0.06132888,  0.00831743, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [119]:
h1_sketch = tf.zeros(shape=(250, ), dtype=tf.float32)

In [121]:
tf.tensor_scatter_nd_add(h1_sketch, pos1, v)

InvalidArgumentError: {{function_node __wrapped__TensorScatterAdd_device_/job:localhost/replica:0/task:0/device:CPU:0}} Inner dimensions of output shape must match inner dimensions of updates shape. Output: [250] updates: [2592202] [Op:TensorScatterAdd] name: 

In [125]:
%%timeit

# Create a tensor of zeros with the shape [250] or more
zeros_tensor = tf.zeros([250], dtype=tf.float32)

# Use tensor_scatter_nd_add to add v to the zeros_tensor at the positions pos1
result_tensor = tf.tensor_scatter_nd_add(zeros_tensor, tf.reshape(pos1, (-1, 1)), v)

66 ms ± 628 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [250]:
v = tf.random.normal(shape=(2592202,))
sk = ams_sketch.sketch_for_vector(v)
real=ams_sketch.estimate_euc_norm_squared(sk)

In [251]:
est=tf.reduce_sum(tf.multiply(v, v))

In [252]:
((real - est)/ real).numpy() 

-0.02384311

## Sketch

In [138]:
%%timeit

delta_is = [
    cnn.trainable_vars_as_vector() - server
    for cnn in cnns
]

euc_norm_squared_clients = [
   tf.reduce_sum(tf.square(delta_i))
   for delta_i in delta_is 
]

sketches = [
    ams_sketch.sketch_for_vector(delta_i)
    for delta_i in delta_is
]

S_1 = tf.reduce_mean(euc_norm_squared_clients)
S_2 = tf.reduce_mean(sketches, axis=0)

# See theoretical analysis above
var = S_1 - (1. / (1. + epsilon)) * AmsSketch.estimate_euc_norm_squared(S_2)

3.65 s ± 15.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [78]:
%%timeit

sk = ams_sketch.sketch_for_vector(v)

352 ms ± 4.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [67]:
%%timeit

ams_sketch.estimate_euc_norm_squared(sk)

1.45 ms ± 31.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Linear

In [51]:
def ksi_unit(w_t0, w_tminus1):
    if tf.reduce_all(tf.equal(w_t0, w_tminus1)):
        # if equal then ksi becomes a random vector (will only happen in round 1)
        ksi = tf.random.stateless_normal(shape=w_t0.shape, seed=[1, 2])

    else:
        ksi = w_t0 - w_tminus1

    # Normalize and return
    return tf.divide(ksi, tf.norm(ksi))

In [58]:
%%timeit

ksi = ksi_unit(server, server2)

delta_is = [
    cnn.trainable_vars_as_vector() - server
    for cnn in cnns
]

euc_norm_squared_clients = [
   tf.reduce_sum(tf.square(delta_i))
   for delta_i in delta_is 
]

ksi_delta_clients = [
    tf.reduce_sum(tf.multiply(ksi, delta_i))
    for delta_i in delta_is
]

S_1 = tf.reduce_mean(euc_norm_squared_clients)
S_2 = tf.reduce_mean(ksi_delta_clients)

var = S_1 - S_2**2

203 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
