# Policy Laplace on Spark

Spark implementation of the Policy Laplace algorithm from [Differentially Private Set Union](https://arxiv.org/abs/2002.09745).

In [1]:
import pyspark
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
# Attach to the active Spark session if there is one; otherwise start a new one.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [2]:
# Fetch and unpack the AskReddit sample data set, skipping any step whose
# result is already present in the working directory.
import os
import zipfile
from os import path
from urllib.request import urlretrieve

csv_name = 'clean_askreddit.csv'
zip_name = csv_name + '.zip'
zip_path = 'https://dp-test-datasets.s3.amazonaws.com/clean_askreddit.csv.zip'

if not path.exists(csv_name):
    # Only hit the network when the archive has not been downloaded yet.
    if not path.exists(zip_name):
        # Download under a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated file that looks like
        # a complete archive on the next run.
        partial_name = zip_name + '.part'
        urlretrieve(zip_path, partial_name)
        os.replace(partial_name, zip_name)
    # `archive` rather than `zip`, to avoid shadowing the builtin.
    with zipfile.ZipFile(zip_name, 'r') as archive:
        archive.extractall('.')
    



In [3]:
filepath = "clean_askreddit.csv"

# Read the CSV with a header row and inferred column types, then discard
# every row that contains a null value.
reddit = (
    spark.read
    .format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load(filepath)
    .dropna()
)

### Prepare Data for Processing

Load the data from file and tokenize.  This code can be any caller-specific tokenization routine, and is independent of differential privacy.  The output RDD should contain one list of tokens per row; it can have multiple rows per user, and does not need to be ordered in any way.  This stage can be combined with other n-gram sizes (e.g. 2-grams, 3-grams) and persisted to feed to DPSU.

In [4]:
import nltk

# Size of the n-grams to emit (1 = single words).
n_grams = 1
# When True, each post contributes every token at most once.
distinct = True

def tokenize(user_post):
    """Split one (user, post) record into a (user, tokens) pair.

    Args:
        user_post: a two-element row/tuple of (user, post text).

    Returns:
        A tuple ``(user, tokens)`` where ``tokens`` is a list of strings.
        If ``n_grams > 1`` the tokens are n-grams joined with ``"_"``.
        If ``distinct`` is true, duplicates are removed and the
        first-occurrence order is kept.
    """
    user, post = user_post
    # split() with no argument collapses whitespace runs and ignores
    # leading/trailing whitespace, so no empty-string tokens are produced
    # (split(" ") yields "" for every extra space and for an empty post).
    tokens = post.split()
    if n_grams > 1:
        tokens = ["_".join(gram) for gram in nltk.ngrams(tokens, n_grams)]
    if distinct:
        # dict.fromkeys de-duplicates with a stable order, unlike set().
        tokens = list(dict.fromkeys(tokens))
    return (user, tokens)
        
# Tokenize every (author, clean_text) row and keep the resulting RDD
# persisted, since later stages read it.
tokenized = (
    reddit
    .select("author", "clean_text")
    .rdd
    .map(tokenize)
    .persist()
)


### Instantiate DPSU Processor

Create the object and pass in the privacy parameters.

In [5]:
from policy_laplace import PolicyLaplace

# Privacy parameters handed to PolicyLaplace.
epsilon, delta, alpha = 3.0, np.exp(-10), 5.0

# Processing parameters.
# NOTE(review): tokens_per_user appears to be the per-user contribution bound
# (the constructor reports it as Delta_0) -- confirm against policy_laplace.
tokens_per_user, prune_tail_below, num_partitions = 500, None, 1

pl = PolicyLaplace(
    epsilon,
    delta,
    alpha,
    tokens_per_user,
    prune_tail_below,
    num_partitions,
)

Params Delta_0=500, delta=4.54e-05, l_param=0.3333333333333333, l_rho=5.175812754165355, Gamma=6.842479420832021


In [6]:
# prune the tail
pruned = pl.prune_tail(tokenized)

# reservoir sample the input tokens
sampled = pl.reservoir_sample(pruned, distinct).persist()

In [7]:
# Persist the result: the next cell runs several actions on this RDD
# (two counts and a take), and without persist() each action would recompute
# the whole process_partitions lineage.
# NOTE(review): if process_partitions draws random noise, persisting also keeps
# those actions looking at the same values -- confirm against policy_laplace.
counted = pl.process_partitions(sampled).persist()

In [8]:
# Keep the first field of every row whose second field passes the
# PolicyLaplace threshold test.
good = (
    counted
    .filter(lambda row: pl.exceeds_threshold(row[1]))
    .map(lambda row: row[0])
)

n_good, n_total = good.count(), counted.count()
print("Retrieved {0} words from {1}".format(n_good, n_total))
print(good.take(5))

Retrieved 13210 words from 150588
['ca', 'changed', 'reddit', 'america', 'oh']
