In [2]:
# Import necessary modules/libraries
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import plotly.express as px
import plotly.graph_objects as go

In [3]:
# ---- Experiment configuration ----
# Seed NumPy's global RNG so the sampled groups are reproducible.
np.random.seed(1)

# Mean / standard deviation of the normal distribution behind each group
mean_a, std_a = 0.5, 1
mean_b, std_b = 0.0, 1

# Grid of alpha values to sweep (100 evenly spaced points, both ends included)
alpha_min, alpha_max = 0.01, 1
alphas = np.linspace(alpha_min, alpha_max, 100)

# Number of samples drawn for each group
n = 5

# Draw group a first, then group b (same RNG stream order), each sorted ascending
a = np.sort(np.random.normal(mean_a, std_a, n))
b = np.sort(np.random.normal(mean_b, std_b, n))

# Group weights: each group's share of the pooled sample
w_a = a.size / (a.size + b.size)
w_b = 1 - w_a

# ---- Step parameters ----
# Assumption 1: U+ / U- > C+ / C-
U_PLUS = 1
U_MINUS = -1.1
C_PLUS = 1
C_MINUS = -0.4
# Optional sanity check of Assumption 1 (uncomment to print it):
# print("Assumption 1: U+ / U- > C+ / C- is", (U_PLUS / U_MINUS) > (C_PLUS / C_MINUS))

# Candidate thresholds: begin, begin + interval, ... (np.arange excludes `end`)
begin, end = -3, 3
interval = 1
THRESHOLDS = np.arange(begin, end, interval)

In [4]:
from utils import expected

def _best_per_alpha(data, alpha_values):
    """Pick, for each alpha, the highest-utility outcome within the fairness budget.

    Parameters
    ----------
    data : iterable of (fairness_diff, total_util) tuples
        Outcomes as produced by ``brute_force``.
    alpha_values : iterable of float
        Upper bounds on ``fairness_diff`` to evaluate.

    Returns
    -------
    utilities, mean_diffs : list, list
        Aligned with ``alpha_values``. For each alpha, the utility and the
        fairness gap of the best qualifying outcome, or ``np.nan`` for both
        when no outcome satisfies ``fairness_diff <= alpha``.
    """
    utilities = []
    mean_diffs = []
    for alpha in alpha_values:
        valid = [pair for pair in data if pair[0] <= alpha]
        if valid:
            # Tuples are stored as (fairness_diff, total_util); unpack in that
            # same order so the two values are not swapped.
            diff, max_util = max(valid, key=lambda pair: pair[1])
        else:
            diff, max_util = np.nan, np.nan
        utilities.append(max_util)
        mean_diffs.append(diff)
    return utilities, mean_diffs


def brute_force(a, b, *, return_frontier=False):
    """Enumerate every two-round threshold combination for groups ``a`` and ``b``.

    For each first-round threshold pair, samples strictly above their group's
    threshold are shifted by ``expected(samples, C_PLUS, C_MINUS)``; the same is
    repeated for a second round on the shifted samples. Each final outcome is
    scored by:

    * ``fairness_diff`` - absolute difference between the group means, and
    * ``total_util``    - ``w_a * util_A + w_b * util_B``, where a group's
      utility sums ``U_PLUS`` for every final value > 0 and ``U_MINUS`` otherwise.

    Reads the module-level names ``THRESHOLDS``, ``C_PLUS``, ``C_MINUS``,
    ``U_PLUS``, ``U_MINUS``, ``w_a``, ``w_b`` and (only when
    ``return_frontier`` is true) ``alphas``.

    Parameters
    ----------
    a, b : np.ndarray
        Samples of the two groups.
    return_frontier : bool, keyword-only, default False
        When true, also return the per-alpha best utility and its fairness gap.

    Returns
    -------
    list of (fairness_diff, total_util)
        One tuple per threshold combination (``len(THRESHOLDS) ** 4`` entries).
        With ``return_frontier=True`` the result is instead
        ``(data, utilities, mean_diffs)``, the two extra lists aligned with
        ``alphas``.
    """
    data = []

    delta_A = expected(a, C_PLUS, C_MINUS)
    delta_B = expected(b, C_PLUS, C_MINUS)

    for thresh_a1 in tqdm(THRESHOLDS):
        A = np.where(a > thresh_a1, a + delta_A, a)
        for thresh_b1 in THRESHOLDS:
            B = np.where(b > thresh_b1, b + delta_B, b)

            # `expected` is defined outside this block; its calls keep their
            # original position and count in case it is not a pure function.
            delta_A2 = expected(A, C_PLUS, C_MINUS)
            delta_B2 = expected(B, C_PLUS, C_MINUS)

            for thresh_a2 in THRESHOLDS:
                A2 = np.where(A > thresh_a2, A + delta_A2, A)
                # Both quantities depend only on A2, so compute them once per
                # thresh_a2 instead of once per (thresh_a2, thresh_b2) pair.
                util_A = np.sum(np.where(A2 > 0, U_PLUS, U_MINUS))
                mean_A2 = np.mean(A2)

                for thresh_b2 in THRESHOLDS:
                    B2 = np.where(B > thresh_b2, B + delta_B2, B)

                    util_B = np.sum(np.where(B2 > 0, U_PLUS, U_MINUS))
                    total_util = w_a * util_A + w_b * util_B
                    fairness_diff = abs(mean_A2 - np.mean(B2))
                    data.append((fairness_diff, total_util))

    if return_frontier:
        utilities, mean_diffs = _best_per_alpha(data, alphas)
        return data, utilities, mean_diffs

    return data

