In [None]:
import json
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from swarmi.net import Net, normalise, MultiNet
from swarmi.operators import average_sampler_11D, normmeanprodop_sampler_11D

In [3]:
num_points = int(1e4)
max_pooling_agents = 20
min_pooling_agents = 10

n_hypotheses = 3
op_name = 'Average'

# Operator name -> sampler producing one (agent inputs, pooled target) draw.
# Both samplers are called with the same arguments, so a lookup replaces the
# copy-pasted if/elif branches; unknown operators are rejected before sampling starts.
SAMPLERS = {
    'Average': average_sampler_11D,
    'NormLogOp': normmeanprodop_sampler_11D,
}
if op_name not in SAMPLERS:
    raise ValueError(f'Operator {op_name} not recognised')
sampler = SAMPLERS[op_name]

data = []
for _ in range(num_points):
    # NOTE(review): the trailing (True, False) flags are passed through unchanged;
    # their meaning is defined in swarmi.operators — TODO document.
    x_values, w1s, w2s, w3s, w4s, w5s, target = sampler(max_pooling_agents, min_pooling_agents, n_hypotheses, True, False)
    # One row per agent, [belief, w1, w2, w3, w4, w5], paired with the sample's target.
    data.append([[[*x] for x in zip(x_values, w1s, w2s, w3s, w4s, w5s)], target])

In [4]:
# Everything (data tensors and the model) is placed on the CPU. To use a GPU when
# available, replace the value with: 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

In [22]:
TRAIN_FRACTION = 0.80

xs = []
ys = []
for sample in data:
    agents, target = sample
    # Flatten each agent row [belief, w1, ..., w5] into one feature vector
    # [*belief, w1, ..., w5]; `belief` must support .tolist() (array/tensor).
    formatted_agents = []
    for agent in agents:
        belief_list = agent[0].tolist()
        flat_sample = [*belief_list, *agent[1:]]
        formatted_agents.append(flat_sample)

    # Rows are agents and columns are features; squeeze(0) only removes the agent
    # axis when a sample has exactly one agent.
    input_sample = torch.tensor(formatted_agents).squeeze(0).to(device)
    input_sample = normalise(input_sample)
    output = torch.tensor(target).to(device)

    # xs stays a Python list (presumably because agent counts vary between samples);
    # targets are stacked below.
    xs.append(input_sample)
    ys.append(output)

# Split on the number of samples actually built rather than on `num_points`, so the
# split stays correct if `data` was regenerated or truncated in another cell.
n_train = int(TRAIN_FRACTION * len(xs))
assert 0 < n_train < len(xs), f"empty train/test split: n_train={n_train}, n_samples={len(xs)}"
xs_train, xs_test = xs[:n_train], xs[n_train:]
ys_train, ys_test = torch.vstack(ys[:n_train]), torch.vstack(ys[n_train:])

In [None]:
# Seed torch's global RNG so MultiNet's parameter initialisation is reproducible across runs.
# NOTE(review): the samplers in swarmi.operators may draw from another RNG (numpy / random);
# seed those before data generation as well if the dataset itself must be reproducible.
SEED = 0
torch.manual_seed(SEED)

# Width heuristic: ceil(e ** min(sqrt(max_pooling_agents), n_features)).
# xs[0][0] is the first agent row of the first sample, so its length is the per-agent feature count.
n_features = xs[0][0].shape[0]
hidden_dim = int(
    torch.ceil(torch.exp(torch.tensor(min(max_pooling_agents ** (0.5), n_features)))).item()
)
# agg_fn is presumably the pooling over agents; the saved-model paths distinguish sum vs avg
# aggregation, so keep this choice in sync with whichever path the model is saved to.
neural_net = MultiNet(hidden_dim, n_hypotheses, agg_fn=torch.sum).to(device)
optimiser = torch.optim.Adam(neural_net.parameters(), lr=3e-3)

# Loss histories are reset only here: re-running just the training cell keeps training
# the same model and appends to these lists.
tr_losses = []
te_losses = []
NUM_EPOCHS = 50
SAVE_MODEL = False
LOG_INTERVAL = 10  # print test metrics every LOG_INTERVAL epochs and on the last epoch

