### *Debiasing Deep Chest X-Ray Classifiers using Intra- and Post-processing Methods*: A Tutorial

Welcome to the demo notebook for our paper "*Debiasing Deep Chest X-Ray Classifiers using Intra- and Post-processing Methods*"! Here, we will learn how to apply the proposed **bias gradient descent/ascent (GD/A)** and **pruning** algorithms to a fully connected neural network trained on the **COMPAS data**.

#### Imports

Let's start by importing basic libraries, such as NumPy, pandas, and PyTorch:

In [1]:
# Imports
import numpy as np
import pandas as pd
import random
import torch
import sys
sys.path.insert(0, '../')

from torch import nn
import torch.nn.functional as F

#### Repeatability

Set the seeds for all relevant pseudorandom number generators:

In [2]:
seed = 220825
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fe4e67585d0>

#### COMPAS Data

We will use the version of the [COMPAS dataset](https://www.propublica.org/article/how-we-analyzed-the-compas-recidivism-algorithm) provided within the [AI Fairness 360 library](https://aif360.mybluemix.net/):

In [3]:
from aif360.datasets import CompasDataset

# Load the data
dataset = CompasDataset()
# Train-validation-test split
dataset_train, dataset_vt = dataset.split([0.6], shuffle=True, seed=seed)
dataset_valid, dataset_test = dataset_vt.split([0.5], shuffle=True, seed=seed)


Let's have a look at the data:

In [4]:
# View data as a pandas DataFrame 
dataset.convert_to_dataframe()[0]

| | sex | age | race | juv_fel_count | juv_misd_count | juv_other_count | priors_count | age_cat=25 - 45 | age_cat=Greater than 45 | age_cat=Less than 25 | ... | c_charge_desc=Viol Injunct Domestic Violence | c_charge_desc=Viol Injunction Protect Dom Vi | c_charge_desc=Viol Pretrial Release Dom Viol | c_charge_desc=Viol Prot Injunc Repeat Viol | c_charge_desc=Violation License Restrictions | c_charge_desc=Violation Of Boater Safety Id | c_charge_desc=Violation of Injunction Order/Stalking/Cyberstalking | c_charge_desc=Voyeurism | c_charge_desc=arrest case no charge | two_year_recid |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.0 | 69.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 0.0 | 34.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| 4 | 0.0 | 24.0 | 0.0 | 0.0 | 0.0 | 1.0 | 4.0 | 0.0 | 0.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| 7 | 0.0 | 44.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 0.0 | 41.0 | 1.0 | 0.0 | 0.0 | 0.0 | 14.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 10996 | 0.0 | 23.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 10997 | 0.0 | 23.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 10999 | 0.0 | 57.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 11000 | 1.0 | 33.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |


Instead of using the AIF360 dataset directly, we will work with the wrapper class `TabularData`:

In [5]:
from utils.data_utils import TabularData

# Device used to store tensors -- we will use CPU this time
device = torch.device('cpu')

# Arguments for the TabularData wrapper: which dataset to load and which attribute is protected
args = {'dataset': 'compas', 'protected': 'race'}
# Create a wrapper object
dataset = TabularData(config=args, seed=seed, device=device)



The wrapper exposes the following attributes:

In [6]:
dataset.__dict__.keys()

dict_keys(['train', 'valid', 'test', 'priv', 'unpriv', 'X_train', 'y_train', 'p_train', 'X_valid', 'X_valid_gpu', 'y_valid', 'y_valid_gpu', 'p_valid', 'p_valid_gpu', 'X_valid_train', 'X_valid_valid', 'y_valid_train', 'y_valid_valid', 'p_valid_train', 'p_valid_valid', 'X_test', 'X_test_gpu', 'y_test', 'y_test_gpu', 'p_test', 'p_test_gpu', 'num_features'])

#### Standard Model

Let's build a simple fully connected neural network for binary classification! Below is a PyTorch implementation of a three-layer perceptron with 401 inputs, two hidden layers of 500 units each, and a single sigmoid output:

In [7]:
# A simple FCNN
# NOTE: this architecture differs from the one studied in the original paper
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(401, 500)
        self.fcs = nn.ModuleList([nn.Linear(500, 500)])
        self.out = nn.Linear(500, 1)

    def forward(self, t):
        t = F.relu(self.fc0(t))
        for fc in self.fcs:
            t = F.relu(fc(t))
        return torch.sigmoid(self.out(t))
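
As a quick sanity check on the size of this architecture, the number of trainable parameters can be worked out from the layer widths alone. The helper `mlp_param_count` below is a hypothetical name introduced for this illustration; it simply counts one weight per input-output pair plus one bias per output unit for each fully connected layer:

```python
def mlp_param_count(sizes):
    """Number of weights and biases in a fully connected network.

    `sizes` lists the layer widths from input to output,
    e.g. [401, 500, 500, 1] for the model above.
    """
    return sum((n_in + 1) * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


# (401 + 1) * 500 + (500 + 1) * 500 + (500 + 1) * 1
n_params = mlp_param_count([401, 500, 500, 1])
print(n_params)
```

The same number should be returned by `sum(p.numel() for p in Model().parameters())`.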

Let's train our model on the COMPAS data! Below, we use the function `train_model` with the routine described in the original paper; however, you can easily implement your preferred procedure:

In [8]:
from models.networks_tabular import train_model

# Initialise an FCNN and train it for a maximum of 300 epochs
model = Model()
train_model(model=model, data=dataset, epochs=300)

#### Model Evaluation

Let's now evaluate our classifier w.r.t. predictive performance and classification parity:

In [9]:
model.eval()

# Predict on the validation and test sets
with torch.no_grad():
    valid_pred = dataset.valid.copy(deepcopy=True)
    valid_pred.scores = model(dataset.X_valid)[:, 0].reshape(-1, 1).numpy()
    valid_pred.labels = np.array(valid_pred.scores > 0.5)

    test_pred = dataset.test.copy(deepcopy=True)
    test_pred.scores = model(dataset.X_test)[:, 0].reshape(-1, 1).numpy()
    test_pred.labels = np.array(test_pred.scores > 0.5)

In [10]:
from sklearn.metrics import balanced_accuracy_score

# Choose a classification threshold maximising balanced accuracy on the validation data
threshs = np.linspace(0, 1, 101)
performances = []
for thresh in threshs:
    perf = balanced_accuracy_score(dataset.y_valid, valid_pred.scores > thresh)
    performances.append(perf)
best_thresh = threshs[np.argmax(performances)]
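
The threshold search above can be wrapped into a small reusable function. The sketch below is self-contained and uses only NumPy; the names `balanced_accuracy` and `select_threshold` and the toy scores are assumptions made for this illustration, not part of the accompanying code base:

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of the true positive rate and the true negative rate (binary labels)."""
    y_true = np.asarray(y_true).astype(bool).ravel()
    y_pred = np.asarray(y_pred).astype(bool).ravel()
    tpr = (y_pred & y_true).sum() / max(y_true.sum(), 1)
    tnr = (~y_pred & ~y_true).sum() / max((~y_true).sum(), 1)
    return 0.5 * (tpr + tnr)


def select_threshold(y_true, scores, grid=None):
    """Return the first grid threshold with the highest balanced accuracy."""
    if grid is None:
        grid = np.linspace(0, 1, 101)
    performances = [balanced_accuracy(y_true, scores > t) for t in grid]
    return grid[int(np.argmax(performances))]


# Toy scores, made up for illustration: the classes separate between 0.345 and 0.455
y_toy = np.array([0, 0, 0, 1, 1, 1])
s_toy = np.array([0.05, 0.15, 0.345, 0.455, 0.75, 0.95])
best = select_threshold(y_toy, s_toy)
print(best, balanced_accuracy(y_toy, s_toy > best))
```

Note that `np.argmax` returns the first maximiser, so ties between thresholds are broken in favour of the smallest one.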

In [11]:
from utils.evaluation import compute_empirical_bias

# Evaluate test-set balanced accuracy and statistical parity difference
print('Balanced accuracy: %.3f' % balanced_accuracy_score(dataset.y_test, test_pred.scores > best_thresh))
print('SPD: %.3f' % compute_empirical_bias((test_pred.scores > best_thresh) * 1., dataset.y_test.numpy(), 
                                           dataset.p_test, metric='spd'))

Balanced accuracy: 0.651
SPD: 0.113


As we can see, the statistical parity difference (SPD) is well above zero. If the SPD is an appropriate fairness criterion for the application, we can try reducing it using a debiasing algorithm!
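
For reference, the SPD compares the rates of positive predictions between the two groups defined by the protected attribute. The sketch below assumes one common sign convention (unprivileged minus privileged); the implementation of `compute_empirical_bias` is not shown in this notebook and may differ in detail. The function name and the toy arrays are made up for this illustration:

```python
import numpy as np


def statistical_parity_difference(y_pred, priv):
    """Positive-prediction rate of the unprivileged group minus that of the privileged group.

    Assumed convention: priv == 1 marks the privileged group.
    """
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    priv = np.asarray(priv).astype(bool).ravel()
    return y_pred[~priv].mean() - y_pred[priv].mean()


# Toy example: 3 of 4 unprivileged and 1 of 4 privileged instances are predicted positive
y_hat = np.array([1, 1, 1, 0, 1, 0, 0, 0])
priv = np.array([0, 0, 0, 0, 1, 1, 1, 1])
spd = statistical_parity_difference(y_hat, priv)
print('SPD: %.3f' % spd)
```

An SPD of zero means that both groups receive positive predictions at the same rate; the sign indicates which group is favoured.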

#### Bias Gradient Descent/Ascent (GD/A)

Bias gradient descent/ascent fine-tunes the network by minimising/maximising a differentiable proxy of the bias measure:

In [12]:
from algorithms.biasGrad import bias_gradient_decent

# Additional arguments for the bias GD/A
args['experiment_name'] = 'demo_experiment'
args['metric'] = 'spd'
args['acc_metric'] = 'balanced_accuracy'
args['objective'] = {}
args['objective']['sharpness'] = 500
args['objective']['epsilon'] = 0.05
args['biasGrad'] = {}
# Learning rate: reasonably small rates work well
args['biasGrad']['lr'] = 1e-5
# The maximum number of fine-tuning epochs
args['biasGrad']['n_epochs'] = 50
# Mini-batch size
args['biasGrad']['batch_size'] = 256
# Use only the validation set for debiasing? 
args['biasGrad']['val_only'] = True
# Lower bound on balanced accuracy: parameter ϱ from the original paper
args['biasGrad']['obj_lb'] = 0.61
# Number of evaluations per epoch (used for early stopping)
args['biasGrad']['n_evals'] = 1

# Fine-tune the network using bias GD/A
# NOTE: since the standard model's SPD > 0, we are performing gradient descent
model_GDA = bias_gradient_decent(model=model, data=dataset, config=args, seed=seed, plot=False, display=False, 
                              asc=False)

Performing bias gradient ascent/descent...



Let's now evaluate the fine-tuned model:

In [13]:
model_GDA.eval()
with torch.no_grad():
    valid_pred_GDA = dataset.valid.copy(deepcopy=True)
    valid_pred_GDA.scores = model_GDA(dataset.X_valid)[:, 0].reshape(-1, 1).numpy()
    valid_pred_GDA.labels = np.array(valid_pred_GDA.scores > 0.5)

    test_pred_GDA = dataset.test.copy(deepcopy=True)
    test_pred_GDA.scores = model_GDA(dataset.X_test)[:, 0].reshape(-1, 1).numpy()
    test_pred_GDA.labels = np.array(test_pred_GDA.scores > 0.5)

threshs = np.linspace(0, 1, 101)
performances = []
for thresh in threshs:
    perf = balanced_accuracy_score(dataset.y_valid, valid_pred_GDA.scores > thresh)
    performances.append(perf)
best_thresh_GDA = threshs[np.argmax(performances)]

print('Balanced accuracy: %.3f' % balanced_accuracy_score(dataset.y_test, test_pred_GDA.scores > best_thresh_GDA))
print('SPD: %.3f' % compute_empirical_bias((test_pred_GDA.scores > best_thresh_GDA) * 1., dataset.y_test.numpy(), 
                                           dataset.p_test, metric='spd'))

Balanced accuracy: 0.635
SPD: -0.019


We have sacrificed some balanced accuracy (0.635 instead of 0.651), but the SPD is now much closer to zero (-0.019 instead of 0.113)!
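
To make the idea of fine-tuning on a differentiable bias proxy more concrete, here is a minimal, self-contained sketch. It is not the implementation behind `bias_gradient_decent`: the proxy (difference in mean predicted scores), the synthetic data, the deliberately biased one-layer classifier, and the stopping tolerance `tol` are all assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic data: feature 0 is shifted upwards for the unprivileged group
n = 400
priv = (torch.arange(n) % 2).float()      # 1 = privileged, 0 = unprivileged
X = torch.randn(n, 5)
X[:, 0] += 2.0 * (1.0 - priv)

# A deliberately biased toy classifier: the score depends only on feature 0
clf = nn.Linear(5, 1)
with torch.no_grad():
    clf.weight.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0]]))
    clf.bias.zero_()


def scores_of(x):
    return torch.sigmoid(clf(x))[:, 0]


def spd_proxy(scores, priv):
    """Differentiable SPD surrogate: difference in mean predicted scores."""
    return scores[priv == 0].mean() - scores[priv == 1].mean()


with torch.no_grad():
    bias_before = spd_proxy(scores_of(X), priv).item()

opt = torch.optim.SGD(clf.parameters(), lr=0.05)
tol = 0.02                                # hypothetical stopping tolerance
for _ in range(2000):
    bias = spd_proxy(scores_of(X), priv)
    if bias.item() <= tol:                # stop once the proxy is small enough
        break
    opt.zero_grad()
    bias.backward()                       # descent, because the initial bias is positive
    opt.step()

with torch.no_grad():
    bias_after = spd_proxy(scores_of(X), priv).item()
print(f'proxy before: {bias_before:.3f}, after: {bias_after:.3f}')
```

If the initial bias were negative, the same loop would maximise the proxy instead (gradient ascent). This sketch has no accuracy constraint; the `obj_lb` argument above plays that role in the actual procedure.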

#### Pruning for Debiasing

The second algorithm we propose is pruning: a dropout-like removal of individual units from the network, selected by their contribution to the differentiable bias proxy. Let's try it out:

In [14]:
from algorithms.pruning import prune_fc

# Additional arguments for pruning
args['pruning'] = {}
# Re-evaluate unit influence after every pruning step?
args['pruning']['dynamic'] = True
# How many units are pruned per step?
args['pruning']['step_size'] = 1
# Stop pruning early if the balanced accuracy drops close to random?
args['pruning']['stop_early'] = True
# Use only the validation set for debiasing? 
args['pruning']['val_only'] = True
# Lower bound on balanced accuracy: parameter ϱ from the original paper
args['pruning']['obj_lb'] = 0.55

# Prune the network
# NOTE: this might take a while. To speed up, increase the step size parameter
model_pruned = prune_fc(model=model, data=dataset, config=args, seed=seed, plot=False, display=False)

Pruning the network...



Evaluate the pruned model:

In [15]:
model_pruned.eval()
with torch.no_grad():
    valid_pred_pruned = dataset.valid.copy(deepcopy=True)
    valid_pred_pruned.scores = model_pruned(dataset.X_valid)[:, 0].reshape(-1, 1).numpy()
    valid_pred_pruned.labels = np.array(valid_pred_pruned.scores > 0.5)

    test_pred_pruned = dataset.test.copy(deepcopy=True)
    test_pred_pruned.scores = model_pruned(dataset.X_test)[:, 0].reshape(-1, 1).numpy()
    test_pred_pruned.labels = np.array(test_pred_pruned.scores > 0.5)

threshs = np.linspace(0, 1, 101)
performances = []
for thresh in threshs:
    perf = balanced_accuracy_score(dataset.y_valid, valid_pred_pruned.scores > thresh)
    performances.append(perf)
best_thresh_pruned = threshs[np.argmax(performances)]

print('Balanced accuracy: %.3f' % balanced_accuracy_score(dataset.y_test, 
                                                          test_pred_pruned.scores > best_thresh_pruned))
print('SPD: %.3f' % compute_empirical_bias((test_pred_pruned.scores > best_thresh_pruned) * 1., 
                                           dataset.y_test.numpy(), dataset.p_test, metric='spd'))

Balanced accuracy: 0.629
SPD: -0.035


Like bias GD/A, pruning brings the SPD close to zero (-0.035) at the cost of a small drop in balanced accuracy (0.629 versus 0.651 for the standard model).
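
The following toy example sketches how pruning by unit influence might look. It is a first-order illustration under stated assumptions, not the logic of `prune_fc`: a unit's influence is approximated here by the gradient of a bias proxy with respect to a multiplicative mask on the hidden activations, and the network weights, synthetic data, and tolerance `tol` are hand-picked for this illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic data: only feature 0 differs between the groups
n = 400
priv = (torch.arange(n) % 2).float()      # 1 = privileged, 0 = unprivileged
X = torch.randn(n, 4)
X[:, 0] += 2.0 * (1.0 - priv)

# Hand-set toy network: hidden unit j reads feature j, all units feed the output equally
fc = nn.Linear(4, 4)
out = nn.Linear(4, 1)
with torch.no_grad():
    fc.weight.copy_(torch.eye(4))
    fc.bias.zero_()
    out.weight.fill_(0.5)
    out.bias.zero_()


def forward(x, mask):
    h = torch.relu(fc(x)) * mask          # mask[j] = 0 removes hidden unit j
    return torch.sigmoid(out(h))[:, 0]


def spd_proxy(scores, priv):
    """Differentiable SPD surrogate: difference in mean predicted scores."""
    return scores[priv == 0].mean() - scores[priv == 1].mean()


mask = torch.ones(4)
with torch.no_grad():
    bias_before = spd_proxy(forward(X, mask), priv).item()

tol = 0.05                                # hypothetical stopping tolerance
pruned = []
for _ in range(4):
    m = mask.clone().requires_grad_(True)
    bias = spd_proxy(forward(X, m), priv)
    if abs(bias.item()) <= tol:
        break
    (grad,) = torch.autograd.grad(bias, m)
    # Positive influence: the unit pushes |bias| up, so removing it should help
    influence = grad * torch.sign(bias.detach())
    influence[mask == 0] = float('-inf')  # skip units that are already pruned
    unit = int(torch.argmax(influence))
    mask[unit] = 0.0                      # one unit per step, influence re-evaluated each time
    pruned.append(unit)

with torch.no_grad():
    bias_after = spd_proxy(forward(X, mask), priv).item()
print(f'pruned units: {pruned}, proxy before: {bias_before:.3f}, after: {bias_after:.3f}')
```

Re-computing the influence after every removal corresponds to the spirit of the `dynamic` option above, while pruning several units per iteration would correspond to a larger `step_size`.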