In [1]:
# Parameters
# Passed to get_cross_validations together with n_cross_val; presumably the share
# of rows used for training (1.0 = all) — confirm against get_cross_validations.
fraction_training = 1.0
# Number of CV splits, passed to get_cross_validations.
# NOTE(review): presumably None means a single split — the training log only shows "CV 0".
n_cross_val = None
# Number of players; used to build the player graph via create_fully_connected(n_player).
n_player = 4
# CSV read with pd.read_csv and filtered to the rows whose experiment_name is in experiment_names.
data_file = "../../data/experiments/pilot_random1_player_round_slim.csv"
# Base output directory; the imports cell appends a `data` sub-directory, into which
# the evaluation results (ev.save) and model.pt are written.
output_path = "../../data/artificial_humans/ah_1_1_simple"
# Key into AH_MODELS selecting the model class.
model_name = "graph"
# Features of the raw test data that are shuffled one at a time for an extra evaluation pass.
# NOTE(review): that pass only runs when a test split exists; with fraction_training = 1.0
# there may be none — confirm. "prev_common_good" is also not listed in model_args["x_encoding"].
shuffle_features = [
    "prev_punishments",
    "prev_contributions",
    "prev_common_good",
    "prev_valid",
]
# Labels passed to ev.save together with the evaluation results.
labels = {}
# Keyword arguments for the model constructor (AH_MODELS[model_name](..., **model_args)).
model_args = {
    "hidden_size": 5,
    "add_rnn": False,
    "add_edge_model": False,
    "add_global_model": False,
    "x_encoding": [
        {"name": "prev_contributions", "n_levels": 21, "encoding": "numeric"},
        {"name": "prev_punishments", "n_levels": 31, "encoding": "numeric"},
        # NOTE(review): this entry uses "etype" where the others use "encoding" — confirm
        # which key the model's encoder actually reads.
        {"etype": "bool", "name": "prev_valid"},
    ],
    "u_encoding": [],
    "y_levels": 21,
    "y_name": "contributions",
}
# Mask name passed to model.encode; presumably it becomes batch_data['mask'], which
# weights the per-entry training loss (entries with mask 0 do not contribute).
mask_name = "valid"
# Values of the `experiment_name` column to keep.
# NOTE(review): "trail_rounds_2" must match the CSV exactly (possibly meant "trial") — verify.
experiment_names = ["trail_rounds_2", "random_1"]
# Keyword arguments for th.optim.Adam.
optimizer_args = {"lr": 0.001, "weight_decay": 1e-05}
# epochs: training epochs per split; batch_size: DataLoader batch size;
# clamp_grad: element-wise gradient clamp bound (a falsy value disables clamping);
# eval_period: evaluate and log every N epochs (and always on the last epoch).
train_args = {"epochs": 2000, "batch_size": 10, "clamp_grad": 1, "eval_period": 10}
# Torch device name, converted with th.device(device).
device = "cpu"


In [2]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import torch as th
from aimanager.generic.data import create_syn_data, create_torch_data, get_cross_validations
from aimanager.artificial_humans import AH_MODELS
from aimanager.artificial_humans.evaluation import Evaluator
from aimanager.utils.array_to_df import using_multiindex
from aimanager.generic.graph_encode import create_fully_connected
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

# Results are written to a `data` sub-directory of the `output_path` parameter.
# The join is guarded so that re-running this cell (without re-running the
# parameters cell) does not nest one more `data/` level each time.
if globals().get('_joined_output_path') != output_path:
    output_path = os.path.join(output_path, 'data')
    _joined_output_path = output_path

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
df = pd.read_csv(data_file)

# Keep only the experiments this model is trained on.
df = df[df['experiment_name'].isin(experiment_names)]
# Fail fast: an empty selection (e.g. a misspelled experiment name) would otherwise
# only fail later in the pipeline, in a much less obvious place.
if df.empty:
    raise ValueError(
        f"No rows in {data_file!r} match experiment_names={experiment_names!r}"
    )

