## Config

In [None]:
import functools
import random
import matplotlib.pyplot as plt
import os
import torch
import json
import pandas as pd
import numpy as np
from sklearn import decomposition

In [None]:
from swarmi.game import (
    evaluate,
    agent_init,
    get_distance_reward,
    represent_poolers
)
from swarmi.operators import (
    prodop,
    NeuralNetOp,
    average,
    lambda_update,
    weighted_avg
)

In [None]:
# GAME
## one hypothesis only, agents care about P(H=true)

# Reproducibility: the evaluation below draws random agents, so every RNG this
# notebook imports is seeded once here. Change SEED to get a different draw.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

N_REPEATS = 500
NUM_ITERS = 50 # 20
N_AGENTS = 100
MAX_POOLERS = 20 # 10
MIN_POOLERS = 10
N_POOLS = 1
AGENT_FUNC = lambda: N_AGENTS
# Agent factory: each per-agent attribute is drawn from the uniform range given here.
INITIALISATION = functools.partial(
    agent_init, 
    belief_init_func=lambda: random.uniform(0,1), 
    evidence_init_func=lambda: random.uniform(0,.5), 
    noise_init_func=lambda: random.uniform(.5,1),
    pooling_init_func=lambda: random.uniform(0,1),
    communication_init_func=lambda: random.uniform(0,.1),
    zealot_init_func=lambda: random.uniform(0,.5),
)
REWARD_FUNC = functools.partial(get_distance_reward, sign=1, exp=1)
N_TRUTH_FLIPS = 0 # 5
UPDATE_FUNC = None # functools.partial(lambda_update, _lambda_new=0.90)
PREPROCESS_FUNC = None # lambda x_features: x_features[2] 

trial_name = 'bigbatch'
# NOTE(review): several later cells look up 'RL 2M' and 'Imitation NN - NormLogOp',
# which are commented out below -- re-enable them if those model files are available.
NN_POLICIES = {
    # tests on big batch
    'RL BigBatch 60k steps': NeuralNetOp(f'./models/{trial_name}_62500.pt'),
    'RL BigBatch 125k steps': NeuralNetOp(f'./models/{trial_name}_125000.pt'),
    'RL BigBatch 600k steps': NeuralNetOp(f'./models/{trial_name}_done.pt'),

    # official paper models    
    # 'RL 2M': NeuralNetOp(f'./models/rl_varied_n_avgagg_6D_done.pt'),
    'Imitation NN - Avg': NeuralNetOp("./models/imitated_avg_10k_50e_150width_20maxagents_6D_normalised_avg_agg.pt"),
    # 'Imitation NN - NormLogOp': NeuralNetOp("./models/imitated_normprodop_10k_50e_150width_20maxagents_6D_normalised_avg_agg.pt")

    # tests on initialisation
    **{f'RL - {x} init': NeuralNetOp(f'./models/{x}_0_done.pt') for x in ['fromrandom', 'fromavg', 'fromsumaggavg', 'fromnormlogop']}
    
}
ALL_POLICIES = {'Avg': average, **NN_POLICIES} # only used for some plots below
BASELINE_OPS = {'Avg': weighted_avg, 'LogOp': prodop}
# Weight pre-processing per baseline: a constant weight of 1, or None (labelled $w_{truth}$).
PREPROC_FUNCS = {'1': lambda x: 1, '$w_{truth}$': None}
_LAMBDAS = [0.95, 1]

## Run Eval

