In [None]:
# Parameters
# Level counts: n_contributions sizes the one-hot target encoder of the model;
# n_punishments is stored on the model alongside it.
n_contributions = 21
n_punishments = 31

# Arguments forwarded to get_cross_validations when splitting the data.
n_cross_val = 2
fraction_training = 1.0

# Input CSV and base output directory.
data_file = "../../data/experiments/pilot_random1_player_round_slim.csv"
output_path = "../../data/training/dev"

# Extra labels handed to the evaluator when results are saved.
labels = {}

# Key into AH_MODELS and keyword arguments for the selected model class.
model_name = "graph"
model_args = {
    "add_rnn": False,
    "add_edge_model": False,
    "add_global_model": False,
    "hidden_size": 10,
    # Node-level input features.
    "x_encoding": [
        {
            "name": "prev_contributions",
            "n_levels": 21,
            "encoding": "numeric",
        },
        {
            "name": "prev_punishments",
            "n_levels": 31,
            "encoding": "numeric",
        },
        {
            "name": "round_number",
            "n_levels": 16,
            "encoding": "numeric",
        },
        {
            "name": "prev_common_good",
            "norm": 128,
            "etype": "float",
        },
        {
            "name": "prev_valid",
            "etype": "bool",
        },
    ],
    # Graph-level (global) input features.
    "u_encoding": [
        {
            "name": "prev_common_good",
            "norm": 128,
            "etype": "float",
        },
    ],
}

# Keyword arguments for th.optim.Adam.
optimizer_args = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
}

# Settings read by the training loop.
train_args = {
    "epochs": 100,
    "batch_size": 20,
    "clamp_grad": 1,
    "eval_period": 10,
}

# Device string passed to th.device.
device = "cpu"


In [22]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import torch as th
from aimanager.generic.data import create_syn_data, create_torch_data, get_cross_validations
from aimanager.artificial_humans import AH_MODELS
from aimanager.artificial_humans.evaluation import Evaluator
from aimanager.utils.array_to_df import using_multiindex
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader

# Append the 'data' sub-directory exactly once. An unconditional
# `output_path = os.path.join(output_path, 'data')` is not idempotent: re-running this
# cell in a live kernel would yield '<base>/data/data'. The marker below remembers the
# last derived value, so a re-run is a no-op, while a fresh `output_path` coming from
# the parameters cell is extended again.
if globals().get('_derived_output_path') != output_path:
    output_path = os.path.join(output_path, 'data')
    _derived_output_path = output_path

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Read the experiment records from the CSV configured in the parameters cell.
df = pd.read_csv(data_file)

# Convert the data frame into the structure consumed by the encoders and by
# get_cross_validations further down.
data = create_torch_data(df)

# Build the synthetic data set from the configured level counts instead of the
# hard-coded 21 / 31, so it stays consistent when the parameters are overridden.
syn_data = create_syn_data(n_contribution=n_contributions, n_punishment=n_punishments)

In [34]:
from torch.nn import Sequential as Seq, Linear as Lin, Tanh, GRU
import torch as th
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
from aimanager.generic.encoder import Encoder, IntEncoder


class EdgeModel(th.nn.Module):
    """Edge update block: a single linear layer followed by a tanh.

    For every edge the layer receives the concatenation of the source-node features,
    the destination-node features, the edge's own features and the global features of
    the graph the edge belongs to.
    """

    def __init__(self, x_features, edge_features, u_features, out_features):
        """Create the layer.

        Args:
            x_features: size of the node feature vector (counted twice: source and destination).
            edge_features: size of the edge feature vector.
            u_features: size of the global feature vector.
            out_features: size of the updated edge feature vector.
        """
        super().__init__()
        n_inputs = x_features * 2 + edge_features + u_features
        linear = Lin(in_features=n_inputs, out_features=out_features)
        self.edge_mlp = Seq(linear, Tanh())

    def forward(self, src, dest, edge_attr, u, batch):
        """Compute updated edge features.

        Args:
            src, dest: features of the source / destination node of each edge; leading dimension E (number of edges).
            edge_attr: current edge features; leading dimension E.
            u: global features; leading dimension B (number of graphs).
            batch: graph index of every edge, length E, values in [0, B - 1].

        Returns:
            Tensor with `out_features` entries on its last dimension.
        """
        # Indexing u with `batch` broadcasts each graph's global features to its edges.
        edge_inputs = th.cat((src, dest, edge_attr, u[batch]), dim=-1)
        return self.edge_mlp(edge_inputs)