# Convert the filtered frame to torch data; default_values is handed to the model constructor.
data, default_values = create_torch_data(df)
# syn_data = create_syn_data(n_contribution=21, n_punishment=31, default_values=default_values)

In [4]:
# Torch device the model (and each training batch) is moved to.
th_device = th.device(device)

# NOTE(review): metrics, confusion_matrix and syn_pred are not used in the visible
# cells; kept in case later cells rely on them.
metrics = []
confusion_matrix = []
syn_pred = []
# Accumulates losses and evaluation results across CV splits / epochs; written by ev.save.
ev = Evaluator()

# Info columns for the (currently commented-out) synthetic-data evaluation.
syn_index = ['prev_punishments', 'prev_contributions']
# Player-graph edge index (per create_fully_connected), passed to every model.encode call.
edge_index = create_fully_connected(n_player)


def shuffle_feature(data, feature_name, generator=None):
    """Return a shallow copy of `data` with one feature permuted along its first dimension.

    Used for the shuffled-feature evaluation: permuting the feature along dim 0
    (presumably the sample axis) keeps its values but breaks their pairing with
    the rest of each sample.

    Args:
        data: Mapping of feature name -> tensor. It is not modified.
        feature_name: Key of the tensor to permute.
        generator: Optional ``torch.Generator`` for a reproducible permutation;
            ``None`` (the default) draws from the global RNG as before.

    Returns:
        A new dict whose ``feature_name`` entry is permuted along dim 0; every
        other entry is the same object as in ``data``.
    """
    shuffled = {**data}
    values = shuffled[feature_name]
    perm = th.randperm(len(values), generator=generator)
    shuffled[feature_name] = values[perm]
    return shuffled

# For each cross-validation split: build a fresh model, encode the split, train it,
# and periodically evaluate on the train set (and, when present, on the test set and
# on feature-shuffled copies of it).
for i, (train_data, test_data) in enumerate(get_cross_validations(data, n_cross_val, fraction_training)):
    model = AH_MODELS[model_name](default_values=default_values, **model_args).to(th_device)

    train_data_ = model.encode(train_data, mask=mask_name, edge_index=edge_index)
    if test_data is not None:
        test_data_ = model.encode(test_data, mask=mask_name, edge_index=edge_index)
    # syn_data_ = model.encode(syn_data, mask=None, y_encode=False, info_columns=syn_index, edge_index=edge_index)

    # syn_df = using_multiindex(
    #     Batch.from_data_list(syn_data_)['info'].detach().cpu().numpy(), ['idx', 'round_number'], syn_index)

    optimizer = th.optim.Adam(model.parameters(), **optimizer_args)
    # Per-entry losses so entries with mask 0 can be excluded below.
    loss_fn = th.nn.CrossEntropyLoss(reduction='none')
    # Built once per split; with shuffle=True each pass over it draws a new order.
    train_loader = DataLoader(train_data_, shuffle=True, batch_size=train_args['batch_size'])
    # Loss accumulated since the last report; reset after every evaluation.
    sum_loss = 0
    n_steps = 0

    for e in range(train_args['epochs']):
        ev.set_labels(cv_split=i, epoch=e)
        model.train()
        for batch_data in train_loader:
            # The model lives on th_device, so the batch must too (no-op on CPU).
            batch_data = batch_data.to(th_device)
            mask = batch_data['mask'].flatten()
            n_valid = mask.sum()
            if n_valid == 0:
                # No valid targets: the masked mean would be 0/0 = NaN and poison the weights.
                continue

            optimizer.zero_grad()
            py = model(batch_data).flatten(end_dim=-2)
            y_true = batch_data['y_enc'].flatten(end_dim=-2)
            loss = loss_fn(py, y_true)
            # Masked mean over the valid entries only.
            loss = (loss * mask).sum() / n_valid

            loss.backward()

            if train_args['clamp_grad']:
                # Element-wise clamp of every gradient to [-clamp_grad, clamp_grad];
                # unlike touching param.grad directly, this skips params whose grad is None.
                th.nn.utils.clip_grad_value_(model.parameters(), train_args['clamp_grad'])
            optimizer.step()
            sum_loss += loss.item()
            n_steps += 1

        last_epoch = e == (train_args['epochs'] - 1)

        if (e % train_args['eval_period'] == 0) or last_epoch:
            # Guard against a reporting period in which no batch was trained on.
            avg_loss = sum_loss / n_steps if n_steps else float('nan')
            print(f'CV {i} | Epoch {e} | Loss {avg_loss}')
            ev.add_loss(avg_loss)

            ev.eval_set(model, train_data_, calc_confusion=last_epoch, set='train')
            if test_data is not None:
                ev.eval_set(model, test_data_, calc_confusion=last_epoch, set='test')
                # Re-evaluate the test set with one raw feature shuffled at a time.
                for sf in shuffle_features:
                    shuffled_data = shuffle_feature(test_data, sf)
                    shuffled_data = model.encode(shuffled_data, mask=mask_name, edge_index=edge_index)
                    ev.eval_set(model, shuffled_data, calc_confusion=False, set='test', shuffle_feature=sf)
            sum_loss = 0
            n_steps = 0
    # ev.eval_syn(model, syn_data_, syn_df)

