In [None]:
!pip install folktables

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Imports
import numpy as np
import pandas as pd
from folktables import ACSDataSource, ACSEmployment, ACSIncome

from sklearn.preprocessing import normalize, StandardScaler
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import csv 

In [None]:
# State list: the 50 US state postal codes in alphabetical order.
# Look positions up with `states.index(...)` rather than hard-coding them;
# for reference: CA -> 4, TX -> 42, UT -> 43.
states = ['AK', 'AL', 'AR', 'AZ', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA',
          'HI', 'IA', 'ID', 'IL', 'IN', 'KS', 'KY', 'LA', 'MA', 'MD', 'ME',
          'MI', 'MN', 'MO', 'MS', 'MT', 'NC', 'ND', 'NE', 'NH', 'NJ', 'NM',
          'NV', 'NY', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX',
          'UT', 'VA', 'VT', 'WA', 'WI', 'WV', 'WY']
assert len(states) == len(set(states)) == 50, "state list must hold 50 unique codes"

data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')

# One raw person-level table per state, aligned with `states`
# (data_array[i] belongs to states[i]). download=True lets folktables fetch
# the survey files when they are not available locally.
data_array = [data_source.get_data(states=[state], download=True) for state in states]

Downloading data for 2018 1-Year person survey for AK...
Downloading data for 2018 1-Year person survey for AL...
Downloading data for 2018 1-Year person survey for AR...
Downloading data for 2018 1-Year person survey for AZ...
Downloading data for 2018 1-Year person survey for CA...
Downloading data for 2018 1-Year person survey for CO...
Downloading data for 2018 1-Year person survey for CT...
Downloading data for 2018 1-Year person survey for DE...
Downloading data for 2018 1-Year person survey for FL...
Downloading data for 2018 1-Year person survey for GA...
Downloading data for 2018 1-Year person survey for HI...
Downloading data for 2018 1-Year person survey for IA...
Downloading data for 2018 1-Year person survey for ID...
Downloading data for 2018 1-Year person survey for IL...
Downloading data for 2018 1-Year person survey for IN...
Downloading data for 2018 1-Year person survey for KS...
Downloading data for 2018 1-Year person survey for KY...
Downloading data for 2018 1-Yea

In [None]:
# Sanity check: print (rows, columns) of every state's raw table, in `states` order.
for state_idx in range(len(data_array)):
    print(data_array[state_idx].shape)

(6711, 286)
(47777, 286)
(30503, 286)
(69990, 286)
(378817, 286)
(55928, 286)
(36287, 286)
(9123, 286)
(202160, 286)
(100855, 286)
(14400, 286)
(32362, 286)
(16711, 286)
(126456, 286)
(67680, 286)
(29567, 286)
(45475, 286)
(43589, 286)
(70131, 286)
(59840, 286)
(13275, 286)
(99419, 286)
(55783, 286)
(62416, 286)
(29124, 286)
(10336, 286)
(102523, 286)
(7876, 286)
(19451, 286)
(13780, 286)
(88586, 286)
(19247, 286)
(28927, 286)
(196967, 286)
(119086, 286)
(37648, 286)
(42117, 286)
(129066, 286)
(10489, 286)
(49818, 286)
(8986, 286)
(67950, 286)
(268100, 286)
(31603, 286)
(84755, 286)
(6436, 286)
(76225, 286)
(59833, 286)
(18066, 286)
(5740, 286)


