This notebook visualizes the pre-calculated Pareto-front (PF) surfaces.

In [59]:
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt

import pandas as pd
from scipy.interpolate import griddata

# Setup and Functions

In [60]:
# Initialize figure settings: one base font size, with axis labels and the
# figure title slightly larger.
# NOTE(review): the surface below is drawn with plotly (go.Figure), so these
# matplotlib settings do not affect it; presumably kept for other plots.
SMALL_SIZE = 22
MEDIUM_SIZE = SMALL_SIZE + 2
BIGGER_SIZE = SMALL_SIZE + 5

# Equivalent to plt.rc('<group>', <name>=value) for each entry below.
plt.rcParams.update({
    'font.size': SMALL_SIZE,            # default text size
    'axes.titlesize': SMALL_SIZE,       # axes title
    'axes.labelsize': MEDIUM_SIZE,      # x and y labels
    'xtick.labelsize': SMALL_SIZE,      # x tick labels
    'ytick.labelsize': SMALL_SIZE,      # y tick labels
    'legend.fontsize': SMALL_SIZE,      # legend
    'figure.titlesize': BIGGER_SIZE,    # figure (suptitle) title
})

# Render text through LaTeX with the Libertine font family
# (requires a working LaTeX installation).
rc_fonts = {
    "text.usetex": True,
    'text.latex.preamble':
        r"""
        \usepackage{amsmath}
        \usepackage{libertine}
        \usepackage[libertine]{newtxmath}
        """,
}

plt.rcParams.update(rc_fonts)

In [2]:
def is_pareto_efficient(costs, return_mask = False):
    """Locate the Pareto-efficient rows of a cost matrix (lower is better).

    Every column is treated as a cost to minimise. A row is discarded when some
    other row is <= it in every column; among exactly identical rows only the
    first occurrence survives.

    :param costs: An (n_points, n_costs) numpy array.
    :param return_mask: If True, return an (n_points,) boolean mask instead of indices.
    :return: An (n_efficient_points,) integer array of row indices, or the boolean
        mask when ``return_mask`` is True.
    """
    total = costs.shape[0]
    survivors = np.arange(total)   # original row indices still in the running
    remaining = costs              # cost rows aligned with `survivors`
    cursor = 0                     # position (within `remaining`) of the row being tested
    while cursor < len(remaining):
        # Keep rows strictly better than the current one in at least one column;
        # everything else is weakly dominated by it and is dropped.
        keep = np.any(remaining < remaining[cursor], axis=1)
        keep[cursor] = True
        survivors = survivors[keep]
        remaining = remaining[keep]
        # Rows before the cursor may have been dropped, so re-derive its new position.
        cursor = np.sum(keep[:cursor]) + 1

    if not return_mask:
        return survivors

    mask = np.zeros(total, dtype = bool)
    mask[survivors] = True
    return mask

In [125]:
# Which dataset / method's pre-computed results to visualise.
dataset = 'mnist'
method = 'dpsgd-g-a'  # the cells below handle 'fairPATE' and 'dpsgd-g-a'
# Directory that holds results.csv for the chosen dataset / method.
loss_function_file_path = f"./{dataset}/{method}/"

# Column of results.csv placed first in the `results` array.
fairness_var = 'tau'
# NOTE(review): not referenced in the visible notebook; presumably a leftover toggle — confirm.
coverage = False

# Graphing

In [126]:
# Read data from a csv
data = pd.read_csv(f'{loss_function_file_path}results.csv')
# Accuracy may be stored as a fraction; rescale to percent when the largest
# value is <= 1. Series.max() skips NaN, unlike the builtin max(), whose result
# with NaNs depends on where they appear.
if data['accuracy'].max() <= 1:
    data['accuracy'] = data['accuracy'] * 100

# Get pf points. is_pareto_efficient minimises every column, so the metrics
# that should be maximised (accuracy, and coverage for fairPATE) are negated.
if method == 'fairPATE':
    pf_costs = np.stack([data['achieved_epsilon'].to_numpy(),
                         data['achieved_fairness_gap'].to_numpy(),
                         -data['accuracy'].to_numpy(),
                         -data['coverage'].to_numpy()], axis=1)
