In [1]:
# Parameters
# Run configuration for this notebook; the names below are read by the later cells.
# NOTE(review): the "# Parameters" header looks like a papermill-style parameters cell —
# confirm before restructuring, since tooling may rewrite or override these values.
fraction_training = 1.0  # forwarded to get_cross_validations together with n_cross_val
n_cross_val = None  # forwarded to get_cross_validations
n_player = 4  # argument of create_fully_connected when building edge_index
data_file = "../../data/experiments/pilot_random1_player_round_slim.csv"  # CSV loaded with pd.read_csv
output_path = "../../data/training/ah_1_1"  # the import cell appends a 'data' sub-directory to this
model_name = "graph"  # key looked up in AH_MODELS
# Features that are permuted, one at a time, in the test data at every evaluation step
# (only when the split provides test data).
shuffle_features = [
    "prev_punishments",
    "prev_contributions",
    "prev_common_good",
    "prev_valid",
]
labels = {}  # passed to ev.save together with the output path
# Keyword arguments unpacked into the model constructor (next to default_values).
# The meaning of the individual keys is defined by the model class, not in this notebook.
model_args = {
    "hidden_size": 5,
    "add_rnn": True,
    "add_edge_model": True,
    "add_global_model": False,
    "x_encoding": [
        {"name": "prev_contributions", "n_levels": 21, "encoding": "numeric"},
        {"name": "prev_punishments", "n_levels": 31, "encoding": "numeric"},
        {"etype": "bool", "name": "prev_valid"},
    ],
    "u_encoding": [
        {"name": "round_number", "n_levels": 16, "encoding": "numeric"},
        {"name": "prev_common_good", "norm": 32, "etype": "float"},
    ],
    "y_levels": 21,
    "y_name": "contributions",
}
mask_name = "valid"  # passed as mask= to model.encode
# Only rows whose 'experiment_name' is in this list are kept after loading.
# NOTE(review): "trail_rounds_2" may be a typo for "trial", but it has to match the values
# stored in the CSV — check the data before changing it.
experiment_names = ["trail_rounds_2", "random_1"]
optimizer_args = {"lr": 0.001, "weight_decay": 1e-05}  # keyword arguments of th.optim.Adam
# epochs: number of passes over the training data; batch_size: DataLoader batch size;
# clamp_grad: bound for the element-wise gradient clamp (a falsy value disables it);
# eval_period: evaluate every N-th epoch (the last epoch is always evaluated).
train_args = {"epochs": 2000, "batch_size": 10, "clamp_grad": 1, "eval_period": 10}
device = "cpu"  # converted with th.device; the model is moved to it


In [2]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import torch as th
from aimanager.generic.data import create_syn_data, create_torch_data, get_cross_validations
from aimanager.artificial_humans import AH_MODELS
from aimanager.artificial_humans.evaluation import Evaluator
from aimanager.utils.array_to_df import using_multiindex
from aimanager.generic.graph_encode import create_fully_connected
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

# Evaluation results and the trained model are written below a 'data' sub-directory.
# NOTE: this rebinds output_path in place, so re-running this cell without first re-running
# the parameters cell appends another 'data' segment to the path.
output_path = os.path.join(output_path, 'data')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the CSV and keep only the rows that belong to the configured experiments.
df = pd.read_csv(data_file).loc[
    lambda frame: frame['experiment_name'].isin(experiment_names)
]

# Convert the filtered frame into the model input data plus the default values
# that are later handed to the model constructor.
data, default_values = create_torch_data(df)
# syn_data = create_syn_data(n_contribution=21, n_punishment=31, default_values=default_values)

In [4]:
# Torch device the model is moved to (previously assigned twice in this cell).
th_device = th.device(device)

# Result containers; nothing in this cell or in the training loop appends to them.
metrics = []
confusion_matrix = []
syn_pred = []

# Collects the losses and evaluation results produced during training.
ev = Evaluator()

# Info columns of the (currently commented-out) synthetic-data evaluation.
syn_index = ['prev_punishments', 'prev_contributions']
# Graph connectivity shared by every encode call; presumably a fully connected
# graph over the n_player players — confirm in create_fully_connected.
edge_index = create_fully_connected(n_player)


