# 🧮 Database Reconstruction Attacks (DRAs)

This notebook is adapted from a paper on database reconstruction attacks. You can find the paper [here](https://queue.acm.org/detail.cfm?id=3295691).

There are many reasons businesses and governments want to share information about people. When sharing it, you need to consider the privacy and anonymity of the people the data is derived from. Publishing only aggregate statistics often does little to stop an attacker from re-creating a database that is very close to, or exactly the same as, the original data. In this notebook, we will take a simple example and re-create a database from nothing but aggregate statistics about the people in it.

We will mainly use [Z3](https://github.com/Z3Prover/z3) for this. Imagine we have the following database containing information about people within a certain geographic area (going forward, we refer to this area as a **block**).

There are *7* people in total in this block. Alongside each resident's **age**, we also have their **smoking status**, **employment status** and whether they are **married**, and from these we publish a variety of statistics. You have probably seen something similar in your country's census.

> 📓 To simplify the example, this fictional world has:
> - Two marriage statuses: Married (**True**) or Single (**False**)
> - Two smoking statuses: Non-Smoker (**False**) or Smoker (**True**)
> - Two employment statuses: Unemployed (**False**) or Employed (**True**)

> 👾 One additional piece of logic we know is that any statistic with a **count of less than 3** is suppressed. Suppressing statistics with low counts is a common tactic for protecting privacy: the fewer people a statistic represents, the more they stick out in a dataset, so their privacy is more at risk than that of people who 'blend in with the crowd'. As we'll see, simply knowing that a statistic is suppressed can itself be used to attack a dataset.
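
A minimal sketch of the suppression rule in plain Python. `suppress` is a hypothetical helper and the counts are illustrative; the point is that a published `None` still tells an attacker the hidden count must be 0, 1 or 2.

```python
from typing import Dict, Optional

THRESHOLD = 3  # counts below this are suppressed, per the rule above


def suppress(counts: Dict[str, int], threshold: int = THRESHOLD) -> Dict[str, Optional[int]]:
    """Publish each count, replacing any count below `threshold` with None."""
    return {name: (count if count >= threshold else None) for name, count in counts.items()}


published = suppress({"total-population": 7, "smoker": 3, "single-adults": 2})
print(published)  # {'total-population': 7, 'smoker': 3, 'single-adults': None}

# An attacker who sees None learns the hidden count is one of these values:
possible = [c for c in range(published["total-population"] + 1) if suppress({"x": c})["x"] is None]
print(possible)  # [0, 1, 2]
```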

The basic steps for setting up and solving a problem like this will often follow a similar process to this:

1. Import your chosen constraint library (in our case, a library called Z3). ✅
2. Declare a solver 🧮
3. Create the variables 💎
4. Define the constraints and add to the solver 📟
5. Check the solver can find a solution 🧠
6. Access the created model and display the results 🎉

## 1. Import Z3. ✅

In [1]:
from typing import Optional, Union, List, Tuple

from pydantic import BaseModel # Pydantic is a great library for modelling data, similar to dataclasses.
import numpy as np
import pandas as pd
import z3 # This is a library allowing us to model our re-construction attack using constraint satisfaction solvers, more on this shortly.

For this notebook, we'll work with two main files containing our data. 👩🏽 Alice is the owner of the original database about our 7 people. She works for a company that conducts census questionnaires, and she ran a survey asking 7 people the following 5 questions:

| Question           | Type    |
|--------------------|---------|
| What is your name? | String  |
| What is your age?  | Integer |
| Are you married?   | Boolean |
| Are you a smoker?  | Boolean |
| Are you employed?  | Boolean |



Here is what Alice can see when she looks at the raw data:

In [2]:
database = pd.read_csv("data/database.csv")
database

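|   | name            | age | married | smoker | employed |
|---|-----------------|-----|---------|--------|----------|
| 0 | Sara Gray       | 8   | False   | False  | False    |
| 1 | Joseph Collins  | 18  | False   | True   | True     |
| 2 | Vincent Porter  | 24  | False   | False  | True     |
| 3 | Tiffany Brown   | 30  | True    | True   | True     |
| 4 | Brenda Small    | 36  | True    | False  | False    |
| 5 | Dr. Tina Ayala  | 66  | True    | False  | False    |
| 6 | Rodney Gonzalez | 84  | True    | True   | False    |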


As you can see, there are 7 rows containing the information of 7 different people. Alice wants to share some statistical information about these people so that people like 🧔 Bob can use it to understand more about the people who live in this area. Sharing this type of information is **incredibly important for open data and open research**. However, as we'll see, there are privacy challenges we need to overcome when sharing sensitive data of this nature.

When sharing the stats, Alice will:

1. Create a CSV containing the block stats for the 7 people
2. Report the count, mean age and median age of each cohort in the block

Below you can see the final stats she outputs: the count, median age and mean age for each cohort in the block. Note that the A3 stat for single adults is suppressed, because fewer than 3 people fall into it.

In [3]:
block_stats = pd.read_csv("data/block-stats.csv")
block_stats

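|   | statistic | name                  | count | median | mean |
|---|-----------|-----------------------|-------|--------|------|
| 0 | A1        | total-population      | 7.0   | 30.0   | 38.0 |
| 1 | A2        | non-smoker            | 4.0   | 30.0   | 33.0 |
| 2 | B2        | smoker                | 3.0   | 30.0   | 44.0 |
| 3 | C2        | unemployed            | 4.0   | 51.0   | 48.0 |
| 4 | D2        | employed              | 3.0   | 24.0   | 24.0 |
| 5 | A3        | single-adults         | NaN   | NaN    | NaN  |
| 6 | B3        | married-adults        | 4.0   | 51.0   | 54.0 |
| 7 | A4        | unemployed-non-smoker | 3.0   | 36.0   | 37.0 |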


What we're going to do is demonstrate how we can re-construct the database, almost perfectly, from nothing more than the summary stats provided above. To do this, we're going to use [Z3](https://github.com/Z3Prover/z3), an SMT (satisfiability modulo theories) constraint solver from Microsoft Research. You can read more about Z3 [here](https://z3prover.github.io/papers/programmingz3.html). Essentially, we'll feed the stats Alice provided into Z3 as a set of constraints and ask it to find combinations of answers to our original questions that are consistent with them.
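
Z3 searches cleverly, but the underlying idea can be seen with plain brute force. The sketch below (a swapped-in technique for illustration, not what this notebook uses) enumerates every possible set of ages for a hypothetical three-person block where everyone is under 30, and keeps only the candidates consistent with the published statistics. Each extra statistic shrinks the set of candidate databases.

```python
from itertools import combinations_with_replacement

AGES = range(30)  # toy assumption: everyone in this hypothetical block is under 30


def median_is_18(ages):
    return ages[1] == 18                 # middle of a sorted triple


def mean_is_20(ages):
    return sum(ages) == 3 * 20           # assumes the published mean is exact


def oldest_two_mean_21(ages):
    return ages[1] + ages[2] == 2 * 21   # a hypothetical extra published statistic


def candidates(*constraints):
    """Every sorted age triple (a <= b <= c) that satisfies all the given constraints."""
    return [
        ages
        for ages in combinations_with_replacement(AGES, 3)
        if all(check(ages) for check in constraints)
    ]


first = candidates(median_is_18, mean_is_20)
second = candidates(median_is_18, mean_is_20, oldest_two_mean_21)
print(len(first))  # 6
print(second)      # [(18, 18, 24)]
```

Brute force blows up quickly: for 7 people with ages up to 125 there are about 10^11 sorted age combinations before even considering the other attributes, which is why a solver such as Z3 is used instead.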

## 2. Declare a solver 🧮
A solver is an object that holds all our constraints. The solver then has a number of methods for adding constraints and computing solutions to those constraints. 

In [4]:
solver = z3.Solver()

# Displaying an empty solver shows an empty list.
# After we add constraints, displaying it will reveal all of the constraints that have been added
solver

Before we add constraints to our solver, let's create a `BlockStats` object that will let us access the data easily.

In [5]:
# Each cohort will have the following information:
class Stat(BaseModel):
    name: str
    count: Optional[int]
    median: Optional[int]
    mean: Optional[Union[int, float]]


# Each row in the final CSV she will output will be given an ID and contain everything from our Stat class
class BlockStats(BaseModel):
    A1: Stat
    A2: Stat
    B2: Stat
    C2: Stat
    D2: Stat
    A3: Stat
    B3: Stat
    A4: Stat
    
stats = BlockStats(**block_stats.replace({np.nan: None}).set_index('statistic').to_dict(orient='index'))
stats

BlockStats(A1=Stat(name='total-population', count=7, median=30, mean=38), A2=Stat(name='non-smoker', count=4, median=30, mean=33), B2=Stat(name='smoker', count=3, median=30, mean=44), C2=Stat(name='unemployed', count=4, median=51, mean=48), D2=Stat(name='employed', count=3, median=24, mean=24), A3=Stat(name='single-adults', count=None, median=None, mean=None), B3=Stat(name='married-adults', count=4, median=51, mean=54), A4=Stat(name='unemployed-non-smoker', count=3, median=36, mean=37))

This object stores each stat by its code, so we can use the values through attribute access; for example, the count of smokers is `stats.B2.count`.

## 3. Create the variables 💎
Now that we have a solver and our stats object set up, we can start creating our variables. What we're trying to re-create here are the answers to the questions:

1. What is your age?
2. Are you a smoker?
3. Are you married?
4. Are you employed?

Let's step through how we're going to model these as Z3 variables.

1. What is your age? This is pretty straightforward. We need an array of integers, one per person in our population (7), representing their ages. What we want at the end is a list of 7 ages that would produce the published statistical tables above.
2. Are you a smoker? This one is a bit harder. It's possible to model boolean values in Z3 (with `z3.Bool`), but the tricky part is choosing which subset of our 7 people belongs to each cohort. For instance, we have 3 people who are smokers: how do we know which of our 7 ages belong to them? There are a few ways to do this, but the approach we will follow is this:
    - Create a function that creates two `z3.IntVector`s.
        - The first has the length of our cohort (e.g. smokers = 3).
        - The second holds the remaining people who are not in that cohort (e.g. non-smokers = 4). The combined length of the two vectors must be 7.
    - The vectors must contain distinct integers from 0 to 6, each representing an index in our ages array.
    - From here, we can determine whether an index of our ages array is part of a cohort by checking whether that index appears in the cohort's vector.

    So, for example, if the positions in our ages array are:
        - [0, 1, 2, 3, 4, 5, 6]

    and we've added constraints that tell us smokers are at positions [0, 1, 2] and non-smokers occupy the remaining positions [3, 4, 5, 6],

    then we can iterate through the positions and check whether each one is in our smokers vector to determine whether the person at that index is a smoker:
    ```
    n = 7  # number of people in the block
    smokers_positions = [0, 1, 2]
    is_smoker = [i in smokers_positions for i in range(n)]
    # [True, True, True, False, False, False, False]
    ```
3. Are you married? Same as above
4. Are you employed? Same as above

Let's start with our ages array. For this we will use `z3.Array`.

In [6]:
ages: z3.ArrayRef = z3.Array('ages', z3.IntSort(), z3.IntSort())  # maps an integer index (person) to an integer (age)

Now we'll start adding some basic constraints to our ages. We know one thing for certain: how many people are in this array. A `z3.Array` has no fixed length (it is defined for every integer index), so we bound it ourselves by only ever constraining and reading indices 0 to n-1. Since we'll be iterating over this number quite a bit, we'll store it in the variable `n`.

The second constraint uses some auxiliary information to make our logic about ages more sound. Age is an integer between 0 and some upper bound, but what should that upper bound be? A little research reveals that [Jeanne Calment](https://en.wikipedia.org/wiki/Jeanne_Calment) is the oldest verified person to have lived, reaching the age of 122. That sounds like a decent cap for our ages, and just to be safe we can add a little buffer.

> ⚠️ Note: It's important to think about what auxiliary information people might have when they're attacking your data. Even a little information can be used to improve an attack.

## 4. Define the constraints and add to the solver 📟

In [7]:
n = stats.A1.count

min_age = 0
max_age = 125

In [8]:
# Constrain each age to our min and max ages
for i in range(n):
    solver.add(z3.And(z3.Select(ages, i) > min_age, z3.Select(ages, i) < max_age))

Now that we've added some constraints, let's have a look at the state of the solver:

In [9]:
solver

As you can see, the solver now holds constraints for each entry in our `ages` variable. As we build up the constraints, you can display the solver to see every constraint it knows about.

Next, let's work on our mean calculation. To calculate means in Z3, we use `z3.Sum()` and then divide by the number of elements. Because our ages are integers, this `/` is integer division, so the constraint requires the true mean, rounded down, to equal the published value. We'll wrap this in a function so we can re-use it for the constraints on our other cohorts.
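
To see what that rounding means in practice, this plain-Python sketch lists every total age that satisfies `total // 7 == 38` (stat A1). It assumes Z3's integer division rounds down for a positive divisor, which is how Python's `//` behaves for non-negative totals; so the published mean only pins the total to a window rather than an exact value.

```python
population = 7
published_mean = 38  # stat A1

# Every total age whose floor-divided mean equals the published value
# (125 matches the notebook's max_age bound).
consistent_totals = [total for total in range(population * 125) if total // population == published_mean]
print(consistent_totals)  # [266, 267, 268, 269, 270, 271, 272]

# Alice's original ages sum to a total inside that window.
original_ages = [8, 18, 24, 30, 36, 66, 84]
print(sum(original_ages), sum(original_ages) in consistent_totals)  # 266 True
```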

In [10]:
def add_mean_constraint(solver: z3.Solver, ages: z3.ArrayRef, indices: List[z3.ArithRef], mean: int) -> None:
    """Mean constraint
    The mean constraint is added by summing the values in `ages` at the positions in `indices` and dividing
    by the length of `indices`. Because the ages are integers, Z3 performs integer division here.
    
    Args:
        solver: Our z3.Solver to add the constraints to
        ages: The z3 array of ages we want to constrain
        indices: Various cohorts (i.e. smokers, non-smokers, etc.) are represented by different indices in the ages
            array (e.g. smokers might be at indices 0, 1 and 2). We pass in the indices of the cohort we want
            to constrain
        mean: The mean we want to constrain to
        
    Returns: None
    """
    # We'll dive into the indices argument more shortly - it's not necessary for the total population stat, but is
    # required in future mean constraints.
    solver.add(z3.Sum([ages[idx] for idx in indices]) / len(indices) == mean)

    
# we call the function, passing in the solver, ages and the mean from our stats object
add_mean_constraint(solver, ages, range(n), stats.A1.mean)

Next up, let's add the constraint for our total population median. Medians are a little harder, since the values need to be sorted before we can pluck out the middle one. Sorting is simple when you know the values ahead of time, but here we don't: finding them is the solver's job. So instead, we add constraints that force the values into sorted order, and then constrain the middle value to equal our median.

For this, we iterate through our `ages` variable and require each element to be less than or equal to the next one until we reach the end of the array (i.e. `ages[0] <= ages[1]`, `ages[1] <= ages[2]`, etc.).
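
The same sorted-then-middle rule can be checked on concrete numbers. This plain-Python sketch uses hypothetical helpers `is_sorted` and `median_matches` (not part of the notebook) and applies them to two cohorts from Alice's database: the married ages 30, 36, 66, 84 (even count) and the smoker ages 18, 30, 84 (odd count).

```python
from typing import Sequence


def is_sorted(values: Sequence[int]) -> bool:
    """True if each element is <= the next one (the pairwise rule described above)."""
    return all(a <= b for a, b in zip(values, values[1:]))


def median_matches(values: Sequence[int], median: int) -> bool:
    """Middle value for an odd count, or the two middle values summing to 2 * median for an even count."""
    if not is_sorted(values):
        raise ValueError("values must be sorted first")
    mid = len(values) // 2
    if len(values) % 2 == 0:
        return values[mid - 1] + values[mid] == 2 * median
    return values[mid] == median


married_ages = [30, 36, 66, 84]  # B3 publishes a median of 51
smoker_ages = [18, 30, 84]       # B2 publishes a median of 30
print(median_matches(married_ages, 51), median_matches(smoker_ages, 30))  # True True
```

Writing the even case as `a + b == 2 * median` avoids division, which is also why the notebook's Z3 constraint is phrased that way.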

In [11]:
# median age constraint
from itertools import tee

def pairwise(iterable):
    """Pairwise
    Given an iterable, return a generator of pairs of values where the second
    generator starts at position 1 and the first starts at position 0.
    Example:
        [a for a in pairwise([1,2,3])]
        # returns: [(1, 2), (2, 3)]
    
    Args:
        iterable: The iterable to generate pairs for
    
    Returns: A generator that yields pairs of items
    """
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def pairwise_sort(solver: z3.Solver, iterable, size: int) -> None:
    """Pairwise Sort
    In order to model our median constraint, we need to ensure that any
    variables related to each of our cohorts is sorted. This pairwise
    sort will iterate through `iterable` in pairs and add constraints 
    ensuring they are sorted smallest -> largest. We use this as part of our
    `add_median_constraint` function below
    Args:
        solver: Our z3.Solver to add the constraints to
        iterable: An indexable sequence of values we need to sort
        size: The total number of elements we need to sort

    Returns: None

    """
    for a, b in pairwise([iterable[i] for i in range(size)]):
        solver.add(a <= b)

def add_median_constraint(solver: z3.Solver, ages: z3.ArrayRef, indices: List[z3.ArithRef], median: int) -> None:
    """Median Constraint
    To calculate the median, the cohort's ages must be sorted, so we first apply our pairwise sort to them.
    We then find the middle index of the cohort and handle cases where it has an odd or even
    number of members.
    Args:
        solver: Our z3.Solver to add the constraints to
        ages: The z3 array of ages we want to constrain
        indices: Various cohorts (i.e. smokers, non-smokers, etc.) are represented by different indices in the ages 
            array (e.g. smokers might be at indices 0, 1 and 2). We pass in the indices of the cohort we want
            to sort and constrain
        median: The median we want to constrain to
    """
    # Sort the ages belonging to this cohort (not just the first len(indices) entries of `ages`)
    cohort_ages = [ages[idx] for idx in indices]
    pairwise_sort(solver, cohort_ages, len(indices))

    med_idx = len(indices) // 2

    if len(indices) % 2 == 0:
        # Even count: the median is the mean of the two middle values
        solver.add(cohort_ages[med_idx - 1] + cohort_ages[med_idx] == median * 2)
    else:
        # Odd count: the median is the middle value
        solver.add(cohort_ages[med_idx] == median)

# here we call our median constraint function, pass in the solver, ages variable and our median we want to constrain to
add_median_constraint(solver, ages, range(n), stats.A1.median)

Up to now, we've been adding constraints on our total population. Now we need to add constraints for our cohorts (i.e. smokers, non-smokers, employed, unemployed, married, single, etc.). I described the approach we'll take above, and the function below implements it.

This function takes our solver, a pair of names and an index to split at, and returns two lists of Z3 integer variables whose values are the indices in our ages array that belong to each cohort.

We also add three constraints on these new variables to ensure they give us the right values:

1. Indices must be between 0 and 6 (i.e. less than 7), since there are 7 people in our population
2. Indices must be distinct across both vectors, i.e. every person appears exactly once, either in the cohort or in the remainder
3. Indices within each vector must be sorted (to aid our median calculation later)
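
These three rules can be expressed as a plain-Python validity check on a concrete split. `is_valid_split` is a hypothetical helper for illustration; the example split uses the positions from Alice's database, where married people are at positions 3 to 6 and single people at 0 to 2.

```python
from typing import List


def is_valid_split(first: List[int], remaining: List[int], n: int = 7) -> bool:
    everyone = first + remaining
    in_range = all(0 <= idx < n for idx in everyone)                     # rule 1
    distinct = len(set(everyone)) == len(everyone) == n                  # rule 2 (and everyone is covered)
    ordered = first == sorted(first) and remaining == sorted(remaining)  # rule 3
    return in_range and distinct and ordered


married, single = [3, 4, 5, 6], [0, 1, 2]
print(is_valid_split(married, single))          # True
print(is_valid_split([3, 4, 5, 5], [0, 1, 2]))  # False: person 5 appears twice
```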

In [12]:
def split_indices(solver: z3.Solver, name_pair: Tuple[str, str], split_at: int) -> Tuple[List[z3.ArithRef], List[z3.ArithRef]]:
    # `n` (the population size) is the global defined earlier
    first_indices = z3.IntVector(name_pair[0], split_at)
    remaining_indices = z3.IntVector(name_pair[1], n - split_at)

    # indices must be between 0 and n - 1, ensuring we only have 7 people accounted for
    solver.add(*[z3.And(idx >= 0, idx < n) for idx in first_indices + remaining_indices])

    # indices must be distinct, i.e. each person appears exactly once across both vectors
    solver.add(z3.Distinct(*first_indices, *remaining_indices))

    # indices must be sorted for our future median calculations
    pairwise_sort(solver, first_indices, split_at)
    pairwise_sort(solver, remaining_indices, n - split_at)

    return first_indices, remaining_indices

Let's go ahead and call this function and see what we get back:

In [13]:
# here we generate the indices for our married and single cohorts by using the number of married people (stat B3) as our splitting point
married_indices, single_indices = split_indices(solver=solver, name_pair=('married', 'single'), split_at=stats.B3.count)

married_indices, single_indices

([married__0, married__1, married__2, married__3],
 [single__0, single__1, single__2])

We now have two new variables we can use to set constraints on our cohorts.

> Note: The suppressed statistic (A3, single adults) belongs to the single cohort, and below we use it to make our result more accurate. It's striking that, by suppressing a statistic, we may be giving information to our attackers! This is why privacy engineering is so important: knowing about these issues helps us mitigate them in the future.

In [14]:
# assume married people are at least the legal age of 18
solver.add(*[ages[idx] >= 18 for idx in married_indices])
solver.add(*[ages[idx] >= 0 for idx in single_indices])

# mean age of the married cohort
add_mean_constraint(solver, ages, married_indices, stats.B3.mean)

# median age of the married cohort
add_median_constraint(solver, ages, married_indices, stats.B3.median)

# This is the suppressed statistic (A3): the number of single adults must be 0, 1 or 2
single_adult_count = [z3.If(ages[idx] >= 18, 1, 0) for idx in single_indices]

solver.add(z3.Sum(single_adult_count) >= 0)
solver.add(z3.Sum(single_adult_count) <= 2)

We can then go ahead and do the same things for our smokers cohort:

In [15]:
smokers, non_smokers = split_indices(solver, name_pair=("smokers", "non-smokers"), split_at=stats.B2.count)

add_mean_constraint(solver=solver, ages=ages, indices=smokers, mean=stats.B2.mean)
add_mean_constraint(solver=solver, ages=ages, indices=non_smokers, mean=stats.A2.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=smokers, median=stats.B2.median)
add_median_constraint(solver=solver, ages=ages, indices=non_smokers, median=stats.A2.median)

Then the employed cohort:

In [16]:
employed, unemployed = split_indices(solver, name_pair=("employed", "unemployed"), split_at=stats.D2.count)

# add mean constraints
add_mean_constraint(solver=solver, ages=ages, indices=employed, mean=stats.D2.mean)
add_mean_constraint(solver=solver, ages=ages, indices=unemployed, mean=stats.C2.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=employed, median=stats.D2.median)
add_median_constraint(solver=solver, ages=ages, indices=unemployed, median=stats.C2.median)

And our unemployed, non-smokers cohort:

In [17]:
unemployed_non_smokers, _ = split_indices(solver, name_pair=('unemployed_non_smoker', 'other'), split_at=stats.A4.count)

# Intersection of unemployed and non-smokers: every index in unemployed_non_smokers must also
# appear in both the unemployed and the non_smokers index vectors
solver.add(
    *[
        z3.And(
            z3.Or(*[i == idx for i in unemployed]),
            z3.Or(*[j == idx for j in non_smokers]),
        )
        for idx in unemployed_non_smokers
    ]
)

# add mean constraints
add_mean_constraint(solver=solver, ages=ages, indices=unemployed_non_smokers, mean=stats.A4.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=unemployed_non_smokers, median=stats.A4.median)
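
In plain Python, the unemployed-non-smoker cohort is just the set intersection of the two cohorts' positions. This sketch uses the positions and ages from Alice's original database to confirm the intersection matches the published A4 count and median. Note that, on its own, the Z3 constraint above only requires each of the three `unemployed_non_smokers` indices to appear in both cohorts; it does not by itself stop a fourth person from being in both.

```python
unemployed_positions = {0, 4, 5, 6}   # Sara Gray, Brenda Small, Dr. Tina Ayala, Rodney Gonzalez
non_smoker_positions = {0, 2, 4, 5}   # Sara Gray, Vincent Porter, Brenda Small, Dr. Tina Ayala
original_ages = [8, 18, 24, 30, 36, 66, 84]  # already sorted by position

both = sorted(unemployed_positions & non_smoker_positions)
print(both)                                # [0, 4, 5]
print(len(both), original_ages[both[1]])   # 3 36  (A4 publishes count 3, median 36)
```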

Before we check whether the solver can find a solution, let's look at all the constraints we've added to get an idea of how complex this problem is. After running the cell below, you'll see all the constraints and logic the solver has to satisfy when finding values for our variables.

In [18]:
solver

## 5. Check the solver can find a solution 🧠
Running `solver.check()` tells us whether the constraints we've stored in the solver can all be satisfied. If the solver finds a solution, the check returns `z3.sat`; if no solution exists, it returns `z3.unsat`. It can also return `z3.unknown` if the solver gives up, for example after a timeout.

In [19]:
solver.check()

## 6. Access the created model and display the results 🎉

Once we've run `solver.check()`, we can then access the model it has created of our variables and see if the re-creation matches the original database.

In [20]:
model = solver.model()
model

To make it easier to work with, let's turn this into a pandas `DataFrame`:

In [21]:
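# Read back the concrete value z3 assigned to each index variable, as a set of Python ints
married_positions = {model.evaluate(idx, model_completion=True).as_long() for idx in married_indices}
smoker_positions = {model.evaluate(idx, model_completion=True).as_long() for idx in smokers}
employed_positions = {model.evaluate(idx, model_completion=True).as_long() for idx in employed}

model_df = pd.DataFrame(
    {
        'age': [model.evaluate(ages[i], model_completion=True).as_long() for i in range(n)],
        'married': [i in married_positions for i in range(n)],
        'smoker': [i in smoker_positions for i in range(n)],
        'employed': [i in employed_positions for i in range(n)],
    }
)

model_df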

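|   | age | married | smoker | employed |
|---|-----|---------|--------|----------|
| 0 | 9   | False   | False  | False    |
| 1 | 18  | False   | True   | True     |
| 2 | 24  | False   | False  | True     |
| 3 | 30  | True    | True   | True     |
| 4 | 36  | True    | False  | False    |
| 5 | 66  | True    | False  | False    |
| 6 | 84  | True    | True   | False    |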


And below we can visually compare it with our original database:

In [22]:
database.drop(columns=["name"])

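|   | age | married | smoker | employed |
|---|-----|---------|--------|----------|
| 0 | 8   | False   | False  | False    |
| 1 | 18  | False   | True   | True     |
| 2 | 24  | False   | False  | True     |
| 3 | 30  | True    | True   | True     |
| 4 | 36  | True    | False  | False    |
| 5 | 66  | True    | False  | False    |
| 6 | 84  | True    | True   | False    |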


As you can see, we've got very close! Every value matches except Sara Gray's age, which was reconstructed as 9 instead of 8. When more than one set of answers is consistent with the constraints, the solver returns just one of them, so depending on the constraints your result may match exactly or be off by a year or two.
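
Why 9 rather than 8? A plausible explanation: the Z3 mean constraints use integer division, so each published mean is treated as the true mean rounded down. The true mean of the unemployed non-smokers (ages 8, 36 and 66) is 110 / 3 ≈ 36.67, yet it is published as 37, so an age of 8 cannot satisfy that constraint. Holding the other six reconstructed ages fixed (an assumption for this illustration), the plain-Python sketch below intersects the ages for Sara that each relevant published mean allows.

```python
# Hypothetical setup: vary only Sara Gray's age; each cohort maps to
# (published mean, ages of the *other* members of that cohort).
cohorts = {
    "A4 unemployed-non-smoker": (37, [36, 66]),
    "C2 unemployed": (48, [36, 66, 84]),
    "A2 non-smoker": (33, [24, 36, 66]),
    "A1 total-population": (38, [18, 24, 30, 36, 66, 84]),
}

feasible = set(range(125))  # same upper bound as the notebook's max_age
for name, (published_mean, others) in cohorts.items():
    size = len(others) + 1  # Sara plus the other members
    # Floor division, matching how the Z3 mean constraint behaves for integer ages
    allowed = {age for age in range(125) if (age + sum(others)) // size == published_mean}
    print(name, sorted(allowed))
    feasible &= allowed

print(sorted(feasible))  # [9]
```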

Let's write a quick function to make sure our answers can also produce the same summary stats:

In [23]:
def check_reconstruction_matches_stats(stats: BlockStats, reconstructed: pd.DataFrame) -> None:
    """Check Re-Construction against original stats
    
    This function serves as a pseudo test-suite for our model. We check that the results it produces all match what we're
    expecting from our stats. Means are truncated with int(), mirroring the integer division in our Z3 mean constraints.
    
    Args:
        stats: The BlockStats object containing our stats
        reconstructed: The dataframe that we've re-created from our z3 model
    
    Returns: None
    
    Raises:
        AssertionError if our re-constructed table doesn't meet the conditions in our stats object
    """
    assert int(reconstructed.age.mean()) == stats.A1.mean
    assert reconstructed.age.median() == stats.A1.median


    married_df, single_df = reconstructed[reconstructed.married == True], reconstructed[reconstructed.married == False]

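    assert int(married_df.age.mean()) == stats.B3.mean
    assert married_df.age.median() == stats.B3.median

    # A3 (single adults) is suppressed, so there must be fewer than 3 of them
    assert (single_df.age >= 18).sum() < 3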


    smokers_df, non_smokers_df = reconstructed[reconstructed.smoker == True], reconstructed[reconstructed.smoker == False]

    assert int(smokers_df.age.mean()) == stats.B2.mean
    assert smokers_df.age.median() == stats.B2.median
    
    assert int(non_smokers_df.age.mean()) == stats.A2.mean
    assert non_smokers_df.age.median() == stats.A2.median
    
    employed_df, unemployed_df = reconstructed[reconstructed.employed == True], reconstructed[reconstructed.employed == False]

    assert int(employed_df.age.mean()) == stats.D2.mean
    assert employed_df.age.median() == stats.D2.median
    
    assert int(unemployed_df.age.mean()) == stats.C2.mean
    assert unemployed_df.age.median() == stats.C2.median
    
    unemployed_non_smokers_df = reconstructed[
        (reconstructed.employed == False) &
        (reconstructed.smoker == False)
    ]
    
    assert int(unemployed_non_smokers_df.age.mean()) == stats.A4.mean
    assert int(unemployed_non_smokers_df.age.median()) == stats.A4.median

check_reconstruction_matches_stats(stats, model_df)

And finally, let's calculate the percentage of the database's values that we've been able to re-create:

In [24]:
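def check_accuracy(database: pd.DataFrame, reconstructed: pd.DataFrame) -> float:
    """Return the percentage of cells (age, married, smoker, employed) that match the original database.

    Assumes both frames list the same people in the same row order, with the same column order once
    `name` is dropped.
    """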
    match, non_match = 0, 0
    computed = [tuple(v.values()) for v in reconstructed.to_dict(orient='records')]
    original = [tuple(v.values()) for v in database.drop(columns=["name"]).to_dict(orient='records')]

    to_check = [list(zip(computed[i], original[i])) for i in range(len(database))]
    for items in to_check:
        for pair in items:
            if pair[0] == pair[1]:
                match += 1
            else:
                non_match += 1
    return (match / (match + non_match)) * 100

In [25]:
check_accuracy(database, model_df)

96.42857142857143
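
A percentage hides *which* values are wrong. With 7 people and 4 attributes there are 28 cells, so 96.43% corresponds to a single wrong cell (27 of 28). This plain-Python sketch, with the two tables above typed in by hand, lists the mismatching cells.

```python
# The two tables above, typed in by hand as (age, married, smoker, employed) rows
original = [
    (8, False, False, False), (18, False, True, True), (24, False, False, True),
    (30, True, True, True), (36, True, False, False), (66, True, False, False),
    (84, True, True, False),
]
reconstructed = [
    (9, False, False, False), (18, False, True, True), (24, False, False, True),
    (30, True, True, True), (36, True, False, False), (66, True, False, False),
    (84, True, True, False),
]
columns = ("age", "married", "smoker", "employed")

mismatches = [
    (row, columns[col], orig_value, recon_value)
    for row, (orig_row, recon_row) in enumerate(zip(original, reconstructed))
    for col, (orig_value, recon_value) in enumerate(zip(orig_row, recon_row))
    if orig_value != recon_value
]
total_cells = len(original) * len(columns)
accuracy = 100 * (total_cells - len(mismatches)) / total_cells

print(mismatches)          # [(0, 'age', 8, 9)]
print(round(accuracy, 2))  # 96.43
```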

And with that, we have seen how you can re-create a database from nothing more than summary statistics. Why is this possible? The answer comes from the same people who developed differential privacy, one of the strongest techniques we know of for protecting against this type of attack:

> **"[Giving] overly accurate answers to too many questions will destroy privacy in a spectacular way"**
>
> Cynthia Dwork and Aaron Roth, authors of *The Algorithmic Foundations of Differential Privacy*