In [None]:
import os
import random
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import wandb
from deap import algorithms
from deap import base
from deap import creator
from deap import tools

wandb.login()

# Seed BOTH random number generators. The random preference lists are drawn
# with `np.random`, which `random.seed` does not affect, so seeding only the
# stdlib generator leaves the run non-reproducible.
SEED = 74385
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# Create Configs
# Hyperparameters of this run. They are sent to Weights & Biases and read back
# everywhere else through `config`.
hyperparams = {
    "__NAME": "FINAL",             # run name; also used for the plot title and plot file name
    "algo": 'MPL',                 # 'S' (eaSimple), 'MPL' (eaMuPlusLambda) or 'MCL' (eaMuCommaLambda)
    "crossover": 'PMX',            # 'PMX' (cxPartialyMatched) or 'OX' (cxOrdered)
    "selection": 'LEX',            # 'NSGA2', 'SPEA2' or 'LEX' (selLexicase)
    "mutation": 'ShuffleIndexes',  # only 'ShuffleIndexes' is handled
    "n_bits": 10,                  # number of men (= number of women), i.e. the permutation length
    "sym": False,                  # True: cyclically shifted ("best case") preferences, False: random ones
    "lambda_mul": 3,               # offspring per generation = int(pop_size * lambda_mul)
    "pop_size": 10,                # population size (also used as mu)
    "tournsize": 5,                # NOTE(review): not read by any code visible in this notebook
    "cxpb": 0.5,                   # crossover probability passed to the DEAP algorithm
    "mutpb": 0.5,                  # mutation probability passed to the DEAP algorithm
    "ngen": 500,                   # number of generations
}

run = wandb.init(
    project='Evolving Tinder',
    config=hyperparams,
    name=hyperparams['__NAME'],
)

config = run.config

In [None]:
# Preferences
# Row i of m_preferences is man i's ranking of all women, and row i of
# w_preferences is woman i's ranking of all men. A later position in a row
# means "more preferred" (evaluate() sums positions and maximises them).
if not config.sym:
    m_preferences = list(map(list, [np.random.permutation(np.arange(config.n_bits))
                                    for _ in range(config.n_bits)]))
    w_preferences = list(map(list, [np.random.permutation(np.arange(config.n_bits))
                                    for _ in range(config.n_bits)]))
    print('Random Preferences')
else:
    # Row `shift` is 0..n-1 rotated by `shift`; both sides share the same table.
    m_preferences = [list(np.roll(np.arange(config.n_bits), shift)) for shift in range(config.n_bits)]
    w_preferences = [list(np.roll(np.arange(config.n_bits), shift)) for shift in range(config.n_bits)]
    print('Best Case Preferences')
    
# Penalty
def blocking_pairs(x, m_preferences=m_preferences, w_preferences = w_preferences):
    """Count the blocking pairs of the matching ``x``.

    ``x[man]`` is the woman matched to ``man``. A pair ``(man, woman)`` is
    blocking when the man ranks that woman strictly above his current partner
    AND the woman ranks that man strictly above her current partner. A later
    position in a preference list means "more preferred".

    Args:
        x: matching as a permutation of ``range(len(x))``.
        m_preferences: ``m_preferences[man]`` is that man's ranking of the women.
        w_preferences: ``w_preferences[woman]`` is that woman's ranking of the men.

    Returns:
        The number of blocking pairs (0 for a stable matching).

    Raises:
        ValueError: if ``x`` is not a permutation of ``range(len(x))``.
    """
    n = len(x)
    if sorted(x) != list(range(n)):
        raise ValueError("x must be a permutation of range(len(x))")

    # Build position lookups once instead of calling list.index() inside the
    # double loop, which made the count O(n^3).
    m_rank = [{woman: pos for pos, woman in enumerate(prefs)} for prefs in m_preferences]
    w_rank = [{man: pos for pos, man in enumerate(prefs)} for prefs in w_preferences]

    # Inverse matching: husband_of[woman] is the man currently matched to her.
    husband_of = [None] * n
    for man, woman in enumerate(x):
        husband_of[woman] = man

    count = 0
    for man, current_woman in enumerate(x):
        man_rank = m_rank[man]
        current_rank = man_rank[current_woman]
        for woman, current_man in enumerate(husband_of):
            if man_rank[woman] > current_rank and w_rank[woman][man] > w_rank[woman][current_man]:
                count += 1
    return count

