<a href="https://colab.research.google.com/github/dolmani38/papers/blob/main/%EC%88%98%EC%A0%95-FR-Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the FR-Train repository (yuji-roh/fr-train). The next cell changes into it so that
# FRTrain_arch and the synthetic .npy files used below can be loaded by relative path.
# NOTE(review): not idempotent — git clone refuses to clone into an existing non-empty directory.
!git clone https://github.com/yuji-roh/fr-train.git

Cloning into 'fr-train'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 31 (delta 9), reused 4 (delta 0), pack-reused 0[K
Unpacking objects: 100% (31/31), 168.32 KiB | 3.18 MiB/s, done.


In [2]:
# Work from inside the cloned repository so the relative imports and data paths below resolve.
# NOTE(review): %cd is relative, so re-running this cell without restarting the kernel
# would try to enter fr-train/fr-train.
%cd fr-train
!pwd

/content/fr-train
/content/fr-train


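# FR-Train on poisoned synthetic data

This is a modified FR-Train: the robustness discriminator (D_R) is disabled, so only the classifier (generator) and the fairness discriminator (D_F) are trained.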

## Import libraries

In [3]:
import sys, os
import numpy as np
import math

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

import math
import matplotlib.pyplot as plt

from argparse import Namespace

from FRTrain_arch import Generator, DiscriminatorF, DiscriminatorR, weights_init_normal, test_model

import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides useful warnings; consider filtering specific categories instead.
warnings.filterwarnings("ignore")


## Load and process data (the training split uses the poisoned labels)

In [4]:
# a namespace object which contains some of the hyperparameters
# Data-split sizes read by the data-loading cell.
opt = Namespace(
    num_train=2000,  # rows reserved for training; the training split uses the first num_train - num_val1
    num_val1=200,    # clean validation rows taken right after num_train
    num_val2=500,    # second validation split (skipped over, currently unused)
    num_test=1000,   # test rows taken after both validation splits
)

In [5]:
# Split sizes from the hyperparameter namespace.
num_train, num_val1, num_val2, num_test = (
    opt.num_train, opt.num_val1, opt.num_val2, opt.num_test
)

# Load the synthetic dataset from the repository and convert each array to a float32 tensor.
X = torch.FloatTensor(np.load('X_synthetic.npy'))     # input features
y = torch.FloatTensor(np.load('y_synthetic.npy'))     # original (clean) labels
y_poi = torch.FloatTensor(np.load('y_poi.npy'))       # poisoned training labels
s1 = torch.FloatTensor(np.load('s1_synthetic.npy'))   # sensitive attribute

# Row ranges of each split. Rows [num_train - num_val1, num_train) are not used by any split,
# and rows [num_train + num_val1, num_train + num_val1 + num_val2) form a second validation
# split that is currently unused. NOTE(review): confirm the first gap is intentional.
_train_rows = slice(None, num_train - num_val1)
_val_rows = slice(num_train, num_train + num_val1)
_test_start = num_train + num_val1 + num_val2
_test_rows = slice(_test_start, _test_start + num_test)

# Training uses the poisoned labels; validation and test use the clean ones.
X_train, y_train, s1_train = X[_train_rows], y_poi[_train_rows], s1[_train_rows]
X_val, y_val, s1_val = X[_val_rows], y[_val_rows], s1[_val_rows]
X_test, y_test, s1_test = X[_test_rows], y[_test_rows], s1[_test_rows]


def _append_sensitive(features, sensitive):
    """Return `features` with the sensitive attribute concatenated as an extra last column."""
    return torch.cat([features, sensitive.reshape((sensitive.shape[0], 1))], dim=1)


XS_train = _append_sensitive(X_train, s1_train)
XS_val = _append_sensitive(X_val, s1_val)
XS_test = _append_sensitive(X_test, s1_test)

In [6]:
# Report how many samples ended up in each split.
print("--------------------- Number of Data -------------------------" )
split_sizes = tuple(map(len, (y_train, y_val, y_test)))
print("Train data : %d, Validation data : %d, Test data : %d " % split_sizes)
print("--------------------------------------------------------------")

--------------------- Number of Data -------------------------
Train data : 1800, Validation data : 200, Test data : 1000 
--------------------------------------------------------------


In [16]:
# Sanity-check the shapes of the training features and labels.
print(X_train.shape, y_train.shape, sep="\n")

torch.Size([1800, 2])
torch.Size([1800])


## Training with poisoned data

In [10]:
def train_model(train_tensors, val_tensors, test_tensors, train_opt, lambda_f, seed):
    """
      Trains the modified FR-Train (generator + fairness discriminator D_F) by using
      the classes in FRTrain_arch.py.

      In this variant the robustness discriminator D_R of the original FR-Train, and
      the sample re-weighting derived from it, are disabled. The clean validation data
      is therefore accepted for interface compatibility but is not used.

      D_F is updated every epoch to predict the sensitive attribute from the
      (detached) generator output. The generator is trained in two phases:
        * Warm-up (epoch < pretrain_epochs): fit the training labels only, with a
          0.1-scaled BCE on (tanh(output) + 1) / 2.
        * Adversarial (every k-th epoch afterwards): minimise
          (1 - lambda_f) * label_BCE_with_logits - lambda_f * D_F_BCE.

      Args:
        train_tensors: Namespace with XS_train (features with the sensitive attribute
                appended), y_train (training labels, mapped to targets via (y + 1) / 2,
                presumably {-1, 1} — confirm) and s1_train (sensitive attribute).
        val_tensors: Clean validation data. Currently unused (D_R is disabled).
        test_tensors: Namespace with XS_test, y_test and s1_test, passed to test_model.
        train_opt: Training options. Reads n_epochs, k (1:k generator/discriminator
                update ratio), lr_g and lr_f, plus the optional pretrain_epochs
                (default 500) and log_every (default 200). Other fields such as
                val and lr_r are ignored.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.

      Returns:
        A one-element list [[lambda_f, test_accuracy, disparate_impact]] built from
        the first two values returned by test_model.
    """
    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train

    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test

    k = train_opt.k                # Update ratio of generator and discriminator (1:k training).
    n_epochs = train_opt.n_epochs  # Number of training epochs.
    pretrain_epochs = getattr(train_opt, "pretrain_epochs", 500)  # Label-only warm-up length.
    log_every = getattr(train_opt, "log_every", 200)              # Progress print interval.

    # Targets are reshaped once here rather than on every epoch.
    s1_target = s1_train.reshape(-1, 1)
    y_target = (y_train.reshape(-1, 1) + 1) / 2
    y_target_flat = (y_train.squeeze() + 1) / 2

    # Loss histories, stored as Python floats so the lists do not keep every
    # epoch's autograd graph alive. One generator entry per actual generator update.
    g_losses = []
    d_f_losses = []

    bce_loss = torch.nn.BCELoss()

    generator = Generator()
    discriminator_F = DiscriminatorF()

    # NOTE(review): the modules are constructed before seeding, so runs are reproducible
    # only if weights_init_normal re-initialises every parameter — confirm.
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)

    for epoch in range(n_epochs):
        gen_y = generator(XS_train).reshape(-1, 1)

        # ---- Fairness discriminator: predict s1 from the generator output.
        # gen_y is detached so this step does not touch the generator's gradients.
        optimizer_D_F.zero_grad()
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s1_target)
        d_f_loss.backward()
        optimizer_D_F.step()
        d_f_losses.append(d_f_loss.item())

        # ---- Generator.
        if epoch < pretrain_epochs:
            optimizer_G.zero_grad()
            g_loss = 0.1 * bce_loss((torch.tanh(gen_y) + 1) / 2, y_target)
            g_loss.backward()
            optimizer_G.step()
            g_losses.append(g_loss.item())
        elif epoch % k == 0:
            optimizer_G.zero_grad()
            # gen_y is NOT detached here: subtracting D_F's loss pushes the generator
            # towards outputs from which s1 is hard to predict.
            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s1_target, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), y_target_flat, reduction="none").squeeze()
            g_loss = (1 - lambda_f) * torch.mean(g_cost) - lambda_f * torch.mean(f_cost)
            # This backward also leaves gradients on D_F's parameters; they are cleared
            # by optimizer_D_F.zero_grad() before D_F's next step.
            g_loss.backward()
            optimizer_G.step()
            g_losses.append(g_loss.item())

        if epoch % log_every == 0:
            print(
                "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [G loss: %f]"
                % (lambda_f, epoch, n_epochs, d_f_losses[-1], g_losses[-1])
            )

    tmp = test_model(generator, XS_test, y_test, s1_test)
    return [[lambda_f, tmp[0].item(), tmp[1]]]

In [11]:
# Bundle each data split into a Namespace so train_model can read it by attribute.
train_tensors = Namespace(XS_train=XS_train, y_train=y_train, s1_train=s1_train)
val_tensors = Namespace(XS_val=XS_val, y_val=y_val, s1_val=s1_val)
test_tensors = Namespace(XS_test=XS_test, y_test=y_test, s1_test=s1_test)

# Training options. `val` and `lr_r` only matter for the robustness discriminator,
# which is disabled in this notebook's train_model.
train_opt = Namespace(
    val=len(y_val),
    n_epochs=10000,
    k=5,
    lr_g=0.001,
    lr_f=0.001,
    lr_r=0.001,
)
seed = 1

# Weights of the fairness term tried in the sweep (no robustness weight lambda_r in this variant).
lambda_f_set = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.52]

train_result = []
for lambda_f in lambda_f_set:
    train_result.append(
        train_model(train_tensors, val_tensors, test_tensors, train_opt, lambda_f=lambda_f, seed=seed)
    )

[Lambda: 0.100000] [Epoch 0/10000] [D_F loss: 0.686340] [G loss: 0.060358]
[Lambda: 0.100000] [Epoch 200/10000] [D_F loss: 0.682480] [G loss: 0.057829]
[Lambda: 0.100000] [Epoch 400/10000] [D_F loss: 0.679608] [G loss: 0.057678]
[Lambda: 0.100000] [Epoch 600/10000] [D_F loss: 0.669466] [G loss: 0.456159]
[Lambda: 0.100000] [Epoch 800/10000] [D_F loss: 0.666391] [G loss: 0.452935]
[Lambda: 0.100000] [Epoch 1000/10000] [D_F loss: 0.664273] [G loss: 0.452936]
[Lambda: 0.100000] [Epoch 1200/10000] [D_F loss: 0.661742] [G loss: 0.453039]
[Lambda: 0.100000] [Epoch 1400/10000] [D_F loss: 0.659712] [G loss: 0.453157]
[Lambda: 0.100000] [Epoch 1600/10000] [D_F loss: 0.657988] [G loss: 0.453279]
[Lambda: 0.100000] [Epoch 1800/10000] [D_F loss: 0.656552] [G loss: 0.453396]
[Lambda: 0.100000] [Epoch 2000/10000] [D_F loss: 0.655381] [G loss: 0.453505]
[Lambda: 0.100000] [Epoch 2200/10000] [D_F loss: 0.654453] [G loss: 0.453603]
[Lambda: 0.100000] [Epoch 2400/10000] [D_F loss: 0.653740] [G loss: 0.4

In [9]:
# Print the lambda sweep results. This cell's format was written for an earlier train_model
# whose rows were [lambda_f, lambda_r, accuracy, DI]. The current train_model returns
# [lambda_f, accuracy, DI], so Lambda_r is printed only when a row actually carries it
# (indexing a fourth entry unconditionally raised IndexError).
print("-----------------------------------------------------------------------------------")
print("------------------ Training Results of FR-Train on poisoned data ------------------" )
for run in train_result:
    row = run[0]
    if len(row) == 4:
        print(
            "[Lambda_f: %.2f] [Lambda_r: %.2f] Accuracy : %.3f, Disparate Impact : %.3f "
            % tuple(row)
        )
    else:
        print(
            "[Lambda_f: %.2f] Accuracy : %.3f, Disparate Impact : %.3f "
            % tuple(row)
        )
print("-----------------------------------------------------------------------------------")

-----------------------------------------------------------------------------------
------------------ Training Results of FR-Train on poisoned data ------------------
[Lambda_f: 0.10] [Lambda_r: 0.40] Accuracy : 0.842, Disparate Impact : 0.657 
[Lambda_f: 0.15] [Lambda_r: 0.40] Accuracy : 0.835, Disparate Impact : 0.704 
[Lambda_f: 0.20] [Lambda_r: 0.40] Accuracy : 0.833, Disparate Impact : 0.722 
[Lambda_f: 0.25] [Lambda_r: 0.40] Accuracy : 0.827, Disparate Impact : 0.745 
[Lambda_f: 0.30] [Lambda_r: 0.40] Accuracy : 0.824, Disparate Impact : 0.760 
[Lambda_f: 0.35] [Lambda_r: 0.40] Accuracy : 0.821, Disparate Impact : 0.764 
[Lambda_f: 0.40] [Lambda_r: 0.40] Accuracy : 0.821, Disparate Impact : 0.770 
[Lambda_f: 0.45] [Lambda_r: 0.40] Accuracy : 0.810, Disparate Impact : 0.786 
[Lambda_f: 0.52] [Lambda_r: 0.40] Accuracy : 0.814, Disparate Impact : 0.827 
-----------------------------------------------------------------------------------


In [13]:
# Print one line per lambda_f: each train_model result is [[lambda_f, accuracy, disparate_impact]].
print("-----------------------------------------------------------------------------------")
print("------------------ Training Results of FR-Train on poisoned data ------------------" )
for run in train_result:
    lam_f, accuracy, disparate_impact = run[0]
    print(
        "[Lambda_f: %.2f] Accuracy : %.3f, Disparate Impact : %.3f "
        % (lam_f, accuracy, disparate_impact)
    )
print("-----------------------------------------------------------------------------------")

-----------------------------------------------------------------------------------
------------------ Training Results of FR-Train on poisoned data ------------------
[Lambda_f: 0.10] Accuracy : 0.813, Disparate Impact : 0.556 
[Lambda_f: 0.15] Accuracy : 0.802, Disparate Impact : 0.606 
[Lambda_f: 0.20] Accuracy : 0.790, Disparate Impact : 0.672 
[Lambda_f: 0.25] Accuracy : 0.783, Disparate Impact : 0.706 
[Lambda_f: 0.30] Accuracy : 0.781, Disparate Impact : 0.721 
[Lambda_f: 0.35] Accuracy : 0.778, Disparate Impact : 0.741 
[Lambda_f: 0.40] Accuracy : 0.772, Disparate Impact : 0.766 
[Lambda_f: 0.45] Accuracy : 0.768, Disparate Impact : 0.784 
[Lambda_f: 0.52] Accuracy : 0.767, Disparate Impact : 0.815 
-----------------------------------------------------------------------------------
