In [4]:
from scipy.optimize import fsolve
from tqdm import tqdm
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import itertools
import ipywidgets as widgets
from IPython.display import display
np.random.seed(1)

#Vectorized optimal step
def opt_step(X, u_plus, u_minus, c_plus, c_minus):

    # Samples
    # All operations are performed on array-like structures
    X = np.asarray(X)

    # Calculate utility and predicted change of all samples
    exp_util = expected(X, u_plus, u_minus) 
    delta_x = expected(X, c_plus, c_minus)

    # Add predicted change to sample if utility is positive else return sample
    max_util = np.where(exp_util > 0, X + delta_x, X) 

    # Smallest threshold where utility is positive
    thresholds = X[exp_util > 0]
    opt_thresh = np.min(thresholds) if thresholds.size > 0 else None

    # array of updated samples, min sample whose util is positive
    return (max_util, opt_thresh)

def p(x) -> float:
    return 1 / (1 + np.exp(-x))

def expected(x, plus, minus):
    return p(x) * plus + (1-p(x)) * minus

def f(x):
    return expected(x, plus=1, minus=-1)

# Solve for x such that f(x) = 0
root = fsolve(f, x0=0)

def alt_fair_step(A, B, u_plus, u_minus, c_plus, c_minus, alpha, range, size):
    begin = range[0]
    end = range[1]
    thresholds = np.arange(begin, end, size)

    max_util = -np.inf
    opt_A = None
    opt_B = None
    thresh_A = None
    thresh_B = None 

    w_a = len(A) / (len(A) + len(B))
    w_b = 1 - w_a
    delta_A = expected(A, c_plus, c_minus)
    delta_B = expected(B, c_plus, c_minus)

    for threshold_A in thresholds:
            for threshold_B in thresholds:
                a = np.where(A + delta_A > threshold_A, A + delta_A, A )
                b = np.where(B + delta_B > threshold_B, B + delta_B, B )
                if np.abs( np.mean(a) - np.mean(b) ) >= alpha:
                    continue
                util = w_a * np.sum(expected(a, u_plus, u_minus)) + w_b * np.sum(expected(b, u_plus, u_minus))
                if util >= max_util: 
                    max_util = util
                    opt_A = a
                    opt_B = b
                    thresh_A = threshold_A
                    thresh_B = threshold_B

    return opt_A, opt_B, thresh_A, thresh_B, max_util

# Distribution parameters
mean_a, std_a = 0.5, 1
mean_b, std_b = 0, 1

# Domain of alpha values to test
# Testing num values of alpha from alpha_min to alpha_max
alpha_min, alpha_max = 0.4, 0.6
alphas = np.linspace(alpha_min, alpha_max, num=25)

# Sample size of each distribution
sample_size = 50

# Generate samples from normal distributions
np.random.seed(1)
a = np.random.normal(mean_a, std_a, sample_size)
b = np.random.normal(mean_b, std_b, sample_size)

# Store results
x = []
fair_mean_A = []
fair_mean_B = []
fair_threshold_A = []
fair_threshold_B = []
total_util = []

# Unconstrained optimal means
opt_result_A = opt_step(a, 1, -1, 1, -1)[0]
opt_result_B = opt_step(b, 1, -1, 1, -1)[0]
y3 = np.mean(opt_result_A)
y4 = np.mean(opt_result_B)
y5 = root
y6 = root

# Run fair optimization over a range of alpha values
for alpha in tqdm(alphas):
    results = alt_fair_step(a, b, 1, -1, 1, -1, alpha, [-3,3], 0.02)
    if results[0] is not None and results[1] is not None:
        A, B, thresh_A, thresh_B, max_util = results
        x.append(alpha)
        
        fair_mean_A.append(np.mean(A))
        fair_mean_B.append(np.mean(B))
        fair_threshold_A.append(thresh_A)
        fair_threshold_B.append(thresh_B)
        total_util.append(max_util * 0.1)  # Scale utility for better visualization

