# Goal: Fix Learning Fair Representation

# Import Libraries





In [1]:
try:
  from google.colab import drive
  drive.mount('/content/drive')
  import sys
  path_to_project = '/content/drive/MyDrive/FairAlgorithm'
  sys.path.append(path_to_project)
  !sudo apt install libcairo2-dev pkg-config python3-dev
  %pip install -r /content/drive/MyDrive/FairAlgorithm/source/requirements.txt  #UPDATE THIS LINE
  IN_COLAB = True
except:
  IN_COLAB = False

Mounted at /content/drive
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
python3-dev is already the newest version (3.10.6-1~22.04.1).
python3-dev set to manually installed.
The following packages were automatically installed and are no longer required:
  libbz2-dev libpkgconf3 libreadline-dev
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libblkid-dev libblkid1 libcairo-script-interpreter2 libffi-dev
  libglib2.0-dev libglib2.0-dev-bin libice-dev liblzo2-2 libmount-dev
  libmount1 libpixman-1-dev libselinux1-dev libsepol-dev libsm-dev
  libxcb-render0-dev libxcb-shm0-dev
Suggested packages:
  libcairo2-doc libgirepository1.0-dev libglib2.0-doc libgdk-pixbuf2.0-bin
  | libgdk-pixbuf2.0-dev libxml2-utils libice-doc cryptsetup-bin libsm-doc
The following packages will be REMOVED:
  pkgconf r-base-dev
The following NEW packages will be installed:
  libblkid-dev libcairo-script-interpreter2 

In [2]:
#import libraries
import numpy as np
import pandas as pd
import pickle
from sklearn import metrics
from sklearn.model_selection import cross_validate,cross_val_score,cross_val_predict,train_test_split,StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, AdaBoostClassifier, BaggingClassifier
from tqdm.notebook import tqdm

from sklearn.metrics import classification_report, recall_score, accuracy_score, precision_score, confusion_matrix, roc_curve
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, normalize

# Mitigation
np.random.seed(1234)
from rich import print
from rich.columns import Columns
from rich.panel import Panel
from rich.align import Align
from source.utils.print_util import *
from source.utils.data_preprocessing import *
import matplotlib.pyplot as plt
from fairlearn.metrics import  MetricFrame, count, false_negative_rate, false_positive_rate, selection_rate, equalized_odds_difference, demographic_parity_difference
from fairlearn.preprocessing import CorrelationRemover
from fairlearn.adversarial import AdversarialFairnessClassifier
from fairlearn.reductions import ExponentiatedGradient, Moment
from fairlearn.postprocessing import ThresholdOptimizer
from aif360.datasets import BinaryLabelDataset, StructuredDataset, StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric, ClassificationMetric, Metric
from aif360.algorithms.preprocessing import DisparateImpactRemover, Reweighing, LFR, OptimPreproc
from aif360.algorithms.preprocessing.optim_preproc_helpers.opt_tools import OptTools
from aif360.algorithms.inprocessing import PrejudiceRemover, AdversarialDebiasing, ExponentiatedGradientReduction
from aif360.algorithms.postprocessing import RejectOptionClassification, CalibratedEqOddsPostprocessing, EqOddsPostprocessing
from aif360.algorithms import Transformer
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

pip install 'aif360[inFairness]'


# Configure the notebook


In the next code cell, set all the variables that are used throughout the notebook.  
The variables are used to configure the notebook, and to set the paths to the data files.

Modify the variables in the next code cell to configure the notebook

- `dataset_name`: The name of the dataset file.
- `dataset_path`: The path to the dataset file.
- `target`: The target feature to predict.
- `target_variable_labels`: The labels for the target feature.
- `sensible_attribute`: The sensible attribute to use for bias mitigation.

In [3]:
# options: [fl-cr, fl-to, aif360-rw, aif360-di, aif360-lfr, aif360-op, aif360-ad, aif360-pr, aif360-er, aif360-ce, aif360-eo, aif360-roc]
mitigation = 'aif360-lfr'

In [4]:
#INPUT
#dataset_name = "diabetes-prediction"
dataset_name = 'diabetes-women' #'stroke-prediction'

if dataset_name == "diabetes-women":
  ignore_cols = ['Age']
  target_variable = 'Outcome'
  target_variable_labels= [1,0]
  sensible_attribute = 'AgeCategory'
  default_mappings = {
      'label_maps': [{1.0: 'Diabetic', 0.0: 'NonDiabetic'}],
      'protected_attribute_maps': [{1.0: 'Adult', 0.0: 'Young'}]
  }

elif dataset_name == "sepsis":
  ignore_cols = []
  target_variable = 'Mortality'
  target_variable_labels= [1,0]
  sensible_attribute = 'Gender_cat'

elif dataset_name == "diabetes-prediction":
  ignore_cols = []
  target_variable = 'diabetes'
  target_variable_labels= [1,0]
  sensible_attribute = 'race_category'
  default_mappings = {
      'label_maps': [{1.0: 'Diabetic', 0.0: 'NonDiabetic'}],
      'protected_attribute_maps': [{1.0: 'Caucasian', 0.0: 'Non-Caucasian'}]
  }

elif dataset_name == 'stroke-prediction':
  ignore_cols = []
  target_variable = 'stroke_prediction'
  target_variable_labels= [1,0]
  sensible_attribute = 'residence_category'

  default_mappings = {
     'label_maps': [{1.0: 'Stroke', 0.0: 'No Stroke'}],
     'protected_attribute_maps': [{1.0: 'Urban', 0.0: 'Rural'}]
  }

In [5]:
n_estimators = 30
random_seed = 1234
n_splits= 10

models = {'Logistic Regression':LogisticRegression(max_iter=500),
          'Decision Tree':DecisionTreeClassifier(max_depth=None),
          'Bagging':BaggingClassifier(DecisionTreeClassifier(max_depth=3),n_estimators=n_estimators),
          'Random Forest':RandomForestClassifier(n_estimators=n_estimators),
          'Extremely Randomized Trees':ExtraTreesClassifier(n_estimators=n_estimators),
          'Ada Boost':AdaBoostClassifier(DecisionTreeClassifier(max_depth=3),n_estimators=n_estimators)}

family = ['division', 'subtraction']
fairness_catalogue = ['GroupFairness', 'PredictiveParity', 'PredictiveEquality', 'EqualOpportunity', 'EqualizedOdds', 'ConditionalUseAccuracyEquality', 'OverallAccuracyEquality', 'TreatmentEquality', 'FORParity', 'FN', 'FP']

all_mitigations = ['original','fl-cr', 'fl-to', 'aif360-rw', 'aif360-di', 'aif360-lfr', 'aif360-op', 'aif360-ad', 'aif360-pr', 'aif360-er', 'aif360-ce', 'aif360-eo', 'aif360-roc']

without_model_mitigations = ['aif360-ad', 'aif360-pr', 'aif360-er']
new_dataset_mitigations = ["fl-cr", "aif360-di", "aif360-op" "aif360-lfr"]

In [18]:
# Load the correct source dataset, considering that pre-processing techniques
# modify the original dataset, while in- and post- processing do not
if mitigation in new_dataset_mitigations:
  dataset_path = path_to_project + '/data/mitigated/mitigated-{}-{}.csv'.format(dataset_name, mitigation) if IN_COLAB else 'data/mitigated/mitigated-{}-{}.csv'.format(dataset_name, mitigation)
else:
  dataset_path = path_to_project + '/data/preprocessed/preprocessed-{}.csv'.format(dataset_name) if IN_COLAB else 'data/preprocessed/preprocessed-{}.csv'.format(dataset_name)

In [24]:
df = pd.read_csv(dataset_path)
feature_cols = df.columns

In [8]:
config = {}
config['df']= df
config['target_variable'] = target_variable
config['sensible_attribute'] = sensible_attribute
config['path_to_project'] = path_to_project
config['n_splits'] = n_splits
config['models'] = models
config['n_estimators'] = n_estimators
config['random_seed'] = random_seed

In [9]:
def unpack_config(config):
  return config['df'], config['target_variable'], config['sensible_attribute'], config['path_to_project'], config['n_splits'], config['models'], config['n_estimators'],config['random_seed']

# AIF Utils

In [10]:
def train_test_splitting(df, n_splits):
  df_splitting = {}

  w = int(len(df)/n_splits)
  window = w
  start_point = 0
  for i in range(0,n_splits):
      train = {}
      test = {}
      df_train_1 = {}
      df_train_2 = {}
      df_test = df[start_point:window]
      if i != 0:
        df_train_1 = df[0: start_point]

      if i != n_splits-1:
        df_train_2 = df[window: len(df)]

      if (i != 0 and  i != n_splits-1):
        concat_df = [df_train_1, df_train_2]
        df_train = pd.concat(concat_df)
      elif i != 0:
        df_train = df_train_1
      else:
        df_train = df_train_2

      start_point= window
      window = window + w

      df_splitting[i] = {'train': df_train, 'test': df_test}
  return df_splitting

In [11]:
def df_X_Y_split(df_train, df_test, target_variable):
  Y_train = df_train[target_variable]
  X_train = df_train.drop(target_variable, axis=1)
  Y_test = df_test[target_variable]
  X_test = df_test.drop(target_variable, axis=1)
  return X_train, Y_train, X_test, Y_test

In [12]:
def compute_predictions_and_tests(df, target_variable, n_splits, models, n_estimators, random_seed):
  predicted_and_real_values = {}
  for model_name in tqdm(models):
    clf = models[model_name]
    df_splitting = train_test_splitting(df, n_splits)
    pred_and_y = {}
    for i in range(0,n_splits):
      df_split = df_splitting[i]
      df_train = df_split['train']
      df_test = df_split['test']

      X_train, Y_train, X_test, Y_test = df_X_Y_split(df_train, df_test, target_variable)
      clf.fit(X_train,Y_train)
      y_pred = clf.predict(X_test)

      S_test = X_test[sensible_attribute].values
      pred_and_y[i] = {'y_test': Y_test.to_numpy().astype(int), 'y_pred': y_pred.astype(int), 's_test':  S_test.astype(int)}

    predicted_and_real_values[model_name] = pred_and_y

  return predicted_and_real_values


In [13]:
def check_results(results):
  for model_name in tqdm(models):
    for i in range(0,n_splits):
      y_test = results[model_name][i]['y_test']
      y_pred = results[model_name][i]['y_pred']
      for j in range(0,67):
        if y_test[j] != y_pred[j]:
          print(y_test[j], y_pred[j])

#Apply Mitigation

Apply mitigation LFR to produce the predictions.

In [20]:
from sklearn import preprocessing

# Build a list with format (feature, correlation with target)
features_corr = [(column, correlation) for column, correlation in zip(df.columns, df.corr()[target_variable])]

# Sort the features by correlation
sorted_features = sorted(features_corr, key=lambda x: x[1], reverse=True)

# Clean and take top 4
main_features = [feature[0] for feature in sorted_features if feature[1] > 0][:5]

# Add sensitive attribute
main_features = main_features + [sensible_attribute]

# Create a new reduced df
df_reduced = df[main_features]
df_reduced = df_reduced[:5000]

In [25]:
df_reduced = df

In [26]:
data_orig_aif = BinaryLabelDataset(
    favorable_label = 1,
    unfavorable_label = 0,
    df = df.copy(),
    label_names = [target_variable],
    protected_attribute_names = [sensible_attribute])

privileged_groups = [{sensible_attribute: 1}]
unprivileged_groups = [{sensible_attribute: 0}]

print(data_orig_aif.privileged_protected_attributes)

In [27]:
def aif360_lfr(config):
  mitigation = 'aif360-lfr'
  df, target_variable,sensible_attribute, path_to_project, n_splits, models, n_estimators, random_seed = unpack_config(config)

  predictions_and_tests = {}

  for k in range(1, 100):
    print("\n Trying for k=", k)

    try:
      TR = LFR(unprivileged_groups=unprivileged_groups,
          privileged_groups=privileged_groups,
          #seed= random_seed, k=10, Ax=0.01, Ay=1.0, Az=50.0,
          # seed= random_seed, k=92, Ax=1.0, Ay=1.0, Az=1.0, # Balanced for reduced
          seed= random_seed, k=k, verbose=1)
      data_orig_aif = BinaryLabelDataset(favorable_label = 1, unfavorable_label = 0, df = df_reduced.copy(), label_names = [target_variable], protected_attribute_names = [sensible_attribute])

      TR = TR.fit(data_orig_aif, maxiter=500, maxfun=500) # Changed 5000 to 500
      transf_dataset = TR.transform(data_orig_aif)
      mit_aif360_lfr = transf_dataset.convert_to_dataframe()[0]
      #save_mitigated_dataset(mit_aif360_lfr,path_to_project,dataset_name, mitigation)


      predictions_and_tests = compute_predictions_and_tests(mit_aif360_lfr, target_variable, n_splits, models, n_estimators, random_seed)

      # If it didn't throw an exception make it print the balancing of the new dataset
      print("\n Balance: ")
      print(mit_aif360_lfr[mit_aif360_lfr[target_variable] == 1.0].count(), '\n')
      print(mit_aif360_lfr[mit_aif360_lfr[target_variable] == 0.0].count(), '\n')

    except Exception as e:
      print("Error k=", k, " not working. \n")
  #save_predictions_and_tests(predictions_and_tests, mitigation, dataset_name, path_to_project)

  return predictions_and_tests, mit_aif360_lfr

In [28]:
predictions_and_tests, mitigated_dataset = aif360_lfr(config)