class NodeModel(th.nn.Module):
    """Node update block: a single linear layer followed by a tanh.

    Each node sees its own features, the mean of the features of its incoming edges and
    the global features of the graph it belongs to.
    """

    def __init__(self, x_features, edge_features, u_features, out_features):
        """Create the layer.

        Args:
            x_features: size of the node feature vector.
            edge_features: size of the (aggregated) edge feature vector.
            u_features: size of the global feature vector.
            out_features: size of the updated node feature vector.
        """
        super().__init__()
        n_inputs = x_features + edge_features + u_features
        linear = Lin(in_features=n_inputs, out_features=out_features)
        self.node_mlp = Seq(linear, Tanh())

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Compute updated node features.

        Args:
            x: node features; leading dimension N (number of nodes).
            edge_index: [2, E] tensor of node indices with max entry N - 1.
            edge_attr: edge features; leading dimension E.
            u: global features; leading dimension B (number of graphs).
            batch: graph index of every node, length N, values in [0, B - 1].

        Returns:
            Tensor with `out_features` entries on its last dimension.
        """
        # Only the second row of edge_index (the receiving node) is needed for aggregation.
        _, receivers = edge_index
        n_nodes = x.size(0)
        incoming = scatter_mean(edge_attr, receivers, dim=0, dim_size=n_nodes)
        node_inputs = th.cat([x, incoming, u[batch]], dim=-1)
        return self.node_mlp(node_inputs)

class GlobalModel(th.nn.Module):
    """Global update block: a single linear layer followed by a tanh.

    The global features of each graph are updated from their previous value and the mean
    of the node features of that graph.
    """

    def __init__(self, x_features, edge_features, u_features, out_features):
        """Create the layer.

        Args:
            x_features: size of the node feature vector.
            edge_features: accepted for a uniform constructor signature; not used here.
            u_features: size of the global feature vector.
            out_features: size of the updated global feature vector.
        """
        super().__init__()
        n_inputs = u_features + x_features
        linear = Lin(in_features=n_inputs, out_features=out_features)
        self.global_mlp = Seq(linear, Tanh())

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Compute updated global features.

        Args:
            x: node features; leading dimension N (number of nodes).
            edge_index: [2, E] tensor of node indices (unused).
            edge_attr: edge features (unused).
            u: global features; leading dimension B (number of graphs).
            batch: graph index of every node, length N, values in [0, B - 1].

        Returns:
            Tensor with `out_features` entries on its last dimension.
        """
        # Average the node features per graph before joining them with the globals.
        node_mean = scatter_mean(x, batch, dim=0)
        global_inputs = th.cat([u, node_mean], dim=-1)
        return self.global_mlp(global_inputs)

