# 🧮 Database Reconstruction Attacks (DRAs)

This notebook is adapted from a paper on database reconstruction attacks. You can find the paper [here](https://cacm.acm.org/magazines/2019/3/234925-understanding-database-reconstruction-attacks-on-public-data/fulltext).

There are many reasons businesses and governments want to share information about people. When sharing such information, it is important to consider the privacy and anonymity of the people the data is derived from. In many cases, publishing only aggregate data does little to stop an attacker from re-creating a database that is very close to, or exactly the same as, the original. In this notebook, we take a simple example and re-create the database from nothing but aggregate statistics about those people.

We will mainly use [Z3](https://github.com/Z3Prover/z3) for this. Imagine we have the following database that contains information about people within a certain geographic area (going forward, we refer to this area as a **block**).

We have *7* people in total in this block. For each resident we have their **age**, **smoking status**, **employment status** and whether they are **married** or not, and we publish a variety of statistics about them. You have probably seen something similar in your country's census.

> 📓 To simplify the example, this fictional world has:
> - Two marriage statuses: Married (**True**) or Single (**False**)
> - Two smoking statuses: Non-Smoker (**False**) or Smoker (**True**)
> - Two employment statuses: Unemployed (**False**) or Employed (**True**)

> 👾 One additional piece of logic we know is that any statistic with a **count of less than 3** is suppressed. Suppressing statistics with low counts is a common tactic for protecting privacy. The fewer people there are behind a statistic, the more they stick out in a dataset, so their privacy is more at risk than that of people who 'blend in with the crowd'. As we'll see, simply knowing that a statistic is suppressed can itself be used to attack a dataset.
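
As a minimal sketch of the suppression rule described above, a publisher might withhold any count below 3. The names `publish_count` and `possible_counts` are illustrative assumptions, not part of Alice's actual pipeline. Note that a withheld value still tells a reader that the true count is 0, 1 or 2:

```python
from typing import List, Optional

SUPPRESSION_THRESHOLD = 3  # counts below this are withheld (the rule stated above)


def publish_count(true_count: int, threshold: int = SUPPRESSION_THRESHOLD) -> Optional[int]:
    """Return the count if it is large enough to publish, otherwise None."""
    return true_count if true_count >= threshold else None


def possible_counts(published: Optional[int], threshold: int = SUPPRESSION_THRESHOLD) -> List[int]:
    """What a reader can infer about the true count from the published value."""
    if published is None:
        return list(range(threshold))  # suppressed: 0, 1, ..., threshold - 1
    return [published]


print(publish_count(4))                   # 4
print(publish_count(2))                   # None
print(possible_counts(publish_count(2)))  # [0, 1, 2]
```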

In [48]:
from typing import Optional, Union, List, Tuple

from pydantic import BaseModel # Pydantic is a great library for modelling data, similar to dataclasses.
import numpy as np
import pandas as pd
import z3 # This is a library allowing us to model our re-construction attack using constraint satisfaction solvers, more on this shortly.

For this notebook, we'll work with two main files containing our data. 👩🏽 Alice is the owner of our original database about our 7 people. Alice works for a company that conducts census questionnaires. She conducted a survey asking 7 people the following 5 questions:

| Question           | Type    |
|--------------------|---------|
| What is your name? | String  |
| What is your age?  | Integer |
| Are you married?   | Boolean |
| Are you a smoker?  | Boolean |
| Are you employed?  | Boolean |



Here is what Alice can see when she looks at the raw data:

In [49]:
database = pd.read_csv("data/database.csv")
database

Unnamed: 0,name,age,married,smoker,employed
0,Sara Gray,8,False,False,False
1,Joseph Collins,18,False,True,True
2,Vincent Porter,24,False,False,True
3,Tiffany Brown,30,True,True,True
4,Brenda Small,36,True,False,False
5,Dr. Tina Ayala,66,True,False,False
6,Rodney Gonzalez,84,True,True,False


As you can see, there are 7 rows containing the information of 7 different people. Alice wants to share some statistical information about these people so that people like 🧔 Bob can use it to understand more about who lives in this area. Sharing this type of information is **incredibly important for open data and open research**. However, as we'll see, there are often privacy challenges we need to overcome when sharing sensitive data of this nature.

When sharing the stats, this is what Alice will create:

1. A CSV containing the block stats for the 7 people
2. In it, she'll report the count, mean age and median age of each cohort in the block
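
The per-cohort figures can be sketched with the standard library alone. The ages and flags below are copied from Alice's raw table above; the helper name `cohort_stats` is hypothetical, and this notebook does not say how Alice rounds non-integer means, so the example sticks to cohorts whose mean is a whole number:

```python
from statistics import mean, median

# (age, married, smoker, employed) for the 7 residents, copied from the raw table
people = [
    (8, False, False, False),
    (18, False, True, True),
    (24, False, False, True),
    (30, True, True, True),
    (36, True, False, False),
    (66, True, False, False),
    (84, True, True, False),
]


def cohort_stats(ages):
    """Return (count, median, mean) for one cohort's ages."""
    return len(ages), median(ages), mean(ages)


everyone = [age for age, *_ in people]
smokers = [age for age, _, smoker, _ in people if smoker]
married = [age for age, is_married, _, _ in people if is_married]

print(cohort_stats(everyone))
print(cohort_stats(smokers))
print(cohort_stats(married))
```

For the whole block this gives a count of 7, a median age of 30 and a mean age of 38.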

In [25]:
# Each cohort in the block will have the following information:

class Stat(BaseModel):
    name: str
    count: Optional[int]
    median: Optional[int]
    mean: Optional[Union[int, float]]


# Each row in the final CSV she will output will be given an ID and contain everything from our Stat class
class BlockStats(BaseModel):
    A1: Stat
    A2: Stat
    B2: Stat
    C2: Stat
    D2: Stat
    A3: Stat
    B3: Stat
    A4: Stat

Below you can see what the final stats she outputs look like. She has shared various statistics about the cohorts within our block.

In [26]:
block_stats = pd.read_csv("data/block-stats.csv")
block_stats

Unnamed: 0,statistic,name,count,median,mean
0,A1,total-population,7.0,30.0,38.0
1,A2,non-smoker,4.0,30.0,33.0
2,B2,smoker,3.0,30.0,44.0
3,C2,unemployed,4.0,51.0,48.0
4,D2,employed,3.0,24.0,24.0
5,A3,single-adults,,,
6,B3,married-adults,4.0,51.0,54.0
7,A4,unemployed-non-smoker,3.0,36.0,37.0


What we're going to do is demonstrate how we can fully re-construct the database from nothing more than the summary stats provided above. To do this, we're going to use [Z3](https://github.com/Z3Prover/z3), an SMT (satisfiability modulo theories) solver. You can read more about Z3 [here](https://z3prover.github.io/papers/programmingz3.html). Essentially, we're going to use the stats that Alice provided as a set of constraints that we feed into Z3, and ask it to find a combination of answers to our original 5 questions that satisfies all of them.

The basic steps for setting up and solving a problem often follow a similar process to this:

1. Import Z3. ✅
2. Declare a solver 🧮
3. Create the variables 💎
4. Define the constraints 📟
5. Check the solver can find a solution 🧠
6. Access the created model and display the results 🎉

In [27]:
solver = z3.Solver()

In [50]:
stats = BlockStats(**block_stats.replace({np.nan: None}).set_index('statistic').to_dict(orient='index'))
stats

BlockStats(A1=Stat(name='total-population', count=7, median=30, mean=38), A2=Stat(name='non-smoker', count=4, median=30, mean=33), B2=Stat(name='smoker', count=3, median=30, mean=44), C2=Stat(name='unemployed', count=4, median=51, mean=48), D2=Stat(name='employed', count=3, median=24, mean=24), A3=Stat(name='single-adults', count=None, median=None, mean=None), B3=Stat(name='married-adults', count=4, median=51, mean=54), A4=Stat(name='unemployed-non-smoker', count=3, median=36, mean=37))

In [None]:
n = stats.A1.count

min_age = 0
max_age = 125

In [29]:
ages: z3.ArrayRef = z3.Array('ages', z3.IntSort(), z3.IntSort())

In [30]:
# Constrain each age to lie strictly between min_age and max_age (both bounds are exclusive)
for i in range(n):
    solver.add(z3.And(z3.Select(ages, i) > min_age, z3.Select(ages, i) < max_age))
    
# median age constraint
from itertools import tee

def pairwise(iterable):
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def pairwise_sort(solver: z3.Solver, iterable, size: int) -> None:
    """Pairwise Sort
    In order to model our median constraint, we need to ensure that the
    variables related to each of our cohorts are sorted. This pairwise
    sort iterates through `iterable` in adjacent pairs and adds constraints
    ensuring they are ordered smallest -> largest. We use this as part of our
    `add_median_constraint` function below.
    Args:
        solver: Our z3.Solver to add the constraints to
        iterable: The indexable collection of values we need to sort
        size: The total number of elements we need to sort

    Returns: None

    """
    for a, b in pairwise([iterable[i] for i in range(size)]):
        solver.add(a <= b)

def add_median_constraint(solver: z3.Solver, ages: z3.ArrayRef, indices: List[z3.ArithRef], median: int) -> None:
    """Median Constraint
    To calculate the median, we need to ensure our variables are sorted. For this we use a pairwise sort function.
    We then calculate the middle index of the cohort described by `indices`, and cater to situations where we
    have an odd or even number of values in the cohort.
    Args:
        solver: Our z3.Solver to add the constraints to
        ages: The Z3 array of ages we want to constrain to our median
        indices: Various cohorts (e.g. smokers, non-smokers, etc.) are represented by different indices in the ages
            array (e.g. smokers might be represented by a subset such as indices 0-3). We pass in which indices
            we care about to ensure we're sorting and constraining them
        median: The median we want to constrain to
    """
    pairwise_sort(solver, ages, len(indices))

    med_idx = len(indices) // 2

    if len(indices) % 2 == 0:
        solver.add(ages[indices[med_idx - 1]] + ages[indices[med_idx]] == median * 2)
    else:
        solver.add(z3.Store(ages, indices[med_idx], median) == ages)


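def add_mean_constraint(solver: z3.Solver, ages: z3.ArrayRef, indices: List[z3.ArithRef], mean: int) -> None:
    """Mean constraint
    Constrain the mean age of the people at `indices` to equal `mean`.
    The ages are Z3 integers, so `/` below is integer division: the
    constraint holds when the sum of ages divided by the cohort size,
    rounded down, equals `mean`.
    """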
    solver.add(z3.Sum([ages[idx] for idx in indices]) / len(indices) == mean)


add_median_constraint(solver, ages, range(n), stats.A1.median)
add_mean_constraint(solver, ages, range(n), stats.A1.mean)
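
To see what the median constraint asserts, here is a plain-Python check of the same odd and even cases on already-sorted ages. It is a sketch for intuition, not the Z3 encoding itself; `satisfies_median` is a hypothetical helper and the ages are taken from Alice's raw table:

```python
def satisfies_median(sorted_ages, median):
    """Mirror the odd/even cases used by the median constraint."""
    mid = len(sorted_ages) // 2
    if len(sorted_ages) % 2 == 0:
        # even count: the two middle values must sum to twice the median
        return sorted_ages[mid - 1] + sorted_ages[mid] == median * 2
    # odd count: the middle value itself is the median
    return sorted_ages[mid] == median


print(satisfies_median([18, 30, 84], 30))      # True  (smokers, odd count)
print(satisfies_median([30, 36, 66, 84], 51))  # True  (married adults, even count)
print(satisfies_median([30, 36, 66, 84], 50))  # False
```

Comparing the sum of the two middle values with `median * 2` keeps the even case in whole numbers, so no fractional arithmetic is needed.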

In [31]:
def split_indices(solver: z3.Solver, name_pair: Tuple[str, str], split_at: int) -> Tuple[List[z3.ArithRef], List[z3.ArithRef]]:
    """Split the n people into two cohorts of sizes `split_at` and `n - split_at`."""
    first_indices = z3.IntVector(name_pair[0], split_at)
    remaining_indices = z3.IntVector(name_pair[1], n - split_at)

    # indices must be between 0 and n - 1, ensuring we only have n (here 7) people accounted for
    solver.add(*[z3.And(idx >= 0, idx < n) for idx in first_indices + remaining_indices])

    # indices must be distinct, i.e. each person is a distinct person
    solver.add(z3.Distinct(*[idx for idx in first_indices + remaining_indices]))

    # indices must be sorted for our future median calculations
    pairwise_sort(solver, first_indices, split_at)
    pairwise_sort(solver, remaining_indices, n - split_at)

    return first_indices, remaining_indices
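
One way to picture what `split_indices` leaves open: because each index list is sorted and all indices are distinct, a split is fully determined by which people land in the first group. The sketch below enumerates those choices with `itertools.combinations` instead of Z3; the helper name `enumerate_splits` is an assumption for illustration:

```python
from itertools import combinations


def enumerate_splits(n, split_at):
    """Yield (first_group, remaining_group) as sorted tuples of person indices."""
    everyone = range(n)
    for first in combinations(everyone, split_at):
        remaining = tuple(i for i in everyone if i not in first)
        yield first, remaining


splits = list(enumerate_splits(7, 4))
print(len(splits))  # 35 ways to pick 4 people out of 7
print(splits[0])    # ((0, 1, 2, 3), (4, 5, 6))
```

Z3 searches this same space, but prunes it using the mean, median and count constraints rather than listing every split.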

In [32]:
married_indices, single_indices = split_indices(solver=solver, name_pair=('married', 'single'), split_at=stats.B3.count)

# constrain the ages of married people to the legal age
solver.add(*[ages[idx] >= 18 for idx in married_indices])
solver.add(*[ages[idx] >= 0 for idx in single_indices])

# constrain the mean age of the married cohort
add_mean_constraint(solver, ages, married_indices, stats.B3.mean)

# constrain the median age of the married cohort
add_median_constraint(solver, ages, married_indices, stats.B3.median)

# This is the suppressed statistic: we know that the count must be 0, 1 or 2
single_adult_count = [z3.If(ages[idx] >= 18, 1, 0) for idx in single_indices]

solver.add(z3.Sum(single_adult_count) >= 0)
solver.add(z3.Sum(single_adult_count) <= 2)
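
The arithmetic behind this suppressed statistic is small enough to check by hand. Using only the published counts (7 people in total, 4 married adults) and the suppression threshold of 3, the sketch below works out how many of the single residents can be adults. The variable names are illustrative:

```python
total_population = 7       # A1 count
married_adults = 4         # B3 count
suppression_threshold = 3  # counts below this are withheld

singles = total_population - married_adults                  # everyone who is not married
possible_single_adults = list(range(suppression_threshold))  # suppressed, so 0, 1 or 2
min_single_minors = singles - max(possible_single_adults)

print(singles)                 # 3
print(possible_single_adults)  # [0, 1, 2]
print(min_single_minors)       # 1
```

So at least one single resident must be under 18, a fact the published table never states directly.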

In [33]:
smokers, non_smokers = split_indices(solver, name_pair=("smokers", "non-smokers"), split_at=stats.B2.count)

add_mean_constraint(solver=solver, ages=ages, indices=smokers, mean=stats.B2.mean)
add_mean_constraint(solver=solver, ages=ages, indices=non_smokers, mean=stats.A2.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=smokers, median=stats.B2.median)
add_median_constraint(solver=solver, ages=ages, indices=non_smokers, median=stats.A2.median)

In [34]:
employed, unemployed = split_indices(solver, name_pair=("employed", "unemployed"), split_at=stats.D2.count)

# add mean constraints
add_mean_constraint(solver=solver, ages=ages, indices=employed, mean=stats.D2.mean)
add_mean_constraint(solver=solver, ages=ages, indices=unemployed, mean=stats.C2.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=employed, median=stats.D2.median)
add_median_constraint(solver=solver, ages=ages, indices=unemployed, median=stats.C2.median)

In [35]:
unemployed_non_smokers, _ = split_indices(solver, name_pair=('unemployed_non_smoker', 'other'), split_at=stats.A4.count)

# intersection of unemployed and non-smoker: for every unemployed-non-smoker in the database,
# ensure that its index matches one of the unemployed indices and one of the non-smoker indices
solver.add(
    *[
        z3.And(
            z3.Or(*[i == idx for i in unemployed]),
            z3.Or(*[j == idx for j in non_smokers]),
        )
        for idx in unemployed_non_smokers
    ]
)

# add mean constraints
add_mean_constraint(solver=solver, ages=ages, indices=unemployed_non_smokers, mean=stats.A4.mean)

# add median constraints
add_median_constraint(solver=solver, ages=ages, indices=unemployed_non_smokers, median=stats.A4.median)
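
The same intersection can be pictured with ordinary Python sets. The index sets below are read off Alice's raw table (row numbers 0-6) purely for illustration; the attacker does not know them and leaves it to Z3 to find index sets that fit:

```python
# Row numbers from the raw table, for illustration only
unemployed = {0, 4, 5, 6}
non_smokers = {0, 2, 4, 5}

unemployed_non_smokers = unemployed & non_smokers
print(sorted(unemployed_non_smokers))  # [0, 4, 5]

# Every member must appear in both parent cohorts, which is what the
# And(Or(...), Or(...)) constraint asks of each index
assert all(i in unemployed and i in non_smokers for i in unemployed_non_smokers)
```

The intersection has 3 members, matching the published count for the unemployed-non-smoker cohort.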

In [36]:
solver.check()

In [37]:
model = solver.model()

model_df = pd.DataFrame(
    {
        'age': [model.evaluate(z3.Select(ages, i)) for i in range(n)],
        'married': [i in [model[idx] for idx in married_indices] for i in range(n)],
        'smoker': [i in [model[idx] for idx in smokers] for i in range(n)],
        'employed': [i in [model[idx] for idx in employed] for i in range(n)],
    }
)

model_df.age = model_df.age.astype(str).astype(int)
model_df

Unnamed: 0,age,married,smoker,employed
0,9,False,False,False
1,18,False,True,True
2,24,False,False,True
3,30,True,True,True
4,36,True,False,False
5,66,True,False,False
6,84,True,True,False


In [38]:
database

Unnamed: 0,name,age,married,smoker,employed
0,Sara Gray,8,False,False,False
1,Joseph Collins,18,False,True,True
2,Vincent Porter,24,False,False,True
3,Tiffany Brown,30,True,True,True
4,Brenda Small,36,True,False,False
5,Dr. Tina Ayala,66,True,False,False
6,Rodney Gonzalez,84,True,True,False
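
The only difference between the two tables is the first age (9 instead of 8). A quick check of the unemployed non-smoker cohort (rows 0, 4 and 5 in both tables, published mean 37) shows how each value fares against a mean constraint that uses integer division. This is a sketch of the arithmetic only; it makes no claim about how the published mean was rounded:

```python
published_mean = 37  # A4, unemployed-non-smoker

original_ages = [8, 36, 66]       # rows 0, 4, 5 of the original database
reconstructed_ages = [9, 36, 66]  # rows 0, 4, 5 of the reconstruction

results = {}
for label, cohort in [("original", original_ages), ("reconstructed", reconstructed_ages)]:
    exact = sum(cohort) / len(cohort)
    floored = sum(cohort) // len(cohort)  # rounds down, like integer division
    results[label] = floored == published_mean
    print(label, round(exact, 2), floored, results[label])

# original 36.67 36 False
# reconstructed 37.0 37 True
```

With integer division, a sum of 110 gives 36, so an age of 8 does not meet that constraint, while 9 does.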


In [39]:
def check_reconstruction_matches_stats(stats: BlockStats, reconstructed: pd.DataFrame) -> None:
    # int() truncates each mean before comparing it with the published integer value
    assert int(reconstructed.age.mean()) == stats.A1.mean
    assert reconstructed.age.median() == stats.A1.median


    married_df, single_df = reconstructed[reconstructed.married == True], reconstructed[reconstructed.married == False]

    assert int(married_df.age.mean()) == stats.B3.mean
    assert married_df.age.median() == stats.B3.median


    smokers_df, non_smokers_df = reconstructed[reconstructed.smoker == True], reconstructed[reconstructed.smoker == False]

    assert int(smokers_df.age.mean()) == stats.B2.mean
    assert smokers_df.age.median() == stats.B2.median
    
    assert int(non_smokers_df.age.mean()) == stats.A2.mean
    assert non_smokers_df.age.median() == stats.A2.median
    
    employed_df, unemployed_df = reconstructed[reconstructed.employed == True], reconstructed[reconstructed.employed == False]

    assert int(employed_df.age.mean()) == stats.D2.mean
    assert employed_df.age.median() == stats.D2.median
    
    assert int(unemployed_df.age.mean()) == stats.C2.mean
    assert unemployed_df.age.median() == stats.C2.median
    
    unemployed_non_smokers_df = reconstructed[
        (reconstructed.employed == False) &
        (reconstructed.smoker == False)
    ]
    
    assert int(unemployed_non_smokers_df.age.mean()) == stats.A4.mean
    assert int(unemployed_non_smokers_df.age.median()) == stats.A4.median

check_reconstruction_matches_stats(stats, model_df)

In [40]:
def check_accuracy(database: pd.DataFrame, reconstructed: pd.DataFrame) -> float:
    match, non_match = 0, 0
    computed = [tuple(v.values()) for v in reconstructed.to_dict(orient='records')]
    original = [tuple(v.values()) for v in database.drop(columns=["name"]).to_dict(orient='records')]

    to_check = [list(zip(computed[i], original[i])) for i in range(len(database))]
    for items in to_check:
        for pair in items:
            if pair[0] == pair[1]:
                match += 1
            else:
                non_match += 1
    return (match / (match + non_match)) * 100

In [47]:
check_accuracy(database, model_df)

96.42857142857143
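
The reported figure follows from simple counting: 7 rows with 4 compared fields each give 28 values, and only one of them (the first age) differs. A sketch with both tables hard-coded from the outputs above, using plain tuples rather than DataFrames:

```python
# (age, married, smoker, employed), copied from the two outputs above
original = [
    (8, False, False, False), (18, False, True, True), (24, False, False, True),
    (30, True, True, True), (36, True, False, False), (66, True, False, False),
    (84, True, True, False),
]
reconstructed = [
    (9, False, False, False), (18, False, True, True), (24, False, False, True),
    (30, True, True, True), (36, True, False, False), (66, True, False, False),
    (84, True, True, False),
]

# Compare the tables field by field, row by row
matches = sum(
    o == r
    for original_row, reconstructed_row in zip(original, reconstructed)
    for o, r in zip(original_row, reconstructed_row)
)
total = sum(len(row) for row in original)
accuracy = matches / total * 100

print(matches, total)      # 27 28
print(round(accuracy, 2))  # 96.43
```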