step: 0, loss: 74.63384855904496, L_x: 7394.473961508567,  L_y: 0.6891089439592808,  L_z: 0.0


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.74245253644887, L_x: 7376.505147187849,  L_y: 0.6439923050756607,  L_z: 0.006668175189894293
step: 250, loss: 74.38313291665759, L_x: 7374.042527350397,  L_y: 0.6394754901006561,  L_z: 6.464306105929962e-05
step: 500, loss: 74.37379961375636, L_x: 7372.861263103856,  L_y: 0.6381272540904444,  L_z: 0.0001411945725470487


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 75.07607875314095, L_x: 7393.765989235988,  L_y: 0.6463734785991285,  L_z: 0.009840907643638755
step: 250, loss: 68.8993793733554, L_x: 6816.123776549033,  L_y: 0.6858186662388132,  L_z: 0.0010464588325252513
step: 500, loss: 50.63506500711643, L_x: 4848.0998271324515,  L_y: 2.1533657214700925,  L_z: 1.4020286436448917e-05


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.89698287045452, L_x: 7395.312691954083,  L_y: 0.695815797157301,  L_z: 0.004960803075127737
step: 250, loss: 73.51051656990974, L_x: 7242.526618902525,  L_y: 0.6287850032935365,  L_z: 0.009129307551819257
step: 500, loss: 53.50898776850322, L_x: 5149.263730703891,  L_y: 1.5821858408052936,  L_z: 0.008683292413180239


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.79477296452298, L_x: 7386.300940249184,  L_y: 0.7139498690555427,  L_z: 0.0043562738595119945
step: 250, loss: 72.25503613716903, L_x: 7140.319650646787,  L_y: 0.6419427084917463,  L_z: 0.004197938444188568
step: 500, loss: 35.903440474742006, L_x: 2062.185001079182,  L_y: 10.424183606176978,  L_z: 0.0971481371554642


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.80341490055346, L_x: 7385.518392961096,  L_y: 0.6871871948425704,  L_z: 0.005220875521998743
step: 250, loss: 73.74607681316303, L_x: 7218.741886020132,  L_y: 1.019179277798805,  L_z: 0.010789573503258177
step: 500, loss: 65.62904569434382, L_x: 6431.127517312547,  L_y: 1.2528299789719108,  L_z: 0.0012988108449286295


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.71353358298462, L_x: 7392.652779531389,  L_y: 0.6683533107580338,  L_z: 0.002373049538254005
step: 250, loss: 73.52199883990447, L_x: 7261.362582926038,  L_y: 0.6712919139736075,  L_z: 0.004741621933409573
step: 500, loss: 58.45603243738658, L_x: 5720.0431272044825,  L_y: 1.2544198460049483,  L_z: 2.36263867361264e-05


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.76638237793001, L_x: 7385.52378069128,  L_y: 0.6845057370986163,  L_z: 0.00453277667837188
step: 250, loss: 73.04409393154904, L_x: 7097.41381290692,  L_y: 0.8804323495581952,  L_z: 0.023790469058432955
step: 500, loss: 71.87839579302731, L_x: 6979.952218635054,  L_y: 0.8095119397889211,  L_z: 0.025387233337756908


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.75810298526068, L_x: 7386.411433821777,  L_y: 0.7337309069644584,  L_z: 0.003205154801568964
step: 250, loss: 73.9948830025903, L_x: 7318.889445677805,  L_y: 0.7436255565657539,  L_z: 0.0012472597849295286
step: 500, loss: 70.69264383259303, L_x: 6994.134619832282,  L_y: 0.637580609779987,  L_z: 0.0022743404898046523


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.78196915019979, L_x: 7391.403246901618,  L_y: 0.7743846404414466,  L_z: 0.0018710408148429614
step: 250, loss: 74.21209358298698, L_x: 7347.545966610264,  L_y: 0.6825044756249604,  L_z: 0.0010825888251872126
step: 500, loss: 71.26182434331919, L_x: 6938.766052193797,  L_y: 1.8214811008391811,  L_z: 0.0010536544108406937


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.69476126964403, L_x: 7386.179139341881,  L_y: 0.742890556023855,  L_z: 0.0018015864040273812
step: 250, loss: 74.27433609910825, L_x: 7353.675324190261,  L_y: 0.6658995206376209,  L_z: 0.0014336667313604495
step: 500, loss: 73.12996829269262, L_x: 7235.171016650531,  L_y: 0.6426542957627804,  L_z: 0.002712076608490712


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.77291343518444, L_x: 7392.7314542543,  L_y: 0.7194195391446723,  L_z: 0.0025235870699352875
step: 250, loss: 74.14382099258889, L_x: 7321.681893613501,  L_y: 0.7834764472094774,  L_z: 0.0028705121848879565
step: 500, loss: 72.34247893974646, L_x: 7107.298028509528,  L_y: 0.8991676728195375,  L_z: 0.007406619636632711


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.83834236110788, L_x: 7398.84905233159,  L_y: 0.7696569481411352,  L_z: 0.0016038977930167174
step: 250, loss: 74.70392750603966, L_x: 7394.3740034342,  L_y: 0.693483875245152,  L_z: 0.001334071929050372
step: 500, loss: 74.09487595271676, L_x: 7334.965657441644,  L_y: 0.6795671933410756,  L_z: 0.0013130436991850244


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.80376451342848, L_x: 7393.909181367589,  L_y: 0.7934882070549428,  L_z: 0.001423689853952832
step: 250, loss: 74.6605306635736, L_x: 7389.519299718588,  L_y: 0.7051283136639579,  L_z: 0.0012041870544751228
step: 500, loss: 73.99168416112555, L_x: 7322.219395552071,  L_y: 0.7205099449854767,  L_z: 0.0009796052123872982


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.80935609322647, L_x: 7394.794246737811,  L_y: 0.7567992696388567,  L_z: 0.0020922871241897943
step: 250, loss: 74.7035779618009, L_x: 7391.128723996692,  L_y: 0.695618023179756,  L_z: 0.0019334539730843851
step: 500, loss: 73.7850946041941, L_x: 7281.6884685292825,  L_y: 0.8817840706666192,  L_z: 0.0017285169646927795


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.85726224982035, L_x: 7403.544494499143,  L_y: 0.7650322702595047,  L_z: 0.0011357006913879867
step: 250, loss: 74.75422937634445, L_x: 7400.081162830241,  L_y: 0.7038275218490372,  L_z: 0.0009918045238602067
step: 500, loss: 74.07074660501134, L_x: 7326.014264618114,  L_y: 0.7450751193910257,  L_z: 0.0013105767887834285


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.73516304430854, L_x: 7385.919424780929,  L_y: 0.7624540793924645,  L_z: 0.002270294342135459
step: 250, loss: 74.62355967370391, L_x: 7381.623885406576,  L_y: 0.7014409966068644,  L_z: 0.0021175964606256472
step: 500, loss: 73.95358099157828, L_x: 7307.835757009986,  L_y: 0.7952676181195358,  L_z: 0.0015991160671772957


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62527630253126, L_x: 7379.675258037223,  L_y: 0.7402222039215617,  L_z: 0.001766030364749574
step: 250, loss: 74.54053957510432, L_x: 7376.130741881337,  L_y: 0.6953852961770379,  L_z: 0.001676937202278132
step: 500, loss: 74.03543003952396, L_x: 7316.908388376949,  L_y: 0.8272639572907539,  L_z: 0.0007816439692744125


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.66515192476561, L_x: 7385.531541432637,  L_y: 0.7498856036414114,  L_z: 0.0011990181359565202
step: 250, loss: 74.57939993837938, L_x: 7382.097294323063,  L_y: 0.7032869203975248,  L_z: 0.001102801495024519
step: 500, loss: 74.10840189030753, L_x: 7322.1174025450255,  L_y: 0.8324298263710225,  L_z: 0.0010959607697250792


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.70121781457348, L_x: 7387.449177247993,  L_y: 0.7551760549775355,  L_z: 0.0014309997423202416
step: 250, loss: 74.62168537615614, L_x: 7384.509862249381,  L_y: 0.7086047094154254,  L_z: 0.001359640884938063
step: 500, loss: 74.25105493502274, L_x: 7347.409487161562,  L_y: 0.7254229192982034,  L_z: 0.0010307428821782828


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.71110685736159, L_x: 7389.0827879153285,  L_y: 0.7390908305126925,  L_z: 0.0016237629539119249
step: 250, loss: 74.64248676494394, L_x: 7386.315131517204,  L_y: 0.7010214518226904,  L_z: 0.001566279958984183
step: 500, loss: 74.22405068802738, L_x: 7342.85459953206,  L_y: 0.7449872927966174,  L_z: 0.001010347998202965


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.72949103243589, L_x: 7395.701622228831,  L_y: 0.7277057150155553,  L_z: 0.0008953819026401702
step: 250, loss: 74.66473580734697, L_x: 7392.869896317858,  L_y: 0.6941429675798402,  L_z: 0.0008378775317711688
step: 500, loss: 74.25188689886791, L_x: 7347.218507681831,  L_y: 0.7485341445674719,  L_z: 0.0006233535496422402


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63142119704298, L_x: 7387.188250560627,  L_y: 0.7087216094518641,  L_z: 0.0010163416396967077
step: 250, loss: 74.57561298994878, L_x: 7384.411532276565,  L_y: 0.6831271252805617,  L_z: 0.0009674108380513841
step: 500, loss: 73.93190913922986, L_x: 7315.230390699397,  L_y: 0.7500732631631593,  L_z: 0.0005906393814545959


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6614141747711, L_x: 7388.940563417667,  L_y: 0.7125961191576275,  L_z: 0.0011882484287361712
step: 250, loss: 74.61081048462603, L_x: 7386.6052071998065,  L_y: 0.6873902973012094,  L_z: 0.001147362306534926
step: 500, loss: 74.08667093309782, L_x: 7332.262675535022,  L_y: 0.7247389499415533,  L_z: 0.000786104556120715


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.73688197261971, L_x: 7397.79949812827,  L_y: 0.7168471984682191,  L_z: 0.0008407958573757926
step: 250, loss: 74.73688196917796, L_x: 7397.799498112319,  L_y: 0.7168471981772917,  L_z: 0.0008407957975495073
step: 500, loss: 74.68153095704515, L_x: 7395.05195452862,  L_y: 0.6900853479078611,  L_z: 0.0008185212770219075


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.70326055194835, L_x: 7389.055655247026,  L_y: 0.7424964356266037,  L_z: 0.0014041512770297665
step: 250, loss: 74.70326058831017, L_x: 7389.055652195117,  L_y: 0.7424964564359273,  L_z: 0.001404152198461552
step: 500, loss: 74.63986805068622, L_x: 7386.297198353772,  L_y: 0.7087215708625467,  L_z: 0.001363489925718959


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63011050778191, L_x: 7383.572742095175,  L_y: 0.745355990245394,  L_z: 0.000980541931695132
step: 250, loss: 74.63011050712954, L_x: 7383.572742079343,  L_y: 0.7453559900024438,  L_z: 0.0009805419266734393
step: 500, loss: 74.56838255869712, L_x: 7380.961043271434,  L_y: 0.7108328130399032,  L_z: 0.0009587862588575914


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64691250254612, L_x: 7386.645137230595,  L_y: 0.736738185603016,  L_z: 0.0008744588927430061
step: 250, loss: 74.64691250044119, L_x: 7386.645135655075,  L_y: 0.7367382014491679,  L_z: 0.0008744588488253812
step: 500, loss: 74.58948132612461, L_x: 7384.105927073113,  L_y: 0.7058929803868134,  L_z: 0.0008505815001333152


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64323448235801, L_x: 7386.537199736109,  L_y: 0.7322270420895683,  L_z: 0.0009127088581470265
step: 250, loss: 74.6432344702883, L_x: 7386.537192682579,  L_y: 0.7322270699106558,  L_z: 0.0009127094710372883
step: 500, loss: 74.59425992097287, L_x: 7384.437440211252,  L_y: 0.7052851249031596,  L_z: 0.0008920078791439109


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6784675390596, L_x: 7388.410008823708,  L_y: 0.7362716551818286,  L_z: 0.0011619159128136952
step: 250, loss: 74.67846754727645, L_x: 7388.410007196696,  L_y: 0.7362716662561462,  L_z: 0.001161916181067082
step: 500, loss: 74.63042763508626, L_x: 7386.463325826664,  L_y: 0.7088040817879457,  L_z: 0.0011398059006337688


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.72338272792773, L_x: 7393.753426864451,  L_y: 0.7474189487819567,  L_z: 0.0007685902100250927
step: 250, loss: 74.72338270047194, L_x: 7393.75341913559,  L_y: 0.7474189984494801,  L_z: 0.0007685902133315176
step: 500, loss: 74.67265876509518, L_x: 7391.7862301458135,  L_y: 0.7173531027431479,  L_z: 0.0007488672178777286


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.67014791742426, L_x: 7389.665903444917,  L_y: 0.7354845707777745,  L_z: 0.0007600862439462869
step: 250, loss: 74.67014780330747, L_x: 7389.665885710652,  L_y: 0.7354846441095457,  L_z: 0.0007600860418280292
step: 500, loss: 74.62387952734414, L_x: 7387.759810224945,  L_y: 0.709142536731269,  L_z: 0.000742777767268575


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.68145318998863, L_x: 7391.2660609776085,  L_y: 0.7302498865239945,  L_z: 0.0007708538737710892
step: 250, loss: 74.68145319413307, L_x: 7391.266060756142,  L_y: 0.7302498875416248,  L_z: 0.0007708539806002997
step: 500, loss: 74.63816630981931, L_x: 7389.4305380895275,  L_y: 0.7059412663993456,  L_z: 0.0007583932504937973


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.71869480745204, L_x: 7395.795598420478,  L_y: 0.7199991930756899,  L_z: 0.0008147926034312473
step: 250, loss: 74.71869480581805, L_x: 7395.795598409266,  L_y: 0.7199991931636105,  L_z: 0.0008147925712354213
step: 500, loss: 74.67650403345236, L_x: 7393.817328831547,  L_y: 0.6980970752974759,  L_z: 0.0008046733967880732


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.6712170823538, L_x: 7389.94136623632,  L_y: 0.7235156419799321,  L_z: 0.0009657555602133806
step: 250, loss: 74.67121707771724, L_x: 7389.941363932674,  L_y: 0.7235156354151429,  L_z: 0.0009657560595072211
step: 500, loss: 74.62754139218468, L_x: 7387.849935351522,  L_y: 0.7015152657541971,  L_z: 0.0009505354583052481


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64664934673003, L_x: 7386.725540263451,  L_y: 0.7411308597473252,  L_z: 0.0007652616869638783
step: 250, loss: 74.64664934658597, L_x: 7386.725540252744,  L_y: 0.7411308598435722,  L_z: 0.0007652616842991087
step: 500, loss: 74.60004103844841, L_x: 7384.696741994973,  L_y: 0.7153538108889355,  L_z: 0.0007543961521946195


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64274920063319, L_x: 7387.261407644954,  L_y: 0.7373048578747613,  L_z: 0.0006566053261777879
step: 250, loss: 74.64274918104556, L_x: 7387.261406404705,  L_y: 0.7373048518463174,  L_z: 0.0006566053030437319
step: 500, loss: 74.59844443890394, L_x: 7385.347701948733,  L_y: 0.712746228499827,  L_z: 0.0006444238183356721


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63194602830792, L_x: 7387.099368102622,  L_y: 0.7253654129118992,  L_z: 0.0007117386873959843
step: 250, loss: 74.6319459803241, L_x: 7387.099362560697,  L_y: 0.725365403441717,  L_z: 0.0007117390255084788
step: 500, loss: 74.59265836246949, L_x: 7385.3102040703125,  L_y: 0.7045959979250717,  L_z: 0.0006992064768261509
step: 750, loss: 74.59265836482922, L_x: 7385.310204072097,  L_y: 0.7045959983477325,  L_z: 0.0006992065152104541


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.65646345805936, L_x: 7387.531679350703,  L_y: 0.7369839088154488,  L_z: 0.0008832551147377626
step: 250, loss: 74.65646344970143, L_x: 7387.5316781232505,  L_y: 0.7369839046477051,  L_z: 0.000883255276424528
step: 500, loss: 74.61794113659234, L_x: 7385.949927664931,  L_y: 0.7148038867460333,  L_z: 0.0008727594639395553
step: 750, loss: 74.61794103020452, L_x: 7385.949912453755,  L_y: 0.7148038992785907,  L_z: 0.0008727601277676342


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.68607795216393, L_x: 7391.06037898507,  L_y: 0.7451293474543428,  L_z: 0.0006068962971777952
step: 250, loss: 74.68607787598604, L_x: 7391.0603733572025,  L_y: 0.7451293275170227,  L_z: 0.0006068962979397128
step: 500, loss: 74.64422536102609, L_x: 7389.373041662142,  L_y: 0.7206575274087215,  L_z: 0.0005967483399189385
step: 750, loss: 74.6442253514681, L_x: 7389.373035439549,  L_y: 0.7206575590574327,  L_z: 0.0005967487603033915


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6626525305242, L_x: 7389.564580074609,  L_y: 0.7349012839510727,  L_z: 0.0006421089165404687
step: 250, loss: 74.66265235531473, L_x: 7389.564566165174,  L_y: 0.7349012561988996,  L_z: 0.0006421087492818058
step: 500, loss: 74.62664703097258, L_x: 7388.076937428232,  L_y: 0.7142012058105867,  L_z: 0.0006335290175934502
step: 750, loss: 74.62664715590988, L_x: 7388.076947993877,  L_y: 0.7142012399299015,  L_z: 0.0006335287208240105


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6201393422299, L_x: 7387.087104512248,  L_y: 0.7223476664206998,  L_z: 0.0005384126137342635
step: 250, loss: 74.62013934399975, L_x: 7387.087104354869,  L_y: 0.7223476661137506,  L_z: 0.0005384126867462717
step: 500, loss: 74.58599103586651, L_x: 7385.525918076066,  L_y: 0.7042255344228833,  L_z: 0.000530126413659387
step: 750, loss: 74.58599099063758, L_x: 7385.5259156914,  L_y: 0.7042255305741624,  L_z: 0.0005301260629883267


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.668597953267, L_x: 7391.874025688665,  L_y: 0.7168066465721176,  L_z: 0.0006610209961644215
step: 250, loss: 74.66859795234832, L_x: 7391.874025678786,  L_y: 0.7168066467495748,  L_z: 0.0006610209762178625
step: 500, loss: 74.63623981564093, L_x: 7390.323741859605,  L_y: 0.7002437336840828,  L_z: 0.0006551732672161212
step: 750, loss: 74.63623980441776, L_x: 7390.3237406353965,  L_y: 0.70024374568437,  L_z: 0.0006551730475881792


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64511231944147, L_x: 7388.781421997847,  L_y: 0.7171792099311072,  L_z: 0.0008023777906376775
step: 250, loss: 74.64511230579872, L_x: 7388.7814202955215,  L_y: 0.717179197213662,  L_z: 0.0008023781125968614
step: 500, loss: 74.61307586908916, L_x: 7387.21020605062,  L_y: 0.7010559164206673,  L_z: 0.0007983578432456534
step: 750, loss: 74.61307592237323, L_x: 7387.210211774436,  L_y: 0.7010559127540327,  L_z: 0.0007983578374966562


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.61102662451728, L_x: 7386.082055916352,  L_y: 0.720728595568243,  L_z: 0.0005895493957100669
step: 250, loss: 74.61102662452163, L_x: 7386.082055906934,  L_y: 0.7207285957394655,  L_z: 0.0005895493942566479
step: 500, loss: 74.57688643161389, L_x: 7384.4042925319545,  L_y: 0.703662979686319,  L_z: 0.0005836105321606866
step: 750, loss: 74.5768863192796, L_x: 7384.404282812639,  L_y: 0.7036629636284571,  L_z: 0.0005836105504949276


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63915268423226, L_x: 7388.4794537414,  L_y: 0.7263011317998374,  L_z: 0.0005611403003683392
step: 250, loss: 74.63915266145615, L_x: 7388.479452707662,  L_y: 0.7263011203072415,  L_z: 0.0005611402814456852
step: 500, loss: 74.60646903083347, L_x: 7386.951403651106,  L_y: 0.7091442816584294,  L_z: 0.0005562142532795996
step: 750, loss: 74.606468847952, L_x: 7386.951403643157,  L_y: 0.7091440975506333,  L_z: 0.0005562142793957244


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6280499522957, L_x: 7387.866477185985,  L_y: 0.720736904153445,  L_z: 0.0005729655256482362
step: 250, loss: 74.62804989917338, L_x: 7387.866472650313,  L_y: 0.7207368852504048,  L_z: 0.0005729657483969393
step: 500, loss: 74.59900422285004, L_x: 7386.530978012459,  L_y: 0.7053574468085968,  L_z: 0.000566739918337108
step: 750, loss: 74.59900407280537, L_x: 7386.530978017593,  L_y: 0.7053572967085077,  L_z: 0.0005667399184185437


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62889982257194, L_x: 7387.113492871973,  L_y: 0.7237783198200923,  L_z: 0.0006797314806422705
step: 250, loss: 74.62889981031554, L_x: 7387.11349189233,  L_y: 0.7237783120466615,  L_z: 0.000679731586911537
step: 500, loss: 74.60109862564613, L_x: 7385.882659129169,  L_y: 0.7085539347452616,  L_z: 0.0006743619921834599
step: 750, loss: 74.60109838867822, L_x: 7385.882647936272,  L_y: 0.7085537897281937,  L_z: 0.0006743623917460983


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6555125761409, L_x: 7390.902351669494,  L_y: 0.7212612309702713,  L_z: 0.0005045565695137074
step: 250, loss: 74.65551249647251, L_x: 7390.902347122183,  L_y: 0.7212611968395735,  L_z: 0.0005045565682221058
step: 500, loss: 74.62777578242833, L_x: 7389.64327331556,  L_y: 0.7063517872594139,  L_z: 0.0004998252402663616
step: 750, loss: 74.62777550456173, L_x: 7389.643261914898,  L_y: 0.7063516265792299,  L_z: 0.0004998251766702191


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.61738869045942, L_x: 7388.6200060276005,  L_y: 0.704215360483057,  L_z: 0.0005394653940068485
step: 250, loss: 74.6173885301638, L_x: 7388.619994924937,  L_y: 0.7042153169974827,  L_z: 0.0005394652783389266
step: 500, loss: 74.61738868955081, L_x: 7388.62000602961,  L_y: 0.7042153604862659,  L_z: 0.0005394653753687591
step: 750, loss: 74.59286419659398, L_x: 7387.392617902099,  L_y: 0.6921934110412407,  L_z: 0.0005348921306347899
step: 1000, loss: 74.59286417804631, L_x: 7387.3926167307845,  L_y: 0.6921934118196019,  L_z: 0.0005348919783772804


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.5892478327041, L_x: 7386.994736388879,  L_y: 0.697372728163664,  L_z: 0.0004385548130332043
step: 250, loss: 74.58924783357733, L_x: 7386.994736258732,  L_y: 0.6973727276766979,  L_z: 0.00043855486626633825
step: 500, loss: 74.58924779056794, L_x: 7386.994733673639,  L_y: 0.6973727280922644,  L_z: 0.0004385545147857446
step: 750, loss: 74.56467122657861, L_x: 7385.658470662662,  L_y: 0.6863699039024274,  L_z: 0.0004343323209912299
step: 1000, loss: 74.5646712339971, L_x: 7385.6584732574265,  L_y: 0.6863699050734449,  L_z: 0.0004343319269879636


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.66603503502252, L_x: 7394.150128417966,  L_y: 0.6976596162950849,  L_z: 0.0005374826909552214
step: 250, loss: 74.66603503415914, L_x: 7394.1501284100395,  L_y: 0.6976596162606608,  L_z: 0.0005374826759615236
step: 500, loss: 74.66603503499395, L_x: 7394.1501284226515,  L_y: 0.697659616296126,  L_z: 0.0005374826894263816
step: 750, loss: 74.64218820132184, L_x: 7392.8391069633135,  L_y: 0.687037491318685,  L_z: 0.0005351928074003149
step: 1000, loss: 74.6421882077302, L_x: 7392.839107741216,  L_y: 0.6870374843993908,  L_z: 0.0005351929183727984


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6485841052945, L_x: 7391.26462107698,  L_y: 0.7042438560439724,  L_z: 0.0006338807696144803
step: 250, loss: 74.64858410388668, L_x: 7391.264619549624,  L_y: 0.7042438582368552,  L_z: 0.0006338810030717923
step: 500, loss: 74.6485840785505, L_x: 7391.264618661005,  L_y: 0.7042438556556394,  L_z: 0.0006338807256959563
step: 750, loss: 74.62415863561674, L_x: 7389.971349419803,  L_y: 0.6928400335519794,  L_z: 0.0006321021573344714
step: 1000, loss: 74.6241585514866, L_x: 7389.971345203468,  L_y: 0.6928399989496672,  L_z: 0.0006321020100449732


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.57637718771421, L_x: 7385.9542595958155,  L_y: 0.6927856224445832,  L_z: 0.0004809793862294429
step: 250, loss: 74.57637718755245, L_x: 7385.954259587814,  L_y: 0.6927856224131151,  L_z: 0.0004809793852239172
step: 500, loss: 74.57637710483326, L_x: 7385.9542494698835,  L_y: 0.6927856228411786,  L_z: 0.0004809797458644855
step: 750, loss: 74.55256388139915, L_x: 7384.571674928029,  L_y: 0.6829416588509374,  L_z: 0.00047810946535843397
step: 1000, loss: 74.5525639639201, L_x: 7384.571679920517,  L_y: 0.6829416855134692,  L_z: 0.00047810958402884015


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62002632793187, L_x: 7390.107445797088,  L_y: 0.6931314755163006,  L_z: 0.0005164078888938758
step: 250, loss: 74.62002631997731, L_x: 7390.107444882624,  L_y: 0.6931314774821119,  L_z: 0.0005164078733793511
step: 500, loss: 74.62002631030357, L_x: 7390.107443462369,  L_y: 0.6931314756120379,  L_z: 0.0005164080013569971
step: 750, loss: 74.59745521839682, L_x: 7388.806453710845,  L_y: 0.6837161342767784,  L_z: 0.0005134909402318099
step: 1000, loss: 74.59745521907912, L_x: 7388.806453725298,  L_y: 0.6837161342530311,  L_z: 0.0005134909514624544


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62325674403571, L_x: 7389.578275505446,  L_y: 0.7030606987336504,  L_z: 0.00048826580495180474
step: 250, loss: 74.62325671630511, L_x: 7389.578271531264,  L_y: 0.7030607019017195,  L_z: 0.0004882659818147643
step: 500, loss: 74.62325671061745, L_x: 7389.578272013989,  L_y: 0.7030606981480627,  L_z: 0.0004882658465899565
step: 750, loss: 74.6009912209749, L_x: 7388.4275387859325,  L_y: 0.6925084138440315,  L_z: 0.0004841483854305808
step: 1000, loss: 74.60099116560795, L_x: 7388.4275337778145,  L_y: 0.6925083992506743,  L_z: 0.0004841485715827709


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.60440840726574, L_x: 7387.2444473190135,  L_y: 0.7042429952317963,  L_z: 0.0005544187768761138
step: 250, loss: 74.6044084038892, L_x: 7387.244446488339,  L_y: 0.7042429965091993,  L_z: 0.0005544188499321194
step: 500, loss: 74.60440828518814, L_x: 7387.244435091607,  L_y: 0.7042429939650672,  L_z: 0.0005544188061400368
step: 750, loss: 74.58303144692309, L_x: 7386.169104342164,  L_y: 0.6937594602792105,  L_z: 0.0005516188644445293
step: 1000, loss: 74.58303143767453, L_x: 7386.169107310859,  L_y: 0.6937594331075777,  L_z: 0.0005516186291675833


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.65302270567491, L_x: 7392.293701471869,  L_y: 0.7080345293556429,  L_z: 0.0004410232320114523
step: 250, loss: 74.65302267082042, L_x: 7392.293697472343,  L_y: 0.7080345349406578,  L_z: 0.00044102322312661784
step: 500, loss: 74.65302270602412, L_x: 7392.293701234559,  L_y: 0.7080345293284607,  L_z: 0.0004410232870012425
step: 750, loss: 74.63057978165652, L_x: 7391.173134521487,  L_y: 0.6969366234514451,  L_z: 0.0004382362598038855
step: 1000, loss: 74.6305798446529, L_x: 7391.173142268325,  L_y: 0.6969366101078687,  L_z: 0.00043823623723582685


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.65240528808614, L_x: 7390.019157100533,  L_y: 0.7287600839518434,  L_z: 0.00046907266257945935
step: 250, loss: 74.65240519290239, L_x: 7390.019147319969,  L_y: 0.7287600909128398,  L_z: 0.00046907257579699717
step: 500, loss: 74.65240528748555, L_x: 7390.019157101828,  L_y: 0.7287600839994625,  L_z: 0.00046907264935606833
step: 750, loss: 74.62807024412916, L_x: 7388.970501490171,  L_y: 0.7150176627691083,  L_z: 0.00046695132916695477
step: 1000, loss: 74.6280702322935, L_x: 7388.970500778469,  L_y: 0.7150176611662566,  L_z: 0.00046695126685118775


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63423929623487, L_x: 7387.9137478358425,  L_y: 0.7359049450411984,  L_z: 0.00038393745670487687
step: 250, loss: 74.63423929717598, L_x: 7387.913747723578,  L_y: 0.7359049451273803,  L_z: 0.00038393749625648555
step: 500, loss: 74.63423925533498, L_x: 7387.913745471036,  L_y: 0.7359049383894136,  L_z: 0.00038393724470421804
step: 750, loss: 74.6085903179494, L_x: 7386.8353298295215,  L_y: 0.7211443710488405,  L_z: 0.00038185297210694
step: 1000, loss: 74.60859029230063, L_x: 7386.83532780341,  L_y: 0.7211443767657147,  L_z: 0.0003818527500162129


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.69806125533894, L_x: 7394.804201782725,  L_y: 0.726939975135856,  L_z: 0.0004615852475169507
step: 250, loss: 74.69806125481828, L_x: 7394.80420177535,  L_y: 0.7269399751884609,  L_z: 0.00046158523752677525
step: 500, loss: 74.69806125537635, L_x: 7394.804201785806,  L_y: 0.726939975193644,  L_z: 0.0004615852464929953
step: 750, loss: 74.67381468316681, L_x: 7393.715001343153,  L_y: 0.7136688824916791,  L_z: 0.00045991574487209746
step: 1000, loss: 74.6738146669685, L_x: 7393.715000556682,  L_y: 0.7136688708162976,  L_z: 0.0004599158117075446


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.68871257892356, L_x: 7394.290593978216,  L_y: 0.721926497450249,  L_z: 0.0004776028338228005
step: 250, loss: 74.6887125698129, L_x: 7394.2905925703435,  L_y: 0.7219264936402456,  L_z: 0.0004776030093845945
step: 500, loss: 74.688712548793, L_x: 7394.290591758417,  L_y: 0.7219264911552207,  L_z: 0.00047760280107229843
step: 750, loss: 74.66508395858524, L_x: 7393.184372224693,  L_y: 0.709423665079212,  L_z: 0.00047633142518177643
step: 1000, loss: 74.6650839074043, L_x: 7393.184365865154,  L_y: 0.7094236747533426,  L_z: 0.0004763314799882055


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64574044150922, L_x: 7390.466781477984,  L_y: 0.7218270721871712,  L_z: 0.0003849110908443079
step: 250, loss: 74.64574044144892, L_x: 7390.4667814710665,  L_y: 0.7218270722361125,  L_z: 0.00038491109004298
step: 500, loss: 74.64574034727843, L_x: 7390.466771901816,  L_y: 0.7218270594676021,  L_z: 0.0003849113758534702
step: 750, loss: 74.62139195896128, L_x: 7389.302668027744,  L_y: 0.7092031407806909,  L_z: 0.000383242758062985
step: 1000, loss: 74.6213921888993, L_x: 7389.302687595991,  L_y: 0.709203176389244,  L_z: 0.0003832427310030035
step: 1250, loss: 74.62139218802245, L_x: 7389.302687604817,  L_y: 0.7092031762523607,  L_z: 0.00038324271443845333


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.66754136187588, L_x: 7393.063554871418,  L_y: 0.7149296549813167,  L_z: 0.0004395231636077181
step: 250, loss: 74.66754134963685, L_x: 7393.063553994825,  L_y: 0.7149296520762799,  L_z: 0.000439523152246178
step: 500, loss: 74.66754134076973, L_x: 7393.063552727252,  L_y: 0.7149296510164321,  L_z: 0.0004395232496156054
step: 750, loss: 74.6452561843819, L_x: 7391.965670365138,  L_y: 0.7037100456109692,  L_z: 0.00043778870239101226
step: 1000, loss: 74.64525618518466, L_x: 7391.965670379734,  L_y: 0.7037100454079481,  L_z: 0.00043778871958752546
step: 1250, loss: 74.64525614531226, L_x: 7391.965667789256,  L_y: 0.7037100391235411,  L_z: 0.00043778856592334176


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64349758788393, L_x: 7390.661176102085,  L_y: 0.7160394221608284,  L_z: 0.00041692809404489655
step: 250, loss: 74.64349755464633, L_x: 7390.661172602635,  L_y: 0.7160394173504186,  L_z: 0.00041692822539109793
step: 500, loss: 74.64349754811698, L_x: 7390.661172991755,  L_y: 0.716039411925248,  L_z: 0.00041692812548334117
step: 750, loss: 74.62251034148424, L_x: 7389.65852242909,  L_y: 0.7052015425562567,  L_z: 0.0004144714927418177
step: 1000, loss: 74.62251030924915, L_x: 7389.658520068422,  L_y: 0.7052015402770011,  L_z: 0.0004144713657583835
step: 1250, loss: 74.62251028707213, L_x: 7389.65851530112,  L_y: 0.7052015612155043,  L_z: 0.0004144714569083971


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62280709600913, L_x: 7388.667070179983,  L_y: 0.712259408315798,  L_z: 0.0004775397178697672
step: 250, loss: 74.62280708931807, L_x: 7388.66706942757,  L_y: 0.7122594064094172,  L_z: 0.0004775397726590984
step: 500, loss: 74.62280696889061, L_x: 7388.6670592879545,  L_y: 0.7122593889567462,  L_z: 0.0004775397410861141
step: 750, loss: 74.60335357767093, L_x: 7387.742552119337,  L_y: 0.702129877282093,  L_z: 0.00047596358390935314
step: 1000, loss: 74.6033535710505, L_x: 7387.742549435785,  L_y: 0.7021299026443012,  L_z: 0.00047596348096715216
step: 1250, loss: 74.60335360755936, L_x: 7387.742555622934,  L_y: 0.7021298804539986,  L_z: 0.00047596341752031095


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.65643074438827, L_x: 7392.652011887497,  L_y: 0.711163378978698,  L_z: 0.0003749449306920162
step: 250, loss: 74.65643070079014, L_x: 7392.652008387569,  L_y: 0.7111633705919119,  L_z: 0.0003749449264505243
step: 500, loss: 74.65643074409653, L_x: 7392.652011680508,  L_y: 0.7111633786693755,  L_z: 0.0003749449724414656
step: 750, loss: 74.63716645098124, L_x: 7391.717505709124,  L_y: 0.701328262978599,  L_z: 0.00037326261822824325
step: 1000, loss: 74.63716647173028, L_x: 7391.71750830101,  L_y: 0.7013282587684618,  L_z: 0.0003732625990346218
step: 1250, loss: 74.6371664536334, L_x: 7391.7175067206945,  L_y: 0.7013282641034431,  L_z: 0.00037326244646045933


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64448385531851, L_x: 7391.036475285876,  L_y: 0.7135283425111413,  L_z: 0.00041181519897249126
step: 250, loss: 74.64448375140353, L_x: 7391.036466544347,  L_y: 0.713528329392285,  L_z: 0.0004118151313553503
step: 500, loss: 74.64448385481396, L_x: 7391.036475287268,  L_y: 0.7135283425045034,  L_z: 0.0004118151887355036
step: 750, loss: 74.62498542100076, L_x: 7390.102989145293,  L_y: 0.7034271094109822,  L_z: 0.00041056840273715037
step: 1000, loss: 74.62498542419763, L_x: 7390.102988765186,  L_y: 0.7034271131432652,  L_z: 0.00041056846805019446
step: 1250, loss: 74.62498541490862, L_x: 7390.102987858225,  L_y: 0.7034271175142862,  L_z: 0.0004105683762418489


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63027626050759, L_x: 7390.017488725495,  L_y: 0.7138503630103511,  L_z: 0.00032502020484567376
step: 250, loss: 74.63027626083235, L_x: 7390.017488620619,  L_y: 0.7138503628336574,  L_z: 0.0003250202358500591
step: 500, loss: 74.6302762313511, L_x: 7390.017486552048,  L_y: 0.7138503637715106,  L_z: 0.0003250200411821949
step: 750, loss: 74.61029233894831, L_x: 7389.049731075325,  L_y: 0.7036061656596443,  L_z: 0.0003237772507081084
step: 1000, loss: 74.6102922125346, L_x: 7389.049729484783,  L_y: 0.7036060561802179,  L_z: 0.0003237772301309632
step: 1250, loss: 74.61029221229266, L_x: 7389.049729327522,  L_y: 0.7036060603055865,  L_z: 0.00032377717423707933


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.66256301044648, L_x: 7394.4400976021425,  L_y: 0.6985503829851976,  L_z: 0.0003922330287968205
step: 250, loss: 74.66256300991198, L_x: 7394.440097594442,  L_y: 0.6985503828966212,  L_z: 0.00039223302141861376
step: 500, loss: 74.66256301042543, L_x: 7394.440097604944,  L_y: 0.698550382973963,  L_z: 0.00039223302804052414
step: 750, loss: 74.64497489385964, L_x: 7393.487845223324,  L_y: 0.6905343217412157,  L_z: 0.0003912423977035963
step: 1000, loss: 74.64497482981396, L_x: 7393.48784394038,  L_y: 0.6905342745168932,  L_z: 0.00039124231786540997
step: 1250, loss: 74.64497468678822, L_x: 7393.487833238771,  L_y: 0.6905342303936786,  L_z: 0.00039124248013649597


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.67855327908326, L_x: 7394.753909410751,  L_y: 0.7103523060237669,  L_z: 0.0004132375790395561
step: 250, loss: 74.6785532799462, L_x: 7394.753908155877,  L_y: 0.7103523124956649,  L_z: 0.00041323771783526383
step: 500, loss: 74.67855325913679, L_x: 7394.753907442559,  L_y: 0.7103523069510179,  L_z: 0.00041323755520353625
step: 750, loss: 74.65979042538729, L_x: 7393.81085766867,  L_y: 0.7010575261041009,  L_z: 0.00041248645192976926
step: 1000, loss: 74.65979019887588, L_x: 7393.810844899209,  L_y: 0.7010574273869553,  L_z: 0.00041248644993650173
step: 1250, loss: 74.65979029132603, L_x: 7393.810857676797,  L_y: 0.7010573911731092,  L_z: 0.0004124864676991177


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63477152316172, L_x: 7390.937216387343,  L_y: 0.7068507411250187,  L_z: 0.0003709723632654457
step: 250, loss: 74.63477152297418, L_x: 7390.937216380267,  L_y: 0.7068507410404586,  L_z: 0.00037097236262107386
step: 500, loss: 74.63477145214945, L_x: 7390.937207940108,  L_y: 0.7068507430599175,  L_z: 0.00037097259376902045
step: 750, loss: 74.61522535346533, L_x: 7389.915757875294,  L_y: 0.6975730347124801,  L_z: 0.0003698947999984416
step: 1000, loss: 74.61522529179553, L_x: 7389.915757874765,  L_y: 0.6975729724923068,  L_z: 0.0003698948111109892
step: 1250, loss: 74.61522529159076, L_x: 7389.9157578810755,  L_y: 0.6975729727194875,  L_z: 0.000369894801210497


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.65849866216122, L_x: 7393.182432250822,  L_y: 0.7066943270435432,  L_z: 0.00039960025218893874
step: 250, loss: 74.65849865937321, L_x: 7393.182431482044,  L_y: 0.7066943324095655,  L_z: 0.0003996002428638095
step: 500, loss: 74.65849864717568, L_x: 7393.182430373463,  L_y: 0.7066943277257673,  L_z: 0.0003996003143055014
step: 750, loss: 74.64049979792156, L_x: 7392.259955218369,  L_y: 0.697975447942748,  L_z: 0.0003984959559021267
step: 1000, loss: 74.64049970981839, L_x: 7392.259955244096,  L_y: 0.697975359657356,  L_z: 0.00039849595440134025
step: 1250, loss: 74.64049969435972, L_x: 7392.259954218232,  L_y: 0.6979753595189124,  L_z: 0.0003984958531696959


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6505599971254, L_x: 7392.423903649866,  L_y: 0.7078146219714117,  L_z: 0.0003701267731064937
step: 250, loss: 74.65055997866409, L_x: 7392.423900375803,  L_y: 0.7078146313750344,  L_z: 0.0003701268706206107
step: 500, loss: 74.65055997129126, L_x: 7392.423900771192,  L_y: 0.7078146237517757,  L_z: 0.000370126796551418
step: 750, loss: 74.63327810286494, L_x: 7391.548133104513,  L_y: 0.6993605877868798,  L_z: 0.000368723680658478
step: 1000, loss: 74.63327797633282, L_x: 7391.548128346563,  L_y: 0.6993605028304285,  L_z: 0.0003687238007350352
step: 1250, loss: 74.63327796847675, L_x: 7391.5481286040185,  L_y: 0.6993604978318733,  L_z: 0.000368723692093757


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62082169443273, L_x: 7389.4637096339375,  L_y: 0.7061617601303596,  L_z: 0.0004004567592599364
step: 250, loss: 74.6208216935705, L_x: 7389.463708951305,  L_y: 0.706161763837383,  L_z: 0.00040045680440135465
step: 500, loss: 74.62082160088384, L_x: 7389.463699841592,  L_y: 0.7061617637556424,  L_z: 0.0004004567742454427
step: 750, loss: 74.62082169402187, L_x: 7389.463709637237,  L_y: 0.7061617601172484,  L_z: 0.00040045675064466176
step: 1000, loss: 74.60453592561818, L_x: 7388.64816314902,  L_y: 0.6980746985628852,  L_z: 0.00039959191130187264
step: 1250, loss: 74.60453593356827, L_x: 7388.648165070908,  L_y: 0.698074685357862,  L_z: 0.00039959195002633525
step: 1500, loss: 74.60453591299364, L_x: 7388.648164206331,  L_y: 0.6980746808467564,  L_z: 0.00039959180167142193


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64840167997252, L_x: 7392.230700336465,  L_y: 0.7086000276703166,  L_z: 0.0003498929787512641
step: 250, loss: 74.64840166558471, L_x: 7392.230697283075,  L_y: 0.7086000440268444,  L_z: 0.00034989297454251773
step: 500, loss: 74.6484016797885, L_x: 7392.2307001536665,  L_y: 0.7086000277195689,  L_z: 0.00034989301064553685
step: 750, loss: 74.64840163936293, L_x: 7392.230697758808,  L_y: 0.708600020448282,  L_z: 0.0003498928265312616
step: 1000, loss: 74.63177900646114, L_x: 7391.408389039345,  L_y: 0.7002478472516909,  L_z: 0.0003489453763200369
step: 1250, loss: 74.63177898631805, L_x: 7391.4083877403245,  L_y: 0.7002478473305715,  L_z: 0.0003489452316846142
step: 1500, loss: 74.63177897305712, L_x: 7391.408386850651,  L_y: 0.7002478376963559,  L_z: 0.000348945337085342


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63209157095186, L_x: 7390.736618790603,  L_y: 0.7063689472396149,  L_z: 0.00036712871612412396
step: 250, loss: 74.63209151601879, L_x: 7390.736611104165,  L_y: 0.7063689720296938,  L_z: 0.0003671286589487725
step: 500, loss: 74.63209157150739, L_x: 7390.736618790626,  L_y: 0.7063689473191241,  L_z: 0.0003671287256400984
step: 750, loss: 74.63209157090787, L_x: 7390.7366187902635,  L_y: 0.7063689472241789,  L_z: 0.00036712871562126866
step: 1000, loss: 74.61560540807517, L_x: 7389.9146379491995,  L_y: 0.6981515395479163,  L_z: 0.0003661497807049738
step: 1250, loss: 74.61560544215024, L_x: 7389.914640472087,  L_y: 0.6981515409916227,  L_z: 0.00036614992875503525
step: 1500, loss: 74.61560543887941, L_x: 7389.914640014307,  L_y: 0.6981515394903238,  L_z: 0.00036614998492035515


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.60083259878408, L_x: 7388.276203891966,  L_y: 0.7019073052848704,  L_z: 0.0003232650915908142
step: 250, loss: 74.60083259943922, L_x: 7388.276203801973,  L_y: 0.7019073055804689,  L_z: 0.0003232651167802649
step: 500, loss: 74.6008325656118, L_x: 7388.276202048818,  L_y: 0.701907296664975,  L_z: 0.0003232649691727332
step: 750, loss: 74.60083256400466, L_x: 7388.276201041774,  L_y: 0.7019072978562613,  L_z: 0.00032326511461312296
step: 1000, loss: 74.58465045248954, L_x: 7387.42032982687,  L_y: 0.694320969629403,  L_z: 0.0003225236918286069
step: 1250, loss: 74.58465046386529, L_x: 7387.420330811194,  L_y: 0.6943209699640558,  L_z: 0.0003225237157857133
step: 1500, loss: 74.5846503848868, L_x: 7387.420322625672,  L_y: 0.6943209739718799,  L_z: 0.0003225236931638172


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.63984169021973, L_x: 7392.221106813096,  L_y: 0.7002476863054189,  L_z: 0.00034765871566686294
step: 250, loss: 74.63984168980905, L_x: 7392.221106806397,  L_y: 0.7002476862517881,  L_z: 0.0003476587098656467
step: 500, loss: 74.63984169028163, L_x: 7392.221106813827,  L_y: 0.7002476863897966,  L_z: 0.00034765871507121907
step: 750, loss: 74.63984161096401, L_x: 7392.221100996993,  L_y: 0.7002476750772321,  L_z: 0.00034765851833694594
step: 1000, loss: 74.62419533020439, L_x: 7391.3919652907825,  L_y: 0.6929343741658511,  L_z: 0.00034682606261431944
step: 1250, loss: 74.62419526735187, L_x: 7391.391960974579,  L_y: 0.6929343529315563,  L_z: 0.0003468260934902828
step: 1500, loss: 74.62419533587085, L_x: 7391.391965563538,  L_y: 0.692934375091803,  L_z: 0.0003468261028732177


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.64467569527805, L_x: 7392.803816152142,  L_y: 0.698604633630277,  L_z: 0.0003606580025270308
step: 250, loss: 74.64467569390142, L_x: 7392.803815113875,  L_y: 0.6986046373731148,  L_z: 0.0003606581077912188
step: 500, loss: 74.64467566857004, L_x: 7392.80381450172,  L_y: 0.6986046243935095,  L_z: 0.00036065798318683773
step: 750, loss: 74.64467566959821, L_x: 7392.803813758315,  L_y: 0.6986046286246301,  L_z: 0.00036065806780836236
step: 1000, loss: 74.62956482335024, L_x: 7391.97993067204,  L_y: 0.6917642518545339,  L_z: 0.00036002529550616043
step: 1250, loss: 74.62956498991888, L_x: 7391.979949296364,  L_y: 0.6917642312999759,  L_z: 0.00036002531310542205
step: 1500, loss: 74.6295649901933, L_x: 7391.979949300428,  L_y: 0.6917642312187638,  L_z: 0.00036002531940495283


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.62334633471718, L_x: 7390.325278991387,  L_y: 0.7031597799721767,  L_z: 0.0003386752966226925
step: 250, loss: 74.62334633458305, L_x: 7390.325278985518,  L_y: 0.7031597799205571,  L_z: 0.0003386752961465122
step: 500, loss: 74.62334624882877, L_x: 7390.325271481281,  L_y: 0.7031597606976525,  L_z: 0.0003386754663658724
step: 750, loss: 74.62334628892228, L_x: 7390.325275728512,  L_y: 0.7031597675868416,  L_z: 0.00033867528100638215
step: 1000, loss: 74.60756822010329, L_x: 7389.4844166783605,  L_y: 0.6958332083619768,  L_z: 0.00033781689915412014
step: 1250, loss: 74.60756821944743, L_x: 7389.484416673906,  L_y: 0.6958332082395063,  L_z: 0.00033781688937717673
step: 1500, loss: 74.60756821406895, L_x: 7389.484414808015,  L_y: 0.6958332174318879,  L_z: 0.0003378169711381484


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.65787756317475, L_x: 7393.421560654675,  L_y: 0.7058417359654496,  L_z: 0.0003564044132509227
step: 250, loss: 74.65787755926952, L_x: 7393.421559968303,  L_y: 0.7058417393095472,  L_z: 0.0003564044055390362
step: 500, loss: 74.65787754209484, L_x: 7393.421558975856,  L_y: 0.7058417291290309,  L_z: 0.0003564044641447455
step: 750, loss: 74.65787748966095, L_x: 7393.421555159241,  L_y: 0.7058417186698743,  L_z: 0.0003564043879734076
step: 1000, loss: 74.64256384914711, L_x: 7392.641279363297,  L_y: 0.6983689258686463,  L_z: 0.00035564259290957537
step: 1250, loss: 74.64256382310708, L_x: 7392.641276508468,  L_y: 0.6983689242189438,  L_z: 0.000355642676069185
step: 1500, loss: 74.64256380969263, L_x: 7392.641275553859,  L_y: 0.6983689228480908,  L_z: 0.0003556426261188719


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.64679441254292, L_x: 7392.945645040269,  L_y: 0.7005262100563378,  L_z: 0.0003362350416778778
step: 250, loss: 74.64679439349256, L_x: 7392.945642145638,  L_y: 0.7005262158102711,  L_z: 0.0003362351245180531
step: 500, loss: 74.64679436997052, L_x: 7392.945642432413,  L_y: 0.7005261925845091,  L_z: 0.00033623506123764294
step: 750, loss: 74.646794408477, L_x: 7392.94564484201,  L_y: 0.7005262096646665,  L_z: 0.0003362350078447023
step: 1000, loss: 74.63211066874476, L_x: 7392.166559713113,  L_y: 0.6936887147050222,  L_z: 0.00033512713817196297
step: 1250, loss: 74.6321106834968, L_x: 7392.166560749123,  L_y: 0.693688726975846,  L_z: 0.0003351269805944989
step: 1500, loss: 74.63211072804229, L_x: 7392.166567996353,  L_y: 0.6936887004613081,  L_z: 0.0003351269523492011


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.61797386759598, L_x: 7390.146925850215,  L_y: 0.6986273034124449,  L_z: 0.0003575461136276663
step: 250, loss: 74.61797386560785, L_x: 7390.1469252306615,  L_y: 0.6986273057150855,  L_z: 0.0003575461517229727
step: 500, loss: 74.61797374575242, L_x: 7390.146916959475,  L_y: 0.6986272700303945,  L_z: 0.0003575461225455979
step: 750, loss: 74.61797386725247, L_x: 7390.146925853169,  L_y: 0.6986273034002658,  L_z: 0.0003575461064103085
step: 1000, loss: 74.60391258298931, L_x: 7389.4070519382785,  L_y: 0.6919950128415474,  L_z: 0.00035694101529952965
step: 1250, loss: 74.60391261717477, L_x: 7389.407058423505,  L_y: 0.691994985790051,  L_z: 0.00035694094299351426
step: 1500, loss: 74.60391259865597, L_x: 7389.407056247119,  L_y: 0.6919949795496154,  L_z: 0.00035694113270329545


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.64634234416802, L_x: 7392.415804485343,  L_y: 0.7064413816951859,  L_z: 0.0003148583523880076
step: 250, loss: 74.6463423267243, L_x: 7392.415801747078,  L_y: 0.7064413918408381,  L_z: 0.000314858348253365
step: 500, loss: 74.64634234328318, L_x: 7392.415804320489,  L_y: 0.706441381130756,  L_z: 0.00031485837895073684
step: 750, loss: 74.64634230618869, L_x: 7392.415802172278,  L_y: 0.7064413727802308,  L_z: 0.00031485823371363205
step: 1000, loss: 74.6316669503419, L_x: 7391.677452292013,  L_y: 0.6991855632276457,  L_z: 0.0003141372838822874
step: 1250, loss: 74.63166694111186, L_x: 7391.6774516226615,  L_y: 0.6991855651704381,  L_z: 0.000314137194296294
step: 1500, loss: 74.63166693546351, L_x: 7391.677451081456,  L_y: 0.699185562244383,  L_z: 0.00031413724809145214


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.62291828795146, L_x: 7390.5576929797335,  L_y: 0.7009120408607419,  L_z: 0.000328586345867794
step: 250, loss: 74.62291823247051, L_x: 7390.55768613345,  L_y: 0.7009120561256424,  L_z: 0.0003285863002071511
step: 500, loss: 74.62291828830308, L_x: 7390.557692979126,  L_y: 0.7009120408577012,  L_z: 0.0003285863530823597
step: 750, loss: 74.62291828789996, L_x: 7390.557692978474,  L_y: 0.7009120408429487,  L_z: 0.0003285863454452801
step: 1000, loss: 74.60851494786081, L_x: 7389.804343576848,  L_y: 0.694073550545487,  L_z: 0.00032795923093692573
step: 1250, loss: 74.60851494010093, L_x: 7389.8043435970385,  L_y: 0.6940735483477642,  L_z: 0.0003279591156556744
step: 1500, loss: 74.60851495868937, L_x: 7389.804344837048,  L_y: 0.6940735550240383,  L_z: 0.00032795910589688665


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.60244757804315, L_x: 7388.861202043565,  L_y: 0.6999187244936308,  L_z: 0.00027833666227775966
step: 250, loss: 74.60244757847792, L_x: 7388.861201961903,  L_y: 0.6999187246802842,  L_z: 0.00027833668357195697
step: 500, loss: 74.6024475568799, L_x: 7388.861200367115,  L_y: 0.6999187248697641,  L_z: 0.00027833656677949054
step: 750, loss: 74.60244754357703, L_x: 7388.861199451807,  L_y: 0.6999187150623353,  L_z: 0.00027833667993228865
step: 1000, loss: 74.58817429595456, L_x: 7388.097149868815,  L_y: 0.6933133049213726,  L_z: 0.0002777898469006477
step: 1250, loss: 74.5881742823651, L_x: 7388.097148717992,  L_y: 0.6933133063153,  L_z: 0.00027778977739752317
step: 1500, loss: 74.58817426193968, L_x: 7388.097147469987,  L_y: 0.6933132937458713,  L_z: 0.00027778986987856715


  0%|          | 0/6 [00:00<?, ?it/s]



