In [27]:
import pyspark
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from sample_agg_blb import SampleAggregate
# Obtain a SparkSession: getOrCreate reuses an active session or builds a new one.
spark = (
    SparkSession.builder
    .getOrCreate()
)

In [28]:
from pyspark.sql.types import FloatType, BooleanType

# PUMS California demographics extract, relative to this notebook's directory.
filepath = "../../../dp-test-datasets/data/PUMS_california_demographics/data.csv"
# Optional debugging cap: set to an int to keep only rows with PersonID < MAX_PERSON_ID.
MAX_PERSON_ID = None
# Columns cast to BooleanType (presumably 0/1-coded in the CSV — confirm).
BOOLEAN_COLUMNS = ["latino", "black", "asian", "married"]

pums = spark.read.load(filepath, format="csv", sep=",", inferSchema="true", header="true")

# The leading column arrives as "_c0" (presumably an unnamed index column); name it.
pums = pums.withColumnRenamed("_c0", "PersonID")
if MAX_PERSON_ID is not None:
    pums = pums.filter(col("PersonID") < MAX_PERSON_ID)
pums = pums.withColumn("income", col("income").cast(FloatType()))
# One cast per boolean column instead of copy-pasted withColumn lines.
for bool_col in BOOLEAN_COLUMNS:
    pums = pums.withColumn(bool_col, col(bool_col).cast(BooleanType()))

sa = SampleAggregate(pums)


In [29]:
from itertools import product
# Columns handed to sa.sample; must be a superset of the grouping keys below.
cols = ["sex", "married", "income", "age", "educ"]
# Grouping columns, in the same order as the elements of each tuple in `groups`.
keys = ["married", "educ"]
assert set(keys) <= set(cols), "cols must include every grouping key"

# Number of educ codes enumerated (0 .. N_EDUC_LEVELS - 1); presumably the
# full educ code range in this dataset — confirm.
N_EDUC_LEVELS = 15
# Every (married, educ) combination, listed explicitly.
groups = list(product([True, False], range(N_EDUC_LEVELS)))

In [30]:
# Sample the columns listed in `cols` (which must cover every grouping key).
# NOTE(review): exact semantics live in sample_agg_blb.SampleAggregate.sample;
# presumably this prepares the data used by the aggregate/apply/estimate steps below.
sa.sample(cols)

In [31]:
# Aggregate the income and age value columns over the (married, educ) `groups`.
# The value-column list need not repeat the grouping keys; this step bootstraps
# (per the author's note — see SampleAggregate.aggregate for details).
sa.aggregate(["income", "age"], keys, groups)

In [32]:
def _weighted_column_mean(data, col_idx):
    """Weighted mean of column ``col_idx``, weighting each row by its last element.

    ``data`` is materialized exactly once, so one-shot iterators work correctly.
    Returns NaN when ``data`` is empty or the total weight is zero.
    """
    rows = list(data)
    if not rows:
        return float("nan")
    values = np.array([row[col_idx] for row in rows], dtype=float)
    weights = np.array([row[-1] for row in rows], dtype=float)
    total_weight = np.sum(weights)
    if total_weight == 0:
        # Weighted mean is undefined with no weight.
        return float("nan")
    return float(np.sum(values * weights) / total_weight)


def mean_income(data):
    """Weighted mean of the income column.

    Args:
        data: iterable of rows with income at index 0 and the weight as the last
            element, e.g. ``(income, age, weight)`` (layout presumably matches
            ``sa.aggregate(["income", "age"], ...)`` — confirm).

    Returns:
        float: sum(income * weight) / sum(weight); NaN for empty input or zero total weight.
    """
    return _weighted_column_mean(data, 0)


def count(data):
    """Weighted row count: the sum of each row's last element (its weight).

    Note: this name shadows ``pyspark.sql.functions.count`` if that was brought
    in by the notebook's ``from pyspark.sql.functions import *``.

    Args:
        data: iterable of rows whose last element is the weight.

    Returns:
        int: total weight, 0 for empty input.
    """
    return int(np.sum([row[-1] for row in data]))


def mean_age(data):
    """Weighted mean of the age column.

    Args:
        data: iterable of rows with age at index 1 and the weight as the last
            element, e.g. ``(income, age, weight)``.

    Returns:
        float: sum(age * weight) / sum(weight); NaN for empty input or zero total weight.
    """
    return _weighted_column_mean(data, 1)


In [33]:
def mean_income_bootstraps(data, income_col=0):
    """Average over bootstrap replicates of the weighted mean of one value column.

    Args:
        data: iterable of ``(vals, weights)`` pairs, one per record. ``vals`` is a
            sequence of column values. ``weights`` holds one weight per bootstrap
            replicate and has the same length for every record.
        income_col: index into ``vals`` of the column to average. The default 0 is
            presumably income, given ``sa.aggregate(["income", "age"], ...)`` — confirm.

    Returns:
        float: the mean, over replicates, of sum(value * weight) / sum(weight).
        Replicates with zero total weight are skipped, because their weighted
        mean is undefined. Returns NaN if ``data`` is empty or no replicate
        has weight.
    """
    rows = list(data)
    if not rows:
        return float("nan")
    values = np.array([vals[income_col] for vals, _ in rows], dtype=float)
    # Shape (num_rows, num_replicates): column w holds replicate w's weights.
    weights = np.array([wts for _, wts in rows], dtype=float)
    weighted_sums = values.dot(weights)
    total_weights = weights.sum(axis=0)
    has_weight = total_weights != 0
    if not has_weight.any():
        return float("nan")
    replicate_means = weighted_sums[has_weight] / total_weights[has_weight]
    return float(np.mean(replicate_means))
 
