This is an attempt to replicate CMS implementation by https://github.com/Samuel-Maddock/pure-LDP

In [17]:
import numpy as np
from scipy.linalg import hadamard
import math
import random
import xxhash
import copy
from collections import Counter

# Server and client

## Hash function creator

In [7]:
def generate_hash_funcs(k, m):
    """
    Generates k hash functions that map data to the range {0, 1,..., m-1}
    Args:
        k: The number of hash functions
        m: The domain {0,1,...,m-1} that hash func maps too
    Returns: List of k hash functions
    """
    hash_funcs = []
    for i in range(0, k):
        hash_funcs.append(generate_hash(m, i))
    return hash_funcs


def generate_hash(m, seed):
    """
    Generate a single hash function that maps data to {0, ... ,m-1}
    Args:
        m: int domain to map too
        seed: int the seed for the hash function

    Returns: Hash function

    """
    return lambda data: xxhash.xxh64(str(data), seed=seed).intdigest() % m


## Client

In [8]:

class CMSClient():
    def __init__(self, epsilon, hash_funcs, m, is_hadamard=False):
        """
        Apple's Count Mean Sketch (CMS) Algorithm
        Args:
            epsilon (float): Privacy Budget Epsilon
            hash_funcs (list of funcs): A list of hash function mapping data to {0...m-1} (can be generated by CMSServer)
            m (int): The length of the hash domain
            is_hadamard (optional bool): If true, uses Hadamard Count Mean Sketch (HCMS)
        """
        self.epsilon = epsilon
        self.index_mapper = lambda x: x - 1
        
        self.sketch_based = True
        self.is_hadamard = is_hadamard
        self.update_params(hash_funcs, m, epsilon)

        if self.is_hadamard:
            self.had = hadamard(self.m)

    def update_params(self, hash_funcs=None, m=None, epsilon=None, index_mapper=None):
        """
        Updates parameters
        Args:
            hash_funcs (optional list): List of k hash functions mapping data to {0...m-1}
            m (optional int): Length of hash domain
            epsilon (optional int): Privacy Budget
            index_mapper (optional function): Index mapper function
        """
        if hash_funcs is not None:
            self.hash_funcs = hash_funcs
            self.k = len(self.hash_funcs)

        self.epsilon = epsilon if epsilon is not None else self.epsilon
        self.m = m if m is not None else self.m

        if epsilon is not None:
            if self.is_hadamard:
                self.prob = 1 / (1 + math.pow(math.e, self.epsilon))
            else:
                self.prob = 1 / (1 + math.pow(math.e, self.epsilon / 2))

    def _one_hot(self, data):
        """
        Used internally to perturb data
        Args:
            data: arbitrary data

        Returns: perturbed vector v and hash index j

        """
        j = random.randint(0, self.k-1)
        h_j = self.hash_funcs[j]
        v = [0] * self.m if self.is_hadamard else np.full(self.m, -1)
        v[h_j(data)] = 1
        return v, j

    def _cms_perturb(self, data):
        """
        Used internally for peturbing data using the CMS algorithm
        Args:
            data: data to be perturbed

        Returns: peturbed cms data

        """
        v, j = self._one_hot(data)
        v[np.random.rand(*v.shape) < self.prob] *= -1 # "flip" bits with prob
        # return np.multiply(v, b), j # Used to generate a random vector b using np.random.choice but it was 3x slower than the above line
        return v, j

    def _hcms_perturb(self, data):
        """
        Used internally for perturbing data using HCMS
        Args:
            data: data to be perturbed

        Returns: peturbed hcms data

        """
        if not (self.m & (self.m - 1)) == 0:
            raise ValueError("m must be a positive integer, and m must be a power of 2 to use hcms")

        v, j = self._one_hot(data)
        b = random.choices([-1, 1], k=1, weights=[self.prob, 1 - self.prob])
        h_j = self.hash_funcs[j]
        w = self.had[:, h_j(data)]
        l = random.randint(0, self.m-1)
        return b[0] * w[l], j, l  # Return (b*w_l, index j, index l)

    def privatise(self, data):
        """
        Privatises data item using CMS/HCMS

        Args:
            data: item to be privatised

        Returns: Privatised data

        """
        data = str(data)
        if self.is_hadamard:
            return self._hcms_perturb(data)
        else:
            return self._cms_perturb(data)

## Server