step: 0, loss: 74.64654666756314, L_x: 7392.03332218104,  L_y: 0.7107720334196048,  L_z: 0.0003088282466630088
step: 250, loss: 74.64654666734468, L_x: 7392.033322174613,  L_y: 0.7107720335159475,  L_z: 0.0003088282416520282
step: 500, loss: 74.64654666754271, L_x: 7392.033322181742,  L_y: 0.710772033417988,  L_z: 0.0003088282461463306
step: 750, loss: 74.6465465930525, L_x: 7392.033317011048,  L_y: 0.710772018065652,  L_z: 0.0003088280975274905
step: 1000, loss: 74.6315395458789, L_x: 7391.291868479241,  L_y: 0.7031945695906967,  L_z: 0.0003085258299158332
step: 1250, loss: 74.63153948497595, L_x: 7391.291863525601,  L_y: 0.7031945620907013,  L_z: 0.00030852575258489743
step: 1500, loss: 74.6315395513746, L_x: 7391.291869435663,  L_y: 0.7031945668654342,  L_z: 0.0003085258030506193
step: 1750, loss: 74.63153955087472, L_x: 7391.291869449036,  L_y: 0.7031945667660247,  L_z: 0.0003085257923667812


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.66000807712123, L_x: 7393.803908377304,  L_y: 0.7056690735717892,  L_z: 0.000325998395528248
step: 250, loss: 74.66000806478665, L_x: 7393.803907416012,  L_y: 0.7056690666915499,  L_z: 0.000325998478699202
step: 500, loss: 74.66000806147244, L_x: 7393.803906859195,  L_y: 0.7056690738655829,  L_z: 0.00032599838029832287
step: 750, loss: 74.66000804491047, L_x: 7393.803906184299,  L_y: 0.7056690665353142,  L_z: 0.0003259983306431232
step: 1000, loss: 74.64569739108389, L_x: 7393.07274239338,  L_y: 0.6986831043033008,  L_z: 0.00032573725693571293
step: 1250, loss: 74.64569743290072, L_x: 7393.072746319427,  L_y: 0.6986831080650628,  L_z: 0.00032573723282763455
step: 1500, loss: 74.64569743227123, L_x: 7393.072746313632,  L_y: 0.6986831080826217,  L_z: 0.0003257372210457314
step: 1750, loss: 74.64569739192409, L_x: 7393.072742753331,  L_y: 0.6986830995113609,  L_z: 0.00032573729758838816


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.61865704286691, L_x: 7389.835785314116,  L_y: 0.7057706401259092,  L_z: 0.0002905709919968663
step: 250, loss: 74.61865704288375, L_x: 7389.83578530862,  L_y: 0.7057706402165848,  L_z: 0.0002905709916195094
step: 500, loss: 74.61865698363908, L_x: 7389.835778663675,  L_y: 0.7057706406752263,  L_z: 0.0002905711265420599
step: 750, loss: 74.61865699713366, L_x: 7389.835782420227,  L_y: 0.7057706239591359,  L_z: 0.0002905709794449631
step: 1000, loss: 74.60407642996627, L_x: 7389.078319530465,  L_y: 0.6987942963043438,  L_z: 0.0002899787671455419
step: 1250, loss: 74.60407643046632, L_x: 7389.078319537219,  L_y: 0.6987942963603939,  L_z: 0.00028997877467487013
step: 1500, loss: 74.60407641417387, L_x: 7389.07831702196,  L_y: 0.6987943018543652,  L_z: 0.00028997884199804137
step: 1750, loss: 74.60407640716046, L_x: 7389.078316410085,  L_y: 0.6987943035133111,  L_z: 0.00028997879092612985


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6520148887429, L_x: 7392.661735265601,  L_y: 0.7094947091938724,  L_z: 0.00031805653786009315
step: 250, loss: 74.65201487635349, L_x: 7392.661734663017,  L_y: 0.709494703148895,  L_z: 0.000318056531488265
step: 500, loss: 74.65201487609293, L_x: 7392.66173377552,  L_y: 0.7094947093174078,  L_z: 0.0003180565804063435
step: 750, loss: 74.65201481654594, L_x: 7392.661730406276,  L_y: 0.7094946865483006,  L_z: 0.00031805651869785636
step: 1000, loss: 74.63765444724069, L_x: 7391.943809863862,  L_y: 0.7023316639239894,  L_z: 0.000317693693561766
step: 1250, loss: 74.63765434301769, L_x: 7391.943803894229,  L_y: 0.7023316142607208,  L_z: 0.00031769379629355683
step: 1500, loss: 74.63765436325711, L_x: 7391.943805503434,  L_y: 0.7023316235990911,  L_z: 0.00031769369247329975
step: 1750, loss: 74.63765438722706, L_x: 7391.9438097928205,  L_y: 0.7023316054554641,  L_z: 0.00031769367686753204


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64094646549216, L_x: 7392.43126652036,  L_y: 0.7024249222539578,  L_z: 0.0002841775606920942
step: 250, loss: 74.64094643270822, L_x: 7392.431263907134,  L_y: 0.7024249123706108,  L_z: 0.0002841776253251535
step: 500, loss: 74.64094644380486, L_x: 7392.431264203253,  L_y: 0.7024249229317638,  L_z: 0.0002841775768111836
step: 750, loss: 74.64094646191187, L_x: 7392.431266344618,  L_y: 0.7024249217356895,  L_z: 0.0002841775345997975
step: 1000, loss: 74.62751044142308, L_x: 7391.727365273864,  L_y: 0.6960633626831895,  L_z: 0.00028346852002492616
step: 1250, loss: 74.62751035752102, L_x: 7391.727364525802,  L_y: 0.6960632853646561,  L_z: 0.000283468537966847
step: 1500, loss: 74.62751036468879, L_x: 7391.727365179364,  L_y: 0.696063287907426,  L_z: 0.0002834684997547317
step: 1750, loss: 74.6275103580707, L_x: 7391.72736458713,  L_y: 0.6960632888581365,  L_z: 0.0002834684668250465


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.61722544069362, L_x: 7389.135382927938,  L_y: 0.7102789207269186,  L_z: 0.00031185381374612843
step: 250, loss: 74.617225432555, L_x: 7389.135382385452,  L_y: 0.7102789164972905,  L_z: 0.0003118538440637741
step: 500, loss: 74.61722536351925, L_x: 7389.135375102836,  L_y: 0.7102789213834672,  L_z: 0.0003118538221486547
step: 750, loss: 74.61722544043424, L_x: 7389.135382930759,  L_y: 0.7102789207247492,  L_z: 0.000311853808038096
step: 1000, loss: 74.60381898795322, L_x: 7388.48585024061,  L_y: 0.7033910244975135,  L_z: 0.0003113892209917496
step: 1250, loss: 74.60381887031232, L_x: 7388.4858500645005,  L_y: 0.7033909100013459,  L_z: 0.00031138919331934406
step: 1500, loss: 74.60381885390497, L_x: 7388.48584869178,  L_y: 0.7033909105610926,  L_z: 0.00031138912852121724
step: 1750, loss: 74.60381885349355, L_x: 7388.485848497976,  L_y: 0.7033909084076956,  L_z: 0.00031138920212191705


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.64649472611616, L_x: 7391.797130285864,  L_y: 0.713862517799386,  L_z: 0.000293218109162479
step: 250, loss: 74.64649468213666, L_x: 7391.797127847103,  L_y: 0.7138624984509063,  L_z: 0.0002932181042945541
step: 500, loss: 74.64649472577392, L_x: 7391.797130138424,  L_y: 0.713862517806669,  L_z: 0.0002932181316600185
step: 750, loss: 74.6464946995765, L_x: 7391.797128220302,  L_y: 0.7138625166420587,  L_z: 0.00029321801462848103
step: 1000, loss: 74.63267918357327, L_x: 7391.143423421891,  L_y: 0.7066056159160623,  L_z: 0.0002927866687658821
step: 1250, loss: 74.63267906832431, L_x: 7391.14342232066,  L_y: 0.7066055076965898,  L_z: 0.00029278674842234196
step: 1500, loss: 74.6326790639722, L_x: 7391.143421828326,  L_y: 0.7066055112979349,  L_z: 0.00029278668782019274
step: 1750, loss: 74.63267906568629, L_x: 7391.143422618273,  L_y: 0.7066055076681851,  L_z: 0.00029278663670747034


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.61794906153418, L_x: 7389.701700557411,  L_y: 0.7070918364813288,  L_z: 0.0002768043895746694
step: 250, loss: 74.61794897133069, L_x: 7389.7016944730185,  L_y: 0.7070918089596098,  L_z: 0.0002768043528180598
step: 500, loss: 74.61794906181251, L_x: 7389.7017005569305,  L_y: 0.7070918364712363,  L_z: 0.0002768043954391615
step: 750, loss: 74.61794906150814, L_x: 7389.701700556723,  L_y: 0.7070918364792426,  L_z: 0.000276804389233366
step: 1000, loss: 74.60433769832022, L_x: 7389.013461398955,  L_y: 0.7003877701307222,  L_z: 0.0002763062839992281
step: 1250, loss: 74.60433758098905, L_x: 7389.013459773219,  L_y: 0.7003876685118591,  L_z: 0.00027630629489986465
step: 1500, loss: 74.60433759500538, L_x: 7389.013460704457,  L_y: 0.7003876724768534,  L_z: 0.00027630630967897103
step: 1750, loss: 74.6043375248415, L_x: 7389.013456019189,  L_y: 0.700387650096844,  L_z: 0.00027630629105509824


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.59941576278565, L_x: 7387.705238944971,  L_y: 0.7095711186332817,  L_z: 0.00025584509405332246
step: 250, loss: 74.59941576257991, L_x: 7387.705238872682,  L_y: 0.709571118275867,  L_z: 0.00025584511154441773
step: 500, loss: 74.59941574535003, L_x: 7387.705237470911,  L_y: 0.7095711197602457,  L_z: 0.0002558450176135064
step: 750, loss: 74.59941573949013, L_x: 7387.705236660979,  L_y: 0.7095711175001103,  L_z: 0.0002558451076045491
step: 1000, loss: 74.58568440576829, L_x: 7387.021106616397,  L_y: 0.7026985247401467,  L_z: 0.0002554962972835426
step: 1250, loss: 74.58568432118702, L_x: 7387.021105168358,  L_y: 0.7026984569240354,  L_z: 0.00025549625158802593
step: 1500, loss: 74.58568428630404, L_x: 7387.021101823797,  L_y: 0.7026984547907736,  L_z: 0.0002554962655060135
step: 1750, loss: 74.58568433646646, L_x: 7387.021106614713,  L_y: 0.7026984557268487,  L_z: 0.00025549629184937157


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.63473697380529, L_x: 7391.009712095073,  L_y: 0.7110510558476769,  L_z: 0.0002717759401375038
step: 250, loss: 74.63473697354131, L_x: 7391.009712089044,  L_y: 0.7110510558513737,  L_z: 0.0002717759359900353
step: 500, loss: 74.63473697377972, L_x: 7391.009712095729,  L_y: 0.71105105583702,  L_z: 0.00027177593970829537
step: 750, loss: 74.63473692007993, L_x: 7391.009707491867,  L_y: 0.7110510541112217,  L_z: 0.00027177582100084444
step: 1000, loss: 74.62079730753197, L_x: 7390.317287623689,  L_y: 0.7040478507338191,  L_z: 0.0002715316112253115
step: 1250, loss: 74.62079718509507, L_x: 7390.317286261238,  L_y: 0.704047741192952,  L_z: 0.00027153162579492585
step: 1500, loss: 74.6207972022866, L_x: 7390.317287625736,  L_y: 0.7040477459056067,  L_z: 0.00027153160247246047
step: 1750, loss: 74.62079720274545, L_x: 7390.317287633629,  L_y: 0.7040477458084411,  L_z: 0.0002715316120142895


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.6554025549886, L_x: 7392.704843633043,  L_y: 0.7137214983103316,  L_z: 0.00029265240695636836
step: 250, loss: 74.65540254959298, L_x: 7392.704842786676,  L_y: 0.7137214979875668,  L_z: 0.0002926524747729373
step: 500, loss: 74.6554025421401, L_x: 7392.704842291196,  L_y: 0.7137214994955345,  L_z: 0.0002926523946520217
step: 750, loss: 74.65540253194884, L_x: 7392.704841684332,  L_y: 0.7137214974007486,  L_z: 0.0002926523540951513
step: 1000, loss: 74.64161350844765, L_x: 7392.029744870169,  L_y: 0.7066909806807865,  L_z: 0.0002925015813030308
step: 1250, loss: 74.64161341920071, L_x: 7392.02974487827,  L_y: 0.706690891724359,  L_z: 0.0002925015738728669
step: 1500, loss: 74.64161341962256, L_x: 7392.029744878875,  L_y: 0.7066908917657283,  L_z: 0.0002925015813614236
step: 1750, loss: 74.64161340157138, L_x: 7392.0297433032265,  L_y: 0.7066908865860346,  L_z: 0.0002925016390616214


  0%|          | 0/6 [00:00<?, ?it/s]

