# Database reconstruction attack demonstration

David Colls, Sep 2022, MIT License

Revisions prior to publishing May 2023

🧮 This notebook demonstrates a database reconstruction attack, showing that it is possible to recreate individual rows of a database from summary statistics. It is inspired by this [paper](https://queue.acm.org/detail.cfm?id=3295691) on database reconstruction attacks. Note that the intent of this notebook is not to compromise any private data, but to raise awareness of the potential for privacy breaches due to reconstruction attacks.

🧑‍🎓 This notebook was developed in parallel and collaboratively with Mitchell Lisle's [solution](https://github.com/kjam/practical-data-privacy/blob/main/database-reconstruction-attack.ipynb), which provides further documentation and is referenced in the book *Practical Data Privacy*. This notebook demonstrates that individual rows of a database may be reconstructed, even if only summary statistics are shared, by considering the constraints that the statistics place on possible values of the data.

🖥 This notebook uses the CP-SAT constraint programming solver from [OR-Tools](https://developers.google.com/optimization).

# Setup

▶️ You can run this notebook as is in [Colab](https://colab.research.google.com/), or a [Jupyter](https://jupyter.org/) instance where the environment includes `pip` (see [pip installation instructions](https://pip.pypa.io/en/stable/installation/)).


✔︎ Install ortools if not already in your environment

In [1]:
import sys
!{sys.executable} -m pip list | grep ortools || {sys.executable} -m pip install ortools

ortools                          9.6.2534


✔︎ Import required modules

In [2]:
from ortools.sat.python import cp_model

import itertools as it
import numpy as np
import pandas as pd
from pydantic import BaseModel
from typing import Optional

# Problem formulation

This section briefly describes the problem. Please see [more detailed documentation](https://github.com/kjam/practical-data-privacy/blob/main/database-reconstruction-attack.ipynb) if required.

## Background

📋 Alice has conducted a survey asking each of 7 people the following 5 questions:
1. What is your name? (String response)
2. What is your age? (Integer response)
3. Are you married? (Boolean response)
4. Are you a smoker? (Boolean response)
5. Are you employed? (Boolean response)

🔒 She collects the data shown below, which she intends to keep private.

In [3]:
database = pd.DataFrame([
    ("Sara Gray", 8, False, False, False),
    ("Joseph Collins", 18, False, True, True),
    ("Vincent Porter", 24, False, False, True),
    ("Tiffany Brown", 30, True, True, True),
    ("Brenda Small", 36, True, False, False),
    ("Dr. Tina Ayala", 66, True, False, False),
    ("Rodney Gonzalez", 84, True, True, False)],
    columns=("name", "age", "married", "smoker", "employed")
)

database

| | name | age | married | smoker | employed |
|---|---|---|---|---|---|
| 0 | Sara Gray | 8 | False | False | False |
| 1 | Joseph Collins | 18 | False | True | True |
| 2 | Vincent Porter | 24 | False | False | True |
| 3 | Tiffany Brown | 30 | True | True | True |
| 4 | Brenda Small | 36 | True | False | False |
| 5 | Dr. Tina Ayala | 66 | True | False | False |
| 6 | Rodney Gonzalez | 84 | True | True | False |


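📊 Alice does, however, define some statistics that she wants to share with other researchers. She will provide the count, median and mean age of cohorts based on classes such as "smoker". Where a mean is not a whole number, it is shared as the fraction `mean / denom`: for example, the non-smoker mean age of 33.5 is shared as `mean=67, denom=2`.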

In [4]:
# Each cohort includes some or all of the following information
class Stat(BaseModel):
    name: str
    count: Optional[int] = None   # explicit defaults keep these fields optional under pydantic v2
    median: Optional[int] = None
    mean: Optional[int] = None
    denom: Optional[int] = None   # divide mean by denom if defined to get rational mean

# A number of cohorts are defined
class BlockStats(BaseModel):
    A1: Stat
    A2: Stat
    B2: Stat
    C2: Stat
    D2: Stat
    A3: Stat
    B3: Stat
    A4: Stat

# The statistics for the cohorts are calculated and shared as follows
stats = BlockStats(
    A1=Stat(name="global_population", count=7, median=30, mean=38),
    A2=Stat(name="non_smoker", count=4, median=30, mean=67, denom=2),
    B2=Stat(name="smoker", count=3, median=30, mean=44),
    C2=Stat(name="unemployed", count=4, median=51, mean=97, denom=2),
    D2=Stat(name="employed", count=3, median=24, mean=24),
    A3=Stat(name="single_adults", count=None, median=None, mean=None),
    B3=Stat(name="married_adults", count=4, median=51, mean=54),
    A4=Stat(name="unemployed-non-smoker", count=3, median=36, mean=110, denom=3)
)

pd.DataFrame(stats.dict()).transpose()

| | name | count | median | mean | denom |
|---|---|---|---|---|---|
| A1 | global_population | 7.0 | 30.0 | 38.0 | |
| A2 | non_smoker | 4.0 | 30.0 | 67.0 | 2.0 |
| B2 | smoker | 3.0 | 30.0 | 44.0 | |
| C2 | unemployed | 4.0 | 51.0 | 97.0 | 2.0 |
| D2 | employed | 3.0 | 24.0 | 24.0 | |
| A3 | single_adults | | | | |
| B3 | married_adults | 4.0 | 51.0 | 54.0 | |
| A4 | unemployed-non-smoker | 3.0 | 36.0 | 110.0 | 3.0 |
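
As a sanity check, the shared statistics can be recomputed from the private rows. This is a minimal sketch using only the standard library; `ROWS` and `cohort_stats` are illustrative names introduced here, not part of the notebook. Holding the mean as a `fractions.Fraction` shows how the `mean`/`denom` pair encodes a non-integer mean as a numerator and a denominator.

```python
from fractions import Fraction
from statistics import median

# Alice's private rows, copied from the database cell: (name, age, married, smoker, employed)
ROWS = [
    ("Sara Gray", 8, False, False, False),
    ("Joseph Collins", 18, False, True, True),
    ("Vincent Porter", 24, False, False, True),
    ("Tiffany Brown", 30, True, True, True),
    ("Brenda Small", 36, True, False, False),
    ("Dr. Tina Ayala", 66, True, False, False),
    ("Rodney Gonzalez", 84, True, True, False),
]


def cohort_stats(predicate):
    """Return (count, median, exact mean) of ages for rows where predicate(married, smoker, employed) holds."""
    ages = sorted(age for _, age, *flags in ROWS if predicate(*flags))
    return len(ages), median(ages), Fraction(sum(ages), len(ages))


# Non-smokers: the exact mean 67/2 matches the shared mean=67, denom=2
print(cohort_stats(lambda married, smoker, employed: not smoker))  # (4, 30.0, Fraction(67, 2))
```

Other predicates work the same way: `lambda married, smoker, employed: not smoker and not employed` should give the A4 values (3, 36, 110/3).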


## The question

🕵️ Do the aggregate statistics Alice shares actually provide enough information to reconstruct the individual rows of the database she intends to keep private?

# OR-Tools CP-SAT model

⚛️ To use the CP-SAT solver, we first define a **model**. The model contains a collection of **variables** and **constraints**. The constraints determine the possible values of the variables.

🔢 We define **variables** as containers for the values in the original database. From the statistics that were shared, we infer that the columns include age (an integer) and three Boolean classes: married, smoker and employed. The number of rows equals the global count statistic. We define a variable for each cell.

🆗 We define **constraints** inferred from the aggregate statistics. The constraints take various forms based on what we can infer from the statistics.

🎯 When we are able to **solve the model** (below), the values of the variables will enable us to reconstruct the private database.
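
To see how a single statistic constrains the data before bringing in CP-SAT, the sketch below swaps the solver for plain brute-force enumeration over one cohort. It assumes the same 0 to 115 age range used for the model; `sorted_candidates` is a hypothetical helper. For the smoker statistic B2 alone (count 3, median 30, mean 44), 31 sorted age triples remain, one of which is the true (18, 30, 84). The model's job is to intersect many such constraints until, ideally, only one database fits.

```python
from fractions import Fraction
from itertools import combinations_with_replacement
from statistics import median


def sorted_candidates(count, target_median, target_mean, vmin=0, vmax=115):
    # combinations_with_replacement yields non-decreasing tuples, i.e. sorted age lists
    return [
        triple
        for triple in combinations_with_replacement(range(vmin, vmax + 1), count)
        if median(triple) == target_median and Fraction(sum(triple), count) == target_mean
    ]


# Smoker cohort statistic B2: count 3, median 30, mean 44
smoker_candidates = sorted_candidates(3, 30, 44)
print(len(smoker_candidates))  # 31
```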


## Model and variables

In the definition of the age variable, we constrain ages to be from 👶 0 to 🧓 115.

In [5]:
def init_model(global_stat):
  # initialise the constraint programming model
  return cp_model.CpModel()

In [6]:
def add_age_variables(model, global_stat, vmin=0, vmax=115):
  # add variables to represent each row in the age column
  rows = range(global_stat.count)
  return [model.NewIntVar(vmin, vmax, f'age_{i}') for i in rows]

In [7]:
def add_class_variables(model, class_stat, global_stat):
  # add variables to represent each row in one of the class columns
  rows = range(global_stat.count)
  return [model.NewBoolVar(f'class_{class_stat.name}_{i}') for i in rows]

🎬 We initialise the model and add age variables. We'll add class variables later.

In [8]:
model = init_model(stats.A1)
ages = add_age_variables(model, stats.A1)

## Global statistic constraints 🌏

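🆗 We define constraints requiring that the values are sorted, that they have the specified median (a constraint that assumes the values are sorted), and that they have the specified mean. Because the order of rows in the private database carries no information, requiring the ages to be sorted loses no generality, and it removes equivalent row permutations from the search.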

In [9]:
def add_sorted_constraint(model, values):
  # require that the list of values is sorted
  for k in zip(values[:-1], values[1:]):
    model.Add(k[0] <= k[1])

In [10]:
def add_median_constraint(model, values, median):
  # require that the middle value in the list is equal to the specified median
  # assumes that the list is sorted
  mid = len(values) // 2
  constraint = model.Add(values[mid - 1] + values[mid] == median * 2) \
    if len(values) % 2 == 0 \
    else model.Add(values[mid] == median)
  return constraint
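
The median rule can be checked in plain Python against `statistics.median`. This sketch mirrors `add_median_constraint` with a hypothetical helper, `median_times_two`. Comparing against twice the median keeps the even-count case in integers, which suits CP-SAT, whose linear constraints take integer coefficients.

```python
from statistics import median


def median_times_two(values):
    # Mirrors add_median_constraint on a sorted list: for an even length, the two
    # middle values sum to 2 * median; for an odd length, the middle value is the median.
    mid = len(values) // 2
    if len(values) % 2 == 0:
        return values[mid - 1] + values[mid]
    return 2 * values[mid]


# Global ages (odd count) and non-smoker ages (even count), both sorted
for sample in ([8, 18, 24, 30, 36, 66, 84], [8, 24, 36, 66]):
    assert median_times_two(sample) == 2 * median(sample)
    print(sample, median(sample))
```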

In [11]:
def add_mean_constraint(model, values, mean, denom):
  # require that the calculated mean of the values is equal to the specified mean,
  # if an integer, or falls between the two adjacent integers, if rational
  denominator = denom if denom is not None else 1
  mean_lb = mean // denominator
  interval_width = 1 if denom is not None else 0
  mean_ub = mean_lb + interval_width
  constraint_lb = model.Add(sum(values) >= len(values) * mean_lb)
  constraint_ub = model.Add(sum(values) <= len(values) * mean_ub)
  return constraint_lb, constraint_ub
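
The arithmetic of `add_mean_constraint` can be reproduced outside the model. With a `denom`, the constraint only bounds the sum between `n * (mean // denom)` and `n * (mean // denom + 1)`. If the shared fraction is exact, a tighter option (our suggestion, not the notebook's) would be `model.Add(denom * sum(values) == len(values) * mean)`, which pins the sum to a single value; the looser bound may be preferable if shared means could be rounded. The helpers below are hypothetical and only compute the two sums.

```python
def mean_sum_bounds(n, mean, denom=None):
    """Bounds on sum(values) that add_mean_constraint imposes for n values."""
    denominator = denom if denom is not None else 1
    mean_lb = mean // denominator
    mean_ub = mean_lb + (1 if denom is not None else 0)
    return n * mean_lb, n * mean_ub


def exact_sum(n, mean, denom=None):
    """Hypothetical tighter alternative: the sum implied by an exact mean of mean / denom."""
    denominator = denom if denom is not None else 1
    total, remainder = divmod(n * mean, denominator)
    if remainder:
        raise ValueError("no integer sum of ages gives this mean")
    return total


# Non-smoker statistic A2: 4 values with mean 67/2 = 33.5
print(mean_sum_bounds(4, 67, 2))  # (132, 136)
print(exact_sum(4, 67, 2))        # 134
```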

In [12]:
def add_global_stat_constraint(model, values, global_stat):
  # add constraints inferred from the global statistic
  add_sorted_constraint(model, values)
  add_median_constraint(model, values, global_stat.median)
  add_mean_constraint(model, values, global_stat.mean, global_stat.denom)

✅ We add these constraints to the model. Now the database values must satisfy the constraints resulting from the global statistic A1.

In [13]:
add_global_stat_constraint(model, ages, stats.A1)

## Class and cohort statistic constraints 👪

🔀 Most of the statistics are for cohorts rather than the global population. We don't know exactly which records make up the cohort though, so we're going to explore every possible cohort for a statistic, using `itertools.combinations`.

In [14]:
def get_cohorts(global_count, cohort_count):
  # enumerate the population and each cohort that could be selected
  # according to the global count, and cohort count
  population = list(range(global_count))
  cohorts = list(it.combinations(population, cohort_count))
  return population, cohorts
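
For 7 rows, each size-3 cohort statistic yields C(7, 3) = 35 candidate cohorts, and the size-4 married statistic gives C(7, 4) = 35 as well. A short standard-library check of the enumeration that `get_cohorts` performs, and of the complement used later for a negative class:

```python
import itertools as it
import math

population = list(range(7))
# The same enumeration get_cohorts(7, 3) performs
smoker_cohorts = list(it.combinations(population, 3))
print(len(smoker_cohorts), math.comb(7, 3), math.comb(7, 4))  # 35 35 35
print(smoker_cohorts[0])  # (0, 1, 2)

# Complementary (non-smoker) cohort for the first smoker cohort
complement = sorted(set(population) - set(smoker_cohorts[0]))
print(complement)  # [3, 4, 5, 6]
```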

🔗 We add a link variable for each possible cohort combination, so that the constraints for that combination are enforced only when its link variable is true. We're using the [channeling technique](https://developers.google.com/optimization/cp/channeling) described in the OR-Tools documentation.

In [15]:
def add_combination_link_variables(model, class_stat, cohorts):
  # add a channeling link variable for each combination of cohort for the statistic
  combinations = range(len(cohorts))
  return [model.NewBoolVar(f'combn_{class_stat.name}_{j}') for j in combinations]

🆗 We add constraints requiring that exactly one link variable (cohort combination) is active, and that the active link forces every row in its cohort to belong to the class. Together with the class count, this makes the active cohort exactly the set of rows in the class.

In [16]:
def add_combination_constraints(model, classes, combination_links, cohorts, count):
  # require that one cohort combination is considered for class statistics requirements
  model.Add(sum(classes) == count)
  model.Add(sum(combination_links) == 1)
  for j, cohort in enumerate(cohorts):
    cohort_classes = [c for i, c in enumerate(classes) if i in cohort]
    model.Add(sum(cohort_classes) == count).OnlyEnforceIf(combination_links[j])
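
One way to convince yourself that the channeling constraints behave as described is to check them exhaustively in plain Python for 7 rows and a class count of 3. This sketch stands in for the solver by brute force: choosing a single index `j` plays the role of exactly one active link variable, and the hypothetical `feasible` function restates the other two constraints from `add_combination_constraints`. It should find exactly one valid link per class assignment.

```python
import itertools as it

N, COUNT = 7, 3
cohorts = list(it.combinations(range(N), COUNT))


def feasible(classes, link_index):
    # sum(classes) == count, and the class variables of the active cohort sum to count
    return sum(classes) == COUNT and sum(classes[i] for i in cohorts[link_index]) == COUNT


# Picking one link index per candidate is the "exactly one link active" constraint
pairs = [
    (classes, j)
    for classes in it.product((0, 1), repeat=N)
    for j in range(len(cohorts))
    if feasible(classes, j)
]
print(len(pairs))  # 35: one active link for each of the C(7, 3) class assignments

# Each feasible pair links a class assignment to the cohort of its True rows
assert all(tuple(i for i, c in enumerate(classes) if c) == cohorts[j] for classes, j in pairs)
```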

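🆗 We extend our median and mean constraints so that they apply to one cohort combination, enforced only when that combination's link variable is active. Because the ages are sorted and each cohort's values are collected in row order, the cohort's values are sorted too, so the median constraint's assumption still holds.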

In [17]:
def add_median_constraint_cohort(model, values, median, link):
  # require the median of values only if link is activated
  constraint = add_median_constraint(model, values, median)
  constraint.OnlyEnforceIf(link)

In [18]:
def add_mean_constraint_cohort(model, values, mean, denom, link):
  # require the mean (rational over denom) of values only if link is activated
  constraint_lb, constraint_ub = add_mean_constraint(model, values, mean, denom)
  constraint_lb.OnlyEnforceIf(link)
  constraint_ub.OnlyEnforceIf(link)

In [19]:
def add_cohort_stat_constraints(model, values, links, cohorts, cohort_stat):
  # require the cohort statistics for the values in one of the possible cohorts
  for j, cohort in enumerate(cohorts):
    combn_values = [v for i, v in enumerate(values) if i in cohort]
    add_mean_constraint_cohort(model, combn_values, cohort_stat.mean, cohort_stat.denom, links[j])
    add_median_constraint_cohort(model, combn_values, cohort_stat.median, links[j])

🆗 We combine all of these constraints so that each class statistic applies to the database values. Where statistics come in complementary pairs (positive and negative classes, such as smoker and non-smoker), we apply them together: both share the same link variables, and the negative cohort is the complement of the positive one.

In [20]:
def add_class_stat(model, values, global_stat, class_stat, class_neg_stat=None):
  # configure variables and constraints to add a class statistic
  # class_neg_stat should be a complementary statistic, if it exists
  population, cohorts = get_cohorts(global_stat.count, class_stat.count)
  classes = add_class_variables(model, class_stat, global_stat)
  combination_links = add_combination_link_variables(model, class_stat, cohorts)
  add_combination_constraints(model, classes, combination_links, cohorts, class_stat.count)

  add_cohort_stat_constraints(model, values, combination_links, cohorts, class_stat)
  if class_neg_stat is not None:
    neg_cohorts = [sorted(set(population) - set(s)) for s in cohorts]
    add_cohort_stat_constraints(model, values, combination_links, neg_cohorts, class_neg_stat)

  return classes, combination_links
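
To illustrate why complementary statistics are applied together, this sketch assumes the sorted ages are already known (it fixes them to the values the solver eventually recovers) and searches, by brute force rather than CP-SAT, for cohorts whose members match the positive statistic and whose complement matches the paired negative statistic. `AGES`, `matches` and `paired_cohorts` are hypothetical names.

```python
import itertools as it
from fractions import Fraction
from statistics import median

# Assumption for illustration: the sorted ages are known
AGES = [8, 18, 24, 30, 36, 66, 84]


def matches(indices, target_median, target_mean):
    values = [AGES[i] for i in indices]  # ascending indices, so values stay sorted
    return median(values) == target_median and Fraction(sum(values), len(values)) == target_mean


def paired_cohorts(count, pos, neg):
    """Cohorts of `count` rows matching `pos` whose complement matches `neg`; each stat is (median, mean)."""
    population = range(len(AGES))
    result = []
    for cohort in it.combinations(population, count):
        complement = [i for i in population if i not in cohort]
        if matches(cohort, *pos) and matches(complement, *neg):
            result.append(cohort)
    return result


# B2 smoker (median 30, mean 44) paired with A2 non-smoker (median 30, mean 67/2)
print(paired_cohorts(3, (30, 44), (30, Fraction(67, 2))))  # [(1, 3, 6)]
# D2 employed (median 24, mean 24) paired with C2 unemployed (median 51, mean 97/2)
print(paired_cohorts(3, (24, 24), (51, Fraction(97, 2))))  # [(1, 2, 3)]
```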

✅ Apply the actual statistics that were shared.

In [21]:
classes_B3, _ = add_class_stat(model, ages, stats.A1, stats.B3)
classes_B2A2, _ = add_class_stat(model, ages, stats.A1, stats.B2, class_neg_stat=stats.A2)
classes_D2C2, _ = add_class_stat(model, ages, stats.A1, stats.D2, class_neg_stat=stats.C2)

# note we might also add explicit constraints with implicit information from stat A3
# but that is not needed to identify a single solution in this case

classes = classes_B3 + classes_B2A2 + classes_D2C2

## Class cross constraints 👪➕👪

➕ One statistic crosses two classes: `A4=Stat(name="unemployed-non-smoker", count=3, median=36, mean=110, denom=3)`, a mean age of 110/3 ≈ 36.7. We use a similar pattern to the single-class cohorts, but in this case create links for the cross combinations.

In [22]:
def add_cross_link_variables(model, cross_stat, cohorts):
  # add a channeling link variable for each cross combination for the statistic
  combinations = range(len(cohorts))
  return [model.NewBoolVar(f'cross_{cross_stat.name}_{j}') for j in combinations]

In [23]:
def add_cross_stat_constraints(model, crossed_classes, class_neg, cross_links, cohorts, count):
  # require the cohort statistics for the values in one of the possible cross cohorts
  model.Add(sum(cross_links) == 1)
  for j, cohort in enumerate(cohorts):
    for k, xcls in enumerate(crossed_classes):
      cross_classes = [c for i, c in enumerate(xcls) if i in cohort]
      target = 0 if class_neg[k] else count
      model.Add(sum(cross_classes) == target).OnlyEnforceIf(cross_links[j])

In [24]:
def add_cross_stat(model, values, global_stat, cross_stat, crossed_classes, class_neg):
  # configure variables and constraints to add a cross statistic
  _, cohorts = get_cohorts(global_stat.count, cross_stat.count)
  cross_links = add_cross_link_variables(model, cross_stat, cohorts)
  add_cross_stat_constraints(model, crossed_classes, class_neg, cross_links, cohorts, cross_stat.count)
  add_cohort_stat_constraints(model, values, cross_links, cohorts, cross_stat)
  return cross_links
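
The cross-cohort condition can be checked the same way. Assuming the smoker and employed flags are known (taken from the reconstruction this notebook arrives at), this brute-force sketch keeps the size-3 cohorts in which every crossed class sums to its target: 0 for a negated class, `count` otherwise, as in `add_cross_stat_constraints`. `cross_cohorts` is a hypothetical helper.

```python
import itertools as it

# Assumed class flags, in the same (age-sorted) row order as the model
SMOKER = [False, True, False, True, False, False, True]
EMPLOYED = [False, True, True, True, False, False, False]


def cross_cohorts(crossed_classes, class_neg, count):
    """Cohorts where each crossed class sums to 0 (if negated) or count, over the cohort's rows."""
    n = len(crossed_classes[0])
    result = []
    for cohort in it.combinations(range(n), count):
        if all(
            sum(cls[i] for i in cohort) == (0 if neg else count)
            for cls, neg in zip(crossed_classes, class_neg)
        ):
            result.append(cohort)
    return result


# A4: unemployed non-smokers, i.e. both classes negated
print(cross_cohorts([SMOKER, EMPLOYED], [True, True], 3))  # [(0, 4, 5)]
```

Like the notebook's constraint, the sketch checks only the cohort's own members; it does not by itself rule out a further unemployed non-smoker outside the cohort.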

✅ Apply the cross statistics that were shared.

In [25]:
crosses = add_cross_stat(model, ages, stats.A1, stats.A4, [classes_B2A2, classes_D2C2], [True, True])

# Solve the model

🎯 Now we apply a solver to the defined constraint model to find permissible values for the variables

## Solver helpers

🦾 These helper classes and methods simplify the OR-Tools interface.

In [26]:
class SolutionAccumulator(cp_model.CpSolverSolutionCallback):

    def __init__(self, variables, df, max_solns):
        cp_model.CpSolverSolutionCallback.__init__(self)
        self.__variables = variables
        self.__solution_count = 0
        self.__max_solns = max_solns
        self.__df = df

    def on_solution_callback(self):
        self.__df.loc[self.__solution_count] = [self.Value(v) for v in self.__variables]
        self.__solution_count += 1
        if self.__solution_count >= self.__max_solns:
          self.StopSearch()

    def solution_count(self):
        return self.__solution_count

In [27]:
def get_dra_solver(model, variables):
  # get the solver for database reconstruction
  model.AddDecisionStrategy(variables,
                            cp_model.CHOOSE_FIRST,
                            cp_model.SELECT_MIN_VALUE)
  solver = cp_model.CpSolver()
  solver.parameters.search_branching = cp_model.FIXED_SEARCH
  solver.parameters.enumerate_all_solutions = True
  return solver

def find_dra_solutions(model, solver, variables, max_solns):
  # find possible solutions for database reconstruction
  df = pd.DataFrame(columns=[str(v) for v in variables])
  solution_logger = SolutionAccumulator(variables, df, max_solns)
  solver.Solve(model, solution_logger)
  return df

## Run solver

🤖 We execute the solver. In general, the solver may find multiple solutions, up to the specified limit (500 here). In this instance, the constraints admit only one solution. Note that without constraints there are about 5.9 × 10^20 possible original databases: each of the 7 rows has 116 possible ages (0 to 115) and 8 combinations of the three Booleans, giving 928^7 configurations.
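
The size of the unconstrained search space follows directly from the variable domains. A quick check, assuming the 0 to 115 age range from `add_age_variables` and three Boolean columns:

```python
NUM_ROWS = 7
AGE_VALUES = 116      # ages 0 to 115 inclusive, as in add_age_variables
BOOLEAN_COLUMNS = 3   # married, smoker, employed

per_row = AGE_VALUES * 2 ** BOOLEAN_COLUMNS  # 928 possible rows
total = per_row ** NUM_ROWS                  # independent choice for each row
print(per_row, f"{total:.2e}")  # 928 5.93e+20
```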

In [28]:
solver = get_dra_solver(model, ages + classes)
sol_df = find_dra_solutions(model, solver, ages + classes, 500)

print(len(sol_df.index))
sol_df.head().transpose()

1


| variable | solution 0 |
|---|---|
| age_0 | 8 |
| age_1 | 18 |
| age_2 | 24 |
| age_3 | 30 |
| age_4 | 36 |
| age_5 | 66 |
| age_6 | 84 |
| class_married_adults_0 | 0 |
| class_married_adults_1 | 0 |
| class_married_adults_2 | 0 |


# Evaluate solutions

⚖️ Compare reconstructed solutions to the original private database

In [29]:
def reconstruct_original_from_solution(df, solution_id):
  # reconstruct the original database from the solution variable values
  rows = stats.A1.count
  ages = list(df.iloc[solution_id, 0:rows])
  married = list(df.iloc[solution_id, rows:2*rows].astype(bool))
  smoker = list(df.iloc[solution_id, 2*rows:3*rows].astype(bool))
  employed = list(df.iloc[solution_id, 3*rows:4*rows].astype(bool))

  reconstructed = pd.DataFrame([ages, married, smoker, employed]).transpose()
  reconstructed.columns = ("age", "married", "smoker", "employed")
  return reconstructed
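
`reconstruct_original_from_solution` relies on the solution columns being laid out as all ages first, then the married, smoker and employed blocks, in the order of `ages + classes`. The same reshaping in plain Python, as a sketch with a hypothetical `to_rows` helper and solution values taken from the reconstruction shown in this notebook:

```python
# Flat solution vector: 7 ages, then one block of 7 values per Boolean class
solution = (
    [8, 18, 24, 30, 36, 66, 84]
    + [0, 0, 0, 1, 1, 1, 1]  # married
    + [0, 1, 0, 1, 0, 0, 1]  # smoker
    + [0, 1, 1, 1, 0, 0, 0]  # employed
)


def to_rows(flat, rows):
    """Split a flat solution vector into column blocks and zip them back into rows."""
    ages = flat[0:rows]
    flags = [
        [bool(v) for v in flat[k * rows:(k + 1) * rows]]
        for k in range(1, len(flat) // rows)
    ]
    return list(zip(ages, *flags))


reconstructed_rows = to_rows(solution, 7)
for row in reconstructed_rows:
    print(row)  # (age, married, smoker, employed)
```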

In [30]:
reconstructed = reconstruct_original_from_solution(sol_df, 0)
reconstructed

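| | age | married | smoker | employed |
|---|---|---|---|---|
| 0 | 8 | False | False | False |
| 1 | 18 | False | True | True |
| 2 | 24 | False | False | True |
| 3 | 30 | True | True | True |
| 4 | 36 | True | False | False |
| 5 | 66 | True | False | False |
| 6 | 84 | True | True | False |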


In [31]:
# check that every value is the same as the original private database
np.sum((reconstructed == database[reconstructed.columns]).to_numpy()) == np.size(reconstructed.to_numpy())

True

# Further challenges

Is there a more elegant solution than creating all those linkage variables?
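
To put a number on "all those linkage variables": for 7 rows the model holds 7 age variables, 21 class variables and one link variable per candidate cohort for each of the four statistics handled with links (B3, B2/A2, D2/C2 and A4). A standard-library count, as a sketch:

```python
import math

NUM_ROWS = 7
# Cohort size of each statistic that gets its own set of link variables
LINKED_STATS = {"B3 married": 4, "B2/A2 smoker": 3, "D2/C2 employed": 3, "A4 cross": 3}

age_vars = NUM_ROWS
class_vars = 3 * NUM_ROWS  # married, smoker, employed
link_vars = sum(math.comb(NUM_ROWS, k) for k in LINKED_STATS.values())
print(age_vars, class_vars, link_vars)    # 7 21 140
print(age_vars + class_vars + link_vars)  # 168

# Link variables grow combinatorially with the population size
print(math.comb(20, 10))  # 184756
```

Since C(n, k) grows combinatorially, the link-variable approach would likely scale poorly to larger populations, which is part of what makes the question above worth pursuing.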