# Persist the evaluation results and the trained model into output_path.
# NOTE(review): only the model from the last CV split is saved.
os.makedirs(output_path, exist_ok=True)
ev.save(output_path, labels)
model_path = os.path.join(output_path, 'model.pt')
model.save(model_path)

CV 0 | Epoch 0 | Loss 3.2217195204326083


CV 0 | Epoch 10 | Loss 3.1040441376822336


CV 0 | Epoch 20 | Loss 2.8315121361187527


CV 0 | Epoch 30 | Loss 2.5433589577674867


CV 0 | Epoch 40 | Loss 2.4814128041267396


CV 0 | Epoch 50 | Loss 2.465010060582842


CV 0 | Epoch 60 | Loss 2.4462982654571532


CV 0 | Epoch 70 | Loss 2.427445639031274


CV 0 | Epoch 80 | Loss 2.4173668009894236


CV 0 | Epoch 90 | Loss 2.3952109456062316


CV 0 | Epoch 100 | Loss 2.391174862214497


CV 0 | Epoch 110 | Loss 2.369913921185902


CV 0 | Epoch 120 | Loss 2.3577675947121213


CV 0 | Epoch 130 | Loss 2.3451263581003463


CV 0 | Epoch 140 | Loss 2.3374577156135015


CV 0 | Epoch 150 | Loss 2.3275561801024844


CV 0 | Epoch 160 | Loss 2.3089271775313787


CV 0 | Epoch 170 | Loss 2.3036900350025724


CV 0 | Epoch 180 | Loss 2.2898571406091963


CV 0 | Epoch 190 | Loss 2.2804923364094325


CV 0 | Epoch 200 | Loss 2.277889965261732


CV 0 | Epoch 210 | Loss 2.272124843086515


CV 0 | Epoch 220 | Loss 2.257448502097811


CV 0 | Epoch 230 | Loss 2.2569517586912426


CV 0 | Epoch 240 | Loss 2.24315773333822


CV 0 | Epoch 250 | Loss 2.2449071594647


CV 0 | Epoch 260 | Loss 2.2361351805073872


CV 0 | Epoch 270 | Loss 2.2312750220298767


CV 0 | Epoch 280 | Loss 2.2193724410874505


CV 0 | Epoch 290 | Loss 2.2226219262395586


CV 0 | Epoch 300 | Loss 2.204880336352757


CV 0 | Epoch 310 | Loss 2.191867564405714