def shuffle_feature(data, feature_name):
    """Return a shallow copy of ``data`` in which one feature is randomly permuted.

    The entries of ``data[feature_name]`` are reordered along the first axis with a
    permutation drawn from torch's global random generator. All other entries are
    shared with the input, and the input mapping itself is left unmodified.

    Args:
        data: mapping from feature name to an indexable value (e.g. a tensor).
        feature_name: key of the feature to permute.

    Returns:
        A new dict with the permuted feature stored under ``feature_name``.
    """
    values = data[feature_name]
    permutation = th.randperm(len(values))
    return {**data, feature_name: values[permutation]}

# Train one model per split returned by get_cross_validations and record losses and
# evaluation results in `ev`.
for i, (train_data, test_data) in enumerate(get_cross_validations(data, n_cross_val, fraction_training)):
    model = AH_MODELS[model_name](default_values=default_values, **model_args).to(th_device)

    train_data_ = model.encode(train_data, mask=mask_name, edge_index=edge_index)
    if test_data is not None:
        test_data_ = model.encode(test_data, mask=mask_name, edge_index=edge_index)
    # syn_data_ = model.encode(syn_data, mask=None, y_encode=False, info_columns=syn_index, edge_index=edge_index)

    # syn_df = using_multiindex(
    #     Batch.from_data_list(syn_data_)['info'].detach().cpu().numpy(), ['idx', 'round_number'], syn_index)

    optimizer = th.optim.Adam(model.parameters(), **optimizer_args)
    # reduction='none' keeps the element-wise losses so masked entries can be excluded below.
    loss_fn = th.nn.CrossEntropyLoss(reduction='none')
    # Built once per split; with shuffle=True every pass over the loader draws a new order.
    train_loader = DataLoader(train_data_, shuffle=True, batch_size=train_args['batch_size'])
    # Running loss statistics, reset after every evaluation.
    sum_loss = 0
    n_steps = 0

    for e in range(train_args['epochs']):
        ev.set_labels(cv_split=i, epoch=e)
        # Re-enable training mode every epoch, in case evaluation changed it.
        model.train()
        for batch_data in train_loader:
            mask = batch_data['mask'].flatten()
            n_valid = mask.sum()
            if n_valid == 0:
                # Nothing to learn from in this batch; the masked mean would be 0/0 = NaN
                # and a NaN loss would corrupt the weights on the optimizer step.
                continue

            optimizer.zero_grad()
            py = model(batch_data).flatten(end_dim=-2)
            y_true = batch_data['y_enc'].flatten(end_dim=-2)
            # Mean loss over the entries selected by the mask.
            loss = (loss_fn(py, y_true) * mask).sum() / n_valid

            loss.backward()

            if train_args['clamp_grad']:
                # Element-wise clamp of all gradients to [-clamp_grad, clamp_grad];
                # parameters without a gradient are skipped.
                th.nn.utils.clip_grad_value_(model.parameters(), train_args['clamp_grad'])
            optimizer.step()
            sum_loss += loss.item()
            n_steps += 1

        last_epoch = e == (train_args['epochs'] - 1)

        if (e % train_args['eval_period'] == 0) or last_epoch:
            # Average over the optimizer steps since the previous evaluation.
            avg_loss = sum_loss / n_steps if n_steps else float('nan')
            print(f'CV {i} | Epoch {e} | Loss {avg_loss}')
            ev.add_loss(avg_loss)

            # The confusion matrix is only requested for the final epoch.
            ev.eval_set(model, train_data_, calc_confusion=last_epoch, set='train')
            if test_data is not None:
                ev.eval_set(model, test_data_, calc_confusion=last_epoch, set='test')
                # Re-evaluate on the test data with one feature permuted at a time.
                for sf in shuffle_features:
                    shuffled_data = shuffle_feature(test_data, sf)
                    shuffled_data = model.encode(shuffled_data, mask=mask_name, edge_index=edge_index)
                    ev.eval_set(model, shuffled_data, calc_confusion=False, set='test', shuffle_feature=sf)
            sum_loss = 0
            n_steps = 0
    # ev.eval_syn(model, syn_data_, syn_df)

# Persist the collected evaluation results and the model.
# NOTE: `model` is rebound in every split, so only the last split's model is saved.
ev.save(output_path, labels)
model_path = os.path.join(output_path, 'model.pt')
model.save(model_path)

CV 0 | Epoch 0 | Loss 3.0724822112492154


CV 0 | Epoch 10 | Loss 2.873840967246464


CV 0 | Epoch 20 | Loss 2.527130034991673


CV 0 | Epoch 30 | Loss 2.3801232823303766


CV 0 | Epoch 40 | Loss 2.3033718568938117


CV 0 | Epoch 50 | Loss 2.232803136110306


