In [7]:
import json
import logging
import random
from pathlib import Path
from typing import Optional

from munch import Munch
import numpy as np
import torch
from torch.utils.data import DataLoader
import yaml

from gtn import GTN
from gtn.dataloader.graphbatch_collator import GraphBatchCollator
from gtn.dataloader.graphdist_dataset import GraphDistDataset
from gtn.dataloader.io import load_from_npz
from gtn.dataloader.pyg_ged import get_pyg_ged_gcolls
from gtn.model import aggregation, geometric_gnn
from gtn.training.metrics import Metrics
from gtn.training.optimizer import add_weight_decay
from gtn.training.training import train
from gtn.training.validation import evaluate

In [8]:
# Send every log record to one console handler attached to the root logger
logger = logging.getLogger()
logger.handlers = []  # drop whatever handlers are currently attached
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

# Configuration

In [26]:
# Read the experiment configuration. The encoding is given explicitly so that
# decoding does not depend on the platform's default text encoding.
with open('configs/aids_sinkhorn.yaml', 'r', encoding='utf-8') as config_file:
    config_seml = yaml.safe_load(config_file)

In [39]:
# Shared settings first, then the settings of the chosen variant on top of them
variant = '1head'  # Single and multi-head variants: 1head, 8head
config = Munch(config_seml['fixed'])
config.update(config_seml[variant]["fixed"])

# The string "None" stands for "not set": replace it with a real None
if config.nystrom == "None":
    config.nystrom = None
if config.sparse == "None":
    config.sparse = None
config.weight_decay = float(config.weight_decay)

# NOTE(review): the next three names are not referenced by any other visible cell
language_src = 'en'
language_tgt = 'es'
data_dir = "./data"  # Download the data first, as described in the README

# Run settings
seed = 216871081
test = False
device = "cuda"  # Change to cpu if necessary

In [40]:
# Seed every random number generator used here with the same value
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    seed_fn(seed)

# Turn the device name into a torch.device and normalise the config strings
device = torch.device(device)
graph_distance = config.graph_distance.lower()
dataname = config.dataname.lower()

# Load data

In [41]:
# Load the graph collections and (optional) pair indices, one entry per split
data_path = Path.cwd() / "data"
splits = ("train", "val", "test")

# Compare case-insensitively: `dataname` is upper-cased below, so a plain
# comparison against "linux" would pick the .npz branch when this cell is re-run.
if dataname.lower() == "linux":
    dataname = dataname.upper()
    # Root directory handed to get_pyg_ged_gcolls: taken from the config when it
    # defines `pyg_data_path`, otherwise the local data directory.
    # NOTE(review): the directory layout the helper expects is not visible here - confirm.
    pyg_data_path = str(config.get("pyg_data_path", data_path))
    gcolls, pair_idxs = get_pyg_ged_gcolls(
        pyg_data_path,
        dataname,
        use_norm_ged=True,
        similarity=config.similarity,
    )
else:
    # One .npz file per split, named <dataname>_<graph_distance>_<split>.npz
    gcolls = {
        split: load_from_npz(data_path / f"{dataname}_{graph_distance}_{split}.npz")
        for split in splits
    }
    # No explicit pair indices for the .npz collections
    pair_idxs = dict.fromkeys(splits)
def _num_categories(attr_name):
    """Return the largest value of `attr_name` over all graphs of all splits, plus one."""
    largest = max(
        max(np.max(getattr(graph, attr_name)) for graph in gcoll)
        for gcoll in gcolls.values()
    )
    return int(largest + 1)


# Node feature size: largest attribute value + 1 when one-hot encoding is used,
# otherwise the number of columns of the first training graph's attribute matrix
node_onehot = True
if node_onehot:
    node_feat_size = _num_categories("attr_matrix")
else:
    node_feat_size = gcolls["train"][0].attr_matrix.shape[1]

# Edge feature size: 0 when the first training graph carries no edge attributes
has_edge_attrs = gcolls["train"][0].edge_attr_matrix is not None
edge_feat_size = _num_categories("edge_attr_matrix") if has_edge_attrs else 0

# Wrap the graph collection of every split in a GraphDistDataset
datasets = {
    split: GraphDistDataset(
        split_gcoll,
        node_feat_size,
        edge_feat_size,
        node_onehot=node_onehot,
        edge_onehot=True,
        pair_idx=pair_idxs[split],
    )
    for split, split_gcoll in gcolls.items()
}

# One DataLoader per split; all of them share the collator and the same
# settings (every split is shuffled, not only the training one)
collator = GraphBatchCollator()
dataloaders = {
    split: DataLoader(
        split_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=1,
    )
    for split, split_dataset in datasets.items()
}

In [47]:
# Metrics to track and the settings passed to the validation tracker
metrics_list = ["rmse", "cvrmse", "label_std"]
metric_to_stop_on = "rmse"
minimize_stop_on = True
patience = np.inf

# Per-epoch trackers exist for every split; only the validation tracker
# receives the stop-on metric, its direction and the patience
metrics_trackers = {
    "iter": {},
    "epoch": {
        "train": Metrics(metrics_list),
        "val": Metrics(metrics_list, metric_to_stop_on, minimize_stop_on, patience),
        "test": Metrics(metrics_list),
    },
}

# Per-iteration trackers are created only when the config sets `print_step`
if config.get("print_step", None) is not None:
    for split in ("train", "val"):
        metrics_trackers["iter"][split] = Metrics(metrics_list)

# Set up model