def plot_fairness(show_means=True, show_thresholds=True, show_utility=True):
    plt.figure(figsize=(10, 6))

    if show_means:
        plt.plot(0, mean_a, 'ro', label='Initial μ(A)')
        plt.plot(0, mean_b, 'bo', label='Initial μ(B)')
        plt.plot(x, fair_mean_A, label="Fair μ(A)'", color='red')
        plt.plot(x, fair_mean_B, label="Fair μ(B)'", color='blue') 
        plt.axhline(y3, color='red', linestyle='--', label='Optimal μ(A)')
        plt.axhline(y4, color='blue', linestyle='--', label='Optimal μ(B)')

    if show_thresholds:
        plt.plot(x, fair_threshold_A, label="Fair Threshold (A)", color='orange')
        plt.plot(x, fair_threshold_B, label="Fair Threshold (B)", color='purple')
        plt.axhline(y5, color='orange', linestyle='--', label='Optimal Threshold (A)')
        plt.axhline(y6, color='purple', linestyle='--', label='Optimal Threshold (B)')

    if show_utility:
        plt.plot(x, total_util, label='Total Utility', color='green')

    plt.title("Fair vs. Optimal Means under Varying Fairness Constraint α")
    plt.xlabel("α (Fairness Threshold)")
    plt.ylabel("Metric Value")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Toggle widgets
means_toggle = widgets.Checkbox(value=True, description='Show Means')
thresholds_toggle = widgets.Checkbox(value=True, description='Show Thresholds')
utility_toggle = widgets.Checkbox(value=True, description='Show Utility')

ui = widgets.HBox([means_toggle, thresholds_toggle, utility_toggle])
out = widgets.interactive_output(plot_fairness, {
    'show_means': means_toggle,
    'show_thresholds': thresholds_toggle,
    'show_utility': utility_toggle
})

display(ui, out)

100%|██████████| 25/25 [01:21<00:00,  3.26s/it]