class GraphNetwork(th.nn.Module):
    """Graph network producing per-node logits over the contribution levels.

    The network is a first MetaLayer (`op1`), optional GRUs applied to node and global
    features, and a second MetaLayer (`op2`) whose node model outputs one value per
    level of the one-hot contribution encoder.
    """

    def __init__(self, n_contributions, n_punishments, x_encoding, u_encoding, add_rnn=True, add_edge_model=True, 
            add_global_model=True, hidden_size=None, op1=None, op2=None, rnn_n=None, rnn_g=None):
        """Build the network, or wrap already constructed modules.

        Args:
            n_contributions: number of levels of the one-hot contribution (target) encoder.
            n_punishments: stored on the instance and written by `save`.
            x_encoding: configuration passed to the node feature `Encoder`.
            u_encoding: configuration passed to the global feature `Encoder` (mean aggregation).
            add_rnn: add a GRU on the node features (and on the global features when
                `add_global_model` is also set).
            add_edge_model: include an `EdgeModel` in the first MetaLayer.
            add_global_model: include a `GlobalModel` in the first MetaLayer.
            hidden_size: width of the hidden layers; needed when the modules are built here,
                because it is used as the output size of the linear layers.
            op1, op2, rnn_n, rnn_g: prebuilt modules. When `op1` is given, all four are used
                as passed and the `add_*` / `hidden_size` arguments are ignored.
        """
        super().__init__()
        self.x_encoder = Encoder(x_encoding)
        self.u_encoder = Encoder(u_encoding, aggregation='mean')
        self.y_encoder = IntEncoder(encoding='onehot', name='contributions', n_levels=n_contributions)
        x_features = self.x_encoder.size
        u_features = self.u_encoder.size
        y_features = self.y_encoder.size
        self.n_contributions = n_contributions
        self.n_punishments = n_punishments
        self.x_encoding = x_encoding
        self.u_encoding = u_encoding

        # The input graphs carry no edge features; they only exist after an edge model ran.
        edge_features = 0
        if op1 is None:
            if add_edge_model:
                edge_model = EdgeModel(
                    x_features=x_features, edge_features=edge_features, 
                    u_features=u_features, out_features=hidden_size)
                edge_features = hidden_size
            else:
                edge_model = None

            node_model = NodeModel(
                x_features=x_features, edge_features=edge_features, 
                u_features=u_features, out_features=hidden_size)
            x_features = hidden_size

            if add_global_model:
                global_model = GlobalModel(
                    x_features=x_features, edge_features=edge_features, 
                    u_features=u_features, out_features=hidden_size)
                u_features = hidden_size
            else:
                global_model = None

            self.op1 = MetaLayer(edge_model, node_model, global_model)

            if add_rnn:
                self.rnn_n = GRU(input_size=x_features, hidden_size=hidden_size, num_layers=1, batch_first=True)
                x_features = hidden_size
            else:
                self.rnn_n = None

            if add_rnn and add_global_model:
                self.rnn_g = GRU(input_size=u_features, hidden_size=hidden_size, num_layers=1, batch_first=True)
                u_features = hidden_size
            else:
                self.rnn_g = None

            # Output layer: a node model without edge inputs that maps to the target size.
            self.op2 = MetaLayer(
                None,
                NodeModel(
                    x_features=x_features, edge_features=0, 
                    u_features=u_features, out_features=y_features), 
                None
            )
        else:
            # Prebuilt modules, e.g. as restored by `load`.
            self.op1 = op1
            self.op2 = op2
            self.rnn_n = rnn_n
            self.rnn_g = rnn_g
    
    def forward(self, data):
        """Run the network on a batch.

        Args:
            data: mapping-like batch providing 'x', 'edge_index', 'edge_attr', 'u' and 'batch'.

        Returns:
            The node output of the second MetaLayer (logits over contribution levels).
        """
        x = data['x']
        edge_index = data['edge_index']
        edge_attr = data['edge_attr']
        u = data['u']
        batch = data['batch']
        x, _, u = self.op1(x, edge_index, edge_attr, u, batch)
        # The final hidden states of the GRUs are not needed, only the output sequences.
        if self.rnn_n is not None:
            x, _ = self.rnn_n(x)
        if self.rnn_g is not None:
            u, _ = self.rnn_g(u)
        x, _, _ = self.op2(x, edge_index, edge_attr, u, batch)
        return x

    def predict(self, data):
        """Predict contributions for a list of graphs.

        Switches the module to eval mode (it is not switched back), evaluates the data in
        unshuffled batches of 10 and applies a softmax over the last dimension.

        Returns:
            Tuple `(y_pred, y_pred_proba)`: the decoded prediction from the target encoder
            and the class probabilities.
        """
        self.eval()
        y_pred_logit = th.cat([self(d)
            for d in iter(DataLoader(data, shuffle=False, batch_size=10))
        ])
        y_pred_proba = th.nn.functional.softmax(y_pred_logit, dim=-1)
        y_pred = self.y_encoder.decode(y_pred_proba)
        return y_pred, y_pred_proba

    def save(self, filename):
        """Serialize everything `load` needs to rebuild the model.

        The stored keys match the constructor's keyword arguments. The GRUs are stored as
        well; leaving them out would make a model trained with `add_rnn=True` come back
        from `load` without its recurrent layers.
        """
        to_save = {
            'op1': self.op1,
            'op2': self.op2,
            'rnn_n': self.rnn_n,
            'rnn_g': self.rnn_g,
            'n_contributions': self.n_contributions,
            'n_punishments': self.n_punishments,
            'x_encoding': self.x_encoding, 
            'u_encoding': self.u_encoding
        }
        th.save(to_save, filename)

    @classmethod
    def load(cls, filename):
        """Rebuild a model from a file written by `save`.

        Files without 'rnn_n' / 'rnn_g' entries still load; the constructor then
        defaults both to None.
        """
        # NOTE(review): th.load unpickles whole module objects; only open trusted files.
        to_load = th.load(filename)
        ah = cls(**to_load)
        return ah