# Multi-Objective Function
def evaluate(x, m_preferences=m_preferences, w_preferences=w_preferences):
    """Multi-objective fitness of the matching ``x``.

    Args:
        x: matching where ``x[man]`` is the woman matched to ``man``.
        m_preferences: ``m_preferences[man]`` is that man's ranking of the women.
        w_preferences: ``w_preferences[woman]`` is that woman's ranking of the men.

    Returns:
        ``(m_scores, w_scores, penalty)``: the men's social value, the women's
        social value (both floats) and the number of blocking pairs.
    """
    # 1 - Social Value
    # Each term is a position in 0..n-1, so each total lies in [0, n * (n-1)].
    m_scores, w_scores = 0.0, 0.0
    for man, woman in enumerate(x):
        m_scores += m_preferences[man].index(woman)
        w_scores += w_preferences[woman].index(man)

    # 2 - Blocking Pairs
    # Forward the preference tables explicitly so that a caller-supplied table
    # is used for the penalty too (previously the module defaults were used).
    # Only unmatched (man, woman) pairs can block, so the count is at most n * (n-1).
    penalty = blocking_pairs(x, m_preferences, w_preferences)

    return m_scores, w_scores, penalty

# DEAP Environment
creator.create("FitnessMulti", base.Fitness, weights=(1.0, 1.0, -1.0))   # Max Social Value, Min Blocking Pairs
creator.create("Individual", list, fitness=creator.FitnessMulti)
toolbox = base.Toolbox()
# Draw the initial permutations with the stdlib `random` module (not np.random)
# so that `random.seed(seed)` in main() actually controls the initial population.
toolbox.register("permutation", random.sample, range(config.n_bits), config.n_bits)
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.permutation)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", evaluate)

# Crossover
if config.crossover == 'PMX':
    toolbox.register("mate", tools.cxPartialyMatched)
elif config.crossover == 'OX':
    toolbox.register("mate", tools.cxOrdered)
else:
    # Fail here rather than with an AttributeError deep inside the DEAP loop.
    raise ValueError(f"Unknown crossover operator: {config.crossover!r}")

# Mutation
if config.mutation == 'ShuffleIndexes':
    # indpb = 2/n: on average two positions are swapped per mutated individual.
    toolbox.register("mutate", tools.mutShuffleIndexes, indpb=2.0/config.n_bits)
else:
    raise ValueError(f"Unknown mutation operator: {config.mutation!r}")

# Selection
if config.selection == 'NSGA2':
    toolbox.register("select", tools.selNSGA2)
elif config.selection == 'SPEA2':
    toolbox.register("select", tools.selSPEA2)
elif config.selection == 'LEX':
    toolbox.register("select", tools.selLexicase)
else:
    raise ValueError(f"Unknown selection operator: {config.selection!r}")

# To save stats at each generation
def wandb_log(fits):
    """Send the per-generation summary of the fitness values to wandb.

    Args:
        fits: sequence of fitness tuples ``(men's value, women's value, blocking pairs)``.

    Returns:
        The string ``'Done'`` (the function is registered as a statistic, so its
        return value ends up in the statistics record).
    """
    column_max = np.max(fits, axis=0)
    column_min = np.min(fits, axis=0)

    metrics = {
        'SVm_max': column_max[0],
        'SVw_max': column_max[1],
        'BP_min': column_min[2],
    }
    wandb.log(metrics)
    return 'Done'

# Evolution
def main(seed=0):
    """Run the evolutionary algorithm selected by ``config.algo``.

    Args:
        seed: value passed to ``random.seed`` before the population is created.

    Returns:
        ``(pop, stats, hof)``: the population list handed to the DEAP algorithm
        (the algorithm's own return value is discarded, so this relies on DEAP
        updating the list in place), the statistics object, and the Pareto
        front collected during the run.

    Raises:
        ValueError: if ``config.algo`` is not one of ``'S'``, ``'MPL'``, ``'MCL'``.
    """
    random.seed(seed)

    pop = toolbox.population(n=config.pop_size)
    hof = tools.ParetoFront()
    # Statistics are computed over the fitness tuples of the population.
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("SVm_max", lambda x:  np.max(x, axis = 0)[0])
    stats.register("SVw_max", lambda x:  np.max(x, axis = 0)[1])
    stats.register("BP_min", lambda x:  np.min(x, axis = 0)[2])
    # Registered as a "statistic" only for its side effect of logging to wandb.
    stats.register("wandb_log", wandb_log)

    # Arguments shared by all three DEAP algorithms.
    common = dict(cxpb=config.cxpb, mutpb=config.mutpb, ngen=config.ngen,
                  stats=stats, halloffame=hof, verbose=True)
    n_offspring = int(config.pop_size * config.lambda_mul)

    if config.algo == 'S':
        algorithms.eaSimple(pop, toolbox, **common)
    elif config.algo == 'MPL':
        algorithms.eaMuPlusLambda(pop, toolbox, mu=config.pop_size, lambda_=n_offspring, **common)
    elif config.algo == 'MCL':
        algorithms.eaMuCommaLambda(pop, toolbox, mu=config.pop_size, lambda_=n_offspring, **common)
    else:
        # Previously an unknown name silently returned the unevolved population.
        raise ValueError(f"Unknown algorithm: {config.algo!r}")

    return pop, stats, hof