HBox(children=(Checkbox(value=True, description='Show Means'), Checkbox(value=True, description='Show Threshol…

Output()

In [3]:
### Vectorized version of the code

from scipy.optimize import fsolve
from tqdm import tqdm
from utils import alt_fair_opt_step, expected

def f(x):
    return expected(x, plus=1, minus=-1)

# Solve for x such that f(x) = 0
root = fsolve(f, x0=0)

def __alt_fair_step(A, B, u_plus, u_minus, c_plus, c_minus, alpha, range, size):
    begin = range[0]
    end = range[1]
    thresholds = np.arange(begin, end, size)

    max_util = -np.inf
    opt_A = None
    opt_B = None
    thresh_A = None
    thresh_B = None 

    w_a = len(A) / (len(A) + len(B))
    w_b = 1 - w_a
    delta_A = expected(A, c_plus, c_minus)
    delta_B = expected(B, c_plus, c_minus)

    a, b = np.meshgrid(thresholds, thresholds, indexing='ij')  # (T, T)
    a_exp = a[:, :, None]  # (T, T, 1)
    b_exp = b[:, :, None]  # (T, T, 1)

    # Expand A/B and compute deltas
    a_matrix = A[None, None, :]               # (1, 1, n)
    b_matrix = B[None, None, :]               # (1, 1, m)
    delta_A = expected(A, c_plus, c_minus)[None, None, :]  # (1, 1, n)
    delta_B = expected(B, c_plus, c_minus)[None, None, :]  # (1, 1, m)

    # Apply delta conditionally based on thresholds
    a_adj = np.where(a_matrix + delta_A > a_exp, a_matrix + delta_A, a_matrix)  # (T, T, n)
    b_adj = np.where(b_matrix + delta_B > b_exp, b_matrix + delta_B, b_matrix)  # (T, T, m)

    # Compute expected utility for each grid cell
    util_A = expected(a_adj, u_plus, u_minus).sum(axis=2)  # (T, T)
    util_B = expected(b_adj, u_plus, u_minus).sum(axis=2)  # (T, T)
    util = w_a * util_A + w_b * util_B                     # (T, T)

    # Fairness check: mean difference under alpha
    mean_diff = np.abs(np.mean(a_adj) - np.mean(b_adj))  # (T, T)
    util[mean_diff >= alpha] = -np.inf  # Mask unfair combinations

    # Find the best combination
    i, j = np.unravel_index(np.argmax(util), util.shape)
    opt_A = a_adj[i, j]
    opt_B = b_adj[i, j]
    thresh_A = a[i, j]
    thresh_B = b[i, j]
    max_util = util[i, j]

    return opt_A, opt_B, thresh_A, thresh_B, max_util

# Distribution parameters
mean_a, std_a = 0.5, 1
mean_b, std_b = 0, 1

# Domain of alpha values to test
# Testing num values of alpha from alpha_min to alpha_max
alpha_min, alpha_max = 0.4, 0.6
alphas = np.linspace(alpha_min, alpha_max, num=25)

# Sample size of each distribution
sample_size = 50

# Generate samples from normal distributions
np.random.seed(1)
a = np.random.normal(mean_a, std_a, sample_size)
b = np.random.normal(mean_b, std_b, sample_size)

# Store results
x = []
fair_mean_A = []
fair_mean_B = []
fair_threshold_A = []
fair_threshold_B = []
total_util = []

# Unconstrained optimal means
opt_result_A = opt_step(a, 1, -1, 1, -1)[0]
opt_result_B = opt_step(b, 1, -1, 1, -1)[0]
y3 = np.mean(opt_result_A)
y4 = np.mean(opt_result_B)
y5 = root
y6 = root

# Run fair optimization over a range of alpha values
for alpha in tqdm(alphas):
    results = alt_fair_opt_step(a, b, 1, -1, 1, -1, alpha, [-3,3], 0.02)
    if results[0] is not None and results[1] is not None:
        A, B, thresh_A, thresh_B, max_util = results
        x.append(alpha)
        
        fair_mean_A.append(np.mean(A))
        fair_mean_B.append(np.mean(B))
        fair_threshold_A.append(thresh_A)
        fair_threshold_B.append(thresh_B)
        total_util.append(max_util * 0.1)  # Scale utility for better visualization

def plot_fairness(show_means=True, show_thresholds=True, show_utility=True):
    plt.figure(figsize=(10, 6))

    if show_means:
        plt.plot(0, mean_a, 'ro', label='Initial μ(A)')
        plt.plot(0, mean_b, 'bo', label='Initial μ(B)')
        plt.plot(x, fair_mean_A, label="Fair μ(A)'", color='red')
        plt.plot(x, fair_mean_B, label="Fair μ(B)'", color='blue') 
        plt.axhline(y3, color='red', linestyle='--', label='Optimal μ(A)')
        plt.axhline(y4, color='blue', linestyle='--', label='Optimal μ(B)')

    if show_thresholds:
        plt.plot(x, fair_threshold_A, label="Fair Threshold (A)", color='orange')
        plt.plot(x, fair_threshold_B, label="Fair Threshold (B)", color='purple')
        plt.axhline(y5, color='orange', linestyle='--', label='Optimal Threshold (A)')
        plt.axhline(y6, color='purple', linestyle='--', label='Optimal Threshold (B)')

    if show_utility:
        plt.plot(x, total_util, label='Total Utility', color='green')

    plt.title("Fair vs. Optimal Means under Varying Fairness Constraint α")
    plt.xlabel("α (Fairness Threshold)")
    plt.ylabel("Metric Value")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Toggle widgets
means_toggle = widgets.Checkbox(value=True, description='Show Means')
thresholds_toggle = widgets.Checkbox(value=True, description='Show Thresholds')
utility_toggle = widgets.Checkbox(value=True, description='Show Utility')

ui = widgets.HBox([means_toggle, thresholds_toggle, utility_toggle])
out = widgets.interactive_output(plot_fairness, {
    'show_means': means_toggle,
    'show_thresholds': thresholds_toggle,
    'show_utility': utility_toggle
})

display(ui, out)



  0%|          | 0/25 [00:00<?, ?it/s]


ValueError: operands could not be broadcast together with shapes (300,300) (50,50) 