In [41]:
import zmq
import numpy as np
import threading

import zmq
import numpy as np

def create_push_socket(ip, port, context=None):
    """Create a ZeroMQ PUSH socket connected to ``tcp://{ip}:{port}``.

    Args:
        ip: Host/address of the peer that binds the matching PULL socket.
        port: TCP port of that peer.
        context: Optional ``zmq.Context``. Defaults to the process-wide shared
            context (``zmq.Context.instance()``) instead of creating a fresh
            context -- and its I/O thread -- on every call.

    Returns:
        The connected PUSH socket.

    Raises:
        zmq.ZMQError: if ``connect`` rejects the endpoint. The socket is closed
            before re-raising so it is not leaked.
    """
    if context is None:
        context = zmq.Context.instance()
    socket = context.socket(zmq.PUSH)
    try:
        socket.connect(f"tcp://{ip}:{port}")
    except Exception:
        # Don't leave a half-configured socket behind on failure.
        socket.close(linger=0)
        raise
    return socket

def create_pull_socket(ip, port, context=None):
    """Create a ZeroMQ PULL socket bound to ``tcp://{ip}:{port}``.

    Args:
        ip: Interface/address to bind on.
        port: TCP port to bind on.
        context: Optional ``zmq.Context``. Defaults to the process-wide shared
            context (``zmq.Context.instance()``) instead of creating a fresh
            context -- and its I/O thread -- on every call.

    Returns:
        The bound PULL socket.

    Raises:
        zmq.ZMQError: if the bind fails (e.g. the port is already in use, which
            happens when this is re-run in a live kernel). The socket is closed
            before re-raising so it is not leaked.
    """
    if context is None:
        context = zmq.Context.instance()
    socket = context.socket(zmq.PULL)
    try:
        socket.bind(f"tcp://{ip}:{port}")
    except Exception:
        # A failed bind would otherwise leave an orphaned socket behind.
        socket.close(linger=0)
        raise
    return socket

def send_number_matrix(push_socket, number, matrix):
    """Send a scalar and a matrix as one JSON message.

    The message has the form ``{"number": <scalar>, "matrix": <nested list>}``,
    which is the schema read by ``receive_number`` and
    ``aggregate_numbers_matrices``.

    Args:
        push_socket: Socket-like object exposing ``send_json`` (e.g. the PUSH
            socket from ``create_push_socket``).
        number: Scalar to send. NumPy scalars (e.g. ``np.float32``) are
            converted to native Python numbers so they are JSON-serialisable.
        matrix: Array-like (NumPy array, nested list, ...); converted via
            ``np.asarray(...).tolist()``.
    """
    # json cannot encode NumPy scalar types such as np.float32; .item() returns
    # the equivalent plain Python int/float.
    if isinstance(number, np.generic):
        number = number.item()

    msg = {"number": number, "matrix": np.asarray(matrix).tolist()}

    push_socket.send_json(msg)
    
def receive_number(pull_socket):
    """Block for one JSON message and return its ``"number"`` field as a float.

    Any other fields of the message (such as ``"matrix"``) are consumed from
    the socket and discarded.
    """
    payload = pull_socket.recv_json()
    value = payload["number"]
    return float(value)
    
    
def aggregate_numbers_matrices(pull_socket, n):
    """Receive ``n`` messages and return the element-wise means.

    Each message must have the ``{"number": ..., "matrix": ...}`` schema
    produced by ``send_number_matrix``. Blocks until all ``n`` messages have
    arrived.

    Args:
        pull_socket: Socket-like object exposing ``recv_json``.
        n: Number of messages to aggregate; must be >= 1.

    Returns:
        tuple: ``(mean_number, mean_matrix)`` -- the mean of the received
        numbers and the element-wise mean of the received matrices.

    Raises:
        ValueError: if ``n < 1``. Averaging zero messages would only produce
            NaN (with a RuntimeWarning), and inside a polling loop it would
            spin without ever blocking.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    local_matrices = []
    local_numbers = []

    for _ in range(n):
        msg = pull_socket.recv_json()
        local_matrices.append(np.array(msg["matrix"]))
        local_numbers.append(float(msg["number"]))

    # All matrices must share one shape for the element-wise mean.
    mean_matrix = np.mean(local_matrices, axis=0)
    mean_number = np.mean(local_numbers)

    return mean_number, mean_matrix

def start_aggregation_loop(pull_socket, n, callback):
    """Repeatedly aggregate ``n`` messages on a background daemon thread.

    Each round calls ``aggregate_numbers_matrices(pull_socket, n)`` and passes
    the resulting ``(mean_number, mean_matrix)`` tuple to ``callback``.

    Args:
        pull_socket: Socket to receive from. ZeroMQ sockets are not thread-safe,
            so after this call the socket should only be used by the loop thread.
        n: Number of messages aggregated per round; must be >= 1.
        callback: Called with each round's aggregation result.

    Returns:
        threading.Thread: the started daemon thread, so callers can check
        ``is_alive()``.

    Raises:
        ValueError: if ``n < 1``. This is checked here, in the caller's thread,
            rather than letting the background loop fail or spin.

    Notes:
        An exception from the aggregation or from ``callback`` ends the loop.
        It is reported through ``threading.excepthook`` and the thread exits.
        Because the thread is a daemon, it does not keep the interpreter alive.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    def loop():
        while True:
            result = aggregate_numbers_matrices(pull_socket, n)
            callback(result)

    t = threading.Thread(target=loop, daemon=True)
    t.start()
    return t
    
    
