# Basic privacy measures

References:
- [ ] https://github.com/Nuclearstar/K-Anonymity
- [ ] Disclosure risk: do we need our own implementation?

**IMPORTANT:** refer to the [README](https://github.com/kasra-hosseini/privgem#credits) for credits.

In [None]:
# Work around an IPython autocomplete issue by disabling the Jedi-based completer
%config Completer.use_jedi = False

%load_ext autoreload
%autoreload 2
%matplotlib inline

from warnings import simplefilter
# Install a warnings filter that ignores every FutureWarning.
simplefilter(action='ignore', category=FutureWarning)

In [None]:
import os
import pandas as pd

### Helper functions

In [None]:
def get_spans(df, partition, discrete_columns, scale=None):
    """
    Compute, for every column of `df`, how widely its values vary inside
    one partition.

    :param               df: the dataframe for which to calculate the spans
    :param        partition: the row index selecting the partition
                             (e.g. `df.index` or a subset of it)
    :param discrete_columns: the columns to treat as discrete; their span is
                             the number of distinct values, whereas every
                             other column uses `max - min`
    :param            scale: if given, the span of each column is divided
                             by the value in `scale` for that column
    :returns               : a dict mapping each column name to its span
    """
    def column_span(name):
        values = df[name][partition]
        if name in discrete_columns:
            width = len(values.unique())
        else:
            width = values.max() - values.min()
        # Normalise only when a scale was supplied.
        return width if scale is None else width / scale[name]

    return {name: column_span(name) for name in df.columns}

def split(df, partition, column, discrete_columns):
    """
    Cut one partition in two along a single column.

    :param               df: The dataframe to split
    :param        partition: The partition to split
    :param           column: The column along which to split
    :param discrete_columns: The columns to treat as discrete. A discrete
                             column is split by assigning the first half of
                             its distinct values to the left part and the
                             rest to the right part; any other column is
                             split at its median.
    :returns               : A tuple `(left, right)` of index objects that
                             split the original partition
    """
    values = df[column][partition]

    if column not in discrete_columns:
        # Rows strictly below the median go left; the median itself and
        # everything above it go right.
        pivot = values.median()
        return (values.index[values < pivot], values.index[values >= pivot])

    # Distinct values in order of first appearance, cut in the middle.
    distinct = values.unique()
    half = len(distinct) // 2
    left_values = set(distinct[:half])
    right_values = set(distinct[half:])
    left_rows = values.index[values.isin(left_values)]
    right_rows = values.index[values.isin(right_values)]
    return left_rows, right_rows
    
    
def is_k_anonymous(df, partition, sensitive_column, k=3):
    """
    Check whether a partition holds enough rows to satisfy k-anonymity.

    Only the size of `partition` is inspected; `df` and `sensitive_column`
    are accepted but not used by this check.

    :param               df: The dataframe on which to check the partition.
    :param        partition: The partition of the dataframe to check.
    :param sensitive_column: The name of the sensitive column
    :param                k: The desired k
    :returns               : True if the partition contains at least `k` entries, False otherwise.
    """
    return not (len(partition) < k)

def partition_dataset(df, feature_columns, sensitive_column, scale, is_valid, discrete_columns, k_anon=3):
    """
    Repeatedly split the dataframe into smaller partitions until no
    partition can be split any further without becoming invalid.

    :param               df: The dataframe to be partitioned.
    :param  feature_columns: A list of column names along which to partition the dataset.
    :param sensitive_column: The name of the sensitive column (to be passed on to the `is_valid` function)
    :param            scale: The column spans as generated before; used to normalise the
                             per-partition spans so that columns are comparable.
    :param         is_valid: A function called as `is_valid(df, partition, sensitive_column, k_anon)`
                             that returns True if the partition is valid.
    :param discrete_columns: The columns to treat as discrete when computing spans and splitting.
    :param           k_anon: The desired k, forwarded to `is_valid` for both halves of every split.
    :returns               : A list of valid partitions that cover the entire dataframe.
    """
    from collections import deque

    finished_partitions = []
    # FIFO queue of partitions still to be examined (same order as list.pop(0)).
    pending = deque([df.index])
    # The feature sub-frame does not change between iterations.
    features = df[feature_columns]
    while pending:
        partition = pending.popleft()
        spans = get_spans(features, partition, discrete_columns, scale)
        # Try the columns from the largest span to the smallest.
        for column, _span in sorted(spans.items(), key=lambda item: -item[1]):
            lp, rp = split(df, partition, column, discrete_columns)
            if is_valid(df, lp, sensitive_column, k_anon) and is_valid(df, rp, sensitive_column, k_anon):
                pending.extend((lp, rp))
                break
        else:
            # No column produced two valid halves: the partition is final.
            finished_partitions.append(partition)
    return finished_partitions

In [None]:
# Columns that the helper functions treat as discrete rather than numeric.
discrete_columns = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
    "income",
]

In [None]:
# Read the CSV file and compute the span of every column over all of its rows.
path2synth_file = "./test/patectgan_001/synthetic_output.csv"
synth_output = pd.read_csv(path2synth_file)
full_spans = get_spans(
    df=synth_output,
    partition=synth_output.index,
    discrete_columns=discrete_columns,
)
# Show the computed spans as the cell output.
full_spans

In [None]:
# We apply our partitioning method to two columns of our dataset, using
# "income" as the sensitive attribute. The full-dataset spans are passed as
# `scale`. To partition along every column instead, set
# `feature_columns = list(full_spans)`.
feature_columns = ["age", "race"]
sensitive_column = 'income'
finished_partitions = partition_dataset(
    df=synth_output,
    feature_columns=feature_columns,
    sensitive_column=sensitive_column,
    scale=full_spans,
    is_valid=is_k_anonymous,
    discrete_columns=discrete_columns,
    k_anon=3,
)

In [None]:
# Show how many partitions the partitioning step returned.
len(finished_partitions)