In [1]:
# Parameters
# Run configuration for this notebook. Every value below is consumed by a later
# cell; nothing here is computed.

# Base output directory; the imports cell appends a 'data' sub-directory to it.
output_path = "../../data/artificial_humans/ahm_1_0"
# CSV that is read with pandas in the data-loading cell.
data_file = "../../data/experiments/pilot_random1_player_round_slim.csv"
# Passed to th.device(); the model is moved to this device.
device = "cpu"
# Passed through to Evaluator.save() together with the output path.
labels = {}
# Only rows whose 'experiment_name' column is in this list are kept.
# NOTE(review): "trail" may be a misspelling of "trial"; the string has to match
# the values stored in the CSV, so it is deliberately left untouched here.
experiment_names = ["trail_rounds_2"]
# Forwarded to get_cross_validations() (presumably the share of data used for
# training when no cross-validation is requested -- TODO confirm).
fraction_training = 1.0
# Name of the mask handed to model.encode(); the resulting mask weights the loss.
mask_name = "manager_valid"
# Keyword arguments forwarded unchanged to the model constructor
# AH_MODELS[model_name](default_values=..., **model_args).
model_args = {
    "add_edge_model": True,
    "add_global_model": False,
    "add_rnn": False,
    "hidden_size": 5,
    # NOTE(review): "u_*" / "x_*" look like global vs. per-node feature
    # encodings -- confirm against the model implementation.
    "u_encoding": [
        {"encoding": "numeric", "n_levels": 16, "name": "round_number"},
        {"etype": "float", "name": "prev_common_good", "norm": 32},
    ],
    "x_encoding": [
        {"encoding": "numeric", "n_levels": 21, "name": "contributions"},
        {"encoding": "numeric", "n_levels": 31, "name": "prev_punishments"},
        {"etype": "bool", "name": "valid"},
        {"etype": "bool", "name": "prev_manager_valid"},
    ],
    # Target definition (presumably 31 discrete punishment levels -- TODO confirm).
    "y_levels": 31,
    "y_name": "punishments",
}
# Key into the AH_MODELS registry.
model_name = "graph"
# Forwarded to get_cross_validations(); the training loop tolerates splits whose
# test set is None.
n_cross_val = None
# Passed to create_fully_connected() to build the edge index.
n_player = 4
# Keyword arguments for th.optim.Adam.
optimizer_args = {"lr": 0.0005, "weight_decay": 0.0001}
# Features that are permuted one at a time on the test set during evaluation.
shuffle_features = [
    "prev_punishments",
    "contributions",
    "prev_common_good",
    "valid",
    "prev_manager_valid",
]
# batch_size: DataLoader batch size; clamp_grad: element-wise gradient bound
# (a falsy value disables clamping); epochs: number of training epochs;
# eval_period: evaluate every N epochs (the final epoch is always evaluated).
train_args = {"batch_size": 20, "clamp_grad": 1, "epochs": 5000, "eval_period": 50}


In [2]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import torch as th
from aimanager.generic.data import create_syn_data, create_torch_data, get_cross_validations
from aimanager.artificial_humans import AH_MODELS
from aimanager.artificial_humans.evaluation import Evaluator
from aimanager.utils.array_to_df import using_multiindex
from aimanager.generic.graph_encode import create_fully_connected
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

# All artifacts of this notebook are written to the 'data' sub-directory of the
# configured output path.
# NOTE: this rebinds `output_path` in terms of itself, so re-running this cell
# without re-running the parameters cell appends another 'data' level.
output_path = os.path.join(output_path, 'data')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read the raw records and keep only the rows that belong to one of the
# selected experiments.
df = pd.read_csv(data_file).loc[
    lambda frame: frame['experiment_name'].isin(experiment_names)
]

# Convert the filtered frame into the structure consumed by the models; the
# second return value is passed on to the model constructor later.
data, default_values = create_torch_data(df)
# syn_data = create_syn_data(n_contribution=21, n_punishment=31, default_values=default_values)

In [4]:
# Device the models are moved to before training.
th_device = th.device(device)