In [14]:
class CMSServer():
    def __init__(self, epsilon, k, m, is_hadamard=False, index_mapper=None):
        """
        Server frequency oracle for Apple's Count Mean Sketch (CMS)

        Args:
            epsilon (float): Privacy Budget
            k (int): Number of hash functions
            m (int): Size of the hash domain
            is_hadamard (optional bool): If True, uses Hadamard Count Mean Sketch (HCMS)
            index_mapper (optional func): Index map function
        """
        self.epsilon = epsilon
        self.n = 0 # The number of data items aggregated

        self.name = "CMSServer" # Name of the frequency oracle for warning messages, set using .set_name(name)
        self.last_estimated = 0

        if index_mapper is None:
            self.index_mapper = lambda x: x - 1
        else:
            self.index_mapper = index_mapper


        self.sketch_based = True
        self.is_hadamard = is_hadamard
        self.update_params(k,m, epsilon, index_mapper=None)
        self.hash_funcs = generate_hash_funcs(k,m)
        self.sketch_matrix = np.zeros((self.k, self.m))
        self.transformed_matrix = np.zeros((self.k, self.m))

        self.last_estimated = self.n
        self.ones = np.ones(self.m)

        if self.is_hadamard:
            self.had = hadamard(self.m)

    def update_params(self, k=None, m=None, epsilon=None, index_mapper=None):
        """
        Updated internal parameters
        Args:
            k (optional int): Number of hash functions
            m (optional int): Size of hash domain
            epsilon (optional float): Privacy Budget
            d (optional int): Size of domain
            index_mapper (optional func): Index map function
        """
        self.k = k if k is not None else self.k
        self.m = m if m is not None else self.m
        self.hash_funcs = generate_hash_funcs(self.k,self.m)
        
        self.epsilon = epsilon if epsilon is not None else self.epsilon # Updating epsilon here will not update any internal probabilities
        # Any class that implements FreqOracleServer, needs to override update_params to update epsilon properly

        self.index_mapper = index_mapper if index_mapper is not None else self.index_mapper
        self.reset()
        

        if epsilon is not None:
            if self.is_hadamard:
                self.c = (math.pow(math.e, epsilon) + 1) / (math.pow(math.e, epsilon) - 1)
            else:
                self.c = (math.pow(math.e, epsilon / 2) + 1) / (math.pow(math.e, epsilon / 2) - 1)

    def _add_to_cms_sketch(self, data):
        """
        Given privatised data, adds it to the sketch matrics (CMS algorithm)
        Args:
            data: privatised data by CMS
        """
        item, hash_index = data
        self.sketch_matrix[hash_index] = self.sketch_matrix[hash_index] + self.k * ((self.c / 2) * item + 0.5 * self.ones)

    def _add_to_hcms_sketch(self, data):
        """
        Given privatised data, adds it to the sketch matrix (HCMS algorithm)

        Args:
            data: privatisd data by HCMS
        """
        bit_value, j, l = data
        self.sketch_matrix[j][l] = self.sketch_matrix[j][l] + self.k * self.c * bit_value

    def _transform_sketch_matrix(self):
        """
        Transforms the sketch matrix using inverse hadamard (HCMS)
        Returns: Transformed sketch matrix

        """
        return np.matmul(self.sketch_matrix, np.transpose(self.had))

    def _update_estimates(self):
        """
        If using HCMS, transforms the sketch matrix using inverse hadamard
        """
        if self.is_hadamard:
            self.last_estimated = self.n # TODO: Is this needed?
            self.transformed_matrix = self._transform_sketch_matrix()

    def get_hash_funcs(self):
        """
        Returns hash functions for CMSClient

        Returns: list of k hash_funcs

        """
        return self.hash_funcs

    def reset(self):
        """
        Resets sketch matrix (i.e resets all aggregated data)
        """
        self.last_estimated = 0
        self.n = 0
        self.sketch_matrix = np.zeros((self.k, self.m))
        self.transformed_matrix = np.zeros((self.k, self.m))

    def aggregate(self, data):
        """
        Aggregates privatised data

        Args:
            data: Data privatised by CMS/HCMS
        """
        if self.is_hadamard:
            self._add_to_hcms_sketch(data)
        else:
            self._add_to_cms_sketch(data)
        self.n += 1

    def estimate(self, data, suppress_warnings=False):
        """
        Estimates the frequency of the data item

        Args:
            data: item to be estimated
            suppress_warnings (optional bool): If True, will suppress estimation warnings

        Returns: Frequency Estimate

        """
        self.check_warnings(suppress_warnings)
        self.check_and_update_estimates()

        # If it's hadamard we need to transform the sketch matrix
            # To prevent this being performance intensive, we only transform if new data has been aggregated since it was last transformed

        sketch = self.sketch_matrix if not self.is_hadamard else self.transformed_matrix

        data = str(data)
        k, m = sketch.shape
        freq_sum = 0
        for i in range(0, k):
            freq_sum += sketch[i][self.hash_funcs[i](data)]

        return (m / (m - 1)) * ((1 / k) * freq_sum - (self.n / m))

# Experiment

In [15]:
N = 100000
epsilon = 3  
m = 2048
k = 1024

cms_params = {"m": m, "k": k, "epsilon": epsilon}
cms = {"client_params": cms_params, "server_params": cms_params}

In [18]:
data = np.concatenate(
    (
        [1] * 8000,
        [2] * 4000,
        [3] * 1000,
        [4] * 500,
        [5] * 1000,
        [6] * 1800,
        [7] * 2000,
        [8] * 300,
    )
)
original_freq = list(Counter(data).values())  
original_freq

[8000, 4000, 1000, 500, 1000, 1800, 2000, 300]

In [16]:
server_cms = CMSServer(epsilon, k, m)
client_cms = CMSClient(epsilon, server_cms.get_hash_funcs(), m)


In [19]:
priv_data = [client_cms.privatise(item) for item in data]
server_cms.aggregate_all(priv_data)
cms_estimates = server_cms.estimate_all(range(1, d + 1))


AttributeError: 'CMSServer' object has no attribute 'aggregate_all'