In [None]:
!pip install fairlearn



In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference

In [None]:
class FairLoss(nn.Module):
    """Binary cross-entropy plus soft fairness penalties for a binary sensitive attribute.

    Loss = BCE + lambda_dp * DP_penalty + lambda_eo * EO_penalty

    DP_penalty is the squared gap in mean predicted probability between the two
    groups (race == 0 vs. race == 1). EO_penalty averages the squared gaps in the
    soft TPR (mean probability among y == 1) and the soft FPR (mean probability
    among y == 0) between the same two groups. Samples whose sensitive value is
    neither 0 nor 1 still contribute to BCE but are ignored by both penalties.

    Args:
        lambda_dp (float): Weight for the demographic parity penalty. Defaults to 2.0.
        lambda_eo (float): Weight for the equalized odds penalty. Defaults to 2.0.
    """

    def __init__(self, lambda_dp=2.0, lambda_eo=2.0):
        super().__init__()
        self.lambda_dp = lambda_dp
        self.lambda_eo = lambda_eo
        self.bce = nn.BCEWithLogitsLoss()

    @staticmethod
    def _group_gap(probs, mask_a, mask_b):
        """Return (mean(probs[mask_a]) - mean(probs[mask_b])) ** 2.

        If either mask selects no elements, the mean would be NaN, so a zero
        tensor is returned instead. It is created with probs' device and dtype so
        it combines cleanly with the BCE term wherever the model runs.
        """
        if bool(mask_a.any()) and bool(mask_b.any()):
            return (probs[mask_a].mean() - probs[mask_b].mean()) ** 2
        return probs.new_zeros(())

    def forward(self, logits, y_true, race):
        """Compute the total loss and its components.

        Args:
            logits (torch.Tensor): Raw model outputs (before sigmoid), e.g. shape (N, 1) or (N,).
            y_true (torch.Tensor): True binary labels, N elements.
            race (torch.Tensor): Sensitive attribute, N elements; only values 0 and 1
                define the two groups compared by the penalties.

        Returns:
            tuple: (total_loss tensor for backprop, bce_loss float, dp_loss float, eo_loss float).
        """
        # reshape(-1) rather than squeeze(): squeeze() turns a (1, 1) output into a
        # 0-d tensor, which BCEWithLogitsLoss rejects against a (1,) target, so a
        # final mini-batch of a single sample would crash training.
        logits = logits.reshape(-1)
        y_true = y_true.reshape(-1).float()
        race = race.reshape(-1).float()
        probs = torch.sigmoid(logits)

        bce_loss = self.bce(logits, y_true)

        in_group_0 = race == 0
        in_group_1 = race == 1
        positive = y_true == 1
        negative = y_true == 0

        dp_loss = self._group_gap(probs, in_group_0, in_group_1)
        tpr_loss = self._group_gap(probs, in_group_0 & positive, in_group_1 & positive)
        fpr_loss = self._group_gap(probs, in_group_0 & negative, in_group_1 & negative)
        eo_loss = (tpr_loss + fpr_loss) / 2

        total_loss = bce_loss + self.lambda_dp * dp_loss + self.lambda_eo * eo_loss
        return total_loss, bce_loss.item(), dp_loss.item(), eo_loss.item()