step: 0, loss: 74.60988469755064, L_x: 7389.053158172854,  L_y: 0.7062663636058429,  L_z: 0.0002617350443250448
step: 250, loss: 74.60988469748462, L_x: 7389.053158167607,  L_y: 0.7062663636083986,  L_z: 0.00026173504400311406
step: 500, loss: 74.60988464630516, L_x: 7389.053152214376,  L_y: 0.7062663662349019,  L_z: 0.00026173515853006475
step: 750, loss: 74.60988466963698, L_x: 7389.053155600865,  L_y: 0.7062663619096606,  L_z: 0.0002617350343735509
step: 1000, loss: 74.59658791202548, L_x: 7388.368324807604,  L_y: 0.6998330744279408,  L_z: 0.0002614317904300041
step: 1250, loss: 74.59658783623875, L_x: 7388.368324807195,  L_y: 0.6998329986272048,  L_z: 0.00026143179079224637
step: 1500, loss: 74.59658780027499, L_x: 7388.368320809135,  L_y: 0.6998329989956232,  L_z: 0.000261431863760391
step: 1750, loss: 74.59658776410566, L_x: 7388.368319575694,  L_y: 0.6998329780694408,  L_z: 0.00026143180558558696


  0%|          | 0/6 [00:00<?, ?it/s]