In [None]:
# Sweep the pool size from 5 below MIN_POOLERS to 5 above MAX_POOLERS and
# score every policy at each size (min and max pool size are set equal).
grid_scores = {}
for param_val in range(MIN_POOLERS - 5, MAX_POOLERS + 5 + 1):
    print(param_val)
    # Positional arguments shared by every evaluate() call at this pool size.
    common_args = (
        N_REPEATS,
        AGENT_FUNC,
        NUM_ITERS,
        param_val,
        param_val,
        N_POOLS,
        INITIALISATION,
        REWARD_FUNC,
        N_TRUTH_FLIPS
    )
    # Learned policies are evaluated together in a single call.
    scores, belief_history = evaluate(
        NN_POLICIES,
        *common_args,
        UPDATE_FUNC,
        PREPROCESS_FUNC,
        using_nn_pooling=True
    )
    # Baselines: one evaluate() call per (operator, normalisation, weights, lambda).
    for op_name, operator in BASELINE_OPS.items():
        for normalised in [True, False]:
            # The unnormalised setting is skipped for weighted_avg.
            if operator == weighted_avg and not normalised:
                continue
            # Only prodop takes the normalise_w switch; everything else is used as-is.
            _operator = functools.partial(prodop, normalise_w=True) if operator == prodop and normalised else operator
            for preproc_name, preproc in PREPROC_FUNCS.items():
                for _lambda in _LAMBDAS:
                    # The backslash is escaped explicitly so the key is still the
                    # literal text "$\lambda$" without an invalid-escape warning.
                    label = f'{op_name}, normalised={normalised}, $w_i$={preproc_name} $\\lambda$={_lambda}'
                    scores_v, belief_history_v = evaluate(
                        {label: _operator},
                        *common_args,
                        functools.partial(lambda_update, _lambda_new=_lambda),
                        preproc,
                        using_nn_pooling=False
                    )
                    scores.update(scores_v)
                    belief_history.update(belief_history_v)

    # get metrics out
    grid_scores[param_val] = {}
    for policy, grid_values in scores.items():
        # average across time, one value per entry of grid_values
        avg_scores = [sum(values) / len(values) for values in grid_values]

        # mean across trials (for each param), with a normal-approximation 95% CI
        mean_value = np.mean(avg_scores)
        std_value = np.std(avg_scores, ddof=1)  # ddof=1 -> sample standard deviation
        confidence_interval = 1.96 * (std_value / np.sqrt(len(avg_scores)))

        grid_scores[param_val][policy] = {'mean': float(mean_value), 'ci': float(confidence_interval), 'samples': avg_scores}

In [None]:
# Persist the sweep so the plotting cells can run without repeating the evaluation.
results_path = f'./grid_results_{N_REPEATS}.json'
with open(results_path, 'w') as out_file:
    json.dump(grid_scores, out_file)

## Performance

In [None]:
# Reload the saved sweep; JSON stores the pool sizes as strings, so cast them back to int.
with open(f'./grid_results_{N_REPEATS}.json', 'r') as in_file:
    loaded_scores = json.load(in_file)
grid_scores = {int(pool_size): per_policy for pool_size, per_policy in loaded_scores.items()}

In [None]:
def convert_label(x: str) -> str:
    """Map an internal policy name to the label shown in plots and tables.

    Args:
        x: policy name as used for the keys of the score dictionaries.

    Returns:
        'Reinforcement Learned' for names containing 'RL'; a 'Supervised
        Learned' variant for names containing 'NN' together with 'NormLogOp'
        or 'Avg'; otherwise the name with the baseline prefixes and weight
        symbol rewritten. An 'NN' name matching neither variant also takes
        the rewrite path instead of yielding None.
    """
    if 'RL' in x:
        return 'Reinforcement Learned'
    if 'NN' in x:
        if 'NormLogOp' in x:
            return 'Supervised Learned (NormLogOp)'
        if 'Avg' in x:
            return 'Supervised Learned'
    # Baselines and unrecognised names: rewrite the known prefixes and symbols.
    return (
        x.replace("Avg, normalised=True,", "Weighted Average")
        .replace('LogOp, normalised=True,', 'NormLogOp')
        .replace(', normalised=False,', '')
        .replace('$w_{truth}$', '$\\omega_t$')
    )

In [None]:
# Turn the nested {pool_size: {policy: stats}} results into one series per policy.
xs = list(grid_scores.keys())
first_param = list(grid_scores.keys())[0]
policy_names = list(grid_scores[first_param].keys())
ys = {name: [] for name in policy_names}
ys_err = {name: [] for name in policy_names}