# Register the notebook-local GraphNetwork under the 'graph' key so that the
# AH_MODELS[model_name] lookup in the training cell resolves to this class.
AH_MODELS['graph'] = GraphNetwork

In [35]:
# Resolve the configured device string into a torch device object.
th_device = th.device(device)

# Result accumulators, plus the evaluator that collects losses and evaluation
# results during training.
metrics, confusion_matrix, syn_pred = [], [], []
ev = Evaluator()

# Columns stacked into the 'info' tensor that identifies each synthetic sample.
syn_index = ['prev_punishments', 'prev_contributions']

def create_fully_connected(n_nodes):
    """Return the edge index of a fully connected directed graph, self-loops included.

    Args:
        n_nodes: number of nodes (non-negative).

    Returns:
        int64 tensor of shape [2, n_nodes * n_nodes]. Row 0 holds the first node of each
        pair and row 1 the second, ordered as (0, 0), (0, 1), ..., (n - 1, n - 1).
        For `n_nodes == 0` the result is an empty [2, 0] tensor, so the shape and dtype
        stay consistent for every input size.
    """
    node_ids = th.arange(n_nodes)
    # First index varies slowest, second index fastest.
    first = node_ids.repeat_interleave(n_nodes)
    second = node_ids.repeat(n_nodes)
    return th.stack([first, second])

def encode(model, data, *, mask=True, index=False, x_encode=True, y_encode=True, u_encode=False, device, n_player=4):
    """Encode raw data with the model's encoders and split it into one graph per episode.

    Args:
        model: object providing `x_encoder`, `y_encoder` and `u_encoder`.
        data: mapping of tensors; it is passed as keyword arguments to the encoders and
            read directly for 'valid', 'contributions' and the `syn_index` columns.
        mask: include `data['valid']` as 'mask'.
        index: include the stacked `syn_index` columns as 'info'.
        x_encode: include the encoded node features as 'x'. The number of rounds is read
            from this tensor, so it has to stay enabled.
        y_encode: include the encoded target as 'y_enc' and the raw contributions as 'y'.
        u_encode: include the encoded global features as 'u'.
        device: device that every tensor of the resulting graphs is placed on.
        n_player: number of nodes per graph; the graph is fully connected.

    Returns:
        List of `Data` objects, one per entry along the first dimension of the tensors.
        All graphs share the same `edge_index` and (zero-width) `edge_attr` tensors.
    """
    tensors = {
        'mask': data['valid'] if mask else None,
        'x': model.x_encoder(**data) if x_encode else None,
        'y_enc': model.y_encoder(**data) if y_encode else None,
        'y': data['contributions'] if y_encode else None,
        'u': model.u_encoder(**data) if u_encode else None,
        'info': th.stack([data[c] for c in syn_index], dim=-1) if index else None,
    }
    # Drop the disabled entries and move the rest to the target device.
    tensors = {k: v.to(device) for k, v in tensors.items() if v is not None}

    _, _, n_rounds, _ = tensors['x'].shape

    # The edge tensors must live on the same device as the node tensors; creating them on
    # the default device breaks the forward pass on any non-CPU device.
    edge_attr = th.zeros(n_player * n_player, n_rounds, 0, device=device)
    edge_index = create_fully_connected(n_player).to(device)

    n_episodes = next(iter(tensors.values())).shape[0]
    return [
        Data(**{k: v[i] for k, v in tensors.items()}, edge_attr=edge_attr, edge_index=edge_index, idx=i, group_idx=i, num_nodes=n_player)
        for i in range(n_episodes)
    ]