elif method == 'dpsgd-g-a':
    pf_costs = np.stack([data['achieved_epsilon'].to_numpy(),
                         data['achieved_fairness_gap'].to_numpy(),
                         -data['accuracy'].to_numpy()], axis=1)
else:
    # Fail loudly instead of leaving pf_index undefined (or stale from an
    # earlier kernel state) for the plotting cells below.
    raise ValueError(
        f"Unsupported method {method!r}; expected 'fairPATE' or 'dpsgd-g-a'"
    )
pf_index = is_pareto_efficient(pf_costs)

In [127]:
# Stack the chosen fairness column and the achieved metrics into one
# (n_rows, n_columns) array:
#   [fairness_var, achieved_epsilon, achieved_fairness_gap, accuracy(, coverage)]
# coverage is only included for fairPATE.
if method == 'fairPATE':
    results = np.stack([data[fairness_var].to_numpy(),
                        data['achieved_epsilon'].to_numpy(),
                        data['achieved_fairness_gap'].to_numpy(),
                        data['accuracy'].to_numpy(),
                        data['coverage'].to_numpy()], axis=1)
elif method == 'dpsgd-g-a':
    results = np.stack([data[fairness_var].to_numpy(),
                        data['achieved_epsilon'].to_numpy(),
                        data['achieved_fairness_gap'].to_numpy(),
                        data['accuracy'].to_numpy()], axis=1)
else:
    # Otherwise `results` would silently be undefined or keep a stale value
    # from a previous run in the same kernel.
    raise ValueError(
        f"Unsupported method {method!r}; expected 'fairPATE' or 'dpsgd-g-a'"
    )

In [135]:
# Pareto-front points: epsilon on x, fairness gap on y, accuracy on z.
x = data['achieved_epsilon'].to_numpy()[pf_index]
y = data['achieved_fairness_gap'].to_numpy()[pf_index]
z = data['accuracy'].to_numpy()[pf_index]

# Regular 100 x 100 grid spanning the observed x / y range.
xi = np.linspace(x.min(), x.max(), 100)
yi = np.linspace(y.min(), y.max(), 100)
X, Y = np.meshgrid(xi, yi)

# Linear interpolation of accuracy onto the grid; cells outside the convex
# hull of the points come back as NaN.
Z = griddata((x, y), z, (X, Y), method='linear')

# Surface colouring: interpolated coverage for fairPATE, otherwise none
# (plotly then falls back to colouring by z).
if method != 'fairPATE':
    C = None
    colorscale = 'Plasma'
else:
    c = data['coverage'].to_numpy()[pf_index]
    C = griddata((x, y), c, (X, Y), method='linear')
    colorscale = 'Viridis'

In [None]:

# 3D surface of the interpolated Pareto front: epsilon (x), fairness gap (y),
# accuracy (z). For fairPATE, C holds interpolated coverage and drives the
# colouring; otherwise C is None and plotly colours by z.
fig = go.Figure(go.Surface(x=xi, y=yi, z=Z, surfacecolor=C, colorscale=colorscale))
fig.update_layout(width=700,
                  height=500,
                  margin=dict(l=0, r=50, b=10, t=10),
                  scene_camera=dict(
                      up=dict(x=0, y=0, z=10),
                      center=dict(x=0, y=0, z=-0.2),
                      eye=dict(x=-1.25, y=1.25, z=0.8)
                  ),
                  scene=dict(
                      xaxis_title='Epsilon Budget Achieved',
                      yaxis_title='Max Fairness Violation',
                      zaxis_title='Accuracy',
                  ))
# NOTE(review): update_yaxes targets 2D cartesian axes; it likely has no
# effect on a 3D scene (configured via `scene`) — confirm before relying on it.
fig.update_yaxes(automargin='left+top')
fig.show()

In [137]:
# Export the surface as a PDF. Create the target directory first so
# write_image does not fail when it does not exist yet.
# (plotly's static image export requires the kaleido package.)
import os

figure_path = f"../../visualizations/figures/{dataset}_{method}_pf.pdf"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
fig.write_image(figure_path)