In [53]:
# Import necessary modules/libraries
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import plotly.express as px
import plotly.graph_objects as go
import pprint as pp

In [93]:
# Set experiment parameters
np.random.seed(1)

# Distribution parameters
mean_a, std_a = 0.5, 1
mean_b, std_b = 0.0, 1

# Domain of alpha values to test
alpha_min, alpha_max = 0.01, 1
alphas = np.linspace(alpha_min, alpha_max, num=100)

# Sample size of each distribution
n = 25

# Generate samples from normal distributions
a = np.random.normal(mean_a, std_a, n)
b = np.random.normal(mean_b, std_b, n)
a.sort(), b.sort()
w_a = len(a) / (len(a) + len(b))
w_b = 1 - w_a

#Step Parameters
ALPHA = 0.5

# U+ / U- > C+ / C-
U_PLUS = 1
U_MINUS = -1.1
C_PLUS = 1
C_MINUS = -0.4
#print( "Assumption 1: U+ / U- > C+ / C- is", (u_plus / u_minus) > (c_plus / c_minus) )

# Threshold intervals
begin, end = -1, 2
interval = 1
THRESHOLDS = np.arange(begin, end, interval)

In [114]:
from utils import expected

def brute_force(a, b, alpha=None):
    data = []
    utilities = []
    mean_diffs = []
    
    delta_A = expected(a, C_PLUS, C_MINUS)
    delta_B = expected(b, C_PLUS, C_MINUS)

    for thresh_a1 in tqdm(THRESHOLDS):
        A = np.where(a > thresh_a1, a + delta_A, a)
        for thresh_b1 in THRESHOLDS:
            B = np.where(b > thresh_b1, b + delta_B, b)

            delta_A2 = expected(A, C_PLUS, C_MINUS)
            delta_B2 = expected(B, C_PLUS, C_MINUS)
            for thresh_a2 in THRESHOLDS:
                A2 = np.where(A > thresh_a2, A + delta_A2, A)
                
                for thresh_b2 in THRESHOLDS:
                    B2 = np.where(B > thresh_b2, B + delta_B2, B)

                    util_A = np.sum(expected(A2, U_PLUS, U_MINUS))
                    util_B = np.sum(expected(B2, U_PLUS, U_MINUS))
                    total_util = w_a * util_A + w_b * util_B
                    fairness_diff = abs(np.mean(A2) - np.mean(B2))
                    data.append( (fairness_diff, total_util) )
        
    for alpha in alphas:
        valid = [i for i in data if i[0] <= alpha]
        count = len(valid)
        if count > 0:
            max_util, diff = max(valid, key = lambda x: x[1])
        else:
            max_util, diff = np.nan, np.nan
        
        utilities.append(max_util)
        mean_diffs.append(diff)
    
        # Filter the utilities to only those less than alpha
    filtered_indices = [i for i, val in enumerate(utilities) if val < alpha]

    # If no elements meet the condition, handle that case
    if not filtered_indices:
        # handle case, e.g., raise error or return default
        raise ValueError("No utilities less than alpha")

    # Find the index of the max element among those filtered
    temp = max(filtered_indices, key=lambda i: utilities[i])

    return mean_diffs[temp], utilities[temp]


data = brute_force(a, b, alpha=1)

100%|██████████| 3/3 [00:00<00:00, 247.87it/s]