def on_aggregated(result):
    """Report one finished aggregation round on stdout.

    Intended as the ``callback`` for ``start_aggregation_loop``.

    Args:
        result: ``(mean_number, mean_matrix)`` tuple as returned by
            ``aggregate_numbers_matrices``.

    TODO: Send back! (the aggregated result is currently only printed.)
    """
    number_avg, matrix_avg = result
    report = (
        ("Aggregation complete",),
        ("Mean number:", number_avg),
        ("Mean matrix:\n", matrix_avg),
    )
    for fields in report:
        print(*fields)

In [42]:
# PUSH socket towards the local aggregator listening on port 9205.
push = create_push_socket(ip='localhost', port=9205)

In [50]:
# Smoke test: push one message (scalar 4.0 plus a random 50x50 matrix).
send_number_matrix(push, number=4.0, matrix=np.random.rand(50, 50))

In [10]:
import json

In [11]:
# Load the experiment hyper-parameters (top-level keys: clients, dataset,
# model, server, training).
# Decode explicitly as UTF-8: the config is written back with UTF-8 and
# ensure_ascii=False, so relying on the platform's locale encoding could
# mis-decode non-ASCII values.
with open('hyperparameters/test_1.json', encoding='utf-8') as f:
    x = json.load(f)

In [14]:
import pprint

In [15]:
# Pretty-print the full hyper-parameter config loaded from test_1.json.
pprint.pprint(x)