for policy_vals in grid_scores.values():
    for name in policy_names:
        stats = policy_vals[name]
        ys[name].append(stats['mean'])
        ys_err[name].append(stats['ci'])

In [None]:
# List the available policy names, one per line.
for policy_name in ys.keys():
    print(policy_name)

In [None]:
plt.figure(figsize=(22, 9))
# Marker chosen by the first letter of the converted label.
markers = {
    'R': 'o',
    'N': 's',
    'W': 'D',
    'L': 'X',
    'S': 'o'
}
# Baseline configurations to show next to the learned policies. Backslashes
# are escaped explicitly; the keys are the same text the evaluation cell writes.
baselines_to_plot = {
    'LogOp, normalised=True, $w_i$=1 $\\lambda$=0.95',
    'LogOp, normalised=False, $w_i$=1 $\\lambda$=0.95',
    'Avg, normalised=True, $w_i$=$w_{truth}$ $\\lambda$=1',
    'LogOp, normalised=False, $w_i$=1 $\\lambda$=1',
}
for p, vals in ys.items():
    is_learned = 'NN' in p or 'RL' in p
    if not (is_learned or p in baselines_to_plot):
        continue
    if 'NN - NormLogOp' in p:
        continue
    label = convert_label(p)
    plt.errorbar(xs, vals, ys_err[p], label=label, marker=markers[label[0]], markersize=10, capsize=6)

# Add vertical lines with arrows to indicate the training pool size
plt.axvline(x=20, color='r', linestyle='--')
plt.axvline(x=10, color='r', linestyle='--')
# Add a horizontal arrow between the two vertical lines
plt.annotate('', xy=(20, 0.18), xytext=(10, 0.18),
             arrowprops=dict(arrowstyle='<->', color='red'))
plt.text(15, 0.185, 'Training Interval', horizontalalignment='center', fontsize=16, color='red')


# Increase the size of the axis ticks and labels
plt.xticks(ticks=xs, fontsize=16)
plt.yticks(fontsize=16)

legend = plt.legend(ncol=3, fontsize=12.1, fancybox=False, edgecolor="black")
legend.get_frame().set_linewidth(0.5)
plt.xlabel('$N_{pool.size}$', fontsize=18)
plt.ylabel('Avg. MAE ± 95% CI', fontsize=18);

In [None]:
# Keep only the pool sizes strictly between 9 and 21 and pool their
# per-trial samples into one list per policy.
results_in_training_dist = {
    pool_size: per_policy
    for pool_size, per_policy in grid_scores.items()
    if 9 < int(pool_size) < 21
}
first_key = list(results_in_training_dist.keys())[0]
samples = {name: [] for name in results_in_training_dist[first_key].keys()}

for policy_vals in results_in_training_dist.values():
    for name, pooled in samples.items():
        pooled.extend(policy_vals[name]['samples'])

In [None]:
# Mean and 95% CI per policy over the pooled samples.
# Columns keep the raw policy names: convert_label maps several policies to the
# same label, and duplicate column labels would make df[name] a DataFrame.
df = pd.DataFrame(samples)
for policy_name in df.columns:
    # Filter on the raw name, because the converted label never contains this text.
    if 'NN - NormLogOp' in policy_name:
        continue
    column = df[policy_name]
    std_dev = column.std(ddof=1)
    ci_interval = 1.96 * (std_dev / np.sqrt(len(column)))
    # The raw name is printed too so that policies sharing a label stay distinguishable.
    print(f"{convert_label(policy_name)} [{policy_name}]: {column.mean():.3f} +/- {ci_interval:.3f}")

## Interpretability

In [None]:
# Build MIN_POOLERS agents whose attributes are fixed constants, except for
# the belief, which is drawn uniformly from [0, 1].
agents = agent_init(
    MIN_POOLERS,
    belief_init_func=lambda: random.uniform(0, 1),
    evidence_init_func=lambda: 0.25,
    noise_init_func=lambda: 0.75,
    pooling_init_func=lambda: 0.5,
    communication_init_func=lambda: 0.05,
    zealot_init_func=lambda: 0.25,
)

