In [1]:
import sys
# Put the parent directory on the import path; presumably this is what lets the
# `src.*` imports in the next cell resolve when the notebook runs from its own folder.
sys.path.append('..')

In [2]:
import torch
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from copy import deepcopy
from torch.utils.data import DataLoader

from src.model import LR
from src.effort import Optimal_Effort
from src.data import FairnessDataset, SyntheticDataset, IncomeDataset
from src.methods import covariance_proxy, fair_batch_proxy

In [3]:
# Hyper-parameters used by the training cell further down:
#   tau   - predictions below this value are selected for the fairness term
#   alpha - half-width of the box the adversarial parameters are clamped to
#   lamb  - weight of the fairness loss in the combined objective
tau, alpha, lamb = 0.5, 0.1, 0.1

# Build the dataset with a fixed seed and keep only the first element returned
# by `tensor(0)` (used as the training tensors); the other two are discarded.
dataset = IncomeDataset(num_samples=20000, seed=0)
train_tensors, _, _ = dataset.tensor(0)
X_train, Y_train, Z_train = train_tensors

# Wrap the three training tensors together with the dataset's `imp_feats`.
train_dataset = FairnessDataset(X_train, Y_train, Z_train, dataset.imp_feats)

In [4]:
# Seed torch's global RNG before anything below is constructed.
torch.manual_seed(0)

# Effort operator, built from the dataset's `delta`.
effort = Optimal_Effort(dataset.delta)

# Model with one input per column of X_train.
model = LR(X_train.shape[1])

# Fairness surrogate used by the training cell (fair_batch_proxy here;
# covariance_proxy is also imported).
proxy = fair_batch_proxy

# Mean-reduced binary cross-entropy for the prediction term, optimised with Adam.
loss_fn = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [5]:
# Training loop: for every batch, an adversarial copy of the model is pushed
# (by projected gradient ascent inside a +/- curr_alpha box around the current
# parameters) towards maximising the fairness proxy; the main model is then
# updated on a convex combination of prediction loss and that fairness loss.

# `d` logs every inner ascent step; `d2` logs the losses of every outer step.
d = {'GD Iter': [], 'Batch Iter': [], 'PGA Iter': [], 'Fairness Loss': []}
d2 = {'GD Iter': [], 'Total Loss': [], 'Pred Loss': [], 'Fair Loss': []}

# Seeded generator so the shuffling order is reproducible.
generator = torch.Generator().manual_seed(0)
data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=generator)

n = 100  # number of passes over the loader
for i in range(n):
    # Box radius for the adversarial parameters. Constant for now; a ramp such
    # as alpha * (i / n) can be substituted here.
    curr_alpha = alpha
    for b, (X_batch, Y_batch, Z_batch) in enumerate(data_loader):
        Y_hat = model(X_batch).reshape(-1)
        pred_loss = loss_fn(Y_hat, Y_batch)

        # Only samples predicted below tau take part in the fairness term.
        # A batch without any such sample is skipped entirely (no update).
        below_tau = Y_hat < tau
        if not below_tau.any():
            continue

        X_e = X_batch[below_tau]
        Z_e = Z_batch[below_tau]
        X_hat_max = effort(model, train_dataset, X_e)

        # Adversarial copy, plus one (parameter, lower, upper) box per weight/bias.
        # Bounds are kept per parameter so each layer is clamped around its own
        # starting values; layers with no weight or a `None` bias are skipped.
        model_adv = deepcopy(model)
        clamp_bounds = []
        for module in model_adv.layers:
            for attr in ('weight', 'bias'):
                param = getattr(module, attr, None)
                if param is None:
                    continue
                centre = param.detach().clone()
                clamp_bounds.append((param, centre - curr_alpha, centre + curr_alpha))

        optimizer_adv = torch.optim.Adam(model_adv.parameters(), maximize=True)

        for j in range(20):
            Y_hat_max_pga = model_adv(X_hat_max).reshape(-1)
            fair_loss_pga = proxy(Z_e, Y_hat_max_pga)

            optimizer_adv.zero_grad()
            fair_loss_pga.backward()
            optimizer_adv.step()

            # Projection step: pull every parameter back into its box.
            with torch.no_grad():
                for param, lower, upper in clamp_bounds:
                    param.clamp_(min=lower, max=upper)

            # Logged value is the loss measured before this ascent step.
            d['GD Iter'].append(str(i))
            d['Batch Iter'].append(str(b))
            d['PGA Iter'].append(j)
            d['Fairness Loss'].append(fair_loss_pga.item())

        # NOTE(review): fair_loss is evaluated with model_adv, a separate copy of
        # the model, so loss.backward() reaches `model`'s parameters through the
        # fairness term only if X_hat_max carries a graph back to `model`.
        # Confirm this is the intended gradient path.
        Y_hat_max = model_adv(X_hat_max).reshape(-1)
        fair_loss = proxy(Z_e, Y_hat_max)
        loss = (1 - lamb) * pred_loss + lamb * fair_loss

        d2['GD Iter'].append(i)
        d2['Total Loss'].append(loss.item())
        d2['Fair Loss'].append(fair_loss.item())
        d2['Pred Loss'].append(pred_loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [11]:
# Tabulate the inner-loop log: one row per (GD Iter, Batch Iter, PGA Iter)
# combination with the mean fairness loss recorded for it.
(
    pd.DataFrame(d)
    .groupby(['GD Iter', 'Batch Iter', 'PGA Iter'], as_index=False)
    .mean()
)

Unnamed: 0,GD Iter,Batch Iter,PGA Iter,Fairness Loss
0,0,0,0,0.244731
1,0,0,1,0.247746
2,0,0,2,0.250780
3,0,0,3,0.253834
4,0,0,4,0.256907
...,...,...,...,...
399915,99,99,15,0.048658
399916,99,99,16,0.049697
399917,99,99,17,0.050735
399918,99,99,18,0.051773


In [9]:
# Fairness loss along the 20 inner ascent steps, averaged over all batches of a
# GD iteration, with one animation frame per GD iteration.
#
# 'PGA Iter' has to be a grouping key: grouping by batch instead also averages
# the step index itself, so every point would land at x = 9.5 and no curve is drawn.
# Only 'Fairness Loss' is aggregated because 'Batch Iter' is stored as strings.
# sort=False keeps the string-typed 'GD Iter' frames in run order rather than
# lexicographic order ('0', '1', '10', ...).
pga_curve = (
    pd.DataFrame(d)
    .groupby(['GD Iter', 'PGA Iter'], as_index=False, sort=False)['Fairness Loss']
    .mean()
)
px.line(pga_curve, x='PGA Iter', y='Fairness Loss', animation_frame='GD Iter', markers=True)

In [7]:
# Average the logged outer-loop losses per GD iteration and plot the total
# loss against the iteration, colouring each marker by the mean fairness loss.
px.scatter(
    pd.DataFrame(d2).groupby('GD Iter', as_index=False).mean(),
    x='GD Iter',
    y='Total Loss',
    color='Fair Loss',
)