{'clients': {'lr': 0.001,
             'network': [{'data_path': '/home/mtheologitis/FDA-Opt-Sys/data/glue/mrpc/c2_0.pth',
                          'id': 0,
                          'ip': 'localhost',
                          'port': 9199},
                         {'data_path': '/home/mtheologitis/FDA-Opt-Sys/data/glue/mrpc/c2_1.pth',
                          'id': 1,
                          'ip': 'localhost',
                          'port': 9201}]},
 'dataset': {'batch_size': 8,
             'dirichlet_alpha': 1,
             'name': 'mrpc',
             'path': 'glue'},
 'model': {'checkpoint': 'prajjwal1/bert-tiny', 'num_labels': 2},
 'server': {'network': {'ip': 'localhost',
                        'ip_pull_socket': 'localhost',
                        'port': 8083,
                        'port_pull_socket': 9101},
            'strategy': {'eta': 0.0001, 'fda': True, 'name': 'FedAdam'}},
 'training': {'clients_per_round': 2,
              'local_epochs': 1,
              

In [17]:
# Show only the per-client network entries (id, ip, port, data_path).
pprint.pprint(x['clients']['network'])

[{'data_path': '/home/mtheologitis/FDA-Opt-Sys/data/glue/mrpc/c2_0.pth',
  'id': 0,
  'ip': 'localhost',
  'port': 9199},
 {'data_path': '/home/mtheologitis/FDA-Opt-Sys/data/glue/mrpc/c2_1.pth',
  'id': 1,
  'ip': 'localhost',
  'port': 9201}]


In [8]:
# Write the (possibly edited) hyper-parameters back to disk.
# Serialise BEFORE opening the file: open(..., 'w') truncates immediately, so a
# json.dump failure half-way (e.g. `x` was rebound to a tensor by a later cell)
# would otherwise leave the config file empty or corrupt.
payload = json.dumps(x, ensure_ascii=False)
with open('hyperparameters/test_1.json', 'w', encoding='utf-8') as f:
    f.write(payload)

In [2]:
import torch
from math import sqrt
import random
import gc

DEVICE = 'cpu'


class AmsSketch:
    """
    AMS Sketch class for approximate second moment estimation in PyTorch.

    A sketch of a vector is a float32 tensor of shape ``(depth, width)``. Each
    coordinate ``j`` is multiplied by a pseudo-random sign (+1/-1) and added
    into one pseudo-random bucket of every row. ``estimate_euc_norm_squared``
    returns the median over rows of each row's sum of squares.

    The sketch is linear in the input vector. Sketches built with the same
    hash coefficients ``F`` (same ``seed`` and ``depth``) and the same
    ``width`` can therefore be summed or averaged element-wise.
    """

    def __init__(self, depth=3, width=500, seed=42):
        """
        Args:
        - depth (int): Number of rows (independent hash functions).
        - width (int): Number of buckets per row.
        - seed (int): Seed used to draw the hash coefficients ``F``.

        Side effect: seeds the *global* torch and Python ``random`` RNGs.
        This is kept as-is for backward compatibility.
        """
        torch.manual_seed(seed)
        random.seed(seed)

        self.depth = depth
        self.width = width

        # 1/sqrt(width). It is not used inside this class; presumably it is an
        # error-scale parameter consumed by callers.
        self.epsilon = 1. / sqrt(width)

        # Hash coefficients, shape=(6, depth). Rows 0-1 feed the bucket hash
        # (see precompute) and rows 2-5 feed the sign hash (see tensor_fourwise).
        self.F = torch.randint(0, (1 << 31) - 1, (6, depth), dtype=torch.int32)

        # Cache of per-length hash tensors, keyed by ('pos_tensor', d) / ('four', d).
        self.precomputed_dict = {}

    def precompute(self, d):
        """Compute and cache the bucket indices and signs for vectors of length ``d``.

        Stores ``('pos_tensor', d)`` (int64 bucket index in ``[0, width)``) and
        ``('four', d)`` (float32 +1/-1 sign). Both have shape=(d, depth) and
        are stored on DEVICE.
        """
        indices = torch.arange(d)
        pos_tensor = self.tensor_hash31(indices, self.F[0], self.F[1]) % self.width  # shape=(d, depth)
        four = self.tensor_fourwise(indices).float()  # shape=(d, depth)
        self.precomputed_dict[('pos_tensor', d)] = pos_tensor.to(DEVICE)  # shape=(d, depth)
        self.precomputed_dict[('four', d)] = four.to(DEVICE)  # shape=(d, depth)

    @staticmethod
    def hash31(x, a, b):
        """Element-wise hash ``((a*x + b) >> 31) XOR (a*x + b)``, masked to 31 bits.

        Not used elsewhere in this class; ``tensor_hash31`` is the batched form.
        """
        r = a * x + b
        fold = torch.bitwise_xor(r >> 31, r)
        return fold & 2147483647

    @staticmethod
    def tensor_hash31(x, a, b):
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., torch.arange(d)).

        Returns a (d, len(a)) tensor of 31-bit hashes; column k uses
        coefficients ``a[k]``, ``b[k]``. When ``x`` is int64 (as from
        torch.arange), the int32 coefficients are promoted to int64, so the
        product does not wrap at 32 bits.
        """
        x_reshaped = x.unsqueeze(-1)
        r = a * x_reshaped + b
        fold = torch.bitwise_xor(r >> 31, r)
        return fold & 2147483647

    def tensor_fourwise(self, x):
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., torch.arange(d)).

        Returns an integer tensor of +1/-1 signs, shape=(d, depth). The sign
        is taken from bit 15 of a chained hash of ``x``.
        """
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # shape = (`x_dim`, `depth`)
        in2 = self.tensor_hash31(x, in1, self.F[4])  # shape = (`x_dim`, `depth`)
        in3 = self.tensor_hash31(x, in2, self.F[5])  # shape = (`x_dim`, `depth`)

        in4 = in3 & 32768  # keep bit 15 only
        return 2 * (in4 >> 15) - 1  # {0, 1} -> {-1, +1}

    def sketch_for_vector(self, v):
        """ Efficient computation of sketch using PyTorch tensors.

        Args:
        - v (torch.Tensor): Tensor to sketch. It is flattened to shape (d,),
          where d = v.numel(), and converted to float32. 1-D input behaves as
          before; n-D inputs (e.g. a weight matrix) and float64 inputs are
          also accepted.

        Returns:
        - torch.Tensor: An AMS Sketch. Shape=(`depth`, `width`), dtype float32.
        """
        v = v.reshape(-1)
        d = v.shape[0]

        if ('four', d) not in self.precomputed_dict:
            self.precompute(d)

        four, pos_tensor = self.precomputed_dict[('four', d)], self.precomputed_dict[('pos_tensor', d)]

        sketch = self._sketch_for_vector(v, four, pos_tensor)

        # Forces a full garbage-collection pass on every call, presumably to
        # release temporaries for very large d.
        # NOTE(review): this can dominate runtime for small vectors.
        gc.collect()

        return sketch

    def _sketch_for_vector(self, v, four, pos_tensor):
        """
        Scatter-add the signed vector entries into the sketch, one row at a time.

        Args:
        - v (torch.Tensor): Vector to sketch. Shape=(d,).
        - four (torch.Tensor): Precomputed +1/-1 signs. Shape=(d, depth).
        - pos_tensor (torch.Tensor): Precomputed bucket indices. Shape=(d, depth).

        Returns:
        - sketch (torch.Tensor): The AMS sketch tensor. Shape=(depth, width), float32.
        """

        # index_add_ requires `source` to have the same dtype as the float32
        # sketch, so e.g. float64 input would otherwise raise. Cast up front.
        v_expand = v.to(device=DEVICE, dtype=torch.float32).unsqueeze(-1)  # shape=(d, 1)

        # Signed contributions of every coordinate to every row.
        deltas_tensor = four * v_expand  # shape=(d, depth)

        sketch = torch.zeros((self.depth, self.width), dtype=torch.float32, device=DEVICE)

        for i in range(self.depth):
            width_indices = pos_tensor[:, i]  # shape=(d,), bucket of each coordinate in row i
            deltas = deltas_tensor[:, i]  # shape=(d,)

            # Accumulate (collisions add up) into row i.
            sketch[i].index_add_(0, width_indices, deltas)

        return sketch

    @staticmethod
    def estimate_euc_norm_squared(sketch):
        """ Estimate the Euclidean norm squared of a vector using its AMS sketch.

        Args:
        - sketch (torch.Tensor): AMS sketch of a vector. Shape=(`depth`, `width`).

        Returns:
        - float: Median over rows of the per-row sum of squares. For an even
          ``depth``, torch.median returns the lower of the two middle values.
        """
        norm_sq_rows = torch.sum(sketch ** 2, dim=1)
        return torch.median(norm_sq_rows).item()

In [3]:
# AMS sketcher with the default parameters written out explicitly.
sk = AmsSketch(depth=3, width=500, seed=42)

In [5]:
import numpy as np

In [6]:
# Random 20-element float32 test vector for the sketch. A seeded Generator
# keeps this cell reproducible under Restart & Run All.
# NOTE(review): this rebinds `x`, which earlier cells use for the loaded
# hyper-parameter dict; re-running those cells afterwards will see a tensor.
x = torch.from_numpy(np.random.default_rng(0).random(20)).to(torch.float32)

In [7]:
# Sketch the test vector; the result has shape (depth, width) = (3, 500) for `sk`.
sketch = sk.sketch_for_vector(x)

In [8]:
# Host-side NumPy copy of the sketch for inspection.
s = sketch.cpu().numpy()

In [9]:
# Peek at the sketch: mostly zeros, since at most 20 coordinates land in the 500 buckets of each row.
s

array([[0.       , 0.       , 0.       , ..., 0.       , 0.       ,
        0.       ],
       [0.       , 0.       , 0.8792607, ..., 0.       , 0.       ,
        0.       ],
       [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
        0.       ]], shape=(3, 500), dtype=float32)

In [11]:
# Estimated squared L2 norm of the sketched vector (median of per-row sums of squares);
# compare with the exact value (x ** 2).sum().
AmsSketch.estimate_euc_norm_squared(torch.from_numpy(s))

7.666426658630371

In [38]:
# NOTE(review): this output differs from the earlier display of `s` (same cell position,
# different values), so `s` was rebound by a cell that is no longer in the notebook (hidden state).
s

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.21161994, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], shape=(3, 500), dtype=float32)

In [40]:
# Nested-list form of the sketch, presumably to check it can go through send_json.
# NOTE(review): this displays all 3 x 500 = 1500 floats; prefer a slice such as s[:, :10].
s.tolist()

[[0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  -0.4632640480995178,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  -0.3484123945236206,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.19549338519573212,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