# Run the exhaustive search on the sampled groups; `data` is a list with one
# (fairness_diff, total_util) tuple per threshold combination.
data = brute_force(a, b)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x000002288E6ACD40>>
Traceback (most recent call last):
  File "C:\Users\joelj\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\ipykernel\ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
100%|██████████| 6/6 [00:00<00:00, 197.36it/s]


In [None]:
# One row per candidate threshold: each row of A_mat / B_mat is a copy of the
# group's samples, and T_mat repeats that row's threshold across the samples.
# NOTE(review): T_mat is sized with len(a) but is also compared against B_mat,
# so this relies on len(a) == len(b) - confirm.
A_mat = np.tile(a, (len(THRESHOLDS), 1))
B_mat = np.tile(b, (len(THRESHOLDS), 1))
T_mat = np.tile(THRESHOLDS, (len(a), 1)).T

# Shift amounts come from the unshifted matrices (A first, then B).
A_delta = expected(A_mat, C_PLUS, C_MINUS)
B_delta = expected(B_mat, C_PLUS, C_MINUS)

# Shift only the entries strictly above their row's threshold.
A_mat = np.where(A_mat > T_mat, A_mat + A_delta, A_mat)
B_mat = np.where(B_mat > T_mat, B_mat + B_delta, B_mat)

# 2-D grids over every (threshold_A, threshold_B) pair; 'ij' indexing makes
# thresh_A vary along axis 0 and thresh_B along axis 1.
thresh_A, thresh_B = np.meshgrid(THRESHOLDS, THRESHOLDS, indexing='ij')

# NOTE(review): the 2-D A_mat / B_mat computed above are replaced here by 3-D
# views of the raw samples with shapes (len(a), 1, 1) and (len(b), 1, 1).
A_mat = a[:, None, None]
B_mat = b[:, None, None]



(6, 6)
(5, 1, 1)
[[[-0.57296862]]

 [[-0.11175641]]

 [[-0.02817175]]

 [[ 1.36540763]]

 [[ 2.12434536]]]
[[[-3.57296862 -3.57296862 -3.57296862 -3.57296862 -3.57296862
   -3.57296862]
  [-2.57296862 -2.57296862 -2.57296862 -2.57296862 -2.57296862
   -2.57296862]
  [-1.57296862 -1.57296862 -1.57296862 -1.57296862 -1.57296862
   -1.57296862]
  [-0.57296862 -0.57296862 -0.57296862 -0.57296862 -0.57296862
   -0.57296862]
  [ 0.42703138  0.42703138  0.42703138  0.42703138  0.42703138
    0.42703138]
  [ 1.42703138  1.42703138  1.42703138  1.42703138  1.42703138
    1.42703138]]

 [[-3.11175641 -3.11175641 -3.11175641 -3.11175641 -3.11175641
   -3.11175641]
  [-2.11175641 -2.11175641 -2.11175641 -2.11175641 -2.11175641
   -2.11175641]
  [-1.11175641 -1.11175641 -1.11175641 -1.11175641 -1.11175641
   -1.11175641]
  [-0.11175641 -0.11175641 -0.11175641 -0.11175641 -0.11175641
   -0.11175641]
  [ 0.88824359  0.88824359  0.88824359  0.88824359  0.88824359
    0.88824359]
  [ 1.88824359  1.8882