In [None]:
import functools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as tkr
from mpl_toolkits.mplot3d import Axes3D

In [None]:
from swarmi.operators import (
    prodop,
    NeuralNetOp,
    average
)
from swarmi.net import normalise

In [None]:
# Closed-form pooling operators. Every entry other than "Avg" is `prodop`,
# optionally with keyword flags pre-bound through functools.partial.
BASE_POLICIES = dict(
    Avg=average,
    LogOp=prodop,
    SLogOp=functools.partial(prodop, normalise_w=True),
    ProdOp=functools.partial(prodop, set_w_to_one=True),
    SProdOp=functools.partial(prodop, set_w_to_one=True, normalise_w=True),
)

# Learned operators: one NeuralNetOp per path under ./models
# (presumably saved model files - confirm against swarmi.operators).
NN_POLICIES = {
    label: NeuralNetOp(model_path)
    for label, model_path in (
        ('Imitation NN', "./models/imitated_logop_10k_100e_150width_2maxagents.pt"),
        ('General RL + Imitation NN', './models/rl_test_classic_n3_p2_1m'),
        ('Specific RL + Imitation NN', './models/rl_test_classic_n3_p2_morenoise'),
        ('yo', './models/rl_test_classic_test'),
    )
}

# Base operators first, neural ones after; later cells iterate in this order.
ALL_POLICIES = {**BASE_POLICIES, **NN_POLICIES}

In [None]:
# Evaluate one policy on a 3-D grid (belief x weight-1 x weight-2) against a
# fixed `evidence` row, then show the pooled output as a coloured 3-D scatter.
chosen_policy = 'yo'

policy = ALL_POLICIES[chosen_policy]
if hasattr(policy, 'net'):
    # Only policies that expose a `.net` have a debug flag to switch off.
    policy.net.debug = False

GRID_DIM = 20
evidence = [1,1,1]
ys = zs = np.linspace(0,1,GRID_DIM)
values = [0,0.25,0.5,0.75,1]

# Same name test the other sweep cells in this notebook apply before calling
# `normalise`; previously 'yo' was skipped here but normalised everywhere else.
# NOTE(review): confirm that the 'yo' model really expects normalised inputs.
needs_normalising = 'Imitation' in chosen_policy or 'yo' in chosen_policy

# grid_scores[i][j][k] is the policy output for (values[i], ys[j], zs[k]).
grid_scores = []
for x in values:
    plane = []
    for y in ys:
        row = []
        for z in zs:
            # Per-slot totals over both rows, used to express each row's
            # weight as a fraction of the total. With evidence weights of 1
            # these totals are never zero on this grid.
            w1_mass = evidence[1]+y
            w2_mass = evidence[2]+z
            inputs_to_model = [
                [evidence[0], evidence[1], evidence[1]/w1_mass, evidence[2], evidence[2]/w2_mass],
                [x, y, y/w1_mass, z, z/w2_mass],
            ]
            if needs_normalising:
                inputs_to_model = [[normalise(v) for v in agent] for agent in inputs_to_model]
            row.append(policy(inputs_to_model))
        plane.append(row)
    grid_scores.append(plane)

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111, projection='3d')

grid_scores = np.array(grid_scores)
# indexing='ij' makes the coordinate arrays (len(values), len(ys), len(zs)),
# the same axis order as grid_scores. The default 'xy' swaps the first two
# axes, so after flattening the colours would be attached to the wrong points.
xs_3d, ys_3d, zs_3d = np.meshgrid(values, ys, zs, indexing='ij')
flat_scores = grid_scores.flatten()
scatter = ax.scatter(xs_3d.flatten(), ys_3d.flatten(), zs_3d.flatten(), c=flat_scores, cmap='viridis', vmin=0, vmax=1)
cbar = plt.colorbar(scatter)
cbar.set_label('Pooled Belief')