# Utils for performance metrics

We need to save the indexes of both groups `privileged` and `discriminated` in two lists.

`y_privileged` is the part the dataset where `sensible_value` = 1 (for example `AgeCategory` = 1), and `y_discriminated` is the part of dataset where `sensible_value` = 0.

Build the confusion matrices (one for the privileged group, one for the discriminated group) for each model.


Not for in-processing that has only one ML model!

In [29]:
def compute_scores(predictions_and_tests, models, n_splits):
  precision = {}
  recall = {}
  accuracy = {}
  f1_score = {}

  if mitigation not in without_model_mitigations:
    for model_name in (models):

      precisions = []
      recalls = []
      accuracys = []
      f1_scores = []
      for i in range(0,n_splits):
        y_test = predictions_and_tests[model_name][i]['y_test']
        y_pred = predictions_and_tests[model_name][i]['y_pred']
        #print(len(y_test), len(y_pred))
        precisions.append(metrics.precision_score(y_test, y_pred))
        recalls.append(metrics.recall_score(y_test, y_pred))
        accuracys.append(metrics.accuracy_score(y_test, y_pred))
        f1_scores.append(metrics.f1_score(y_test, y_pred))
      precision[model_name] = precisions
      recall[model_name] = recalls
      accuracy[model_name] = accuracys
      f1_score[model_name] = f1_scores
  else:
    precisions = []
    recalls = []
    accuracys = []
    f1_scores = []
    for i in range(0,n_splits):
        y_test = predictions_and_tests[i]['y_test']
        y_pred = predictions_and_tests[i]['y_pred']
        #print(len(y_test), len(y_pred))
        precisions.append(metrics.precision_score(y_test, y_pred))
        recalls.append(metrics.recall_score(y_test, y_pred))
        accuracys.append(metrics.accuracy_score(y_test, y_pred))
        f1_scores.append(metrics.f1_score(y_test, y_pred))
    precision = precisions
    recall = recalls
    accuracy = accuracys
    f1_score = f1_scores
  return accuracy, precision, recall, f1_score