CV 0 | Epoch 320 | Loss 2.193220079796655


CV 0 | Epoch 330 | Loss 2.1825887492724827


CV 0 | Epoch 340 | Loss 2.177130093744823


CV 0 | Epoch 350 | Loss 2.1698235648018973


CV 0 | Epoch 360 | Loss 2.1672125611986433


CV 0 | Epoch 370 | Loss 2.154352333715984


CV 0 | Epoch 380 | Loss 2.1532169912542614


CV 0 | Epoch 390 | Loss 2.1371764404433113


CV 0 | Epoch 400 | Loss 2.1407703604016985


CV 0 | Epoch 410 | Loss 2.129302153417042


CV 0 | Epoch 420 | Loss 2.1339898603303094


CV 0 | Epoch 430 | Loss 2.1324758887290955


CV 0 | Epoch 440 | Loss 2.117849039179938


CV 0 | Epoch 450 | Loss 2.119842402424131


CV 0 | Epoch 460 | Loss 2.1098273805209566


CV 0 | Epoch 470 | Loss 2.103941284758704


CV 0 | Epoch 480 | Loss 2.1042658388614655


CV 0 | Epoch 490 | Loss 2.101267626455852


CV 0 | Epoch 500 | Loss 2.0897314071655275


CV 0 | Epoch 510 | Loss 2.0828457244804928


CV 0 | Epoch 520 | Loss 2.0820424377918245


CV 0 | Epoch 530 | Loss 2.086986775909151


CV 0 | Epoch 540 | Loss 2.0788648792675564


CV 0 | Epoch 550 | Loss 2.075914006573813


CV 0 | Epoch 560 | Loss 2.07006123321397


CV 0 | Epoch 570 | Loss 2.0766502618789673


CV 0 | Epoch 580 | Loss 2.063584386450904


CV 0 | Epoch 590 | Loss 2.0610405726092202


CV 0 | Epoch 600 | Loss 2.063654966865267


CV 0 | Epoch 610 | Loss 2.0540550666196005


CV 0 | Epoch 620 | Loss 2.0592568022864204


CV 0 | Epoch 630 | Loss 2.0555728767599377


CV 0 | Epoch 640 | Loss 2.052333391564233


CV 0 | Epoch 650 | Loss 2.055504884890148


CV 0 | Epoch 660 | Loss 2.050924125739506


CV 0 | Epoch 670 | Loss 2.052869858060564


CV 0 | Epoch 680 | Loss 2.047356303674834


CV 0 | Epoch 690 | Loss 2.0466870980603353


CV 0 | Epoch 700 | Loss 2.0437390880925315


CV 0 | Epoch 710 | Loss 2.0472152471542358


CV 0 | Epoch 720 | Loss 2.0410083251340048


CV 0 | Epoch 730 | Loss 2.043183287552425


CV 0 | Epoch 740 | Loss 2.0360197586672646


CV 0 | Epoch 750 | Loss 2.038947615453175


CV 0 | Epoch 760 | Loss 2.0344144122941152


CV 0 | Epoch 770 | Loss 2.034136469023568


CV 0 | Epoch 780 | Loss 2.03224681360381


CV 0 | Epoch 790 | Loss 2.026940680401666


CV 0 | Epoch 800 | Loss 2.031033594267709


CV 0 | Epoch 810 | Loss 2.0259644091129303


CV 0 | Epoch 820 | Loss 2.034927519730159


CV 0 | Epoch 830 | Loss 2.0259989193507604


CV 0 | Epoch 840 | Loss 2.028445316212518


CV 0 | Epoch 850 | Loss 2.028379184007645


CV 0 | Epoch 860 | Loss 2.0240364917687006


CV 0 | Epoch 870 | Loss 2.0227450234549385


CV 0 | Epoch 880 | Loss 2.0210057147911615


CV 0 | Epoch 890 | Loss 2.0184995813029154