# Overwrite the first element of each representation: 1 for keys below 5, 0 otherwise.
corrected = {
    agent_id: ((1 if agent_id < 5 else 0), *rep[1:])
    for agent_id, rep in agents.items()
}
corrected

In [None]:
# Candidate attribute tuples appended after an agent's belief in the probes below.
# NOTE(review): field order presumably follows the agent representation -- confirm.
reliable = (0.5, 1, 0.01, 0.001, 0.25)
unreliable = (0.5, 0.5, 0.01, 0.001, 0.25)
avg_reliable = (0.25, 0.5, 0.5, 0.001, 0.25)
# The tuple actually used by the probe cells.
const = avg_reliable

# Settings for the baseline operator in the probe cells.
_lambda = 0.95
preproc = None
logop = functools.partial(prodop, normalise_w=False)

def kl_divergence(p, q):
    """Return the KL divergence sum_i p[i] * log(p[i] / q[i]) in nats.

    Terms with p[i] == 0 contribute 0 (the usual 0 * log 0 = 0 convention)
    instead of turning the whole sum into nan.

    Args:
        p: sequence of probabilities.
        q: sequence of probabilities, indexed at the same positions as p.
    """
    return sum(p_i * np.log(p_i / q[i]) for i, p_i in enumerate(p) if p_i != 0)

def entropy(p):
    """Return the Shannon entropy -sum_i p[i] * log(p[i]) in nats.

    Zero-probability entries contribute 0 (the usual 0 * log 0 = 0
    convention) instead of turning the result into nan.

    Args:
        p: sequence of probabilities.
    """
    return -sum(p_i * np.log(p_i) for p_i in p if p_i != 0)

# For each policy, probe a 20x20 grid of belief pairs (x1, x2): half of the
# agents hold x1, the other half x2. Record the KL divergence between the two
# beliefs and the change in entropy caused by pooling.
effects = {
    p.replace('\n', ' - '): {
        'kls': [], 
        'entropy_diff': [], 
        'policy': ALL_POLICIES[p] if p in ALL_POLICIES else prodop
    }
    # Backslash escaped explicitly; the key is still the text "$\lambda=$<value>".
    for p in ['Avg', f'LogOp $w_i=1$ $\\lambda=${_lambda}'] + list(NN_POLICIES.keys())
}

n_agents = 10
xs = [round(x,4) for x in np.linspace(0.01,1-0.01,20)]
for p in effects.keys():
    is_nn = 'NN' in p or 'RL' in p
    # The baseline entry is keyed 'LogOp ...'. The old 'Product' substring test
    # never matched it, which left the lambda-update branch unreachable.
    uses_lambda_update = p.startswith('LogOp')
    for x1 in xs:
        for x2 in xs:
            effects[p]['kls'].append(kl_divergence([x1,1-x1], [x2,1-x2]))
            original_entropy = ((entropy([x1,1-x1]) + entropy([x2,1-x2]))/2)
            
            # First five agents hold x1, the rest x2; every other attribute is `const`.
            beliefs = [x1 if a < 5 else x2 for a in range(n_agents)]
            example = {i: (b, *const) for i, b in enumerate(beliefs)}
            merged_reps = represent_poolers(
                list(example.values()), 
                preproc if uses_lambda_update else None, 
                is_nn
            )   
            out = effects[p]['policy'](merged_reps)
            if uses_lambda_update:
                # Each group blends the pooled value with its own prior belief,
                # so the post-pooling entropy is averaged over the two groups.
                out1 = lambda_update(x1, out, _lambda_new=_lambda)
                out2 = lambda_update(x2, out, _lambda_new=_lambda)
                new_ent = ((entropy([out1,1-out1]) + entropy([out2,1-out2]))/2)
            else:
                new_ent = entropy([out,1-out])
            ent_diff = new_ent - original_entropy
            effects[p]['entropy_diff'].append(ent_diff)