# Collects losses and evaluation results across splits and epochs.
ev = Evaluator()

# Accumulators kept at module level (not filled by the training cell itself).
metrics, confusion_matrix, syn_pred = [], [], []

# Graph connectivity shared by every encoded sample.
edge_index = create_fully_connected(n_player)

# Info columns for the (currently disabled) synthetic-data evaluation.
syn_index = ['prev_punishments', 'prev_contributions']


def shuffle_feature(data, feature_name, generator=None):
    """Return a shallow copy of ``data`` with one feature randomly permuted.

    The entry stored under ``feature_name`` is re-ordered along its first axis
    with a random permutation; every other entry is shared with the input
    mapping and the input mapping itself is left unmodified.

    Args:
        data: Mapping from feature name to a value that supports ``len()`` and
            indexing with an integer tensor (e.g. a torch tensor).
        feature_name: Key of the feature to permute.
        generator: Optional ``th.Generator`` used to draw the permutation.
            Pass a seeded generator to make the shuffle reproducible; ``None``
            (the default) uses torch's global random state, as before.

    Returns:
        A new dict with the same keys as ``data``.

    Raises:
        KeyError: If ``feature_name`` is not present in ``data``.
    """
    values = data[feature_name]
    permutation = th.randperm(len(values), generator=generator)
    # Building a new dict keeps the caller's mapping untouched.
    return {**data, feature_name: values[permutation]}

# Train one model per cross-validation split. Every `eval_period` epochs (and
# after the final epoch) the running average training loss is logged and the
# model is evaluated on the training set, on the test set (if the split has
# one), and on the test set with each of `shuffle_features` permuted in turn.
for i, (train_data, test_data) in enumerate(get_cross_validations(data, n_cross_val, fraction_training)):
    model = AH_MODELS[model_name](default_values=default_values, **model_args).to(th_device)

    train_data_ = model.encode(train_data, mask=mask_name, edge_index=edge_index)
    if test_data is not None:
        test_data_ = model.encode(test_data, mask=mask_name, edge_index=edge_index)
    # Synthetic-data evaluation is currently disabled:
    # syn_data_ = model.encode(syn_data, mask=None, y_encode=False, info_columns=syn_index, edge_index=edge_index)

    # syn_df = using_multiindex(
    #     Batch.from_data_list(syn_data_)['info'].detach().cpu().numpy(), ['idx', 'round_number'], syn_index)

    optimizer = th.optim.Adam(model.parameters(), **optimizer_args)
    # reduction='none' keeps one loss value per entry so the mask can be applied.
    loss_fn = th.nn.CrossEntropyLoss(reduction='none')
    # Built once per split; with shuffle=True the loader still draws a fresh
    # order every time it is iterated, i.e. once per epoch.
    train_loader = DataLoader(train_data_, shuffle=True, batch_size=train_args['batch_size'])
    sum_loss = 0
    n_steps = 0

    for e in range(train_args['epochs']):
        ev.set_labels(cv_split=i, epoch=e)
        model.train()
        for batch_data in train_loader:
            optimizer.zero_grad()
            py = model(batch_data).flatten(end_dim=-2)
            y_true = batch_data['y_enc'].flatten(end_dim=-2)
            mask = batch_data['mask'].flatten()

            n_valid = mask.sum()
            if n_valid.item() == 0:
                # No unmasked entry in this batch: the masked mean would be
                # 0/0 = NaN and a single such step would corrupt the weights.
                continue

            # Mean loss over the unmasked entries only.
            loss = (loss_fn(py, y_true) * mask).sum() / n_valid

            loss.backward()

            if train_args['clamp_grad']:
                # Element-wise clamp of every gradient to [-clamp_grad, clamp_grad];
                # unlike a manual loop this ignores parameters without a gradient.
                th.nn.utils.clip_grad_value_(model.parameters(), train_args['clamp_grad'])
            optimizer.step()
            sum_loss += loss.item()
            n_steps += 1

        last_epoch = e == (train_args['epochs'] - 1)

        if (e % train_args['eval_period'] == 0) or last_epoch:
            # Average over the optimisation steps since the previous evaluation.
            avg_loss = sum_loss / n_steps if n_steps else float('nan')
            print(f'CV {i} | Epoch {e} | Loss {avg_loss}')
            ev.add_loss(avg_loss)

            # The confusion matrix is only requested for the final epoch.
            ev.eval_set(model, train_data_, calc_confusion=last_epoch, set='train')
            if test_data is not None:
                ev.eval_set(model, test_data_, calc_confusion=last_epoch, set='test')
                for sf in shuffle_features:
                    # Permute one raw feature, re-encode, and evaluate again.
                    shuffled_data = shuffle_feature(test_data, sf)
                    shuffled_data = model.encode(shuffled_data, mask=mask_name, edge_index=edge_index)
                    ev.eval_set(model, shuffled_data, calc_confusion=False, set='test', shuffle_feature=sf)
            sum_loss = 0
            n_steps = 0
    # ev.eval_syn(model, syn_data_, syn_df)