# x carries `values` (first model feature), y carries `ys`, z carries `zs`.
ax.set_xlabel('$X_{belief}$')
ax.set_ylabel('$P_{location}$')
ax.set_zlabel('$P_{truth}$')
# Values are listed in the order the label announces: minimum, then maximum.
print('Min / Max: ', [round(min(flat_scores), 2), round(max(flat_scores), 2)])
ax.set_title(f'Outputs when using {chosen_policy=} with {evidence=}')
plt.show()

In [None]:
def num_fmt(x, pos, grid_dim=None):
    """Tick formatter mapping an image pixel index to its grid value.

    The heatmaps in this notebook are drawn from ``np.linspace(0, 1, N)``
    sweeps, so pixel index ``k`` stands for ``k / (N - 1)``: the last index,
    ``N - 1``, must be labelled 1.0. Dividing by ``N`` would label it
    ``(N - 1) / N`` instead.

    Args:
        x: Tick position in pixel/index coordinates.
        pos: Tick index supplied by matplotlib's FuncFormatter; unused.
        grid_dim: Number of grid points per axis. Defaults to the
            notebook-level ``GRID_DIM`` when omitted.

    Returns:
        The grid value as a string with two decimals.
    """
    if grid_dim is None:
        grid_dim = GRID_DIM
    # A one-point grid has no span; fall back to 1 to avoid dividing by zero.
    last_index = max(grid_dim - 1, 1)
    return "{:.2f}".format(x / last_index)

# One 2-D heatmap per entry of `values`: the matching slice of grid_scores.
for x, grid_cut in zip(values, grid_scores):
    fig, ax = plt.subplots()
    img = ax.imshow(grid_cut, cmap='viridis')

    # Tick positions are pixel indices; num_fmt rescales them for display.
    yfmt = tkr.FuncFormatter(num_fmt)
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_major_formatter(yfmt)

    ax.set_title(f"Output with {x=} and {evidence=}")
    ax.set_ylabel("$P_{location}$")
    ax.set_xlabel("$P_{truth}$")
    fig.colorbar(img)
    plt.show()

In [None]:
# Sweep the first entry of both input rows over [0, 1] while every remaining
# entry is pinned to `const`; one GRID_DIM x GRID_DIM table per policy.
heatmaps = {}
const = .5
grid = np.linspace(0, 1, GRID_DIM)

for p, policy in ALL_POLICIES.items():
    # Policies selected by this name test receive five normalised features
    # per row; every other policy receives the plain two-entry rows.
    wants_nn_features = 'Imitation' in p or 'yo' in p
    surface = []
    for x in grid:
        row = []
        for y in grid:
            if wants_nn_features:
                filler = [normalise(const)] * 4
                inputs = [[normalise(x)] + filler, [normalise(y)] + filler]
            else:
                inputs = [[x, const], [y, const]]
            row.append(policy(inputs))
        surface.append(row)
    heatmaps[p] = surface

In [None]:
# Draw one heatmap per policy from the equal-confidence sweep. Rows follow the
# outer sweep variable (first input row), columns the inner one (second row).
for p, ys in heatmaps.items():
    fig, ax = plt.subplots()
    fig.set_figwidth(7)
    fig.set_figheight(7)
    img = ax.imshow(ys, cmap='viridis', clim=(0,1))
    yfmt = tkr.FuncFormatter(num_fmt)
    ax.yaxis.set_major_formatter(yfmt)
    ax.xaxis.set_major_formatter(yfmt)
    # Flatten once and reuse it for both extremes shown in the title.
    flat = np.array(ys).flatten()
    # Inside the f-string `{1}` / `{2}` interpolate to the digits 1 / 2, so the
    # rendered title contains $X_1$ and $X_2$.
    ax.set_title(f"{p} \n equal confidence - $X_{1}$ & $X_{2}$ - max / min: {round(max(flat), 2)} / {round(min(flat), 2)}")
    # Axis labels and colour scale were commented out, leaving the figure
    # unreadable on its own; restored to match the other heatmap cells.
    ax.set_ylabel("$X_{1}$")
    ax.set_xlabel("$X_{2}$")
    fig.colorbar(img)
    plt.show()