CV 0 | Epoch 60 | Loss 2.171295028073447


CV 0 | Epoch 70 | Loss 2.118096257959093


CV 0 | Epoch 80 | Loss 2.074875008208411


CV 0 | Epoch 90 | Loss 2.0484698150839122


CV 0 | Epoch 100 | Loss 2.0179469312940324


CV 0 | Epoch 110 | Loss 1.9988083890506199


CV 0 | Epoch 120 | Loss 1.9925649302346367


CV 0 | Epoch 130 | Loss 1.974823408467429


CV 0 | Epoch 140 | Loss 1.9638534111636026


CV 0 | Epoch 150 | Loss 1.9469006930078778


CV 0 | Epoch 160 | Loss 1.9436432463782174


CV 0 | Epoch 170 | Loss 1.937765666416713


CV 0 | Epoch 180 | Loss 1.9244589362825666


CV 0 | Epoch 190 | Loss 1.9115864711148398


CV 0 | Epoch 200 | Loss 1.9087885081768037


CV 0 | Epoch 210 | Loss 1.8976905422551291


CV 0 | Epoch 220 | Loss 1.885180059501103


CV 0 | Epoch 230 | Loss 1.8864352669034685


CV 0 | Epoch 240 | Loss 1.8803684728486196


CV 0 | Epoch 250 | Loss 1.872054808480399


CV 0 | Epoch 260 | Loss 1.8731209933757782


CV 0 | Epoch 270 | Loss 1.8663375641618456


CV 0 | Epoch 280 | Loss 1.8684042649609702


CV 0 | Epoch 290 | Loss 1.8623125544616155


CV 0 | Epoch 300 | Loss 1.8629028030804224


CV 0 | Epoch 310 | Loss 1.8578201387609754


CV 0 | Epoch 320 | Loss 1.8544398188591003


CV 0 | Epoch 330 | Loss 1.8482477605342864


CV 0 | Epoch 340 | Loss 1.8475742663655963


CV 0 | Epoch 350 | Loss 1.8470849982329778


CV 0 | Epoch 360 | Loss 1.8427626201084681


CV 0 | Epoch 370 | Loss 1.8478488547461374


CV 0 | Epoch 380 | Loss 1.8446684675557272


CV 0 | Epoch 390 | Loss 1.8430589658873422


CV 0 | Epoch 400 | Loss 1.8417558593409402


CV 0 | Epoch 410 | Loss 1.8401351400784083


CV 0 | Epoch 420 | Loss 1.8367727790560042


CV 0 | Epoch 430 | Loss 1.8390623543943678


CV 0 | Epoch 440 | Loss 1.8350911276681083


CV 0 | Epoch 450 | Loss 1.8364346095493862


CV 0 | Epoch 460 | Loss 1.8357852535588401


CV 0 | Epoch 470 | Loss 1.829982076372419


CV 0 | Epoch 480 | Loss 1.836296535389764


CV 0 | Epoch 490 | Loss 1.8275393609489714


CV 0 | Epoch 500 | Loss 1.8320405244827271


CV 0 | Epoch 510 | Loss 1.8246231692177908


CV 0 | Epoch 520 | Loss 1.8309813899653298


CV 0 | Epoch 530 | Loss 1.822685844557626


CV 0 | Epoch 540 | Loss 1.8240989182676588


CV 0 | Epoch 550 | Loss 1.826633999177388


CV 0 | Epoch 560 | Loss 1.819433684859957


CV 0 | Epoch 570 | Loss 1.8199635675975254


CV 0 | Epoch 580 | Loss 1.822504763518061


CV 0 | Epoch 590 | Loss 1.821158117055893


CV 0 | Epoch 600 | Loss 1.8195791499955314


CV 0 | Epoch 610 | Loss 1.8190991333552768


CV 0 | Epoch 620 | Loss 1.8219255030155181


CV 0 | Epoch 630 | Loss 1.8146072208881379


CV 0 | Epoch 640 | Loss 1.8128644815513066


CV 0 | Epoch 650 | Loss 1.8210971508707319


CV 0 | Epoch 660 | Loss 1.8175890377589634


CV 0 | Epoch 670 | Loss 1.8152562754494803


CV 0 | Epoch 680 | Loss 1.8171730756759643


CV 0 | Epoch 690 | Loss 1.8079574116638728


CV 0 | Epoch 700 | Loss 1.8112033035073962


CV 0 | Epoch 710 | Loss 1.8099025070667267