# Persist the collected evaluation results and the model.
# NOTE: only the model of the last cross-validation split is saved.
ev.save(output_path, labels)
model_path = os.path.join(output_path, 'model.pt')
model.save(model_path)

CV 0 | Epoch 0 | Loss 3.500137948989868


CV 0 | Epoch 50 | Loss 3.0759399843215944


CV 0 | Epoch 100 | Loss 1.888313014984131


CV 0 | Epoch 150 | Loss 1.6561043243408202


CV 0 | Epoch 200 | Loss 1.6194266653060914


CV 0 | Epoch 250 | Loss 1.5858565196990966


CV 0 | Epoch 300 | Loss 1.5752317218780518


CV 0 | Epoch 350 | Loss 1.5522488424777985


CV 0 | Epoch 400 | Loss 1.5347375345230103


CV 0 | Epoch 450 | Loss 1.521586466550827


CV 0 | Epoch 500 | Loss 1.5136036217212676


CV 0 | Epoch 550 | Loss 1.5003368849754333


CV 0 | Epoch 600 | Loss 1.4935480971336366


CV 0 | Epoch 650 | Loss 1.484337068080902


CV 0 | Epoch 700 | Loss 1.48970307803154


CV 0 | Epoch 750 | Loss 1.4779928183555604


CV 0 | Epoch 800 | Loss 1.4604814867973328


CV 0 | Epoch 850 | Loss 1.4640244891643523


CV 0 | Epoch 900 | Loss 1.4409496965408326


CV 0 | Epoch 950 | Loss 1.4605785667896272


CV 0 | Epoch 1000 | Loss 1.4495114312171935


CV 0 | Epoch 1050 | Loss 1.4487529826164245


CV 0 | Epoch 1100 | Loss 1.433277945280075


CV 0 | Epoch 1150 | Loss 1.4414787995815277


CV 0 | Epoch 1200 | Loss 1.4305339498519898


CV 0 | Epoch 1250 | Loss 1.444915419816971


CV 0 | Epoch 1300 | Loss 1.4270888080596924


CV 0 | Epoch 1350 | Loss 1.4351849353313446


CV 0 | Epoch 1400 | Loss 1.4284181566238403


CV 0 | Epoch 1450 | Loss 1.427361958026886


CV 0 | Epoch 1500 | Loss 1.4192534883022307


CV 0 | Epoch 1550 | Loss 1.422819787979126


CV 0 | Epoch 1600 | Loss 1.4158158876895905


CV 0 | Epoch 1650 | Loss 1.4125487740039826


CV 0 | Epoch 1700 | Loss 1.4199052312374114


CV 0 | Epoch 1750 | Loss 1.4128065259456635


CV 0 | Epoch 1800 | Loss 1.412502495288849


CV 0 | Epoch 1850 | Loss 1.4078909990787507


CV 0 | Epoch 1900 | Loss 1.4167726953029633


CV 0 | Epoch 1950 | Loss 1.4153132193088531


CV 0 | Epoch 2000 | Loss 1.4093877260684966


CV 0 | Epoch 2050 | Loss 1.411502007484436


CV 0 | Epoch 2100 | Loss 1.4118750123977661