In [30]:
def compute_mean_std_dev(metric_list, models):
  metric_dict = {}
  if models is not None:
    for model_name in (models):
      metric = np.array(metric_list[model_name])
      mean_metric = metric.mean()
      std_metric = metric.std()
      metric_dict[model_name] = [mean_metric, std_metric]
  else:
    metric = np.array(metric_list)
    mean_metric = metric.mean()
    std_metric = metric.std()
    metric_dict = [mean_metric, std_metric]
  return metric_dict

In [31]:
def compute_performance_metrics(predictions_and_tests, models, n_splits):
  accuracy, precision, recall, f1_score = compute_scores(predictions_and_tests, models, n_splits)

  if mitigation not in without_model_mitigations:
    #for each model compute mean and standard deviation
    acc = compute_mean_std_dev(accuracy, models)
    prec = compute_mean_std_dev(precision, models)
    rec = compute_mean_std_dev(recall, models)
    f1 = compute_mean_std_dev(f1_score, models)
  else:
    acc = compute_mean_std_dev(accuracy, None)
    prec = compute_mean_std_dev(precision, None)
    rec = compute_mean_std_dev(recall, None)
    f1 = compute_mean_std_dev(f1_score, None)

  performance_metrics = {}
  performance_metrics['accuracy'] = acc
  performance_metrics['precision'] = prec
  performance_metrics['recall'] = rec
  performance_metrics['f1_score'] = f1

  return performance_metrics

In [32]:
def compute_confusion_matrices(predictions_and_tests, target_variable_labels, models, n_splits):
  confusion_matrices = {}
  if mitigation not in without_model_mitigations:
    for model_name in (models):
      cm_splits = {}
      for i in range(0,n_splits):
        temp_dict = {}
        cm_priviliged = {}
        cm_discriminated = {}
        y_test = predictions_and_tests[model_name][i]['y_test']
        y_pred = predictions_and_tests[model_name][i]['y_pred']
        s_test = predictions_and_tests[model_name][i]['s_test']

        df_metrics = pd.DataFrame({'s_test': s_test, 'y_test':y_test, 'y_pred':y_pred})
        df_discrim = df_metrics[df_metrics['s_test'] == 0]
        #len_dicr = len(df_discrim)
        df_priv = df_metrics[df_metrics['s_test'] == 1]
        #len_priv = len(df_priv)

        cm_discriminated = confusion_matrix(df_discrim['y_test'], df_discrim['y_pred'], labels=target_variable_labels)
        cm_privileged = confusion_matrix(df_priv['y_test'], df_priv['y_pred'], labels=target_variable_labels)
        temp_dict['discriminated'] = cm_discriminated
        temp_dict['privileged'] = cm_privileged
        cm_splits[i] = temp_dict
      confusion_matrices[model_name] = cm_splits
  else:
    cm_splits = {}
    for i in range(0,n_splits):
      temp_dict = {}
      cm_priviliged = {}
      cm_discriminated = {}
      y_test = predictions_and_tests[i]['y_test']
      y_pred = predictions_and_tests[i]['y_pred']
      s_test = predictions_and_tests[i]['s_test']

      df_metrics = pd.DataFrame({'s_test': s_test, 'y_test':y_test, 'y_pred':y_pred})

      df_discrim = df_metrics[df_metrics['s_test'] == 0]
      #len_dicr = len(df_discrim)
      df_priv = df_metrics[df_metrics['s_test'] == 1]
      #len_priv = len(df_priv)

      cm_discriminated = confusion_matrix(df_discrim['y_test'], df_discrim['y_pred'])
      cm_privileged = confusion_matrix(df_priv['y_test'], df_priv['y_pred'])
      temp_dict['discriminated'] = cm_discriminated
      temp_dict['privileged'] = cm_privileged
      cm_splits[i] = temp_dict
      confusion_matrices = cm_splits
  return confusion_matrices

##Functions to compute fairness metrics

Terminology:

- d is the predicted value,
- Y is the actual value in the dataset
- G the protected attribute, priv= privileged group, discr=discriminated group
- L is the legittimate attribute (only for Conditional Statistical Parity)

Fairness Metrics List:

1. Group Fairness: (d=1|G=priv) = (d=1|G=discr)
2. Predictive Parity: (Y=1|d=1,G=priv) = (Y=1|d=1,G=discr)
3. Predictive Equality: (d=1|Y=0,G=priv) = (d=1|Y=0,G=discr)
4. Equal Opportunity:  (d=0|Y=1,G=priv) = (d=0|Y=1,G=discr)
5. Equalized Odds: (d=1|Y=i,G=priv) = (d=1|Y=i,G=discr), i ∈ 0,1
6. ConditionalUseAccuracyEquality: (Y=1|d=1, G=priv) = (Y=1|d=1,G=discr) and (Y=0|d=0,G=priv) = (Y=0|d=0,G=discr)
7. Overall Accuracy Equality: (d=Y, G=priv) = (d=Y, G=priv)
8. Treatment Equality: (Y=1, d=0, G=priv)/(Y=0, d=1, G=priv) = (Y=1, d=0, G=discr)/(Y=0, d=1, G=discr)
9. FOR Parity: (Y=1|d=0, G=priv) = (Y=1|d=0,G=discr)

