In [1]:
# Import necessary modules/libraries
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import plotly.express as px
import plotly.graph_objects as go
import pprint as pp

In [2]:
# Set experiment parameters
np.random.seed(1)

# Distribution parameters
mean_a, std_a = 0.5, 1
mean_b, std_b = 0.0, 1

# Domain of alpha values to test
alpha_min, alpha_max = 0.01, 1
alphas = np.linspace(alpha_min, alpha_max, num=100)

# Sample size of each distribution
n = 25

# Generate samples from normal distributions
a = np.random.normal(mean_a, std_a, n)
b = np.random.normal(mean_b, std_b, n)
a.sort(), b.sort()
w_a = len(a) / (len(a) + len(b))
w_b = 1 - w_a

#Step Parameters
ALPHA = 0.5

# U+ / U- > C+ / C-
U_PLUS = 1
U_MINUS = -1.1
C_PLUS = 1
C_MINUS = -0.4
#print( "Assumption 1: U+ / U- > C+ / C- is", (u_plus / u_minus) > (c_plus / c_minus) )

# Threshold intervals
begin, end = -3, 3.1
interval = 0.1
THRESHOLDS = np.arange(begin, end, interval)

In [3]:
from utils import expected

def brute_force(a, b):
    utilities = []
    mean_diffs = []
    
    delta_A = expected(a, C_PLUS, C_MINUS)
    delta_B = expected(b, C_PLUS, C_MINUS)

    for thresh_a1 in tqdm(THRESHOLDS):
        A = np.where(a > thresh_a1, a + delta_A, a)
        for thresh_b1 in THRESHOLDS:
            B = np.where(b > thresh_b1, b + delta_B, b)

            delta_A2 = expected(A, C_PLUS, C_MINUS)
            delta_B2 = expected(B, C_PLUS, C_MINUS)
            for thresh_a2 in THRESHOLDS:
                A2 = np.where(A > thresh_a2, A + delta_A2, A)
                
                for thresh_b2 in THRESHOLDS:
                    B2 = np.where(B > thresh_b2, B + delta_B2, B)

                    util_A = np.sum(expected(A2, U_PLUS, U_MINUS))
                    util_B = np.sum(expected(B2, U_PLUS, U_MINUS))
                    total_util = w_a * util_A + w_b * util_B
                    fairness_diff = abs(np.mean(A2) - np.mean(B2))

                    utilities.append(total_util)
                    mean_diffs.append(fairness_diff)
    return (mean_diffs, utilities)

In [4]:
diffs, utilities = brute_force(a, b)

100%|██████████| 61/61 [08:13<00:00,  8.08s/it]


In [5]:
y1 = []
for alpha in tqdm(alphas):
    valid_diffs = np.where(diffs <= alpha, utilities, -np.inf)
    max_util = np.max(valid_diffs)
    y1.append(max_util)

100%|██████████| 100/100 [01:43<00:00,  1.04s/it]


In [6]:
n4 = go.Scatter(x=alphas, y=y1)
fig = go.Figure(data=[n4])
fig.show()

In [4]:
# Vectorized approach
def pairwise_thresholding(a, b, c_plus, c_minus, thresholds):
    T = len(thresholds)

    n, m = len(a), len(b)

    delta_A = expected(a, c_plus, c_minus)  # (n,)
    delta_B = expected(b, c_plus, c_minus)  # (m,)

    thresh1_a, thresh1_b = np.meshgrid(thresholds, thresholds, indexing='ij')  # (T, T)

    # Broadcast thresholds
    thresh1_a_bc = thresh1_a[:, :, np.newaxis]  # (T, T, 1)
    thresh1_b_bc = thresh1_b[:, :, np.newaxis]  # (T, T, 1)

    A_bc = a[np.newaxis, np.newaxis, :]  # (1, 1, n)
    B_bc = b[np.newaxis, np.newaxis, :]  # (1, 1, m)
    delta_A_bc = delta_A[np.newaxis, np.newaxis, :]
    delta_B_bc = delta_B[np.newaxis, np.newaxis, :]

    condition_A1 = A_bc > thresh1_a_bc
    condition_B1 = B_bc > thresh1_b_bc
    
    # Updated arrays
    a_adj = np.where(condition_A1, A_bc + delta_A_bc, A_bc)  # (T, T, n)
    b_adj = np.where(condition_B1, B_bc + delta_B_bc, B_bc)  # (T, T, m)


    return a_adj, b_adj