In [None]:
# Display names for the initialisation variants.
names = {
    'fromrandom init': 'Random Init.',
    'fromavg init': 'Avg. Init.',
    'fromnormlogop init': 'NormLogOp Init.',
}
# One heatmap of the entropy change per policy, rows flipped so X_A increases upwards.
for policy_to_check in effects:
    entropy_diff_array = np.array(effects[policy_to_check]['entropy_diff'])
    heatmap = entropy_diff_array.reshape(20, 20)[::-1]
    plt.imshow(heatmap, cmap='coolwarm', extent=[0, 1, 0, 1])
    plt.clim(-1, 1)
    plt.colorbar()
    # Swap in the display name before printing.
    for internal_name, display_name in names.items():
        policy_to_check = policy_to_check.replace(internal_name, display_name)
    # plt.title(f'$\Delta$ Avg. Entropy: {policy_to_check}')
    plt.xlabel('$X_{B}$')
    plt.ylabel('$X_{A}$')
    print(policy_to_check)
    plt.show()

In [None]:
# Scatter of entropy change against the KL divergence between the two beliefs,
# one series per policy (diamonds for names containing 'RL').
plt.figure(figsize=(10, 5))
for e in effects:
    plt.scatter(effects[e]['kls'], effects[e]['entropy_diff'], label=e, marker='D' if 'RL' in e else 'o')
plt.xlabel('KL Divergence (between $X_{A}$, $X_{B}$)')
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\Delta$ Entropy (post-pooling)')
plt.legend()
plt.title('KL Divergence vs Entropy Difference, homogeneous swarm, half with $X_{A}$, half with $X_{B}$')

In [None]:
# Variance of each policy's output over repeated calls on the same input,
# for every pair of beliefs (x1, x2) on a 20x20 grid.
variances = {}
xs = [round(x,4) for x in np.linspace(0.01,1-0.01,20)]
repeats = 200
# Learned models are looked up with .get() so that a model which is not
# loaded in NN_POLICIES is skipped instead of raising KeyError.
candidate_policies = [
    ('Imitation NN NormLogOp', NN_POLICIES.get('Imitation NN - NormLogOp')),
    ('RL 2M', NN_POLICIES.get('RL 2M')),
    ('Avg', average),
    ('LogOp $w_i=w_{truth}$', prodop),
]
for name, policy in candidate_policies:
    if policy is None:
        print(f"Skipping '{name}': its model is not loaded in NN_POLICIES")
        continue
    is_nn = 'RL' in name or 'NN' in name
    variances[name] = []
    for x1 in xs:
        row = []
        for x2 in xs:
            # First five agents hold x1, the rest x2. The input does not depend
            # on the repeat index, so it is built once per grid point.
            beliefs = [x1 if a < 5 else x2 for a in range(n_agents)]
            example = {i: (b, *const) for i, b in enumerate(beliefs)}
            outputs = []
            for _ in range(repeats):
                merged_reps = represent_poolers(list(example.values()), None, is_nn)
                outputs.append(policy(merged_reps))
            row.append(np.var(outputs))
        variances[name].append(row)

In [None]:
# Optional upper colour limit shared by all heatmaps; None leaves autoscaling on.
cap = None
for name, variance in variances.items():
    variance_arr = np.array(variance)
    # plt.figure(figsize=(4,4))
    plt.imshow(variance_arr.reshape(20, 20)[::-1], cmap='coolwarm', extent=[0, 1, 0, 1])
    # Raw string: '\s' is not a valid escape sequence in a normal string literal.
    plt.colorbar().set_label(r'$\sigma^2$')
    plt.title(f"{name}")
    plt.xlabel('$X_{B}$')
    plt.ylabel('$X_{A}$')
    if cap:
        plt.clim(1e-8,cap)
    plt.show()