CV 0 | Epoch 2150 | Loss 1.4018728563785552


CV 0 | Epoch 2200 | Loss 1.4072290992736816


CV 0 | Epoch 2250 | Loss 1.3929835443496703


CV 0 | Epoch 2300 | Loss 1.394537209033966


CV 0 | Epoch 2350 | Loss 1.385819105386734


CV 0 | Epoch 2400 | Loss 1.391748658657074


CV 0 | Epoch 2450 | Loss 1.402610449552536


CV 0 | Epoch 2500 | Loss 1.3940145552158356


CV 0 | Epoch 2550 | Loss 1.3874357146024705


CV 0 | Epoch 2600 | Loss 1.3884948437213898


CV 0 | Epoch 2650 | Loss 1.3878671865463257


CV 0 | Epoch 2700 | Loss 1.3861274354457855


CV 0 | Epoch 2750 | Loss 1.3853743736743926


CV 0 | Epoch 2800 | Loss 1.3899763482809067


CV 0 | Epoch 2850 | Loss 1.385875095129013


CV 0 | Epoch 2900 | Loss 1.3748401670455932


CV 0 | Epoch 2950 | Loss 1.3703460640907288


CV 0 | Epoch 3000 | Loss 1.3779486112594606


CV 0 | Epoch 3050 | Loss 1.3654646261930465


CV 0 | Epoch 3100 | Loss 1.3675807099342345


CV 0 | Epoch 3150 | Loss 1.3758512907028198


CV 0 | Epoch 3200 | Loss 1.3707515325546265


CV 0 | Epoch 3250 | Loss 1.365604222536087


CV 0 | Epoch 3300 | Loss 1.3698775525093079


CV 0 | Epoch 3350 | Loss 1.3586621539592743


CV 0 | Epoch 3400 | Loss 1.3641138956546783


CV 0 | Epoch 3450 | Loss 1.3665747983455658


CV 0 | Epoch 3500 | Loss 1.3655056936740875


CV 0 | Epoch 3550 | Loss 1.359401424884796


CV 0 | Epoch 3600 | Loss 1.3654948635101318


CV 0 | Epoch 3650 | Loss 1.3613304667472839


CV 0 | Epoch 3700 | Loss 1.3618234689235686


CV 0 | Epoch 3750 | Loss 1.3542640581130982


CV 0 | Epoch 3800 | Loss 1.363657231092453


CV 0 | Epoch 3850 | Loss 1.3552354218959808


CV 0 | Epoch 3900 | Loss 1.3665721395015717


CV 0 | Epoch 3950 | Loss 1.3526167476177215


CV 0 | Epoch 4000 | Loss 1.34828435587883


CV 0 | Epoch 4050 | Loss 1.348082805633545


CV 0 | Epoch 4100 | Loss 1.3479335088729858


CV 0 | Epoch 4150 | Loss 1.3435925352573395


CV 0 | Epoch 4200 | Loss 1.3517408220767975


CV 0 | Epoch 4250 | Loss 1.3494736757278443


CV 0 | Epoch 4300 | Loss 1.345869161605835


CV 0 | Epoch 4350 | Loss 1.342041302204132


CV 0 | Epoch 4400 | Loss 1.346231059551239


CV 0 | Epoch 4450 | Loss 1.33980513381958


CV 0 | Epoch 4500 | Loss 1.3473378250598906


CV 0 | Epoch 4550 | Loss 1.3351399698257447


CV 0 | Epoch 4600 | Loss 1.329675498008728


CV 0 | Epoch 4650 | Loss 1.3388611211776733


CV 0 | Epoch 4700 | Loss 1.336333280324936


CV 0 | Epoch 4750 | Loss 1.3457052099704743


CV 0 | Epoch 4800 | Loss 1.3381726024150848


CV 0 | Epoch 4850 | Loss 1.3358524692058564


CV 0 | Epoch 4900 | Loss 1.3236241834163667


CV 0 | Epoch 4950 | Loss 1.3331311650276183


CV 0 | Epoch 4999 | Loss 1.3388474890163966