In [None]:
class FERMI(torch.nn.Module):
    """Logistic-regression classifier with the FERMI (ERMI) fairness regularizer.

    The classifier is ``sigmoid(X @ theta)``. ``W`` is the auxiliary variable of the
    regularizer's min-max formulation; ``fair_training`` descends on ``theta`` and
    ascends on ``W`` (it flips the sign of ``W``'s gradient).

    Args:
        X_train: (n, d) numpy feature matrix.
        Y_train: (n, m) numpy label matrix; a single column is treated as a binary
            problem (m = 2 label values).
        S_train: (n, k) numpy sensitive-attribute indicator matrix (one-hot rows as
            built in this notebook).
        batch_size, epochs: stored on the instance; ``fair_training`` takes its own values.
        verbose: print the group frequencies and P_s^{-1/2} on construction
            (default True, the historical behavior).

    Raises:
        ValueError: if some column of ``S_train`` has no members, because P_s^{-1/2}
            would be infinite and the regularizer NaN.
    """

    def __init__(self, X_train, Y_train, S_train, batch_size=64, epochs=2000, verbose=True):
        super().__init__()

        self.X_train = X_train
        self.Y_train = Y_train
        self.S_train = S_train

        self.batch_size = batch_size
        self.epochs = epochs

        self.n = X_train.shape[0]
        self.d = X_train.shape[1]
        # A single label column encodes a binary problem, i.e. two label values.
        self.m = Y_train.shape[1]
        if self.m == 1:
            self.m = 2

        self.k = S_train.shape[1]

        # k: support size of the sensitive attribute, m: number of label values.
        self.W = nn.Parameter(torch.zeros(self.k, self.m))
        self.theta = nn.Parameter(torch.zeros(self.d, 1))

        # Empirical marginal P(S = s) of each sensitive group.
        sums = self.S_train.sum(axis=0) / self.n
        if verbose:
            print(sums)
            print(sums.shape)

        empty_groups = np.flatnonzero(sums <= 0)
        if empty_groups.size:
            raise ValueError(
                f"sensitive groups {empty_groups.tolist()} have no samples; "
                "P_s^(-1/2) is undefined"
            )

        self.P_s = np.diag(sums)
        # Diagonal P_s^{-1/2}, used by the regularizer's cross term.
        self.P_s_sqrt_inv = torch.from_numpy(np.diag(1.0 / np.sqrt(sums))).double()
        if verbose:
            print(self.P_s_sqrt_inv)

    def forward(self, X):
        """Return ``sigmoid(X @ theta)`` as an (n, 1) float64 tensor of probabilities."""
        scores = torch.mm(X.double(), self.theta.double())
        probs = torch.sigmoid(scores)
        return probs

    def grad_loss(self, X, Y):
        """Return ``X^T (sigmoid(X @ theta) - Y)``, the gradient of the summed logistic loss w.r.t. theta.

        ``X`` must already be float64; it is not cast here.
        """
        outputs = torch.mm(X, self.theta.double())
        probs = torch.sigmoid(outputs)
        return torch.matmul(torch.t(X), probs - Y)

    def fairness_regularizer(self, X, S):
        """Batch sum of the FERMI (ERMI) objective for the current ``theta`` and ``W``.

        For sample i with p_i = sigmoid(x_i @ theta), let q_i = (1 - p_i, p_i, 0, ..., 0)
        (length m) and s_i the i-th row of S. The per-sample term is

            r_i = -tr(diag(q_i) W^T W) + 2 tr((q_i s_i^T) P_s^{-1/2} W) - 1
                = -q_i . diag(W^T W) + 2 q_i . (W^T (P_s^{-1/2} s_i)) - 1,

        computed here for the whole batch at once. All k sensitive groups contribute.

        Args:
            X: (B, d) float64 features.
            S: (B, k) float64 sensitive-attribute indicators.

        Returns:
            0-dim float64 tensor ``sum_i r_i``. This is a sum, not a mean, so its scale
            grows with the batch size.
        """
        batch = X.shape[0]
        p = torch.sigmoid(torch.matmul(X, self.theta.double()))  # (B, 1)
        # Binary output: columns hold P(y_hat = 0) and P(y_hat = 1); any further label
        # columns (m > 2) are zero.
        q = torch.cat([1.0 - p, p, p.new_zeros(batch, self.m - 2)], dim=1)  # (B, m)

        W = self.W.double()  # (k, m); cast once per call, not once per sample
        # diag(W^T W) is the squared norm of each column of W (loop-invariant over i).
        col_sq_norms = (W * W).sum(dim=0)  # (m,)
        quad = q @ col_sq_norms  # (B,)

        inv_sqrt_ps = torch.diagonal(self.P_s_sqrt_inv)  # (k,)
        cross = (((S * inv_sqrt_ps) @ W) * q).sum(dim=1)  # (B,)

        return (2.0 * cross - quad - 1.0).sum()