In [24]:
# Checkpoint locations for previously trained imitation models. Only IMITATED_MODEL_PATH
# is used by the saving cell of this notebook; the others are kept for loading older runs.
IMITATED_AVG_VARY_AGENTS_20MAXAGENTS_AVG_AGG_6D_MODEL_PATH = (
    "./models/imitated_avg_10k_50e_150width_20maxagents_6D_normalised_avg_agg.pt"
)
IMITATED_AVG_VARY_AGENTS_20MAXAGENTS_SUM_AGG_6D_MODEL_PATH = (
    "./models/imitated_avg_10k_50e_150width_20maxagents_6D_normalised_sum_agg.pt"
)
IMITATED_NORMPRODOP_VARY_AGENTS_20MAXAGENTS_AVG_AGG_6D_MODEL_PATH = (
    "./models/imitated_normprodop_10k_50e_150width_20maxagents_6D_normalised_avg_agg.pt"
)

# Destination for the model trained in this notebook (written to the working directory).
IMITATED_MODEL_PATH = (
    "testing_multi_net.pt"
)

In [None]:
for e in range(NUM_EPOCHS):
    # Evaluate before this epoch's update, so te_losses[e] and tr_losses[e] both
    # describe the same (pre-update) model state.
    # NOTE(review): assumes MultiNet is a torch.nn.Module (it exposes .to()/.parameters()).
    neural_net.eval()
    with torch.no_grad():
        test_out = neural_net(xs_test)
        test_loss = F.mse_loss(test_out, ys_test)
        te_losses.append(test_loss.item())

    if e % LOG_INTERVAL == 0 or e == NUM_EPOCHS - 1:
        print(f"Epoch {e}: pred={test_out[0]} real={ys_test[0]} loss={round(te_losses[-1], 4)}")

    # One full-batch gradient step over the whole training split.
    neural_net.train()
    optimiser.zero_grad()
    outs = neural_net(xs_train)
    tr_loss = F.mse_loss(outs, ys_train)
    tr_loss.backward()
    optimiser.step()
    tr_losses.append(tr_loss.item())

# The loop only records pre-update losses, so evaluate the model after its final step once.
neural_net.eval()
with torch.no_grad():
    final_test_loss = F.mse_loss(neural_net(xs_test), ys_test).item()
print(f"Final test loss after {NUM_EPOCHS} epochs: {round(final_test_loss, 4)}")

In [26]:
# Persist the network only when saving is enabled and the training loop ran to its last epoch.
# SAVE_MODEL is tested first so that, with saving off, this cell never touches `e`
# (which is undefined if the training loop did not run, e.g. NUM_EPOCHS == 0).
if SAVE_MODEL and e == NUM_EPOCHS - 1:
    # Saving the module object pickles the whole class: loading needs swarmi importable and,
    # on PyTorch >= 2.6, torch.load(..., weights_only=False). neural_net.state_dict() is more portable.
    torch.save(neural_net, IMITATED_MODEL_PATH)

In [27]:
# optional saving of results
SAVE_RESULT = True
RESULTS_PATH = 'scale_imitation_results.json'
if SAVE_RESULT:
    # Results accumulate across runs as {op_name: {str(n_hypotheses): te_losses}};
    # this run's entry overwrites any earlier one for the same (operator, hypotheses) pair.
    if os.path.exists(RESULTS_PATH):
        with open(RESULTS_PATH, mode='r') as f:
            json_scores = json.load(f)
    else:
        json_scores = {}
    json_scores.setdefault(op_name, {})[str(n_hypotheses)] = te_losses

    # Write to a sibling temp file, then swap it in with os.replace, so an interrupted
    # write cannot truncate the accumulated results of earlier runs.
    tmp_path = RESULTS_PATH + '.tmp'
    with open(tmp_path, mode='w') as f:
        json.dump(json_scores, f)
    os.replace(tmp_path, RESULTS_PATH)

In [None]:
with open('scale_imitation_results.json', mode='r') as f:
    json_scores = json.load(f)

# Regroup {op: {n_hypotheses: losses}} into {n_hypotheses: {op: losses}} so each figure
# compares all operators at one fixed number of hypotheses.
belief_ordered_results = {}
for op, scores_by_belief in json_scores.items():
    for belief, scores in scores_by_belief.items():
        belief_ordered_results.setdefault(belief, {})[op] = scores

# JSON object keys are strings; sort numerically across all operators, not just within each one.
for belief in sorted(belief_ordered_results, key=int):
    fig, ax = plt.subplots(figsize=(10, 6))
    for op, scores in belief_ordered_results[belief].items():
        ax.plot(scores, marker='o', markersize=6, label=op)
    ax.set_title(f'{belief} hypotheses', fontsize=16)
    ax.set_xlabel('# Epochs', fontsize=16)
    ax.set_ylabel('MSE Test Loss', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    ax.legend(fontsize=15)
    plt.show()