CV 0 | Epoch 720 | Loss 1.808925919021879


CV 0 | Epoch 730 | Loss 1.8123317565236772


CV 0 | Epoch 740 | Loss 1.8068795382976532


CV 0 | Epoch 750 | Loss 1.8140386181218284


CV 0 | Epoch 760 | Loss 1.806416270988328


CV 0 | Epoch 770 | Loss 1.8133113469396318


CV 0 | Epoch 780 | Loss 1.8089951012815748


CV 0 | Epoch 790 | Loss 1.8041754177638463


CV 0 | Epoch 800 | Loss 1.8030495720250266


CV 0 | Epoch 810 | Loss 1.8097633353301457


CV 0 | Epoch 820 | Loss 1.8052352556160518


CV 0 | Epoch 830 | Loss 1.8076482338564737


CV 0 | Epoch 840 | Loss 1.8183183985097067


CV 0 | Epoch 850 | Loss 1.807883163860866


CV 0 | Epoch 860 | Loss 1.8069151937961578


CV 0 | Epoch 870 | Loss 1.805785711322512


CV 0 | Epoch 880 | Loss 1.8088448601109641


CV 0 | Epoch 890 | Loss 1.8054460746901375


CV 0 | Epoch 900 | Loss 1.80837921500206


CV 0 | Epoch 910 | Loss 1.799895258460726


CV 0 | Epoch 920 | Loss 1.8023620196751187


CV 0 | Epoch 930 | Loss 1.80244123339653


CV 0 | Epoch 940 | Loss 1.7973296139921462


CV 0 | Epoch 950 | Loss 1.8014558034283774


CV 0 | Epoch 960 | Loss 1.8013598663466317


CV 0 | Epoch 970 | Loss 1.8031348892620631


CV 0 | Epoch 980 | Loss 1.7944554201194218


CV 0 | Epoch 990 | Loss 1.7988085716962814


CV 0 | Epoch 1000 | Loss 1.7965721453939165


CV 0 | Epoch 1010 | Loss 1.7988810207162584


CV 0 | Epoch 1020 | Loss 1.8017283924988339


CV 0 | Epoch 1030 | Loss 1.7969008335045407


CV 0 | Epoch 1040 | Loss 1.796467420884541


CV 0 | Epoch 1050 | Loss 1.7895247297627586


CV 0 | Epoch 1060 | Loss 1.7950963871819632


CV 0 | Epoch 1070 | Loss 1.7948219512190138


CV 0 | Epoch 1080 | Loss 1.8035139194556644


CV 0 | Epoch 1090 | Loss 1.7958407512732915


CV 0 | Epoch 1100 | Loss 1.798167793239866


CV 0 | Epoch 1110 | Loss 1.7992566977228437


CV 0 | Epoch 1120 | Loss 1.7929772555828094


CV 0 | Epoch 1130 | Loss 1.7909652386392867


CV 0 | Epoch 1140 | Loss 1.790451397214617


CV 0 | Epoch 1150 | Loss 1.7899292009217398


CV 0 | Epoch 1160 | Loss 1.7910017848014832


CV 0 | Epoch 1170 | Loss 1.794347104855946


CV 0 | Epoch 1180 | Loss 1.7943556700434005


CV 0 | Epoch 1190 | Loss 1.7844875855105264


CV 0 | Epoch 1200 | Loss 1.7956273325851986


CV 0 | Epoch 1210 | Loss 1.7896426047597613


CV 0 | Epoch 1220 | Loss 1.7912925090108598


CV 0 | Epoch 1230 | Loss 1.7865520886012487


CV 0 | Epoch 1240 | Loss 1.7935820758342742


CV 0 | Epoch 1250 | Loss 1.7890424762453352


CV 0 | Epoch 1260 | Loss 1.7873572749750954


CV 0 | Epoch 1270 | Loss 1.7848016943250384


CV 0 | Epoch 1280 | Loss 1.793694519996643


CV 0 | Epoch 1290 | Loss 1.7900516842092786


CV 0 | Epoch 1300 | Loss 1.7816607807363782


CV 0 | Epoch 1310 | Loss 1.7896257996559144


CV 0 | Epoch 1320 | Loss 1.7896881665502276


CV 0 | Epoch 1330 | Loss 1.7876452028751373


CV 0 | Epoch 1340 | Loss 1.7778811940125057


CV 0 | Epoch 1350 | Loss 1.7852874747344425


CV 0 | Epoch 1360 | Loss 1.7904097718851908