CV 0 | Epoch 900 | Loss 2.0234789448125023


CV 0 | Epoch 910 | Loss 2.0177497395447324


CV 0 | Epoch 920 | Loss 2.02218142918178


CV 0 | Epoch 930 | Loss 2.0171112511839184


CV 0 | Epoch 940 | Loss 2.013681101799011


CV 0 | Epoch 950 | Loss 2.0202770778111048


CV 0 | Epoch 960 | Loss 2.0103428666080747


CV 0 | Epoch 970 | Loss 2.0158905812672208


CV 0 | Epoch 980 | Loss 2.010819160938263


CV 0 | Epoch 990 | Loss 2.0173353339944566


CV 0 | Epoch 1000 | Loss 2.010574631180082


CV 0 | Epoch 1010 | Loss 2.0112335307257516


CV 0 | Epoch 1020 | Loss 2.0105170377663204


CV 0 | Epoch 1030 | Loss 2.0081470830099923


CV 0 | Epoch 1040 | Loss 2.004446895633425


CV 0 | Epoch 1050 | Loss 2.013406811441694


CV 0 | Epoch 1060 | Loss 2.005657357828958


CV 0 | Epoch 1070 | Loss 2.015169282470431


CV 0 | Epoch 1080 | Loss 2.001809812443597


CV 0 | Epoch 1090 | Loss 2.0133160190922874


CV 0 | Epoch 1100 | Loss 2.0003034932272774


CV 0 | Epoch 1110 | Loss 2.0044510202748436


CV 0 | Epoch 1120 | Loss 1.9994373764310565


CV 0 | Epoch 1130 | Loss 2.0048552547182354


CV 0 | Epoch 1140 | Loss 2.0022791172776904


CV 0 | Epoch 1150 | Loss 2.001563079868044


CV 0 | Epoch 1160 | Loss 1.9997527931417738


CV 0 | Epoch 1170 | Loss 1.994414611373629


CV 0 | Epoch 1180 | Loss 2.0033116672720226


CV 0 | Epoch 1190 | Loss 1.9999556737286703


CV 0 | Epoch 1200 | Loss 1.9940745847565786


CV 0 | Epoch 1210 | Loss 1.9984365480286734


CV 0 | Epoch 1220 | Loss 1.9980993977614812


CV 0 | Epoch 1230 | Loss 1.999337146963392


CV 0 | Epoch 1240 | Loss 1.9966966135161264


CV 0 | Epoch 1250 | Loss 1.9954472005367279


CV 0 | Epoch 1260 | Loss 1.9971210360527039


CV 0 | Epoch 1270 | Loss 1.9918289346354348


CV 0 | Epoch 1280 | Loss 1.9930351555347443


CV 0 | Epoch 1290 | Loss 1.9978357221399035


CV 0 | Epoch 1300 | Loss 1.992096415587834


CV 0 | Epoch 1310 | Loss 1.9904477383409227


CV 0 | Epoch 1320 | Loss 1.9848278769424983


CV 0 | Epoch 1330 | Loss 1.987185071195875


CV 0 | Epoch 1340 | Loss 1.9962478041648866


CV 0 | Epoch 1350 | Loss 1.9941962897777556


CV 0 | Epoch 1360 | Loss 1.9912937828472683


CV 0 | Epoch 1370 | Loss 1.9927130741732462


CV 0 | Epoch 1380 | Loss 1.9890705449240549


CV 0 | Epoch 1390 | Loss 1.9897258222103118


CV 0 | Epoch 1400 | Loss 1.99251241173063


CV 0 | Epoch 1410 | Loss 1.991244408914021


CV 0 | Epoch 1420 | Loss 1.989010133913585


CV 0 | Epoch 1430 | Loss 1.9848586244242532


CV 0 | Epoch 1440 | Loss 1.9888498978955405