In [106]:
def pairwise_thresholding(a, b, u_plus, u_minus, c_plus, c_minus, thresholds):
    T = len(thresholds)

    n, m = len(a), len(b)

    delta_A = expected(a, c_plus, c_minus)  # (n,)
    delta_B = expected(b, c_plus, c_minus)  # (m,)

    thresh1_a, thresh1_b = np.meshgrid(thresholds, thresholds, indexing='ij')  # (T, T)

    # Broadcast thresholds
    thresh1_a_bc = thresh1_a[:, :, np.newaxis]  # (T, T, 1)
    thresh1_b_bc = thresh1_b[:, :, np.newaxis]  # (T, T, 1)

    A_bc = a[np.newaxis, np.newaxis, :]  # (1, 1, n)
    B_bc = b[np.newaxis, np.newaxis, :]  # (1, 1, m)
    delta_A_bc = delta_A[np.newaxis, np.newaxis, :]
    delta_B_bc = delta_B[np.newaxis, np.newaxis, :]

    condition_A1 = A_bc > thresh1_a_bc
    condition_B1 = B_bc > thresh1_b_bc
    
    # Updated arrays
    a_adj = np.where(condition_A1, A_bc + delta_A_bc, A_bc)  # (T, T, n)
    b_adj = np.where(condition_B1, B_bc + delta_B_bc, B_bc)  # (T, T, m)


    return a_adj, b_adj

alpha = 1
a_adj, b_adj = pairwise_thresholding(a, b, U_PLUS, U_MINUS, C_PLUS, C_MINUS, THRESHOLDS)

T = len(THRESHOLDS)
a_rows = []
b_rows = []

for i in range(T):
    for j in range(T):
        print(f"Threshold pair: (a={THRESHOLDS[i]}, b={THRESHOLDS[j]})")
        print(f"a_adj[{i}, {j}] = {a_adj[i, j]}")  # shape (n,)
        print(f"b_adj[{i}, {j}] = {b_adj[i, j]}")  # shape (m,)
        print("-" * 40)
        a_rows.append(a_adj[i, j])
        b_rows.append(b_adj[i, j])

max_util = -np.inf
for A, B in zip(a_rows, b_rows):
    a_2, b_2 = pairwise_thresholding(A, B, U_PLUS, U_MINUS, C_PLUS, C_MINUS, THRESHOLDS)
    
    a_mean = np.mean(a_2, axis=-1)
    b_mean = np.mean(b_2, axis=-1)
    a_util = np.sum(expected( a_2, U_PLUS, U_MINUS), axis=-1)
    b_util = np.sum(expected( b_2, U_PLUS, U_MINUS), axis=-1)

    diff = np.abs(a_mean - b_mean)
    util = 0.5 * a_util + 0.5 * b_util

    util = np.where(diff > alpha, -np.inf, util)
    print(util)
    print("-" * 40)
    print(diff)
    print("-" * 40)


    max_util = max( np.max(util), max_util)

print(max_util)



Threshold pair: (a=-1, b=-1)
a_adj[0, 0] = [-1.8015387  -1.56014071 -0.50473631 -0.50377527 -0.46819566 -0.20855748
 -0.05211303  0.1491695   0.26196879  0.45648122  0.53957395  0.63789368
  0.74120761  1.02747188  1.39068436  1.62666254  1.72865213  2.08069967
  2.1241035   2.12500139  2.40512895  2.41817105  2.7895732   2.97489116
  3.11069466]
b_adj[0, 0] = [-1.11731035 -0.94131147 -0.87938985 -0.82459027 -0.69714116 -0.62453153
 -0.61864546 -0.61412615 -0.59774112 -0.55269926 -0.2338239  -0.06109218
  0.04122716  0.13415224  0.28290285  0.36858664  0.46216406  0.55753368
  0.61608754  1.01174789  1.12661815  1.29046394  2.4360982   2.47481997
  2.9475543 ]
----------------------------------------
Threshold pair: (a=-1, b=0)
a_adj[0, 1] = [-1.8015387  -1.56014071 -0.50473631 -0.50377527 -0.46819566 -0.20855748
 -0.05211303  0.1491695   0.26196879  0.45648122  0.53957395  0.63789368
  0.74120761  1.02747188  1.39068436  1.62666254  1.72865213  2.08069967
  2.1241035   2.12500139  2.4

In [115]:
print(data)


(np.float64(6.983792499140073), np.float64(0.8985219856359585))