In [None]:
class SimpleNN(nn.Module):
    """Small feedforward binary classifier with three Linear layers (two hidden layers).

    Architecture:
        Linear(in_features, 32) -> ReLU -> Dropout(dropout)
        -> Linear(32, 16) -> ReLU
        -> Linear(16, 1)   (one raw logit per sample; no sigmoid)

    Args:
        in_features (int): Number of input features. Defaults to 2, the width the
            notebook's data cell feeds in.
        dropout (float): Dropout probability applied after the first hidden layer
            only. Defaults to 0.2.
    """

    def __init__(self, in_features=2, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        """Run a forward pass.

        Args:
            x (torch.Tensor): Input features, shape (N, in_features).

        Returns:
            torch.Tensor: Raw logits, shape (N, 1).
        """
        return self.net(x)

In [None]:
SEED = 42
# Seed torch so weight initialization and the training-loop shuffles are reproducible.
torch.manual_seed(SEED)

df = pd.read_csv("/content/lsac_data.csv")
if "ZFYGPA" in df.columns:
    df = df.rename(columns={"ZFYGPA": "zfygpa"})

needed = ["race", "gender", "lsat", "ugpa", "zfygpa"]
# Model inputs; SimpleNN's default input width (2) must match this list.
feature_cols = ["lsat", "ugpa"]

for c in needed:
    df[c] = pd.to_numeric(df[c], errors="coerce")

# errors="coerce" turns unparseable cells into NaN. Drop those rows: NaN would
# otherwise leak into the median cutoff and the feature tensors, and makes the
# .astype(int) casts on gender/race below raise.
df = df.dropna(subset=needed).reset_index(drop=True)

# Label: at or above median zfygpa
cutoff_value = np.median(df["zfygpa"])
df["admit_sim"] = (df["zfygpa"] >= cutoff_value).astype(int)

# Features: explicit columns only (label, zfygpa and protected attributes excluded)
X = df[feature_cols]
y = df["admit_sim"].astype(int)
gender = df["gender"].astype(int)
race = df["race"].astype(int)

X_train, X_test, y_train, y_test, g_train, g_test, r_train, r_test = train_test_split(
    X, y, gender, race, stratify=y, test_size=0.25, random_state=SEED
)

X_train = torch.FloatTensor(X_train.to_numpy())
X_test = torch.FloatTensor(X_test.to_numpy())

# Standardize with training-split statistics only (no test leakage) so features on
# different numeric ranges enter the network on a comparable scale.
feat_mean = X_train.mean(dim=0, keepdim=True)
feat_std = X_train.std(dim=0, keepdim=True).clamp_min(1e-8)
X_train = (X_train - feat_mean) / feat_std
X_test = (X_test - feat_mean) / feat_std

y_train = torch.FloatTensor(y_train.to_numpy())
y_test = torch.FloatTensor(y_test.to_numpy())
r_train = torch.LongTensor(r_train.to_numpy())
r_test = torch.LongTensor(r_test.to_numpy())

model = SimpleNN()

In [None]:
criterion = FairLoss(lambda_dp=0.05, lambda_eo=0.05)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
batch_size = 256
epochs = 500
log_every = 20  # epochs between progress lines

n_train = len(X_train)
for epoch in range(epochs):
    model.train()
    order = torch.randperm(n_train)
    # Sample-weighted sums of the per-batch values, so each logged line summarizes
    # the whole epoch rather than only the last (possibly tiny) mini-batch.
    epoch_sums = {"total": 0.0, "bce": 0.0, "dp": 0.0, "eo": 0.0}

    for start in range(0, n_train, batch_size):
        idx = order[start:start + batch_size]
        X_batch, y_batch, r_batch = X_train[idx], y_train[idx], r_train[idx]

        optimizer.zero_grad()
        logits = model(X_batch)
        loss, batch_bce, batch_dp, batch_eo = criterion(logits, y_batch, r_batch)
        loss.backward()
        optimizer.step()

        n_batch = len(idx)
        epoch_sums["total"] += loss.item() * n_batch
        epoch_sums["bce"] += batch_bce * n_batch
        epoch_sums["dp"] += batch_dp * n_batch
        epoch_sums["eo"] += batch_eo * n_batch

    if (epoch + 1) % log_every == 0:
        avg = {name: value / max(n_train, 1) for name, value in epoch_sums.items()}
        print(
            f"Epoch {epoch+1}: BCE={avg['bce']:.4f}, DP={avg['dp']:.4f}, "
            f"EO={avg['eo']:.4f}, Total={avg['total']:.4f}"
        )

model.eval()
with torch.no_grad():
    logits_test = model(X_test)
    # reshape(-1) keeps a 1-D array even for a single test row (squeeze() would give 0-d).
    probs_test = torch.sigmoid(logits_test).reshape(-1).numpy()
    preds_test = (probs_test > 0.5).astype(int)

y_test_np = y_test.numpy()
r_test_np = r_test.numpy()
acc = accuracy_score(y_test_np, preds_test)
auc = roc_auc_score(y_test_np, probs_test)
# These dp/eo are fairlearn's metrics on hard 0/1 test predictions, not the soft
# training penalties logged above.
dp = demographic_parity_difference(y_test_np, preds_test, sensitive_features=r_test_np)
eo = equalized_odds_difference(y_test_np, preds_test, sensitive_features=r_test_np)

print("\nNeural Network (Original Labels):")
print(f"Accuracy: {acc:.3f} | AUC: {auc:.3f}")
print(f"DP difference: {dp:.3f} | EO difference: {eo:.3f}")

Epoch 20: BCE=0.6652, DP=0.0140, EO=0.0178, Total=0.6668
Epoch 40: BCE=0.6800, DP=0.0294, EO=0.0196, Total=0.6824
Epoch 60: BCE=0.6434, DP=0.0813, EO=0.0345, Total=0.6492
Epoch 80: BCE=0.6677, DP=0.0185, EO=0.0189, Total=0.6696
Epoch 100: BCE=0.6683, DP=0.0431, EO=0.0425, Total=0.6725
Epoch 120: BCE=0.6347, DP=0.0760, EO=0.0557, Total=0.6412
Epoch 140: BCE=0.6925, DP=0.0425, EO=0.0521, Total=0.6972
Epoch 160: BCE=0.6594, DP=0.0728, EO=0.0328, Total=0.6647
Epoch 180: BCE=0.6435, DP=0.0598, EO=0.0231, Total=0.6476
Epoch 200: BCE=0.6530, DP=0.0492, EO=0.0182, Total=0.6564
Epoch 220: BCE=0.6520, DP=0.0171, EO=0.0052, Total=0.6531
Epoch 240: BCE=0.6856, DP=0.0668, EO=0.0575, Total=0.6919
Epoch 260: BCE=0.6451, DP=0.0284, EO=0.0103, Total=0.6471
Epoch 280: BCE=0.6449, DP=0.0550, EO=0.0233, Total=0.6488
Epoch 300: BCE=0.6545, DP=0.0689, EO=0.0613, Total=0.6610
Epoch 320: BCE=0.6751, DP=0.0353, EO=0.0275, Total=0.6783
Epoch 340: BCE=0.7123, DP=0.0527, EO=0.0811, Total=0.7190
Epoch 360: BCE=0.6