How to evaluate the results?

Looking at the value for each corresponding metric:

- If the value is between 0 and 1-t the discriminated group suffers from unfairness
- If the value is greater than 1+t the privileged group suffers from unfairness
- If the value is between 1-t and 1+t both privileged and discriminated group have a fair treatment

t is a threshold that should be choose by the user according to the context and the goal of the task.


In [33]:
# Retrieve TP, TN, FP, FN values from a confusion matrix
def retrieve_values(cm):
  TN = cm[0][0]
  FP = cm[0][1]
  FN = cm[1][0]
  TP = cm[1][1]
  total = TN+FP+FN+TP
  return TP, TN, FP, FN, total

def rescale(metric):
  metric = metric - 1
  return metric

def standardization(metric):
  if metric > 1:
    metric = 1
  elif metric < -1:
    metric = -1
  return metric

def valid(metric, th):
  if metric > 1-th and metric < 1+th:
    return True
  return False

def and_function(m1, m2, th):
  if m1 > 1+th and m2 > 1+th:
    return max(m1, m2)
  elif m1 < 1-th and m2 < 1-th:
    return min(m1, m2)
  elif valid(m1, th) and valid(m2, th):
    return max(m1, m2)
  elif (valid(m1, th) or valid(m2, th)) and (m1 > 1+th or m2 > 1+th):
    return max(m1, m2)
  elif (valid(m1, th) or valid(m2, th)) and (m1 < 1-th or m2 < 1-th):
    return min(m1, m2)
  else:
    return max(m1, m2)

In [34]:
# Fairness metrics computed using division operator
def fairness_metrics_division(confusion_matrix, threshold = 0.15):

  TP_priv, TN_priv, FP_priv, FN_priv, len_priv = retrieve_values(confusion_matrix['privileged'])
  TP_discr, TN_discr, FP_discr, FN_discr, len_discr = retrieve_values(confusion_matrix['discriminated'])

  GroupFairness_discr = (TP_discr+FP_discr)/len_discr
  GroupFairness_priv = (TP_priv+FP_priv)/len_priv
  if GroupFairness_priv == 0:
    GroupFairness = 2  #max value
  else:
     GroupFairness = GroupFairness_discr/GroupFairness_priv

  if TP_discr+FP_discr == 0:
    PredictiveParity_discr = 0
    PredictiveParity = 0  #min value
  else:
    PredictiveParity_discr = (TP_discr)/(TP_discr+FP_discr)
  if TP_priv+FP_priv == 0:
    PredictiveParity_priv = 0
    PredictiveParity = 2  #max value
  else:
    PredictiveParity_priv = (TP_priv)/(TP_priv+FP_priv)
  if PredictiveParity_discr != 0 and PredictiveParity_priv != 0:
    PredictiveParity = PredictiveParity_discr/PredictiveParity_priv
  elif PredictiveParity_priv == 0:
    PredictiveParity = 2  #max value
  else:
    PredictiveParity = 0  #min value

  if TN_discr+FP_discr == 0:
    PredictiveEquality_discr = 0
    PredictiveEquality = 0  #min value
  else:
    PredictiveEquality_discr = (FP_discr)/(TN_discr+FP_discr)
  if TN_priv+FP_priv == 0:
    PredictiveEquality_priv = 0
    PredictiveEquality = 2  #max value
  else:
    PredictiveEquality_priv = (FP_priv)/(TN_priv+FP_priv)
  if PredictiveEquality_discr != 0 and PredictiveEquality_priv != 0:
    PredictiveEquality = PredictiveEquality_discr/PredictiveEquality_priv
  elif PredictiveEquality_priv == 0:
    PredictiveEquality = 2  #max value
  else:
    PredictiveEquality = 0  #min value

  if FN_priv+TP_priv == 0:
    EqualOpportunity_priv = 0
    EqualOpportunity = 2  #max value
  else:
    EqualOpportunity_priv = (FN_priv)/(TP_priv+FN_priv)
  if FN_discr+TP_discr == 0:
    EqualOpportunity_discr = 0
    EqualOpportunity = 0  #min value
  else:
    EqualOpportunity_discr = (FN_discr)/(TP_discr+FN_discr)
  if EqualOpportunity_priv != 0 and EqualOpportunity_discr != 0:
    EqualOpportunity = EqualOpportunity_priv/EqualOpportunity_discr
  elif EqualOpportunity_discr == 0:
    EqualOpportunity = 0  #min value
  else:
    EqualOpportunity = 2  #max value

  if FN_discr+TP_discr == 0:
    EqualizedOdds1 = 0
    EqualizedOdds = 0 #min value
  elif FN_priv+TP_priv == 0:
    EqualizedOdds1 = 0
    EqualizedOdds = 2 #max value
  elif (TP_priv/(TP_priv+FN_priv)) == 0:
    EqualizedOdds1 = 2 #max value
  else:
    EqualizedOdds1 = ((TP_discr/(TP_discr+FN_discr)) / (TP_priv/(TP_priv+FN_priv))) # (1-equalOpportunity_discr)/(1-equalOpportunity_priv)
  if TN_priv+FP_priv == 0:
    EqualizedOdds2 = 0
    EqualizedOdds = 2 #max value
  elif TN_discr+FP_discr == 0:
    EqualizedOdds2 = 0
    EqualizedOdds = 0 #min value
  elif (FP_priv/(TN_priv+FP_priv)) == 0:
    EqualizedOdds2 = 2 #max value
  else:
    EqualizedOdds2 = ((FP_discr/(TN_discr+FP_discr)) / (FP_priv/(TN_priv+FP_priv))) # = PredictiveEquality
  # EqualizedOdds = (EqualizedOdds1 * EqualizedOdds2)
  if EqualizedOdds1 != 0 and EqualizedOdds2 != 0:
    EqualizedOdds = and_function(EqualizedOdds1, EqualizedOdds2, threshold)
  else:
    EqualizedOdds = 2 #max value

  if TP_discr+FP_discr == 0 or TN_discr+FP_discr == 0:
    ConditionalUseAccuracyEquality1 = 0
    ConditionalUseAccuracyEquality= 0 #min value
  elif (TP_priv/(TP_priv+FP_priv)) == 0:
    ConditionalUseAccuracyEquality1 = 2 #max value
  else:
    ConditionalUseAccuracyEquality1 = ((TP_discr/(TP_discr+FP_discr)) / (TP_priv/(TP_priv+FP_priv)))
  if TN_discr+FN_discr == 0 or TN_priv+FN_priv == 0:
    ConditionalUseAccuracyEquality2 = 0
    ConditionalUseAccuracyEquality = 2 #max value
  elif (TN_priv/(TN_priv+FN_priv)) == 0:
    ConditionalUseAccuracyEquality2 = 2 #max value
  else:
    ConditionalUseAccuracyEquality2 = ((TN_discr/(TN_discr+FN_discr)) / (TN_priv/(TN_priv+FN_priv)))
  # ConditionalUseAccuracyEquality = (ConditionalUseAccuracyEquality1 * ConditionalUseAccuracyEquality2)
  if ConditionalUseAccuracyEquality1 != 0 and ConditionalUseAccuracyEquality2 != 0:
    ConditionalUseAccuracyEquality = and_function(ConditionalUseAccuracyEquality1, ConditionalUseAccuracyEquality2, threshold)
  else:
    ConditionalUseAccuracyEquality = 2 #max value

  if TP_priv == 0:
    OAE1 = 0
    OverallAccuracyEquality = 2 #max value
  else:
    OAE1 = TP_discr/TP_priv
  if TN_priv == 0:
    OAE2 = 0
    OverallAccuracyEquality = 2 #max value
  else:
    OAE2 = TN_discr/TN_priv
  # OverallAccuracyEquality = (OAE1 * OAE2)
  if OAE1 != 0 and OAE2 != 0:
    OverallAccuracyEquality = and_function(OAE1, OAE2, threshold)
  else:
    OverallAccuracyEquality = 2 #max value

  if FP_priv == 0:
    TreatmentEquality_priv = 0
    TreatmentEquality = 2  #max value
  else:
    TreatmentEquality_priv = (FN_priv/FP_priv)
  if FP_discr == 0:
    TreatmentEquality_discr = 0
    TreatmentEquality = 0 #min value
  elif (FN_discr/FP_discr) == 0:
    TreatmentEquality_discr = 0 #max value
    TreatmentEquality = 0 #min value
  else:
    TreatmentEquality_discr = (FN_discr/FP_discr)
  if TreatmentEquality_priv != 0 and TreatmentEquality_discr != 0:
    TreatmentEquality = TreatmentEquality_priv/TreatmentEquality_discr
  elif TreatmentEquality_priv == 0:
    TreatmentEquality = 2 #max value
  else:
    TreatmentEquality = 0 #min value

  if TN_priv+FN_priv == 0:
    FORParity_priv = 0
    FORParity = 2 #max value
  else:
    FORParity_priv = (FN_priv)/(TN_priv+FN_priv)
  if TN_discr+FN_discr == 0:
    FORParity_discr = 0
    FORParity = 0  #min value
  elif (FN_discr)/(TN_discr+FN_discr) == 0:
    FORParity_discr = 0
    FORParity = 0 #min value
  else:
    FORParity_discr = (FN_discr)/(TN_discr+FN_discr)
  if FORParity_priv != 0 and FORParity_discr != 0:
    FORParity = FORParity_priv/FORParity_discr
  elif FORParity_priv == 0:
    FORParity = 2 #max value
  else:
    FORParity = 0 #min value


  FN_P_discr = (FN_discr)/len_discr
  FN_P_priv = (FN_priv)/len_priv
  if FN_P_discr == 0:
    FN_metric = 2  #max value
  else:
    FN_metric = FN_P_priv/FN_P_discr


  FP_P_discr = (FP_discr)/len_discr
  FP_P_priv = (FP_priv)/len_priv
  if FP_P_priv == 0:
    FP_metric = 0  #min value
  else:
    FP_metric = FP_P_discr/FP_P_priv


  #RecallParity = (TP_discr/(TP_discr+FN_discr))/(TP_priv/(TP_priv+FN_priv))

  metrics = {}
  metrics['GroupFairness'] = [GroupFairness, GroupFairness_discr, GroupFairness_priv]
  metrics['PredictiveParity'] = [PredictiveParity, PredictiveParity_discr, PredictiveParity_priv]
  metrics['PredictiveEquality'] = [PredictiveEquality, PredictiveEquality_discr, PredictiveEquality_priv]
  metrics['EqualOpportunity'] = [EqualOpportunity, EqualOpportunity_discr, EqualOpportunity_priv]
  metrics['EqualizedOdds'] = [EqualizedOdds, EqualizedOdds1, EqualizedOdds2]
  metrics['ConditionalUseAccuracyEquality'] = [ConditionalUseAccuracyEquality, ConditionalUseAccuracyEquality1 , ConditionalUseAccuracyEquality2]
  metrics['OverallAccuracyEquality'] = [OverallAccuracyEquality, OAE1, OAE2]
  metrics['TreatmentEquality'] = [TreatmentEquality, TreatmentEquality_discr, TreatmentEquality_priv]
  metrics['FORParity'] = [FORParity, FORParity_discr, FORParity_priv]
  metrics['FN'] = [FN_metric, FN_P_discr, FN_P_priv]
  metrics['FP'] = [FP_metric, FP_P_discr, FP_P_priv]

  for k in metrics.keys():
    value = standardization(rescale(metrics[k][0]))
    discr = metrics[k][1]
    priv = metrics[k][2]
    metrics[k] = {'Value': value, 'Discr_group': discr, 'Priv_group': priv}

  return metrics