In [None]:
def _print_state_metrics(theta, state_names, X_list, Y_list, S_list):
    """Print accuracy and demographic-parity violation of the linear rule ``X @ theta > 0`` per state.

    Args:
        theta: (d, 1) numpy weight vector.
        state_names: label printed for each entry; aligned with the three lists.
        X_list, Y_list, S_list: per-state features (n, d), labels (n, 1) and
            two-column sensitive indicators (n, 2).
    """
    for name, XTest, YTest, STest in zip(state_names, X_list, Y_list, S_list):
        print("State Name:", name)
        pre_logits = np.dot(XTest, theta)
        # sigmoid(z) > 0.5  <=>  z > 0; comparing raw scores avoids exp overflow.
        final_preds = pre_logits > 0
        acc = final_preds == (YTest == 1)
        true_preds = acc.sum(axis=0)
        print("Accuracy: ", true_preds[0] / pre_logits.shape[0] * 100, "%")

        # Positive predictions per sensitive group divided by group size = positive rate.
        intersections = np.dot(final_preds.T, STest)
        numbers = STest.sum(axis=0)
        group1 = intersections[0][0] / numbers[0]
        group2 = intersections[0][1] / numbers[1]
        print("DP Violation: ", np.abs(group1 - group2))
        print("*********************************************")


def fair_training(fermi, batch_size, epochs, initial_epochs = 300, initial_learning_rate = 1, lam=0.1, learning_rate_min = 0.01, learning_rate_max = 0.01, dr_type='L1', epsilon=0.1):
    """Train ``fermi``: plain-BCE warm-up, then FERMI-regularized min-max training.

    Data is taken from ``fermi.X_train`` / ``fermi.Y_train`` / ``fermi.S_train`` and
    visited in order (no shuffling), in mini-batches that include the final partial
    batch. Every 10 epochs, after that epoch's updates, per-state test metrics are printed
    using the notebook-level ``states``, ``X_Test_Array``, ``Y_Test_Array`` and
    ``S_Test_Array``.

    Args:
        fermi: a ``FERMI`` model (needs ``theta``, ``W``, ``fairness_regularizer`` and the
            ``*_train`` arrays).
        batch_size: mini-batch size.
        epochs: number of regularized (min-max) epochs after the warm-up.
        initial_epochs: number of warm-up epochs trained with BCE only.
        initial_learning_rate: SGD learning rate; during the warm-up it is the step size.
        lam: regularizer weight.
        learning_rate_min: effective step size for ``theta`` (descent) after the warm-up.
        learning_rate_max: effective step size for ``W`` (ascent) after the warm-up.
        dr_type: ``'L1'`` -> loss + (lam + epsilon) * reg;
            ``'L2'`` -> loss + lam * reg - epsilon * sqrt(-reg).
        epsilon: robustness radius used by both ``dr_type`` variants.

    Returns:
        ``(fermi.theta, fermi.W)`` — the trained parameters.

    Raises:
        ValueError: if ``epochs > 0`` and ``dr_type`` is neither ``'L1'`` nor ``'L2'``.
    """
    if epochs > 0 and dr_type not in ("L1", "L2"):
        raise ValueError(f"dr_type must be 'L1' or 'L2', got {dr_type!r}")

    # Train on the data stored in the model, not on notebook globals.
    X = fermi.X_train
    S_Matrix = fermi.S_train
    Y = fermi.Y_train
    print(X.shape)
    print(S_Matrix.shape)
    print(Y.shape)

    criterion = torch.nn.BCELoss()
    # One SGD optimizer for both players; per-phase step sizes (and the ascent sign
    # for W) are applied by rescaling gradients before each step.
    minimizer = torch.optim.SGD([fermi.theta, fermi.W], lr=initial_learning_rate)

    n_samples = X.shape[0]
    for ep in range(epochs + initial_epochs):
        warm_up = ep < initial_epochs

        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            XTorch = torch.from_numpy(X[start:end]).double()
            YTorch = torch.from_numpy(Y[start:end]).double()
            STorch = torch.from_numpy(S_Matrix[start:end]).double()

            logits = fermi(XTorch)
            loss_min = criterion(logits, YTorch)
            if not warm_up:
                reg_val = fermi.fairness_regularizer(XTorch, STorch)
                if dr_type == "L1":
                    loss_min = loss_min + (lam + epsilon) * reg_val
                else:
                    # Robust L2. NOTE(review): torch.sqrt(-reg_val) is NaN whenever
                    # reg_val > 0 — confirm this formulation is intended.
                    loss_min = loss_min + lam * reg_val - epsilon * torch.sqrt(-reg_val)

            minimizer.zero_grad()
            loss_min.backward()

            if not warm_up:
                with torch.no_grad():
                    # theta: descend with effective step learning_rate_min.
                    if fermi.theta.grad is not None:
                        fermi.theta.grad.mul_(learning_rate_min / initial_learning_rate)
                    # W: ascend (sign flip) with effective step learning_rate_max.
                    if fermi.W.grad is not None:
                        fermi.W.grad.mul_(-learning_rate_max / initial_learning_rate)

            minimizer.step()

        if ep % 10 == 9:
            # Evaluated after this epoch's updates, so ep + 1 epochs have really run.
            print(ep+1, " epochs:")
            _print_state_metrics(
                fermi.theta.detach().numpy(), states, X_Test_Array, Y_Test_Array, S_Test_Array
            )

    return fermi.theta, fermi.W