# Run the evolution with the default seed (0); `hof` is the Pareto front built by main().
pop, stats, hof = main()

# Deferred Acceptance Algorithm
def DAA(proposers_prefs, acceptors_prefs):
    """Deferred Acceptance (Gale-Shapley) with the proposers proposing.

    Preference lists are ordered from least to most preferred: a later position
    means "more preferred", so each proposer works through their list from the
    end towards the start. Lists are expected to be strict (no duplicates).

    Args:
        proposers_prefs: ``proposers_prefs[p]`` is proposer p's ranking of the acceptors.
        acceptors_prefs: ``acceptors_prefs[a]`` is acceptor a's ranking of the proposers.

    Returns:
        A list where entry ``p`` is the acceptor matched to proposer ``p``, or
        ``None`` if that proposer ran out of acceptors to propose to.
    """
    n = len(proposers_prefs)  # Number of proposers (and acceptors)

    # Position lookups built once, instead of list.index() on every comparison.
    acceptor_rank = [
        {proposer: pos for pos, proposer in enumerate(prefs)}
        for prefs in acceptors_prefs
    ]
    # next_choice[p] is the index of the next acceptor p will propose to; it
    # replaces the linear "already proposed?" scan over a per-proposer list.
    next_choice = [len(prefs) - 1 for prefs in proposers_prefs]

    acceptor_matches = {acceptor: None for acceptor in range(n)}
    proposer_matches = [None] * n
    free_proposers = deque(range(n))  # FIFO queue with O(1) popleft

    while free_proposers:
        proposer = free_proposers.popleft()
        if next_choice[proposer] < 0:
            # Nobody left to propose to: this proposer stays unmatched.
            continue

        acceptor = proposers_prefs[proposer][next_choice[proposer]]
        next_choice[proposer] -= 1
        current_match = acceptor_matches[acceptor]

        if current_match is None:
            # The acceptor is free and accepts the proposal.
            acceptor_matches[acceptor] = proposer
            proposer_matches[proposer] = acceptor
        elif acceptor_rank[acceptor][proposer] > acceptor_rank[acceptor][current_match]:
            # The acceptor prefers the new proposer; the previous partner is freed.
            acceptor_matches[acceptor] = proposer
            proposer_matches[proposer] = acceptor
            proposer_matches[current_match] = None
            free_proposers.append(current_match)
        else:
            # Rejected: the proposer stays free and will try their next choice.
            free_proposers.append(proposer)

    return proposer_matches

# Compare solutions
# Run Deferred Acceptance once and reuse both the matching and its fitness.
daa_x = DAA(m_preferences, w_preferences)
daa_fit = evaluate(daa_x)
print(daa_x)
print(daa_fit)
# True when the evolved Pareto front contains the DAA matching.
evo_found_daa = daa_x in list(hof)
fits = list(map(evaluate, list(hof)))
fits.append(daa_fit)
xs = [fit[0] for fit in fits]       # men's social value
ys = [fit[1] for fit in fits]       # women's social value
labels = [fit[2] for fit in fits]   # number of blocking pairs

# Visualize Pareto Front
plt.scatter(xs, ys, c=labels, cmap='viridis')   #cmap='tab10'
plt.annotate('DAA_solution',
            daa_fit[:2],
            textcoords="offset points",
            xytext=(-100,0),# distance from the point
            ha='center',
            color='green' if evo_found_daa else 'red',
            arrowprops=dict(arrowstyle="->", lw=.5))
cbar = plt.colorbar()
cbar.set_label('Stability')
# The integer cast can produce repeated values when the range is small;
# np.unique keeps each tick once.
ticks = np.unique(np.linspace(0, np.max(labels) + 1, 10, dtype=int))
cbar.set_ticks(ticks)
cbar.set_ticklabels(ticks)
plt.ylabel('Women\'s Social Value')
plt.xlabel('Men\'s Social Value')
plt.title(f'{config.__NAME}')
# savefig does not create missing directories, so make sure ./plots exists.
os.makedirs('./plots', exist_ok=True)
plot_path = f'./plots/{config.__NAME}.png'
plt.savefig(plot_path)
plt.show()
wandb.log({'ParetoFront': wandb.Image(plot_path)})

run.finish()