In [43]:
# Select activation function
def _identity(x):
    """Return the input unchanged (used for the "linear" activation)."""
    return x


# Activation name -> callable
_ACT_FNS = {
    "linear": _identity,
    "relu": torch.nn.functional.relu,
    # torch.sigmoid is used instead of the deprecated torch.nn.functional.sigmoid
    "sigmoid": torch.sigmoid,
    "leaky_relu": torch.nn.functional.leaky_relu,
}

if config.act_fn not in _ACT_FNS:
    raise ValueError(f"Invalid act_fn '{config.act_fn}'.")
act_fn = _ACT_FNS[config.act_fn]

# Pick how the outputs of the GNN layers are combined:
# an MLP for a single head, `aggregation.All` for several heads
assert config.num_heads >= 1
single_head = config.num_heads == 1
layer_aggregation = (
    aggregation.MLP(
        emb_size=config.emb_size,
        nlayers=config.nlayers,
        output_size=config.emb_size,
    )
    if single_head
    else aggregation.All()
)

# Average degree used to prevent embedding magnitude changes in non-normalized aggregation:
# mean over the training graphs of each graph's mean row sum of the adjacency matrix
mean_degrees = [graph.adj_matrix.sum(1).mean() for graph in gcolls["train"]]
avg_degree = np.mean(mean_degrees)

# Build the GNN from the feature sizes and the config
gnn_kwargs = dict(
    node_feat_size=node_feat_size,
    edge_feat_size=edge_feat_size,
    emb_size=config.emb_size,
    nlayers=config.nlayers,
    layer_aggregation=layer_aggregation,
    device=device,
    act_fn=act_fn,
    avg_degree=avg_degree,
    deg_norm_hidden=config.deg_norm_hidden,
)
gnn = geometric_gnn.Net(**gnn_kwargs)

# Statistics for normalizing embeddings used for Sinkhorn:
# mean of the training distances selected by the training pair index
train_pair_idx = datasets["train"].pair_idx
emb_dist_scale = np.mean(gcolls["train"].dists.A[train_pair_idx])

# With `extensive` set, divide by the mean node count of the paired graphs
if config.extensive:
    nodes_per_graph = [
        gcolls["train"][idx].num_nodes() for idx in train_pair_idx.flatten()
    ]
    emb_dist_scale /= np.mean(nodes_per_graph)

# Overall GTN model: the GNN plus the Sinkhorn-related settings from the config
gtn_kwargs = dict(
    gnn=gnn,
    emb_dist_scale=emb_dist_scale,
    device=device,
    sinkhorn_reg=config.sinkhorn_reg,
    sinkhorn_niter=config.sinkhorn_niter,
    unbalanced_mode=config.unbalanced_mode,
    nystrom=config.nystrom,
    sparse=config.sparse,
    extensive=config.extensive,
    num_heads=config.num_heads,
    # Falls back to 1 when the config does not define it
    multihead_scale_basis=config.get("multihead_scale_basis", 1),
    similarity=config.similarity,
)
model = GTN(**gtn_kwargs)

# Training

In [44]:
# Optimizer (Adam) with a step-wise learning rate schedule
parameters = add_weight_decay(model, weight_decay=config.weight_decay)
optimizer = torch.optim.Adam(parameters, lr=config.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config.lr_stepsize,
    gamma=config.lr_gamma,
)

# One empty dict per metric; this is the value `result` keeps if train()
# raises before returning
result = {metric_name: {} for metric_name in metrics_list}

train_kwargs = dict(
    num_epochs=config.num_epochs,
    print_step=config.get("print_step", None),
)
result = train(
    model,
    device,
    dataloaders,
    optimizer,
    lr_scheduler,
    metrics_trackers,
    **train_kwargs,
)

2021-07-12 02:12:47 (INFO): Epoch 0/199,    train    loss: 206.2427, rmse: 14.3612, cvrmse: 0.2991, label_std: 14.0401 (21.52s)
2021-07-12 02:12:49 (INFO): Epoch 0/199,    val      loss: 36.3762, rmse: 6.0313, cvrmse: 0.1164, label_std: 15.8700 (2.10s)
Traceback (most recent call last):
  File "/nfs/staff-ssd/klicpera/anaconda3/envs/pytorch1.4/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/nfs/staff-ssd/klicpera/anaconda3/envs/pytorch1.4/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/nfs/staff-ssd/klicpera/anaconda3/envs/pytorch1.4/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/nfs/staff-ssd/klicpera/anaconda3/envs/pytorch1.4/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

# Evaluation

In [54]:
# Evaluate the model on the test split and store each tracked metric
# under the "test" key of `result`
logging.info("Evaluating on test")
test_loader = dataloaders["test"]
test_tracker = metrics_trackers["epoch"]["test"]
result_test = evaluate(
    model,
    device,
    test_loader,
    test_tracker,
    disable_tqdm=False,
)
for metric_name in metrics_list:
    result[metric_name]["test"] = result_test[metric_name]

2021-07-12 02:15:37 (INFO): Evaluating on test


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

2021-07-12 02:15:40 (INFO): loss: 40.9043, rmse: 6.3956, cvrmse: 0.1320, label_std: 13.9971 (2.94s)





In [55]:
# Log the collected metrics as text
logging.info(str(result))

2021-07-12 02:15:43 (INFO): {'loss': {'test': 40.904275630382784}, 'rmse': {'test': 6.395645}, 'cvrmse': {'test': 0.13199443281115636}, 'label_std': {'test': 13.997074}}