CV 0 | Epoch 1370 | Loss 1.7840905734470913


CV 0 | Epoch 1380 | Loss 1.7757891493184226


CV 0 | Epoch 1390 | Loss 1.7846561568123953


CV 0 | Epoch 1400 | Loss 1.7852755427360534


CV 0 | Epoch 1410 | Loss 1.7818754630429403


CV 0 | Epoch 1420 | Loss 1.7832771590777805


CV 0 | Epoch 1430 | Loss 1.7807730989796775


CV 0 | Epoch 1440 | Loss 1.7840175083705356


CV 0 | Epoch 1450 | Loss 1.7819494145257133


CV 0 | Epoch 1460 | Loss 1.7792614076818738


CV 0 | Epoch 1470 | Loss 1.7757945043700083


CV 0 | Epoch 1480 | Loss 1.780961606332234


CV 0 | Epoch 1490 | Loss 1.7808467039040157


CV 0 | Epoch 1500 | Loss 1.782241109439305


CV 0 | Epoch 1510 | Loss 1.7740887769630977


CV 0 | Epoch 1520 | Loss 1.7767156549862453


CV 0 | Epoch 1530 | Loss 1.768018708058766


CV 0 | Epoch 1540 | Loss 1.7763544636113302


CV 0 | Epoch 1550 | Loss 1.7756750319685255


CV 0 | Epoch 1560 | Loss 1.7811514939580644


CV 0 | Epoch 1570 | Loss 1.7709945491382053


CV 0 | Epoch 1580 | Loss 1.7776046182428087


CV 0 | Epoch 1590 | Loss 1.7738327179636275


CV 0 | Epoch 1600 | Loss 1.7733334941523415


CV 0 | Epoch 1610 | Loss 1.7784218209130422


CV 0 | Epoch 1620 | Loss 1.7734043530055454


CV 0 | Epoch 1630 | Loss 1.7699183745043618


CV 0 | Epoch 1640 | Loss 1.7729328836713518


CV 0 | Epoch 1650 | Loss 1.771484432901655


CV 0 | Epoch 1660 | Loss 1.7739495098590852


CV 0 | Epoch 1670 | Loss 1.774606601681028


CV 0 | Epoch 1680 | Loss 1.7729547909327916


CV 0 | Epoch 1690 | Loss 1.7726965870176044


CV 0 | Epoch 1700 | Loss 1.7715156878743852


CV 0 | Epoch 1710 | Loss 1.7752707242965697


CV 0 | Epoch 1720 | Loss 1.7742346176079342


CV 0 | Epoch 1730 | Loss 1.7706272866044726


CV 0 | Epoch 1740 | Loss 1.7709398439952306


CV 0 | Epoch 1750 | Loss 1.7739239437239511


CV 0 | Epoch 1760 | Loss 1.764308969463621


CV 0 | Epoch 1770 | Loss 1.7675303344215665


CV 0 | Epoch 1780 | Loss 1.765504424061094


CV 0 | Epoch 1790 | Loss 1.7760460930211204


CV 0 | Epoch 1800 | Loss 1.7688161960669926


CV 0 | Epoch 1810 | Loss 1.7667194102491652


CV 0 | Epoch 1820 | Loss 1.7723044659410203


CV 0 | Epoch 1830 | Loss 1.7676065632275173


CV 0 | Epoch 1840 | Loss 1.758405341420855


CV 0 | Epoch 1850 | Loss 1.7793122734342302


CV 0 | Epoch 1860 | Loss 1.7760081836155484


CV 0 | Epoch 1870 | Loss 1.7653478843825203


CV 0 | Epoch 1880 | Loss 1.7638688027858733


CV 0 | Epoch 1890 | Loss 1.766831406525203


CV 0 | Epoch 1900 | Loss 1.7694868113313402


CV 0 | Epoch 1910 | Loss 1.7681328185967036


CV 0 | Epoch 1920 | Loss 1.7663669671331133


CV 0 | Epoch 1930 | Loss 1.7675711035728454


CV 0 | Epoch 1940 | Loss 1.7668583350522178


CV 0 | Epoch 1950 | Loss 1.7686975027833667


CV 0 | Epoch 1960 | Loss 1.768448268515723


CV 0 | Epoch 1970 | Loss 1.76498213495527


CV 0 | Epoch 1980 | Loss 1.7597916679722923


CV 0 | Epoch 1990 | Loss 1.7600783331053598


CV 0 | Epoch 1999 | Loss 1.7718114067637731