# Fairness metrics computed using subtraction operator
def fairness_metrics_subtraction(confusion_matrix, threshold = 0.15):

  TP_priv, TN_priv, FP_priv, FN_priv, len_priv = retrieve_values(confusion_matrix['privileged'])
  TP_discr, TN_discr, FP_discr, FN_discr, len_discr = retrieve_values(confusion_matrix['discriminated'])

  GroupFairness_discr = (TP_discr+FP_discr)/len_discr
  GroupFairness_priv = (TP_priv+FP_priv)/len_priv
  GroupFairness = GroupFairness_priv-GroupFairness_discr

  if (TP_discr+FP_discr) == 0:
    PredictiveParity_discr = 0
    PredictiveParity = -1  #min value
  else:
    PredictiveParity_discr = (TP_discr)/(TP_discr+FP_discr)
  if (TP_priv+FP_priv) == 0:
    PredictiveParity_priv = 0
    PredictiveParity = 1 #max value
  else:
    PredictiveParity_priv = (TP_priv)/(TP_priv+FP_priv)
  if PredictiveParity_priv != 0 and PredictiveParity_discr != 0:
    PredictiveParity = PredictiveParity_priv-PredictiveParity_discr
  elif PredictiveParity_priv == 0:
    PredictiveParity = 1 #max value
  else:
    PredictiveParity = -1 #min value

  if TN_discr+FP_discr == 0:
    PredictiveEquality_discr = 0
    PredictiveEquality = -1  #min value
  else:
    PredictiveEquality_discr = (FP_discr)/(TN_discr+FP_discr)
  if TN_priv+FP_priv == 0:
    PredictiveEquality_priv = 0
    PredictiveEquality = 1 #max value
  else:
    PredictiveEquality_priv = (FP_priv)/(TN_priv+FP_priv)
  if PredictiveEquality_priv != 0 and PredictiveEquality_discr != 0:
    PredictiveEquality = PredictiveEquality_priv-PredictiveEquality_discr
  elif PredictiveEquality_priv == 0:
    PredictiveEquality = 1 #max value
  else:
    PredictiveEquality = -1 #min value

  if TP_discr+FN_discr == 0:
    EqualOpportunity_discr = 0
    EqualOpportunity = -1  #min value
  else:
    EqualOpportunity_discr = (FN_discr)/(TP_discr+FN_discr)
  if TP_priv+FN_priv == 0:
    EqualOpportunity_priv = 0
    EqualOpportunity = 1 #max value
  else:
    EqualOpportunity_priv = (FN_priv)/(TP_priv+FN_priv)
  if EqualOpportunity_priv != 0 and EqualOpportunity_discr != 0:
    EqualOpportunity = EqualOpportunity_priv-EqualOpportunity_discr
  elif EqualOpportunity_priv == 0:
    EqualOpportunity = 1 #max value
  else:
    EqualOpportunity = -1 #min value

  if FN_discr+TP_discr == 0:
    EqualizedOdds1 = 0
    EqualizedOdds = -1 #min value
  elif FN_priv+TP_priv == 0:
    EqualizedOdds1 = 0
    EqualizedOdds = 1 #max value
  else:
    EqualizedOdds1 = (TP_priv/(TP_priv+FN_priv))-(TP_discr/(TP_discr+FN_discr)) # (1-equalOpportunity_discr)/(1-equalOpportunity_priv)
  if FP_priv+TN_priv == 0:
    EqualizedOdds2 = 0
    EqualizedOdds = 1 #max value
  elif FP_discr+TN_discr == 0:
    EqualizedOdds2 = 0
    EqualizedOdds = -1 #min value
  else:
    EqualizedOdds2 = (FP_priv/(TN_priv+FP_priv))-(FP_discr/(TN_discr+FP_discr)) # = PredictiveEquality
  if EqualizedOdds1 != 0 and EqualizedOdds2 != 0:
    EqualizedOdds = and_function(EqualizedOdds1, EqualizedOdds2, threshold)
  elif EqualizedOdds1 == 0:
    EqualizedOdds = 1 #max value
  else:
    EqualizedOdds = -1 #min value

  if TP_discr+FP_discr == 0:
    ConditionalUseAccuracyEquality1 = 0
    ConditionalUseAccuracyEquality= -1 #min value
  elif TP_priv+FP_priv == 0:
    ConditionalUseAccuracyEquality1 = 0
    ConditionalUseAccuracyEquality = 1 #max value
  else:
    ConditionalUseAccuracyEquality1 = (TP_priv/(TP_priv+FP_priv)) - (TP_discr/(TP_discr+FP_discr))
  if TN_discr+FN_discr == 0:
    ConditionalUseAccuracyEquality2 = 0
    ConditionalUseAccuracyEquality = -1 #min value
  elif TN_priv+FN_priv == 0:
    ConditionalUseAccuracyEquality2 = 0
    ConditionalUseAccuracyEquality = 1 #max value
  else:
    ConditionalUseAccuracyEquality2 = (TN_priv/(TN_priv+FN_priv)) - (TN_discr/(TN_discr+FN_discr))
  if ConditionalUseAccuracyEquality1 != 0 and ConditionalUseAccuracyEquality2 != 0:
    ConditionalUseAccuracyEquality = and_function(ConditionalUseAccuracyEquality1, ConditionalUseAccuracyEquality2, threshold)
  elif ConditionalUseAccuracyEquality1 == 0:
    ConditionalUseAccuracyEquality = 1 #max value
  else:
    ConditionalUseAccuracyEquality = -1 #min value

  OAE1 = TP_priv-TP_discr
  OAE2 = TN_priv-TN_discr
  OverallAccuracyEquality = and_function(OAE1, OAE2, threshold)

  if FP_discr == 0:
    TreatmentEquality_discr = 0
    TreatmentEquality = 0 #min value
  else:
    TreatmentEquality_discr = (FN_discr/FP_discr)
  if FP_priv == 0:
    TreatmentEquality_priv = 0
    TreatmentEquality = 2 #max value
  else:
    TreatmentEquality_priv = (FN_priv/FP_priv)
  if TreatmentEquality_priv != 0 and TreatmentEquality_discr != 0:
    TreatmentEquality = TreatmentEquality_priv-TreatmentEquality_discr
  elif TreatmentEquality_priv == 0:
    TreatmentEquality = 2 #max value
  else:
    TreatmentEquality = 0 #min value

  if TN_discr+FN_discr == 0:
    FORParity_discr = 0
    FORParity = -1  #min value
  else:
    FORParity_discr = (FN_discr)/(TN_discr+FN_discr)
  if TN_priv+FN_priv == 0:
    FORParity_priv = 0
    FORParity = 1 #max value
  else:
    FORParity_priv = (FN_priv)/(TN_priv+FN_priv)
  if FORParity_priv != 0 and FORParity_discr != 0:
    FORParity = FORParity_priv-FORParity_discr
  elif FORParity_priv == 0:
    FORParity = 1 #max value
  else:
    FORParity = -1 #min value

  FN_P_discr =  (FN_discr)/len_discr
  FN_P_priv =  (FN_priv)/len_priv

  FP_P_discr = (FP_discr)/len_discr
  FP_P_priv =  (FP_priv)/len_priv

  #RecallParity = (TP_discr/(TP_discr+FN_discr))/(TP_priv/(TP_priv+FN_priv))

  metrics = {}
  metrics['GroupFairness'] = [GroupFairness, GroupFairness_discr, GroupFairness_priv]
  metrics['PredictiveParity'] = [PredictiveParity, PredictiveParity_discr, PredictiveParity_priv]
  metrics['PredictiveEquality'] = [PredictiveEquality, PredictiveEquality_discr, PredictiveEquality_priv]
  metrics['EqualOpportunity'] = [EqualOpportunity, EqualOpportunity_discr, EqualOpportunity_priv]
  metrics['EqualizedOdds'] = [EqualizedOdds, EqualizedOdds1, EqualizedOdds2]
  metrics['ConditionalUseAccuracyEquality'] = [ConditionalUseAccuracyEquality, ConditionalUseAccuracyEquality1 , ConditionalUseAccuracyEquality2]
  metrics['OverallAccuracyEquality'] = [OverallAccuracyEquality, OAE1, OAE2]
  metrics['TreatmentEquality'] = [TreatmentEquality, TreatmentEquality_discr, TreatmentEquality_priv]
  metrics['FORParity'] = [FORParity, FORParity_discr, FORParity_priv]
  metrics['FN'] = [FN_P_priv-FN_P_discr, FN_P_discr, FN_P_priv]
  metrics['FP'] = [FP_P_discr-FP_P_priv, FP_P_discr, FP_P_priv]

  for k in metrics.keys():
    value = standardization(metrics[k][0])
    discr = metrics[k][1]
    priv = metrics[k][2]
    metrics[k] = {'Value': value, 'Discr_group': discr, 'Priv_group': priv}

  return metrics

In [35]:
def compute_fairness_metrics(predictions_and_tests, target_variable_labels, models, n_splits):
  confusion_matrices = compute_confusion_matrices(predictions_and_tests, target_variable_labels, models, n_splits)
  fairness_metrics = {}
  sub_fairness_metrics = {}
  div_fairness_metrics = {}
  sub_dict = {}
  div_dict = {}
  #mitigation technique allow multiple models
  if mitigation not in without_model_mitigations:
    for model_name in (models):
      sub_dict = {}
      div_dict = {}
      for i in range(0,n_splits):
        model_split_conf_matrix = fairness_metrics_division(confusion_matrices[model_name][i])
        sub_dict[i] = fairness_metrics_subtraction(confusion_matrices[model_name][i])
        div_dict[i] = fairness_metrics_division(confusion_matrices[model_name][i])

      div_fairness_metrics[model_name] = div_dict
      sub_fairness_metrics[model_name] = sub_dict
  else:
    sub_dict = {}
    div_dict = {}
    for i in range(0,n_splits):
        sub_dict[i] = fairness_metrics_subtraction(confusion_matrices[i])
        div_dict[i] = fairness_metrics_division(confusion_matrices[i])

    div_fairness_metrics = div_dict
    sub_fairness_metrics = sub_dict

  fairness_metrics['division'] = div_fairness_metrics
  fairness_metrics['subtraction'] = sub_fairness_metrics

  return fairness_metrics

In [36]:
def compute_mean_std_dev_fairness_metrics(fairness_metrics, models):
  family_metrics = {}
  for f in family:
    model_metrics = {}
    #print(f)
    #mitigation technique allow multiple models
    if mitigation not in without_model_mitigations:
      metric_dict = {}
      for m in models:
        #print(m)
        for fair_m in fairness_catalogue:
          #print(fair_m)
          vec_metrics = []
          for i in range(0,n_splits):
            vec_metrics.append(fairness_metrics[f][m][i][fair_m]['Value'])
          #print(vec_metrics)
          #print(np.mean(vec_metrics), np.std(vec_metrics))
          metric_dict[fair_m] = [np.mean(vec_metrics), np.std(vec_metrics)]
        #print(metric_dict)
        model_metrics[m] = metric_dict
    #without multiple models
    else:
      metric_dict = {}
      for fair_m in fairness_catalogue:
        vec_metrics = []
        for i in range(0,n_splits):
          vec_metrics.append(fairness_metrics[f][i][fair_m]['Value'])
        metric_dict[fair_m] = [np.mean(vec_metrics), np.std(vec_metrics)]
      model_metrics = metric_dict

    family_metrics[f]=model_metrics

  return family_metrics

# Compute performance metrics

In [37]:
performance_metrics = compute_performance_metrics(predictions_and_tests, models, n_splits)
print(performance_metrics)

In [None]:
#Save performance metrics
save_path = path_to_project + '/measurements/performance_metrics-{}-{}.p'.format(dataset_name, mitigation)
with open(save_path, 'wb') as fp:
    pickle.dump(performance_metrics, fp, protocol=pickle.HIGHEST_PROTOCOL)

# Compute fairness metrics

In [38]:
if mitigation not in without_model_mitigations:
  print(compute_confusion_matrices(predictions_and_tests, target_variable_labels, models, n_splits))
else:
  print(compute_confusion_matrices(predictions_and_tests, target_variable_labels, None, n_splits))

In [None]:
if mitigation not in without_model_mitigations:
  fairness_metrics = compute_fairness_metrics(predictions_and_tests, target_variable_labels, models, n_splits)
else:
  fairness_metrics = compute_fairness_metrics(predictions_and_tests, target_variable_labels, None, n_splits)
print(fairness_metrics)

{'division': {'Logistic Regression': {0: {'GroupFairness': {'Value': -0.005005005005005003, 'Discr_group': 0.9594594594594594, 'Priv_group': 0.9642857142857143}, 'PredictiveParity': {'Value': 0.0, 'Discr_group': 1.0, 'Priv_group': 1.0}, 'PredictiveEquality': {'Value': 1, 'Discr_group': 0.0, 'Priv_group': 0.0}, 'EqualOpportunity': {'Value': -1, 'Discr_group': 0.0, 'Priv_group': 0.0}, 'EqualizedOdds': {'Value': 1, 'Discr_group': 1.0, 'Priv_group': 2}, 'ConditionalUseAccuracyEquality': {'Value': 0.0, 'Discr_group': 1.0, 'Priv_group': 1.0}, 'OverallAccuracyEquality': {'Value': 0.19999999999999996, 'Discr_group': 1.0518518518518518, 'Priv_group': 1.2}, 'TreatmentEquality': {'Value': 1, 'Discr_group': 0, 'Priv_group': 0}, 'FORParity': {'Value': 1, 'Discr_group': 0, 'Priv_group': 0.0}, 'FN': {'Value': 1, 'Discr_group': 0.0, 'Priv_group': 0.0}, 'FP': {'Value': -1, 'Discr_group': 0.0, 'Priv_group': 0.0}}, 1: {'GroupFairness': {'Value': -0.0064102564102563875, 'Discr_group': 0.96875, 'Priv_group

In [None]:
model_to_print = "Logistic Regression"
m = 'GroupFairness'
round_value = 5

if mitigation not in without_model_mitigations:
  print(m, np.round(fairness_metrics["division"][model_to_print][1][m]["Value"], round_value))
else:
  print(m, np.round(fairness_metrics["division"][1][m]["Value"], round_value))

GroupFairness -0.00641


In [None]:
if mitigation not in without_model_mitigations:
  final_metrics = compute_mean_std_dev_fairness_metrics(fairness_metrics, models)
else:
  final_metrics = compute_mean_std_dev_fairness_metrics(fairness_metrics, None)

print(final_metrics)

{'division': {'Logistic Regression': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.0, 0.0], 'EqualOpportunity': [-1.0, 0.0], 'EqualizedOdds': [1.0, 0.0], 'ConditionalUseAccuracyEquality': [0.0, 0.0], 'OverallAccuracyEquality': [-0.017879290215056652, 0.40158984312962265], 'TreatmentEquality': [1.0, 0.0], 'FORParity': [1.0, 0.0], 'FN': [1.0, 0.0], 'FP': [-1.0, 0.0]}, 'Decision Tree': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.0, 0.0], 'EqualOpportunity': [-1.0, 0.0], 'EqualizedOdds': [1.0, 0.0], 'ConditionalUseAccuracyEquality': [0.0, 0.0], 'OverallAccuracyEquality': [-0.017879290215056652, 0.40158984312962265], 'TreatmentEquality': [1.0, 0.0], 'FORParity': [1.0, 0.0], 'FN': [1.0, 0.0], 'FP': [-1.0, 0.0]}, 'Bagging': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.

In [None]:
model_to_print = "Logistic Regression"
round_value = 5
for f in family:
  print(f)
  for m in fairness_catalogue:
    if mitigation not in without_model_mitigations:
      print(m, np.round(final_metrics[f][model_to_print][m][0], round_value))
    else:
      print(m, np.round(final_metrics[f][m][0], round_value))


division
GroupFairness 0.00042
PredictiveParity 0.0
PredictiveEquality 1.0
EqualOpportunity -1.0
EqualizedOdds 1.0
ConditionalUseAccuracyEquality 0.0
OverallAccuracyEquality -0.01788
TreatmentEquality 1.0
FORParity 1.0
FN 1.0
FP -1.0
subtraction
GroupFairness -0.00011
PredictiveParity 0.0
PredictiveEquality 1.0
EqualOpportunity 1.0
EqualizedOdds 1.0
ConditionalUseAccuracyEquality 1.0
OverallAccuracyEquality 0.4
TreatmentEquality 1.0
FORParity 1.0
FN 0.0
FP 0.0


In [None]:
#Save the metrics results
save_path = path_to_project + '/measurements/metrics-{}-{}.p'.format(dataset_name, mitigation)
with open(save_path, 'wb') as fp:
    pickle.dump(final_metrics, fp, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
print(final_metrics)

{'division': {'Logistic Regression': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.0, 0.0], 'EqualOpportunity': [-1.0, 0.0], 'EqualizedOdds': [1.0, 0.0], 'ConditionalUseAccuracyEquality': [0.0, 0.0], 'OverallAccuracyEquality': [-0.017879290215056652, 0.40158984312962265], 'TreatmentEquality': [1.0, 0.0], 'FORParity': [1.0, 0.0], 'FN': [1.0, 0.0], 'FP': [-1.0, 0.0]}, 'Decision Tree': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.0, 0.0], 'EqualOpportunity': [-1.0, 0.0], 'EqualizedOdds': [1.0, 0.0], 'ConditionalUseAccuracyEquality': [0.0, 0.0], 'OverallAccuracyEquality': [-0.017879290215056652, 0.40158984312962265], 'TreatmentEquality': [1.0, 0.0], 'FORParity': [1.0, 0.0], 'FN': [1.0, 0.0], 'FP': [-1.0, 0.0]}, 'Bagging': {'GroupFairness': [0.0004162433230156193, 0.02143738204486135], 'PredictiveParity': [0.0, 0.0], 'PredictiveEquality': [1.

In [None]:
print(dataset_name, mitigation)

stroke-prediction aif360-roc


# Extra

Print example of metrics for a given model, e.g., Logistic Regression.

In [None]:
model_to_print = "Logistic Regression"
round_value = 5

print("Division \n")
for m in metrics:
  print(m, np.round(overall_metrics["division"][model_to_print][m]["Value"], round_value))
print("\nSubtraction \n")
for m in metrics:
  print(m, np.round(overall_metrics["subtraction"][model_to_print][m]["Value"], round_value))

Division 



TypeError: 'module' object is not iterable