def vectorized(a, b):
    a_adj, b_adj = pairwise_thresholding(a, b, C_PLUS, C_MINUS, THRESHOLDS)

    T = len(THRESHOLDS)
    a_rows = []
    b_rows = []

    
    for i in range(T):
        for j in range(T):
            #print(f"Threshold pair: (a={THRESHOLDS[i]}, b={THRESHOLDS[j]})")
            #print(f"a_adj[{i}, {j}] = {a_adj[i, j]}")  # shape (n,)
            #print(f"b_adj[{i}, {j}] = {b_adj[i, j]}")  # shape (m,)
            #print("-" * 40)
            a_rows.append(a_adj[i, j])
            b_rows.append(b_adj[i, j])

    print(a_adj.shape, b_adj.shape)
    print(len(a_rows), len(b_rows))
    
    util_matrix = []
    diff_matrix = []

    count = 0
    for A, B in tqdm(zip(a_rows, b_rows)):
        a_2, b_2 = pairwise_thresholding(A, B, C_PLUS, C_MINUS, THRESHOLDS)
    
        a_mean = np.mean(a_2, axis=-1)
        b_mean = np.mean(b_2, axis=-1)
        a_util = np.sum(expected( a_2, U_PLUS, U_MINUS), axis=-1)
        b_util = np.sum(expected( b_2, U_PLUS, U_MINUS), axis=-1)

        diff = np.abs(a_mean - b_mean)
        util = 0.5 * a_util + 0.5 * b_util

        util_matrix.append(util)
        diff_matrix.append(diff)

        count += util.size

    util_matrix = np.vstack(util_matrix)
    diff_matrix = np.vstack(diff_matrix)
    print(count)
    print(util_matrix.shape, diff_matrix.shape)
    print(util_matrix.size, diff_matrix.size)

    return util_matrix, diff_matrix




In [5]:
util_matrix, diff_matrix = vectorized(a, b)

(61, 61, 25) (61, 61, 25)
3721 3721


3721it [00:21, 175.12it/s]


13845841
(226981, 61) (226981, 61)
13845841 13845841


In [9]:
print(util_matrix.shape)
print(diff_matrix.shape)
print(len(THRESHOLDS))

(226981, 61)
(226981, 61)
61


In [6]:
y2 = []
print(util_matrix)
print(diff_matrix)
for alpha in tqdm(alphas):
    valid_diffs = np.where(diff_matrix <= alpha, util_matrix, -np.inf)
    y2.append(np.max(valid_diffs))


[[6.86681721 6.86681721 6.86681721 ... 5.50223763 5.50223763 5.47118409]
 [6.86681721 6.86681721 6.86681721 ... 5.50223763 5.50223763 5.47118409]
 [6.86681721 6.86681721 6.86681721 ... 5.50223763 5.50223763 5.47118409]
 ...
 [2.19010447 2.19010447 2.19010447 ... 0.87909218 0.87909218 0.87909218]
 [2.19010447 2.19010447 2.19010447 ... 0.87909218 0.87909218 0.87909218]
 [2.19010447 2.19010447 2.19010447 ... 0.87909218 0.87909218 0.87909218]]
[[0.87231449 0.87231449 0.87231449 ... 1.18063343 1.18063343 1.2178417 ]
 [0.87231449 0.87231449 0.87231449 ... 1.18063343 1.18063343 1.2178417 ]
 [0.87231449 0.87231449 0.87231449 ... 1.18063343 1.18063343 1.2178417 ]
 ...
 [0.24996337 0.24996337 0.24996337 ... 0.52556992 0.52556992 0.52556992]
 [0.24996337 0.24996337 0.24996337 ... 0.52556992 0.52556992 0.52556992]
 [0.24996337 0.24996337 0.24996337 ... 0.52556992 0.52556992 0.52556992]]


100%|██████████| 100/100 [00:08<00:00, 12.14it/s]


In [7]:
n2 = go.Scatter(x=alphas, y=y2)
fig = go.Figure(data=[n2])
fig.show()