In [None]:
# Build per-state evaluation arrays (aligned with `states`) from the raw tables.
# Column 8 of the ACSIncome feature matrix is SEX (coded 1/2) per folktables'
# ACSIncome feature list — verify if the feature definition changes.
SEX_COLUMN = 8

S_Test_Array = []
Y_Test_Array = []
X_Test_Array = []

for item in data_array:
    features, labels, _ = ACSIncome.df_to_numpy(item)
    Y_Test_Array.append(labels[:, np.newaxis])

    # Constant intercept column; default-int dtype, as the original Python-list version produced.
    intercept_numpy = np.ones((features.shape[0], 1), dtype=int)
    X_Test = np.append(features, intercept_numpy, axis=1)
    X_Test_Array.append(X_Test)

    # Shift SEX codes to 0-based group indices and one-hot encode them in a single
    # vectorized lookup (row j of the 2x2 identity is the one-hot vector for group j).
    sensitive_attributeTest = features[:, SEX_COLUMN] - 1
    one_hot_encodeTest = np.eye(2)[sensitive_attributeTest.astype(int)]
    print(X_Test.shape)
    S_Test_Array.append(one_hot_encodeTest)



(3546, 11)
(22268, 11)
(13929, 11)
(33277, 11)
(195665, 11)
(31306, 11)
(19785, 11)
(4713, 11)
(98925, 11)
(50915, 11)
(7731, 11)
(17745, 11)
(8265, 11)
(67016, 11)
(35022, 11)
(15807, 11)
(22006, 11)
(20667, 11)
(40114, 11)
(33042, 11)
(7002, 11)
(50008, 11)
(31021, 11)
(31664, 11)
(13189, 11)
(5463, 11)
(52067, 11)
(4455, 11)
(10785, 11)
(7966, 11)
(47781, 11)
(8711, 11)
(14807, 11)
(103021, 11)
(62135, 11)
(17917, 11)
(21919, 11)
(68308, 11)
(5712, 11)
(24879, 11)
(4899, 11)
(34003, 11)
(135924, 11)
(16337, 11)
(46144, 11)
(3767, 11)
(39944, 11)
(32690, 11)
(8103, 11)
(3064, 11)


In [None]:
# Train on Texas. The per-state evaluation inside fair_training also reports TX,
# so TX numbers are training-set (not held-out) metrics.
TRAIN_STATE = 'TX'
train_idx = states.index(TRAIN_STATE)
S_Train = S_Test_Array[train_idx]
X_Train = X_Test_Array[train_idx]
Y_Train = Y_Test_Array[train_idx]
print(S_Train.shape)

(135924, 2)


In [None]:
# Run FERMI
fermi_instance = FERMI(X_Train, Y_Train, S_Train)
theta_star, W_star = fair_training(fermi_instance, batch_size = 1000, epochs=2000, initial_epochs=300, initial_learning_rate=5e-8, learning_rate_min=5e-10, learning_rate_max=1e-10, lam=1, epsilon=0.1, dr_type='L2')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
State Name: CT
Accuracy:  60.89461713419257 %
DP Violation:  0.032096839811819045
*********************************************
State Name: DE
Accuracy:  65.79673244218121 %
DP Violation:  0.02053948081212073
*********************************************
State Name: FL
Accuracy:  71.0194591862522 %
DP Violation:  0.028002482257605987
*********************************************
State Name: GA
Accuracy:  69.37641166650299 %
DP Violation:  0.026611682231645534
*********************************************
State Name: HI
Accuracy:  66.21394386237228 %
DP Violation:  0.019976395785122725
*********************************************
State Name: IA
Accuracy:  71.66525781910397 %
DP Violation:  0.026250589643870018
*********************************************
State Name: ID
Accuracy:  73.1881427707199 %
DP Violation:  0.044298890218060175
*********************************************
State Name: IL
Accuracy:  66.4393577653097