CV 0 | Epoch 1450 | Loss 1.985841473511287


CV 0 | Epoch 1460 | Loss 1.990732261112758


CV 0 | Epoch 1470 | Loss 1.9888053238391876


CV 0 | Epoch 1480 | Loss 1.9813461678368705


CV 0 | Epoch 1490 | Loss 1.991432912009103


CV 0 | Epoch 1500 | Loss 1.9866252047674997


CV 0 | Epoch 1510 | Loss 1.986868578195572


CV 0 | Epoch 1520 | Loss 1.9891538075038364


CV 0 | Epoch 1530 | Loss 1.9814694762229919


CV 0 | Epoch 1540 | Loss 1.98393817118236


CV 0 | Epoch 1550 | Loss 1.980151241166251


CV 0 | Epoch 1560 | Loss 1.9862340612070901


CV 0 | Epoch 1570 | Loss 1.984018416064126


CV 0 | Epoch 1580 | Loss 1.9799756220408848


CV 0 | Epoch 1590 | Loss 1.9850191584655217


CV 0 | Epoch 1600 | Loss 1.9824102682726723


CV 0 | Epoch 1610 | Loss 1.9844546573502677


CV 0 | Epoch 1620 | Loss 1.9825703169618334


CV 0 | Epoch 1630 | Loss 1.9780361201081957


CV 0 | Epoch 1640 | Loss 1.9913844457694463


CV 0 | Epoch 1650 | Loss 1.9852923291070121


CV 0 | Epoch 1660 | Loss 1.9829074033669063


CV 0 | Epoch 1670 | Loss 1.9804954579898288


CV 0 | Epoch 1680 | Loss 1.9797384807041714


CV 0 | Epoch 1690 | Loss 1.9807159193924495


CV 0 | Epoch 1700 | Loss 1.976983209167208


CV 0 | Epoch 1710 | Loss 1.9797523174967084


CV 0 | Epoch 1720 | Loss 1.988167302949088


CV 0 | Epoch 1730 | Loss 1.9818725160190038


CV 0 | Epoch 1740 | Loss 1.978081544807979


CV 0 | Epoch 1750 | Loss 1.9794288745948247


CV 0 | Epoch 1760 | Loss 1.9810319244861603


CV 0 | Epoch 1770 | Loss 1.9781260813985553


CV 0 | Epoch 1780 | Loss 1.9842962656702314


CV 0 | Epoch 1790 | Loss 1.9815629056521824


CV 0 | Epoch 1800 | Loss 1.982296517065593


CV 0 | Epoch 1810 | Loss 1.979565337726048


CV 0 | Epoch 1820 | Loss 1.9809239745140075


CV 0 | Epoch 1830 | Loss 1.974367251566478


CV 0 | Epoch 1840 | Loss 1.9760643294879368


CV 0 | Epoch 1850 | Loss 1.9857161564486367


CV 0 | Epoch 1860 | Loss 1.9829535373619624


CV 0 | Epoch 1870 | Loss 1.9714605689048768


CV 0 | Epoch 1880 | Loss 1.9830192591462816


CV 0 | Epoch 1890 | Loss 1.9736227512359619


CV 0 | Epoch 1900 | Loss 1.9787108446870532


CV 0 | Epoch 1910 | Loss 1.9763869762420654


CV 0 | Epoch 1920 | Loss 1.969763662985393


CV 0 | Epoch 1930 | Loss 1.981737151316234


CV 0 | Epoch 1940 | Loss 1.9773678949901037


CV 0 | Epoch 1950 | Loss 1.9749861913067954


CV 0 | Epoch 1960 | Loss 1.974205711909703


CV 0 | Epoch 1970 | Loss 1.971550990002496


CV 0 | Epoch 1980 | Loss 1.976595115661621


CV 0 | Epoch 1990 | Loss 1.9753645973546163


CV 0 | Epoch 1999 | Loss 1.9721657160728696