In [None]:
# Collect first-layer (lin1) outputs of one learned policy for random agents,
# together with the agent attributes used later to colour the PCA plots.
rows = []
attr_name_to_index = {
    '$x$': 0,
    '$w_{evidence}$': 1,
    '$w_{truth}$': 2,
    '$w_{pool}$': 3,
    '$w_{noise}$': 4,
    '$w_{doubt}$': 5
}
attr_vals = {attr_name: [] for attr_name in attr_name_to_index}
n_agents = 10

# The policy is picked by position (fourth entry of NN_POLICIES).
chosen = list(NN_POLICIES)[3]
network = NN_POLICIES[chosen].net

for _ in range(200):
    agents = INITIALISATION(n_agents)
    merged_reps = represent_poolers(list(agents.values()), None, True)
    with torch.no_grad():
        features = torch.tensor(merged_reps, dtype=torch.float32)
        projected = network.lin1(features)

    # Pair each projected row with the agent it was zipped against.
    for rep, agent in zip(projected, agents.values()):
        rows.append(rep.numpy())
        for attr_name, collected in attr_vals.items():
            collected.append(agent[attr_name_to_index[attr_name]])

In [None]:
# Reduce the collected first-layer outputs to five principal components and
# display the share of variance (in percent) explained by each.
df = pd.DataFrame(rows)
pca = decomposition.PCA(n_components=5)
transformed = pca.fit_transform(df)
[round(ratio * 100, 1) for ratio in pca.explained_variance_ratio_]

In [None]:
# One scatter of the first two principal components per agent attribute,
# coloured by that attribute's value.
for attr_name in attr_vals:
    colour_values = attr_vals[attr_name]
    plt.scatter(transformed[:, 0], transformed[:, 1], c=colour_values, cmap='coolwarm')
    colour_bar = plt.colorbar()
    colour_bar.set_label(attr_name)
    print(chosen)
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.show()

In [None]:
# Idempotence probe: every agent holds the same belief x; record what each
# policy returns when pooling that unanimous swarm.
idempotent = {}
xs = np.linspace(0,1,1000)
# Learned models are looked up with .get() so that a model which is not
# loaded in NN_POLICIES is skipped instead of raising KeyError.
candidate_policies = [
    ('Imitation NN - Avg', NN_POLICIES.get('Imitation NN - Avg')),
    ('Reinforcement Learned (ours)', NN_POLICIES.get('RL 2M')),
    ('Avg', average),
    ('LogOp $w_i=w_{truth}$', prodop),
]
active_policies = []
for name, policy in candidate_policies:
    if policy is None:
        print(f"Skipping '{name}': its model is not loaded in NN_POLICIES")
        continue
    # The third element marks policies that need the NN-style representation.
    active_policies.append((name, policy, 'Learned' in name or 'NN' in name))

for x in xs:
    example = {i: (x, *const) for i in range(10)}
    for name, policy, is_nn in active_policies:
        merged_reps = represent_poolers(list(example.values()), None, is_nn)
        idempotent.setdefault(name, []).append(policy(merged_reps))

In [None]:
# Pooled belief against the shared input belief, one line per policy.
for policy_name in idempotent:
    plt.plot(xs, idempotent[policy_name], label=policy_name)
plt.legend()
plt.ylabel('Post-pooling belief')
plt.xlabel('Pre-pooling belief');

In [None]:
# Imported once at the top of the cell instead of inside the innermost loop.
from swarmi.game import receive_evidence

# For each w_doubt, compare "evidence then pool" with "pool then evidence"
# and store the mean absolute difference between the two resulting beliefs.
indiv_bayes = {}
xs = np.linspace(0,1,100)
# Learned models are looked up with .get() so that a model which is not
# loaded in NN_POLICIES is skipped instead of raising KeyError.
candidate_policies = [
    ('Imitation NN - Avg', NN_POLICIES.get('Imitation NN - Avg')),
    ('Reinforcement Learned (ours)', NN_POLICIES.get('RL 2M')),
    ('Avg', average),
    ('LogOp $w_i=w_{truth}$', prodop),
]
active_policies = []
for name, policy in candidate_policies:
    if policy is None:
        print(f"Skipping '{name}': its model is not loaded in NN_POLICIES")
        continue
    active_policies.append((name, policy, 'Learned' in name or 'NN' in name))