In [None]:
# Fix the first entry of each row (const_1 / const_2) and sweep the second
# entry of both rows. `delta` keeps the grid strictly inside (0, 1) so that
# x + y is never zero when it is used as a divisor.
heatmaps = {}
const_1 = .25
const_2 = 0.75
delta = 1e-10
grid = np.linspace(delta, 1 - delta, GRID_DIM)

for p, policy in ALL_POLICIES.items():
    # Policies selected by this name test get five normalised features per
    # row; the others get the plain [constant, swept value] pair.
    uses_nn_layout = 'Imitation' in p or 'yo' in p
    surface = []
    for x in grid:
        row = []
        for y in grid:
            if uses_nn_layout:
                w_sum = x + y
                first = [const_1, x, x/w_sum, 0.50, 0.50]
                second = [const_2, y, y/w_sum, 0.50, 0.50]
                inputs = [[normalise(k) for k in agent] for agent in (first, second)]
            else:
                inputs = [[const_1, x], [const_2, y]]
            row.append(policy(inputs))
        surface.append(row)
    heatmaps[p] = surface

In [None]:
# One 7x7-inch heatmap per policy for the sweep stored in `heatmaps`.
for p, ys in heatmaps.items():
    fig, ax = plt.subplots()
    fig.set_figwidth(7)
    fig.set_figheight(7)

    img = ax.imshow(ys, cmap='viridis', clim=(0,1))

    # Tick positions are pixel indices; num_fmt rescales them for display.
    tick_fmt = tkr.FuncFormatter(num_fmt)
    ax.yaxis.set_major_formatter(tick_fmt)
    ax.xaxis.set_major_formatter(tick_fmt)

    # Extremes of the table, rounded for the title.
    flat = np.array(ys).flatten()
    hi = round(max(flat), 2)
    lo = round(min(flat), 2)
    ax.set_title(f"{p} \n $X_1={const_1}$, $X_2={const_2}$, max / min: {hi} / {lo}")

    ax.set_ylabel("$w^1_{location}$")
    ax.set_xlabel("$w^2_{location}$")
    fig.colorbar(img)
    plt.show()

In [None]:
# Same sweep as the previous weight cell, except that for the five-feature
# layout the swept value sits in the last two slots instead of slots 1-2.
heatmaps = {}
const_1 = .25
const_2 = 0.75
delta = 1e-10
grid = np.linspace(delta, 1 - delta, GRID_DIM)


def _truth_sweep_inputs(name, x, y):
    """Return the two input rows for policy `name` at sweep point (x, y)."""
    if 'Imitation' in name or 'yo' in name:
        w_sum = x+y
        rows = [[const_1, 0.50, 0.50, x, x/w_sum], [const_2, 0.50, 0.50, y, y/w_sum]]
        return [[normalise(k) for k in b] for b in rows]
    return [[const_1, x], [const_2, y]]


for p, policy in ALL_POLICIES.items():
    # Outer index follows x, inner index follows y.
    heatmaps[p] = [
        [policy(_truth_sweep_inputs(p, x, y)) for y in grid]
        for x in grid
    ]

In [None]:
# Render every stored sweep table as a heatmap on a fixed 0..1 colour scale.
for p, ys in heatmaps.items():
    fig, ax = plt.subplots()
    for resize in (fig.set_figwidth, fig.set_figheight):
        resize(7)
    img = ax.imshow(ys, cmap='viridis', clim=(0,1))

    # The same formatter goes on the y axis first, then the x axis.
    yfmt = tkr.FuncFormatter(num_fmt)
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_major_formatter(yfmt)

    cells = np.array(ys).flatten()
    largest, smallest = round(max(cells), 2), round(min(cells), 2)
    ax.set_title(f"{p} \n $X_1={const_1}$, $X_2={const_2}$, max / min: {largest} / {smallest}")
    ax.set_ylabel("$w^{1}_{truth}$")
    ax.set_xlabel("$w^{2}_{truth}$")
    fig.colorbar(img)
    plt.show()