# Train and evaluate one freshly initialised model per cross-validation split.
for i, (train_data, test_data) in enumerate(get_cross_validations(data, n_cross_val, fraction_training)):
    model = AH_MODELS[model_name](
        n_contributions=n_contributions, n_punishments=n_punishments,
        **model_args).to(th_device)

    # Encode the splits and the synthetic data with this model's encoders.
    train_data_ = encode(model, train_data, mask=True, u_encode=True, device=th_device)
    test_data_ = encode(model, test_data, mask=True, u_encode=True, device=th_device)
    syn_data_ = encode(model, syn_data, mask=False, y_encode=False, u_encode=True, index=True, device=th_device)

    syn_df = using_multiindex(
        Batch.from_data_list(syn_data_)['info'], ['idx', 'round_number'], syn_index)

    ev.set_data(test=test_data_, train=train_data_, syn=syn_data_, syn_df=syn_df)

    optimizer = th.optim.Adam(model.parameters(), **optimizer_args)
    # Per-element losses are needed so that masked-out entries can be excluded below.
    loss_fn = th.nn.CrossEntropyLoss(reduction='none')
    sum_loss = 0
    n_steps = 0

    # The loader reshuffles on every iteration, so one instance serves all epochs.
    train_loader = DataLoader(train_data_, shuffle=True, batch_size=train_args['batch_size'])

    for e in range(train_args['epochs']):
        ev.set_labels(cv_split=i, epoch=e)
        model.train()
        for batch_data in train_loader:

            optimizer.zero_grad()
            py = model(batch_data).flatten(end_dim=-2)
            y_true = batch_data['y_enc'].flatten(end_dim=-2)
            mask = batch_data['mask'].flatten()
            loss = loss_fn(py, y_true)
            # Mean loss over the entries selected by the mask.
            loss = (loss * mask).sum() / mask.sum()

            loss.backward()

            if train_args['clamp_grad']:
                # Clamp every gradient element to [-clamp_grad, clamp_grad]. Unlike a manual
                # loop over `param.grad`, this skips parameters that received no gradient.
                th.nn.utils.clip_grad_value_(model.parameters(), train_args['clamp_grad'])
            optimizer.step()
            sum_loss += loss.item()
            n_steps +=1
        
        if e % train_args['eval_period'] == 0:
            # Report the loss averaged over all steps since the previous evaluation.
            avg_loss = sum_loss/n_steps
            print(f'CV {i} | Epoch {e} | Loss {avg_loss}')
            ev.add_loss(avg_loss)
            ev.eval_set(model, 'train')
            ev.eval_set(model, 'test')
            sum_loss = 0
            n_steps = 0

    ev.eval_sync(model, syn_index=syn_index)

ev.save(output_path, labels)
# Only the model of the last cross-validation split is still in scope and gets saved.
model_path = os.path.join(output_path, 'model.pt')
model.save(model_path)

135 4 16
68 4 16
651 4 16
135
CV 0 | Epoch 0 | Loss 3.0514208248683383
CV 0 | Epoch 10 | Loss 3.028706737927028
CV 0 | Epoch 20 | Loss 2.9845533439091274
CV 0 | Epoch 30 | Loss 2.936578382764544
CV 0 | Epoch 40 | Loss 2.8867049660001483
CV 0 | Epoch 50 | Loss 2.836121620450701
CV 0 | Epoch 60 | Loss 2.793965765408107
CV 0 | Epoch 70 | Loss 2.7599569286618912
CV 0 | Epoch 80 | Loss 2.7316042218889507
CV 0 | Epoch 90 | Loss 2.7124502079827444
135 4 16
67 4 16
651 4 16
135
CV 1 | Epoch 0 | Loss 3.1176112719944546
CV 1 | Epoch 10 | Loss 3.0891532284872874
CV 1 | Epoch 20 | Loss 3.0383883919034687
CV 1 | Epoch 30 | Loss 2.9914779015949793
CV 1 | Epoch 40 | Loss 2.9447254283087596
CV 1 | Epoch 50 | Loss 2.901524247441973
CV 1 | Epoch 60 | Loss 2.8620461668287005
CV 1 | Epoch 70 | Loss 2.824734786578587
CV 1 | Epoch 80 | Loss 2.7936674867357527
CV 1 | Epoch 90 | Loss 2.7664338690893993