# Register mean_income_bootstraps as the per-group statistic (see SampleAggregate.apply).
sa.apply([mean_income_bootstraps])
# Peek at the first five (group, val) results.
sa.applied.take(5)

[Row(group=(True, 1), val=(9600.0,)),
 Row(group=(True, 4), val=(8783.129075441524,)),
 Row(group=(True, 5), val=(9038.975202754495,)),
 Row(group=(True, 6), val=(0.0,)),
 Row(group=(True, 7), val=(9003.299863340126,))]

In [34]:
import math

def mean_estimator(data, eps, delta, lam, parts):
    """Differentially private mean of per-partition estimates (Gaussian mechanism).

    Each value is clamped to ``[-lam, lam]`` and the clamped values are averaged.
    Gaussian noise is then added with
    sigma = (2 * lam / parts) * sqrt(2 * ln(1.25 / delta)) / eps.

    Args:
        data: iterable of numeric per-partition estimates. NaNs stay NaN after
            clamping and are ignored by the average.
        eps: privacy parameter epsilon (> 0).
        delta: privacy parameter delta; must be < 1.25 for a positive sigma.
        lam: clamping bound. It is required because the noise scale depends on it.
        parts: number of partitions; the sensitivity is 2 * lam / parts.

    Returns:
        float: noisy clamped mean.

    Raises:
        ValueError: if ``lam`` is None.
    """
    if lam is None:
        # Without a bound the sensitivity is unbounded, so no finite noise is private.
        raise ValueError("mean_estimator requires a clamping bound `lam`")
    clamped = np.clip(np.asarray(list(data), dtype=float), -lam, lam)
    sd = (2 * lam * math.sqrt(2 * math.log(1.25 / delta))) / (parts * eps)
    # NOTE(review): nanmean divides by the number of non-NaN values, which may be
    # fewer than `parts`; the 2*lam/parts sensitivity assumes every part is present.
    theta = float(np.nanmean(clamped))
    # Fresh OS-seeded generator per call: independent noise without mutating
    # NumPy's global RNG state.
    noise = np.random.RandomState().normal(0, sd)
    return theta + noise
    
def median_estimator(data, eps, delta, lam, parts):
    """Plain median of ``data``; NOT differentially private.

    The signature mirrors ``mean_estimator`` so either can be passed to
    ``sa.estimate``. However, ``eps``, ``delta``, ``lam`` and ``parts`` are
    ignored: no clamping is done and no noise is added. NaN inputs propagate
    to the result.
    """
    middle = np.median(data)
    return float(middle)

In [35]:
# Privacy parameters passed to each estimator call below.
eps = 1.0
delta = 1E-9
# Clamping bounds, presumably forwarded by sa.estimate to each estimator's `lam`.
# NOTE(review): three bounds are given, but only one statistic was applied
# (sa.apply([mean_income_bootstraps])). Confirm how SampleAggregate.estimate
# pairs these with the applied values and whether the extra entries are unused.
sens = [100000, 65, 350000]

# Noisy (Gaussian-mechanism) mean of the per-partition statistics.
est_mean = sa.estimate(mean_estimator, eps, delta, sens)

# Median of the same statistics. median_estimator adds no noise, so this is NOT private.
est_med = sa.estimate(median_estimator, eps, delta, sens)


# Print up to 50 rows of each result without truncating cell contents.
est_mean.toDF().show(50, False)
est_med.toDF().show(50, False)

+-----------+---------------------+
|group      |val                  |
+-----------+---------------------+
|[false, 7] |[18645.11218398119]  |
|[true, 4]  |[2281.982280936567]  |
|[false, 8] |[5258.986370168394]  |
|[false, 9] |[29708.2882832217]   |
|[true, 10] |[23403.616917873085] |
|[true, 11] |[39760.65794305519]  |
|[false, 10]|[24052.39671825686]  |
|[true, 9]  |[34483.33409222805]  |
|[false, 11]|[15880.360331671496] |
|[true, 8]  |[5644.037290017994]  |
|[false, 12]|[29759.115894791077] |
|[false, 13]|[14337.414295909026] |
|[true, 14] |[16654.59016626612]  |
|[false, 14]|[26276.122479631136] |
|[true, 13] |[43696.03705429409]  |
|[false, 1] |[-3629.634980108347] |
|[true, 12] |[27543.68267376242]  |
|[true, 2]  |[19316.518361979844] |
|[true, 3]  |[15215.098003449108] |
|[false, 2] |[17880.37184781104]  |
|[true, 1]  |[-244.26163148868]   |
|[false, 3] |[14244.68664625223]  |
|[false, 4] |[15388.822295219088] |
|[false, 5] |[-1150.1092781562475]|
|[true, 6]  |[16537.73143025