# Model training

This notebook trains the models used in the replication experiment. It relies heavily on predefined configuration files that describe the parameter settings of each model.
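
The configuration files are JSON, and the code below reads nested fields from them such as `config.model.seed`. The internals of `Selector` are not shown here. The following is a minimal sketch that assumes the file holds a nested JSON object with a `model` section. The hypothetical `load_config` helper shows how this kind of attribute-style access can be obtained with the standard library alone.

```python
import json
from types import SimpleNamespace


def load_config(path):
    """Hypothetical stand-in for Selector(path).args: parse a JSON file so
    nested keys can be read as attributes (e.g. cfg.model.seed)."""
    with open(path) as f:
        # object_hook converts every decoded JSON object into a SimpleNamespace
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))
```

For example, a file containing `{"model": {"seed": 42, "dataset": "syn1"}}` gives `cfg.model.seed == 42`.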

**To replace the pretrained models in the replication study**, copy the trained model from `checkpoints` to `Explanation/models/pretrained/<_model>/<_dataset>`, where `_model` and `_dataset` are defined in the code below.
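
The copy step can be scripted. This is a minimal sketch that assumes the files for one trained model sit together in a single directory under `checkpoints`. That layout is an assumption, and `install_pretrained` is a hypothetical helper name.

```python
import shutil
from pathlib import Path


def install_pretrained(checkpoint_dir, _model, _dataset,
                       pretrained_root="Explanation/models/pretrained"):
    """Hypothetical helper: copy every file in checkpoint_dir to
    <pretrained_root>/<_model>/<_dataset>, overwriting existing files."""
    src = Path(checkpoint_dir)
    dest = Path(pretrained_root) / _model / _dataset
    dest.mkdir(parents=True, exist_ok=True)  # create the target folders if missing
    copied = []
    for item in src.iterdir():
        if item.is_file():
            # copy2 also preserves file metadata such as modification times
            shutil.copy2(item, dest / item.name)
            copied.append(dest / item.name)
    return copied
```

Subdirectories of `checkpoint_dir` are skipped. If your checkpoints are nested, `shutil.copytree(src, dest, dirs_exist_ok=True)` (Python 3.8+) is an alternative.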

```python
import numpy as np
import torch

from ExplanationEvaluation.configs.selector import Selector
from ExplanationEvaluation.tasks.training import meta_train_node, meta_train_graph

# Training is forced onto the CPU. To use a GPU when one is available, use instead:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
```

```python
_dataset = 'bacommunity'  # One of: bashapes, bacommunity, treecycles, treegrids, ba2motifs, mutag
_folder = 'replication'
_model = 'gnn'
config_path = f"./ExplanationEvaluation/configs/{_folder}/models/model_{_model}_{_dataset}.json"

# Load the configuration once; the training settings are read from `.args`
config = Selector(config_path).args
extension = (_folder == 'extension')

# Seed the random number generators for reproducibility
torch.manual_seed(config.model.seed)
torch.cuda.manual_seed(config.model.seed)
np.random.seed(config.model.seed)
```
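
The configuration path follows a fixed pattern built from `_folder`, `_model` and `_dataset`. The sketch below builds the same path. It also rejects dataset names that are not in the list given in the comment above, so a typo fails before `Selector` tries to open a missing file. The helper name `model_config_path` and the up-front check are illustrative additions, not part of the repository.

```python
DATASETS = ("bashapes", "bacommunity", "treecycles", "treegrids", "ba2motifs", "mutag")


def model_config_path(_folder, _model, _dataset):
    """Hypothetical helper reproducing the config path pattern used above."""
    if _dataset not in DATASETS:
        raise ValueError(f"Unknown dataset {_dataset!r}; expected one of {DATASETS}")
    return f"./ExplanationEvaluation/configs/{_folder}/models/model_{_model}_{_dataset}.json"
```

For example, `model_config_path('replication', 'gnn', 'bacommunity')` returns `'./ExplanationEvaluation/configs/replication/models/model_gnn_bacommunity.json'`.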

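```python
# Replace _dataset with the dataset name stored in the configuration
_dataset = config.model.dataset
_explainer = config.model.paper

if _dataset.startswith("syn"):
    # Synthetic datasets: node classification
    meta_train_node(_dataset, _explainer, config.model, device)
elif _dataset in ("ba2", "mutag"):
    # Graph classification datasets
    meta_train_graph(_dataset, _explainer, config.model, device)
else:
    raise ValueError(f"Unknown dataset: {_dataset}")
```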

Loading syn1 dataset
NodeGCN(
  (conv1): GCNConv(10, 20)
  (relu1): ReLU()
  (conv2): GCNConv(20, 20)
  (relu2): ReLU()
  (conv3): GCNConv(20, 20)
  (relu3): ReLU()
  (lin): Linear(in_features=60, out_features=4, bias=True)
)
Epoch: 0, train_acc: 0.1196, val_acc: 0.0714, train_loss: 1.4274
Val improved
Epoch: 1, train_acc: 0.2179, val_acc: 0.3143, train_loss: 1.4113
Val improved
Epoch: 2, train_acc: 0.2179, val_acc: 0.3143, train_loss: 1.3936
Epoch: 3, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3791
Epoch: 4, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3657
Epoch: 5, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3545
Epoch: 6, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3452
Epoch: 7, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3377
Epoch: 8, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3304
Epoch: 9, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3240
Epoch: 10, train_acc: 0.4464, val_acc: 0.3143, train_loss: 1.3181
Epoch: 11, train_acc: 0.4464, va