In [None]:
# Import necessary modules/libraries
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Set experiment parameters (seeded so the sampled groups are reproducible)
np.random.seed(1)

# Normal-distribution parameters for group A and group B
mean_a, std_a = 0.5, 1
mean_b, std_b = 0.0, 1

# Grid of alpha values (upper bounds on the allowed fairness gap) to test
alpha_min, alpha_max = 0.01, 1
alphas = np.linspace(alpha_min, alpha_max, num=100)

# Sample size of each distribution
n = 5

# Draw the samples, then sort each group in place
a = np.random.normal(mean_a, std_a, n)
b = np.random.normal(mean_b, std_b, n)
a.sort()
b.sort()

# Weight each group by its share of the pooled sample
w_a = len(a) / (len(a) + len(b))
w_b = 1 - w_a

# Step parameters
# U_PLUS / U_MINUS: utility credited per sample whose final value is > 0 / <= 0.
U_PLUS = 1
U_MINUS = -1.1
# C_PLUS / C_MINUS: forwarded to utils.expected() to compute each round's shift.
C_PLUS = 1
C_MINUS = -0.4
# The analysis relies on U+ / U- > C+ / C-; report whether it holds.
print("Assumption 1: U+ / U- > C+ / C- is", (U_PLUS / U_MINUS) > (C_PLUS / C_MINUS))

# Threshold grid searched for every round: [begin, end) in steps of `interval`
begin, end = -3, 3
interval = 0.1
THRESHOLDS = np.arange(begin, end, interval)

In [None]:
from utils import expected

def brute_force(a, b):
    """Grid-search two rounds of per-group thresholds and trace a utility/fairness frontier.

    Every combination of four thresholds from ``THRESHOLDS`` is tried: a
    first- and a second-round threshold for group A, and the same for group B.
    In each round, samples above that round's threshold are shifted by
    ``expected(samples, C_PLUS, C_MINUS)`` evaluated on the current samples.
    After both rounds, each combination is scored by:

    * ``total_util = w_a * util_A + w_b * util_B``, where a group's utility
      sums ``U_PLUS`` for final values > 0 and ``U_MINUS`` otherwise;
    * ``fairness_diff = abs(mean(A2) - mean(B2))``.

    For each ``alpha`` in ``alphas``, the combination with the highest
    ``total_util`` among those with ``fairness_diff <= alpha`` is selected.
    On ties, the first one in loop order (a1, b1, a2, b2) wins.

    Parameters
    ----------
    a, b : numpy.ndarray
        Samples for group A and group B.

    Returns
    -------
    utilities : list
        Best total utility per alpha (``np.nan`` if no combination qualifies).
    mean_diffs : list
        Fairness gap of the selected combination per alpha
        (``np.nan`` if no combination qualifies).
    """
    n_t = len(THRESHOLDS)
    # Score buffers indexed [a1, b1, a2, b2]. Their C-order ravel reproduces
    # the loop nesting order, which fixes tie-breaking below. The size grows
    # as len(THRESHOLDS) ** 4.
    diffs = np.empty((n_t, n_t, n_t, n_t))
    totals = np.empty((n_t, n_t, n_t, n_t))

    delta_A = expected(a, C_PLUS, C_MINUS)
    delta_B = expected(b, C_PLUS, C_MINUS)

    for i_a1, thresh_a1 in enumerate(THRESHOLDS):
        A = np.where(a > thresh_a1, a + delta_A, a)
        for i_b1, thresh_b1 in enumerate(THRESHOLDS):
            B = np.where(b > thresh_b1, b + delta_B, b)

            delta_A2 = expected(A, C_PLUS, C_MINUS)
            delta_B2 = expected(B, C_PLUS, C_MINUS)

            # B's second round does not depend on A's second threshold, so
            # score every thresh_b2 once here instead of once per thresh_a2.
            util_B = np.empty(n_t)
            mean_B = np.empty(n_t)
            for i_b2, thresh_b2 in enumerate(THRESHOLDS):
                B2 = np.where(B > thresh_b2, B + delta_B2, B)
                util_B[i_b2] = np.sum(np.where(B2 > 0, U_PLUS, U_MINUS))
                mean_B[i_b2] = np.mean(B2)

            for i_a2, thresh_a2 in enumerate(THRESHOLDS):
                A2 = np.where(A > thresh_a2, A + delta_A2, A)
                util_A = np.sum(np.where(A2 > 0, U_PLUS, U_MINUS))
                # Vectorised over thresh_b2 (last axis).
                totals[i_a1, i_b1, i_a2] = w_a * util_A + w_b * util_B
                diffs[i_a1, i_b1, i_a2] = np.abs(np.mean(A2) - mean_B)

    diffs = diffs.ravel()
    totals = totals.ravel()

    utilities = []
    mean_diffs = []
    for alpha in alphas:
        feasible = diffs <= alpha
        if feasible.any():
            # argmax returns the first maximum, matching max() over a list.
            best = np.argmax(np.where(feasible, totals, -np.inf))
            max_util, diff = totals[best], diffs[best]
        else:
            max_util, diff = np.nan, np.nan

        utilities.append(max_util)
        mean_diffs.append(diff)

    return utilities, mean_diffs

                