for w_doubt in xs:
    for name, policy, is_nn in active_policies:
        befores = []
        afters = []
        for _ in range(100):
            # Random beliefs; every other attribute is fixed, only w_doubt varies.
            example = INITIALISATION(10)
            example = {i: (b[0], 1, 1, 1, 0, w_doubt) for i, b in example.items()}

            # pool second: each agent receives evidence, then the swarm is pooled
            evidence_first = {i: [receive_evidence(a, 1) or a[0], *a[1:]] for i, a in example.items()}
            merged_reps = represent_poolers(list(evidence_first.values()), None, is_nn)
            befores.append(float(policy(merged_reps)))

            # pool first: pool the ORIGINAL agents, then each agent receives
            # evidence starting from the pooled belief. (The previous code pooled
            # the evidence-updated agents here, so evidence was applied twice.)
            merged_reps = represent_poolers(list(example.values()), None, is_nn)
            pooled = policy(merged_reps)
            pool_first = {i: [receive_evidence([pooled, *a[1:]], 1) or pooled, *a[1:]] for i, a in example.items()}
            after_out = sum(a[0] for a in pool_first.values()) / len(pool_first)
            afters.append(float(after_out))

        mean_gap = sum(abs(after - before) for after, before in zip(afters, befores)) / len(afters)
        indiv_bayes.setdefault(name, []).append(mean_gap)

In [None]:
# Mean pool-order gap against w_doubt, one scatter series per policy.
for policy_name in indiv_bayes:
    plt.scatter(xs, indiv_bayes[policy_name], label=policy_name)
plt.legend()
plt.ylabel('$|x_{belief}^{pool.first} - x_{belief}^{pool.second}|$')
plt.xlabel('$w_{doubt}$');

## Training Metrics

In [None]:
# Legend name -> file-name tag of the corresponding training run.
name_to_tag = dict([
    ('Random', 'fromrandom'),
    ('Weighted Average', 'fromavg'),
    ('NormLogOp', 'fromnormlogop'),
    # ('Average - $\sum$ Agg', 'fromsumaggavg'),
])

In [None]:
# Load the logged training rewards of each run and spread them evenly over the
# episode range implied by the checkpoint file names in ./models.
losses = {}
for name, tag in name_to_tag.items():

    # SECURITY NOTE(review): eval() executes whatever the file contains. Only use
    # this on reward logs you produced yourself; if the file is a plain Python
    # literal, ast.literal_eval would be the safe replacement.
    with open(f"./plots/{tag}_100k_rewards_0.txt", 'r') as f:
        rewards = eval(f.read())

    # Checkpoint files for this run, excluding the final '*done*' model.
    chkpoints = [x for x in os.listdir('./models') if tag in x and 'done' not in x]
    intervals = sorted(int(x.split(f'{tag}_0_')[-1].replace('.pt', '')) for x in chkpoints)
    if len(intervals) < 2:
        raise ValueError(
            f"Need at least two '{tag}' checkpoints in ./models to infer the "
            f"checkpoint spacing, found {len(intervals)}"
        )
    # Extrapolate one checkpoint spacing past the last saved checkpoint.
    max_ = intervals[-1]+(intervals[-1]-intervals[-2])

    losses[name] = {'reward': rewards, 'num_episodes': np.linspace(0, max_, len(rewards)).tolist()}

In [None]:
# Training reward against the number of RL episodes (in millions), one curve per run.
plt.figure(figsize=(8, 6))
for run_name, curve in losses.items():
    episodes_in_millions = [count/1_000_000 for count in curve['num_episodes']]
    plt.plot(episodes_in_millions, curve['reward'], label=run_name, marker='o', markersize=6)
plt.xlabel('# RL Episodes (millions)', fontsize=16)
plt.ylabel('Training Reward', fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fontsize=15)
plt.xlim(0, 5.1)