In [None]:
import os

# Move one directory up from where the kernel started (presumably the repo
# root, so the `fair_graphs` imports below resolve).
# The start directory is remembered so that re-running this cell does not keep
# climbing one more level each time, which makes the cell idempotent.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
from fair_graphs.launchers.launch_fap import optim_fairAutoEncoder

In [None]:
# Number of repeated splits every optim_fairAutoEncoder run below uses
# (except where a cell overrides it).
num_splits = 30
# Lambda sweep passed to every run; print_scatter draws one colour per value.
# Presumably the weight of the fairness term -- TODO confirm in launch_fap.
lambdas = [1, 0.1, 0.01, 0.001]

In [None]:
import matplotlib.pyplot as plt

def print_scatter(mses_dp, fair_losses_dp,
                  mses_eo0, fair_losses_eo0,
                  mses_eo1, fair_losses_eo1,
                  lambda_values=None, colors=None):
    """Show MSE vs. fairness-loss scatter plots, one panel per fairness metric.

    Draws three side-by-side panels labelled DP, EO- and EO+. In each panel,
    entry ``i`` of the MSE / fairness-loss inputs is plotted with the legend
    label ``'lambda =' + str(lambda_values[i])``. The figure is shown with
    ``plt.show()`` and nothing is returned.

    Parameters
    ----------
    mses_dp, fair_losses_dp :
        Results of the ``metric='dp'`` run, indexable by lambda position.
    mses_eo0, fair_losses_eo0 :
        Results of the ``metric='eo', pos=0`` run (EO- panel).
    mses_eo1, fair_losses_eo1 :
        Results of the ``metric='eo', pos=1`` run (EO+ panel).
    lambda_values : sequence, optional
        Lambdas the results correspond to. Defaults to the notebook-level
        ``lambdas``.
    colors : sequence of matplotlib colours, optional
        Colours assigned to the lambdas in order, cycled if there are more
        lambdas than colours. Defaults to blue/orange/green/red (the original
        palette) followed by six further named colours.
    """
    if lambda_values is None:
        lambda_values = lambdas  # notebook-level sweep from the config cell
    if not colors:
        colors = ['blue', 'orange', 'green', 'red',
                  'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    panels = [(mses_dp, fair_losses_dp, 'DP'),
              (mses_eo0, fair_losses_eo0, 'EO-'),
              (mses_eo1, fair_losses_eo1, 'EO+')]

    _fig, axs = plt.subplots(1, len(panels), figsize=(15, 5), constrained_layout=True)

    for ax, (mses, fair_losses, fair_name) in zip(axs, panels):
        # Iterate over the lambdas, not over a fixed colour list, so that no
        # result is silently dropped when the sweep has more than four values.
        for i, lam in enumerate(lambda_values):
            ax.scatter(mses[i], fair_losses[i], c=colors[i % len(colors)], s=200,
                       label='lambda =' + str(lam), alpha=0.9, edgecolors='none')
        ax.set(xlabel='MSE', ylabel=fair_name)
        ax.legend()

    plt.show()

## German

In [None]:
from fair_graphs.datasets.graph_datasets import GermanData

# German dataset: 'Gender' is the sensitive attribute, 'GoodCustomer' the target.
# NOTE(review): num_samples=0 presumably means "use every sample" -- confirm in GermanData.
data = GermanData(
    sensitive_attribute='Gender',
    target_attribute='GoodCustomer',
    include_sensitive=True,
    num_samples=0,
    pre_scale_features=False,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter); indexed by lambda position.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel in print_scatter).
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel in print_scatter).
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
)

In [None]:
# German: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)

## Bail

In [None]:
from fair_graphs.datasets.graph_datasets import BailData

# Bail dataset: 'WHITE' is the sensitive attribute, 'RECID' the target.
data = BailData(
    sensitive_attribute='WHITE',
    target_attribute='RECID',
    include_sensitive=True,
    num_samples=0,
    pre_scale_features=False,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter); indexed by lambda position.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel in print_scatter).
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel in print_scatter).
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
)

In [None]:
# Bail: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)

## Credit

In [None]:
from fair_graphs.datasets.graph_datasets import CreditData

# Credit dataset: 'Age' is the sensitive attribute, 'NoDefaultNextMonth' the target.
data = CreditData(
    sensitive_attribute='Age',
    target_attribute='NoDefaultNextMonth',
    include_sensitive=True,
    num_samples=0,
    pre_scale_features=False,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter); indexed by lambda position.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel in print_scatter).
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel in print_scatter).
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
)

In [None]:
# Credit: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)

## Pokec 

In [None]:
from fair_graphs.datasets.graph_datasets import PokecData

# Pokec dataset: 'region' is the sensitive attribute, 'marital_status_indicator' the target.
data = PokecData(
    sensitive_attribute='region',
    target_attribute='marital_status_indicator',
    include_sensitive=True,
    num_samples=0,
    pre_scale_features=False,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter).
# Unlike the other datasets, the Pokec runs pass train_percentage=0.4 explicitly.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
    train_percentage=0.4,
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel); Pokec-specific train_percentage=0.4.
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
    train_percentage=0.4,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel); Pokec-specific train_percentage=0.4.
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
    train_percentage=0.4,
)

In [None]:
# Pokec: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)

## Facebook

In [None]:
from fair_graphs.datasets.graph_datasets import FacebookData

# Facebook dataset: 'gender' is the sensitive attribute, 'egocircle' the target.
# NOTE(review): unlike German/Bail/Credit/Pokec, pre_scale_features is not
# passed here, so FacebookData's default applies -- confirm this is intended.
data = FacebookData(
    sensitive_attribute='gender',
    target_attribute='egocircle',
    include_sensitive=True,
    num_samples=0,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter); indexed by lambda position.
# Uses the shared `num_splits` like every other run, including the Facebook EO
# runs below. This cell previously hard-coded num_splits=1, so its DP panel was
# computed from a single split while the EO panels used `num_splits` splits.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel in print_scatter).
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel in print_scatter).
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
)

In [None]:
# Facebook: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)

## Google Plus

In [None]:
from fair_graphs.datasets.graph_datasets import GooglePlusData

# Google Plus dataset: 'gender' is the sensitive attribute, 'egocircle' the target.
# NOTE(review): unlike German/Bail/Credit/Pokec, pre_scale_features is not
# passed here, so GooglePlusData's default applies -- confirm this is intended.
data = GooglePlusData(
    sensitive_attribute='gender',
    target_attribute='egocircle',
    include_sensitive=True,
    num_samples=0,
)

# Sanity check of what was loaded; this tuple is the cell's displayed output.
(data.samples.shape, data.sensitive.shape, data.labels.shape, data.adj_mtx)

In [None]:
# 'dp' fairness metric (the DP panel in print_scatter); indexed by lambda position.
mses_dp, fair_losses_dp = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='dp',
)

In [None]:
# 'eo' fairness metric with pos=0 (the EO- panel in print_scatter).
mses_eo0, fair_losses_eo0 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=0,
)

In [None]:
# 'eo' fairness metric with pos=1 (the EO+ panel in print_scatter).
mses_eo1, fair_losses_eo1 = optim_fairAutoEncoder(
    num_splits=num_splits,
    data=data,
    lambdas=lambdas,
    metric='eo',
    pos=1,
)

In [None]:
# Google Plus: MSE vs. fairness loss for the DP, EO- and EO+ runs above.
print_scatter(
    mses_dp, fair_losses_dp,
    mses_eo0, fair_losses_eo0,
    mses_eo1, fair_losses_eo1,
)