In [3]:
from typing import Callable

# import os
# os.environ["DGLBACKEND"] = "pytorch"


# from dgl import function as fn
# from dgl.nn.pytorch import GlobalAttentionPooling
import torch as th
from torch import nn
from torch.nn import functional as F


# https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py
from torch_geometric.nn import AttentionalAggregation, SAGEConv


class GlobalAttentionNet(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.att = AttentionalAggregation(Linear(hidden, 1))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.att.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.att(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__


class GINLayer(nn.Module):
    def __init__(self, dim_feat: int, edge_encoder: nn.Module, edge_feat: str="feat"):
        # https://mlabonne.github.io/blog/gin/
        # TODO: must study this
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim_feat, 2 * dim_feat),
            nn.ReLU(),
            nn.Linear(2 * dim_feat, dim_feat),
        )
        self.eps = nn.Parameter(th.Tensor([0]))
        self.edge_encoder = edge_encoder
        self._dim_feat = dim_feat
        self._edge_feat = edge_feat

    def forward(self, graph, feat):
        graph.ndata["h"] = feat
        graph.apply_edges(lambda edges: {"e": F.relu(edges.src["h"] + self.edge_encoder(edges.data[self._edge_feat]).view((-1, self._dim_feat)))})
        graph.update_all(fn.copy_e("e", "m"), fn.sum("m", "a"))
        feat = self.mlp((1 + self.eps) * graph.ndata.pop("h") + graph.ndata.pop("a"))
        return feat


# class GraphNorm(nn.Module):
#     def __init__(self, dim_feat: int):
#         super().__init__()

#         self.weight = nn.Parameter(th.ones(dim_feat))
#         self.bias = nn.Parameter(th.zeros(dim_feat))
#         self.mean_scale = nn.Parameter(th.ones(dim_feat))

#     def forward(self, graph, feat):
#         batch_list = graph.batch_num_nodes().to(device=graph.device, dtype=th.int64)
#         batch_index = th.arange(graph.batch_size, device=graph.device).repeat_interleave(batch_list)
#         batch_index = batch_index.view((-1,) + (1,) * (feat.dim() - 1)).expand_as(feat)

#         mean = th.zeros(graph.batch_size, *feat.shape[1:], device=graph.device)
#         mean = mean.scatter_add_(0, batch_index, feat)
#         mean = (mean.T / batch_list).T
#         mean = mean.repeat_interleave(batch_list, dim=0)

#         sub = feat - mean * self.mean_scale

#         std = th.zeros(graph.batch_size, *feat.shape[1:], device=graph.device)
#         std = std.scatter_add_(0, batch_index, sub.pow(2))
#         std = ((std.T / batch_list).T + 1e-6).sqrt()
#         std = std.repeat_interleave(batch_list, dim=0)

#         return self.weight * sub / std + self.bias

    
from pytorch_geometric.norm import GraphNorm

class GNNBase(nn.Module):
    def __init__(
            self,
            layer: str,
            norm: str,
            res: bool,
            dim_feat: int,
            depth: int,
            edge_encoder: Callable[[], nn.Module],
            *,
            dropout: bool=True,
            edge_feat: str="feat"
    ):
        super().__init__()

        self._dim_feat = dim_feat
        self._depth = depth
        self._res = res
        self._dropout = dropout

        if layer == "gin":
            self.layers = nn.ModuleList(
                [GINLayer(dim_feat, edge_encoder(), edge_feat) for _ in range(depth)]
            )
        else:
            raise ValueError(f"invalid layer type: {layer}")

        if norm == "none":
            self.norms = None
        elif norm == "graphnorm":
            self.norms = nn.ModuleList([GraphNorm(dim_feat) for _ in range(depth)])
        else:
            raise ValueError(f"invalid norm: {norm}.")


class GNNSimple(GNNBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.pooling = GlobalAttentionPooling(nn.Sequential(
            nn.Linear(self._dim_feat, 2 * self._dim_feat),
            nn.ReLU(),
            nn.Linear(2 * self._dim_feat, 1),
        ))

    def forward(self, graph, feat):
        for d in range(self._depth):
            feat_in = feat
            feat = self.layers[d](graph, feat)
            if self.norms:
                feat = self.norms[d](graph, feat)
            if d < self._depth - 1:
                feat = F.relu(feat)
            if self._dropout:
                feat = F.dropout(feat, training=self.training)
            if self._res:
                feat = feat + feat_in

        return self.pooling(graph, feat)

    def get_emb(self, graph, feat):
        return self.forward(graph, feat).unsqueeze(0)

In [2]:
!nvidia-smi

Thu Apr  6 05:28:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.05    Driver Version: 510.73.05    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro P5000        Off  | 00000000:00:05.0 Off |                  Off |
| 26%   34C    P8     6W / 180W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# from dgl.nn.pytorch import GlobalAttentionPooling
import torch as th
from torch import nn
from torch.nn import functional as F

class MLAPBase(GNNBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.poolings = nn.ModuleList(
            [GlobalAttentionPooling(nn.Sequential(
                nn.Linear(self._dim_feat, 2 * self._dim_feat),
                nn.ReLU(),
                nn.Linear(2 * self._dim_feat, 1),
            )) for _ in range(self._depth)]
        )

    def forward(self, graph, feat):
        self._graph_embs = []

        for d in range(self._depth):
            feat_in = feat
            feat = self.layers[d](graph, feat)
            if self.norms:
                feat = self.norms[d](graph, feat)
            if d < self._depth - 1:
                feat = F.relu(feat)
            if self._dropout:
                feat = F.dropout(feat, training=self.training)
            if self._res:
                feat = feat + feat_in

            self._graph_embs.append(self.poolings[d](graph, feat))

        return self._aggregate()

    def _aggregate(self):
        raise NotImplementedError

    def get_emb(self, graph, feat):
        out = self.forward(graph, feat)
        self._graph_embs.append(out)
        return th.stack(self._graph_embs, dim=0)


class MLAPSum(MLAPBase):
    def _aggregate(self):
        return th.stack(self._graph_embs, dim=0).sum(dim=0)


class MLAPWeighted(MLAPBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = nn.Parameter(th.ones(self._depth, 1, 1))

    def _aggregate(self):
        a = F.softmax(self.weight, dim=0)
        h = th.stack(self._graph_embs, dim=0)
        return (a * h).sum(dim=0)

In [5]:
import os
from pathlib import Path
from typing import Dict, List, Optional, Union

import random
import numpy as np
import torch as th
from torch import cuda


def log(path: Optional[Path], msg: str):
    print(msg)

    if path:
        with open(path, "a") as f:
            f.write(msg)
            f.write("\n")


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    if cuda.is_available():
        cuda.manual_seed(seed)


def get_repo_root(default_root="./mlap_root") -> Path:
    return Path(default_root)


def encode_seq_to_arr(seq: List[str], vocab2idx: Dict[str, int], max_seq_len: int) -> th.Tensor:
    seq = seq[:max_seq_len] + ["__EOS__"] * max(0, max_seq_len - len(seq))
    return th.tensor([vocab2idx[w] if w in vocab2idx else vocab2idx["__UNK__"] for w in seq], dtype=th.int64)


def decode_arr_to_seq(arr: Union[List[int], th.Tensor], idx2vocab: List[str], vocab2idx: Dict[str, int]) -> List[str]:
    if isinstance(arr, th.Tensor):
        arr = arr.tolist()
    if vocab2idx["__EOS__"] in arr:
        arr = arr[:arr.index(vocab2idx["__EOS__"])]
    return [idx2vocab[i] for i in arr]

In [6]:
from typing import Dict, List

import dgl
import torch as th
from torch import nn
from torch.nn import functional as F


class LinearDecoder(nn.Module):
    def __init__(self, dim_feat: int, max_seq_len: int, vocab2idx: Dict[str, int]):
        super().__init__()
        self._max_seq_len = max_seq_len
        self._vocab2idx = vocab2idx

        self.decoders = nn.ModuleList([nn.Linear(dim_feat, len(vocab2idx)) for _ in range(max_seq_len)])

    def forward(self, graph: dgl.DGLGraph, feats: th.Tensor, labels: List[List[str]]) -> List[th.Tensor]:
        return [d(feats[-1]) for d in self.decoders]


class LSTMDecoder(nn.Module):
    def __init__(self, dim_feat: int, max_seq_len: int, vocab2idx: Dict[str, int]):
        super().__init__()
        self._max_seq_len = max_seq_len
        self._vocab2idx = vocab2idx

        self.lstm = nn.LSTMCell(dim_feat, dim_feat)
        self.w_hc = nn.Linear(dim_feat * 2, dim_feat)
        self.layernorm = nn.LayerNorm(dim_feat)
        self.vocab_encoder = nn.Embedding(len(vocab2idx), dim_feat)
        self.vocab_bias = nn.Parameter(th.zeros(len(vocab2idx)))

    def forward(self, graph: dgl.DGLGraph, feats: th.Tensor, labels: List[List[str]]) -> List[th.Tensor]:
        if self.training:
            # teacher forcing
            batched_label = th.vstack([encode_seq_to_arr(label, self._vocab2idx, self._max_seq_len - 1) for label in labels])
            batched_label = th.hstack((th.zeros((graph.batch_size, 1), dtype=th.int64), batched_label))
            true_emb = self.vocab_encoder(batched_label.to(device=graph.device))

        h_t, c_t = feats[-1].clone(), feats[-1].clone()
        feats = feats.transpose(0, 1)  # (batch_size, L + 1, dim_feat)
        out = []
        pred_emb = self.vocab_encoder(th.zeros((graph.batch_size), dtype=th.int64, device=graph.device))

        vocab_mat = self.vocab_encoder(th.arange(len(self._vocab2idx), dtype=th.int64, device=graph.device))

        for i in range(self._max_seq_len):
            if self.training:
                _in = true_emb[:, i]
            else:
                _in = pred_emb
            h_t, c_t = self.lstm(_in, (h_t, c_t))

            a = F.softmax(th.bmm(feats, h_t.unsqueeze(-1)).squeeze(-1), dim=1)  # (batch_size, L + 1)
            context = th.bmm(a.unsqueeze(1), feats).squeeze(1)
            pred_emb = th.tanh(self.layernorm(self.w_hc(th.hstack((h_t, context)))))  # (batch_size, dim_feat)

            out.append(th.matmul(pred_emb, vocab_mat.T) + self.vocab_bias.unsqueeze(0))

        return out

In [7]:
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import dgl
import numpy as np
import torch as th
from torch import nn, optim
from tqdm import tqdm


GNN_CLASS = {
    "simple": GNNSimple,

    "mlap-sum": MLAPSum,
    "mlap-weighted": MLAPWeighted,
}


class TeacherForcing(nn.Module):
    def forward(self, graph: dgl.DGLGraph, labels: Any) -> th.Tensor:
        raise NotImplementedError


class Task:
    def __init__(self, dataset_name: str, device: th.device, save: bool) -> None:
        self.dataset_name = dataset_name
        self.device = device
        self.save = save

        if not self.log_dir.exists():
            self.log_dir.mkdir()
        if not self.model_dir.exists():
            self.model_dir.mkdir()

        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None

    @property
    def log_dir(self) -> Path:
        return get_repo_root() / "log" / self.dataset_name

    @property
    def model_dir(self) -> Path:
        return get_repo_root() / "model" / self.dataset_name

    def load_dataset(self, batch_size: int) -> None:
        raise NotImplementedError

    def build_model(self, arch: str, norm: str, res: bool, dim_feat: int, depth: int) -> None:
        self.model = self._build_model(arch, norm, res, dim_feat, depth)

    def _build_model(self, arch: str, norm: str, res: bool, dim_feat: int, depth: int) -> nn.Module:
        raise NotImplementedError

    def _build_gnn(self, arch: str, *args, **kwargs) -> nn.Module:
        i = arch.index("-")
        self._current_gnn = GNN_CLASS[arch[(i + 1):]](arch[:i], *args, **kwargs)
        return self._current_gnn

    def set_seed(self, seed: int):
        self._seed = seed
        set_seed(seed)

    def evaluate(self, loader, *, silent=False) -> float:
        self.model.eval()
        batch_sizes = []
        y_true = []
        y_pred = []

        for batch in tqdm(loader, disable=silent):
            g, labels = batch
            with th.no_grad():
                if isinstance(self.model, TeacherForcing):
                    out = [t.detach().cpu() for t in self.model(g.to(self.device), None)]
                else:
                    out = self.model(g.to(self.device)).detach().cpu()
            batch_sizes.append(g.batch_size)
            y_true.append(labels)
            y_pred.append(out)

        return self._evaluate_score(y_true, y_pred, batch_sizes)

    def _evaluate_score(self, y_true: List[Any], y_pred: List[th.Tensor], batch_sizes: List[int]) -> float:
        raise NotImplementedError

    def train(
            self,
            epochs: int,
            optimizer: optim.Optimizer,
            scheduler: Optional[Union[optim.lr_scheduler.StepLR, optim.lr_scheduler.ReduceLROnPlateau]],
            *,
            save: Optional[Tuple[Path, Path, str]]=None,
    ) -> Tuple[List[float], List[float]]:

        train_curve = []
        valid_curve = []

        log_path = save[0] if save else None

        for epoch in range(1, epochs + 1):
            print("Training...")
            self.model.train()
            with tqdm(self.train_loader) as pbar:
                for i, batch in enumerate(pbar):
                    g, labels = batch
                    if isinstance(self.model, TeacherForcing):
                        out = self.model(g.to(self.device), labels)
                    else:
                        out = self.model(g.to(self.device))
                    loss = self._loss(out, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch {epoch: >3}, batch {i: > 5} loss: {loss.data:.3f}")

            print("Evaluating...")
            # train_perf = self.evaluate(self.train_loader)
            train_perf = 0.0
            valid_perf = self.evaluate(self.valid_loader)
            log(log_path, f"Epoch {epoch}, train {train_perf}, valid {valid_perf}")
            train_curve.append(train_perf)
            valid_curve.append(valid_perf)

            if scheduler:
                if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(valid_perf)
                else:
                    scheduler.step()

            if save:
                _, save_dir, save_name = save
                th.save(self.model.state_dict(), save_dir / f"{save_name}_e{epoch}")

        log(log_path, f"Best validation score: {np.max(valid_curve)}")
        return train_curve, valid_curve

    def _loss(self, out: th.Tensor, labels: th.Tensor) -> th.Tensor:
        raise NotImplementedError

In [8]:
from collections import defaultdict
from pathlib import Path
import pickle
import re
from typing import Any, Dict, List, Optional, Tuple

import dgl
import numpy as np
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
import pandas as pd
import torch as th
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm


class OGBCodeTask(Task):
    def __init__(
            self,
            dataset_name: str,
            device: th.device,
            save: bool,
            use_subtoken: bool,
            decoder_type: str
    ) -> None:

        super().__init__(dataset_name, device, save)
        self._use_subtoken = use_subtoken
        self._decoder_type = decoder_type

    def load_dataset(self, batch_size: int) -> None:
        self.dataset = DglGraphPropPredDataset(name=self.dataset_name)
        if "feat" in self.dataset[0][0].ndata:
            self.dataset.dim_node = self.dataset[0][0].ndata["feat"].shape[1]
        else:
            self.dataset.dim_node = 0
        if "feat" in self.dataset[0][0].edata:
            self.dataset.dim_edge = self.dataset[0][0].edata["feat"].shape[1]
        else:
            self.dataset.dim_edge = 0

        self._max_depth = 20
        self._max_seq_len = 5
        self._num_vocab = 5000
        self._num_nodetypes = len(pd.read_csv(Path(self.dataset.root) / "mapping" / "typeidx2type.csv.gz")["type"])

        split_idx = self.dataset.get_idx_split()
        self._vocab2idx, self._id2vocab = _get_vocab_mapping([l for _, l in self.dataset[split_idx["train"]]], self._num_vocab)

        if self._use_subtoken:
            pickle_path = Path(self.dataset.root) / "saved" / "loaders-subtoken.pkl"
        else:
            pickle_path = Path(self.dataset.root) / "saved" / "loaders.pkl"

        if pickle_path.exists():
            with open(pickle_path, "rb") as f:
                train_samples = pickle.load(f)
                valid_samples = pickle.load(f)
                test_samples = pickle.load(f)
                self._num_nodeattrs = pickle.load(f)

        else:
            if self._use_subtoken:
                def _subtokenize_attr(attr: str) -> List[str]:
                    def camel_case_split(s: str) -> List[str]:
                        matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", s)
                        return [m.group(0).lower() for m in matches]

                    if attr in ["__NONE__", "__UNK__"]:
                        res = [attr]
                    else:
                        res = []
                        for part in str(attr).split('_'):
                            res.extend(camel_case_split(part))

                    if len(res) >= self._max_seq_len:
                        return res[:self._max_seq_len]
                    else:
                        return res + ["__NONE__"] * (self._max_seq_len - len(res))

                attr_subtokens = {}
                for i, _, a in pd.read_csv(Path(self.dataset.root) / "mapping" / "attridx2attr.csv.gz").itertuples():
                    attr_subtokens[i] = _subtokenize_attr(a)

                def get_split(idx, is_training):
                    samples = []
                    for i in tqdm(idx):
                        samples.append((_augment_edge(_subtokenize(self.dataset[i][0], attr_subtokens, used_subtokens, is_training)), self.dataset[i][1]))
                    return samples

                used_subtokens = {"__NONE__": 0, "__UNK__": 1, "__NUM__": 2}
                train_samples = get_split(split_idx["train"], True)
                self._num_nodeattrs = len(used_subtokens)
                valid_samples = get_split(split_idx["valid"], False)
                test_samples = get_split(split_idx["test"], False)

            else:
                self._num_nodeattrs = len(pd.read_csv(Path(self.dataset.root) / "mapping" / "attridx2attr.csv.gz")["attr"])

                def get_split(idx):
                    samples = []
                    for i in tqdm(idx):
                        samples.append((_augment_edge(self.dataset[i][0]), self.dataset[i][1]))
                    return samples

                train_samples = get_split(split_idx["train"])
                valid_samples = get_split(split_idx["valid"])
                test_samples = get_split(split_idx["test"])

            pickle_path.parent.mkdir(exist_ok=True)
            with open(pickle_path, "wb") as f:
                pickle.dump(train_samples, f)
                pickle.dump(valid_samples, f)
                pickle.dump(test_samples, f)
                pickle.dump(self._num_nodeattrs, f)

        def collate(samples) -> Tuple[dgl.DGLGraph, List[List[str]]]:
            graphs, labels = map(list, zip(*samples))
            batched_graph = dgl.batch(graphs)
            return batched_graph, labels

        self.train_loader = DataLoader(train_samples, batch_size=batch_size, shuffle=True, collate_fn=collate)
        self.valid_loader = DataLoader(valid_samples, batch_size=batch_size, shuffle=False, collate_fn=collate)
        self.test_loader = DataLoader(test_samples, batch_size=batch_size, shuffle=False, collate_fn=collate)

    def _build_model(self, arch: str, norm: str, residual: bool, dim_feat: int, depth: int) -> nn.Module:
        edge_encoder = lambda: nn.Embedding(4, dim_feat)
        gnn = self._build_gnn(arch, norm, residual, dim_feat, depth, edge_encoder)
        return CodeEncDec(
            gnn,
            dim_feat,
            self._max_depth,
            self._max_seq_len,
            self._num_nodetypes,
            self._num_nodeattrs,
            self._vocab2idx,
            self._decoder_type,
        ).to(self.device)

    def _evaluate_score(self, y_true: List[Any], y_pred: List[th.Tensor], batch_sizes: List[int]) -> float:
        y_true = sum(y_true, [])
        y_pred = th.vstack([th.hstack([t.argmax(dim=1).view(b, -1) for t in l]) for l, b in zip(y_pred, batch_sizes)])
        y_pred = [decode_arr_to_seq(a, self._id2vocab, self._vocab2idx) for a in y_pred]
        metric: str = self.dataset.eval_metric
        return Evaluator(self.dataset_name).eval({"seq_ref": y_true, "seq_pred": y_pred})[metric]

    def _loss(self, out: th.Tensor, labels: List[List[str]]) -> th.Tensor:
        batched_label = th.vstack([encode_seq_to_arr(label, self._vocab2idx, self._max_seq_len) for label in labels])
        return sum([nn.CrossEntropyLoss()(out[i], batched_label[:, i].to(device=self.device)) for i in range(self._max_seq_len)])

    def get_emb(self, save_emb_name: str):
        for name, loader in {"train": self.train_loader, "valid": self.valid_loader, "test": self.test_loader}.items():
            embs = []
            labels = []
            for batch in tqdm(loader):
                g, l = batch
                with th.no_grad():
                    embs.append(self.model.get_emb(g.to(self.device)).detach().cpu())
                    labels.extend(l)
            embs = th.cat(embs, dim=1)
            labels = th.vstack([encode_seq_to_arr(label, self._vocab2idx, self._max_seq_len) for label in labels])

            save_path = self.emb_dir / f"{save_emb_name}_{name}"
            np.savez(save_path, embs=embs.numpy(), labels=labels.numpy())


class CodeEncDec(TeacherForcing, nn.Module):
    def __init__(
            self,
            gnn: nn.Module,
            dim_feat: int,
            max_depth: int,
            max_seq_len: int,
            num_nodetypes: int,
            num_nodeattrs: int,
            vocab2idx: Dict[str, int],
            decoder_type: str,
    ):

        super().__init__()
        self._dim_feat = dim_feat
        self._max_depth = max_depth
        self._max_seq_len = max_seq_len
        self._vocab2idx = vocab2idx
        self._decoder_type = decoder_type

        self.type_encoder = nn.Embedding(num_nodetypes, dim_feat)
        self.attr_encoder = nn.Embedding(num_nodeattrs, dim_feat)
        self.depth_encoder = nn.Embedding(max_depth + 1, dim_feat)
        self.node_mlp = nn.Sequential(
            nn.Linear(3 * dim_feat, 2 * dim_feat),
            nn.ReLU(),
            nn.Linear(2 * dim_feat, dim_feat),
        )
        self.gnn = gnn

        if decoder_type == "linear":
            self.linear_decoder = LinearDecoder(dim_feat, max_seq_len, vocab2idx)
        elif decoder_type == "lstm":
            self.lstm_decoder = LSTMDecoder(dim_feat, max_seq_len, vocab2idx)

    def forward(self, graph: dgl.DGLGraph, labels: Any) -> List[th.Tensor]:
        feats = self.get_emb(graph)  # (L+1, batch_size, dim_feat)

        if self._decoder_type == "linear":
            return self.linear_decoder(graph, feats, labels)
        elif self._decoder_type == "lstm":
            return self.lstm_decoder(graph, feats, labels)

    def get_emb(self, graph: dgl.DGLGraph) -> th.Tensor:
        type_emb = self.type_encoder(graph.ndata["feat"][:, 0])
        attr_emb = (self.attr_encoder(graph.ndata["feat"][:, 1:]) * (graph.ndata["feat"][:, 1:] > 0).unsqueeze(-1)).sum(dim=1)
        depth = graph.ndata["depth"].view(-1)
        depth[depth > self._max_depth] = self._max_depth
        depth_emb = self.depth_encoder(depth)
        feat = self.node_mlp(th.hstack((type_emb, attr_emb, depth_emb)))
        return self.gnn.get_emb(graph, feat)


def _get_vocab_mapping(words_list: List[List[str]], num_vocab: int) -> Tuple[Dict[str, int], List[str]]:
    vocab_count = defaultdict(int)
    for words in tqdm(words_list):
        for word in words:
            vocab_count[word] += 1
    idx2vocab = ["__SOS__", "__UNK__", "__EOS__"]
    idx2vocab += list(list(zip(*sorted([(c, w) for w, c in vocab_count.items()], reverse=True)[:num_vocab]))[1])
    vocab2idx = {w: i for i, w in enumerate(idx2vocab)}

    # test
    for idx, vocab in enumerate(idx2vocab):
        assert(idx == vocab2idx[vocab])
    assert(vocab2idx["__SOS__"] == 0)
    assert(vocab2idx["__UNK__"] == 1)
    assert(vocab2idx["__EOS__"] == 2)

    return vocab2idx, idx2vocab


def _augment_edge(graph: dgl.DGLGraph) -> dgl.DGLGraph:
    num_ast_edges = graph.num_edges()
    src_ast = th.hstack((graph.edges()[0], graph.edges()[1]))
    dst_ast = th.hstack((graph.edges()[1], graph.edges()[0]))
    attr_ast = th.vstack((th.zeros((num_ast_edges, 1)), th.ones((num_ast_edges, 1))))

    terminals = th.where(graph.ndata["is_attributed"] == 1)[0]
    num_nt_edges = terminals.shape[0] - 1
    src_nt = th.hstack((terminals[:-1], terminals[1:]))
    dst_nt = th.hstack((terminals[1:], terminals[:-1]))
    attr_nt = th.vstack((th.ones((num_nt_edges, 1)) * 2, th.ones((num_nt_edges, 1)) * 3))

    graph.remove_edges(np.arange(num_ast_edges))
    graph.add_edges(th.hstack((src_ast, src_nt)), th.hstack((dst_ast, dst_nt)), {"feat": th.vstack((attr_ast, attr_nt)).to(th.int64)})

    return graph


def _subtokenize(
        graph: dgl.DGLGraph,
        attr_subtokens: Dict[int, List[str]],
        used_subtokens: Dict[str, int],
        is_training: bool,
) -> dgl.DGLGraph:

    feat = th.hstack((graph.ndata["feat"][:, 0].view(-1, 1), th.zeros((graph.ndata["feat"].shape[0], 5)).to(graph.ndata["feat"])))

    for i in range(graph.ndata["feat"].shape[0]):
        if int(graph.ndata["feat"][i, 0]) == 67:  # Num
            feat[i, 1] = used_subtokens["__NUM__"]
        elif int(graph.ndata["feat"][i, 1]) == 10028:  # __NONE__
            pass
        else:
            for j, s in enumerate(attr_subtokens[int(graph.ndata["feat"][i, 1])]):
                if is_training:
                    if s not in used_subtokens:
                        used_subtokens[s] = len(used_subtokens)
                    feat[i, j + 1] = used_subtokens[s]
                else:
                    if s in used_subtokens:
                        feat[i, j + 1] = used_subtokens[s]
                    else:
                        feat[i, j + 1] = used_subtokens["__UNK__"]

    graph.ndata["feat"] = feat

    return graph

In [9]:
from datetime import datetime
import itertools
import os
from pathlib import Path
import re
import subprocess
from typing import List, Optional, Tuple

import numpy as np
import torch as th
from torch import cuda, optim
from torch.backends import cudnn
from torch.nn import functional as F


DEVICE = th.device("cuda:0" if cuda.is_available() else "cpu")
TIMESTR = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
GITHASH = os.environ.get("GITHASH") or subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).strip().decode()
MODEL_REGEX = re.compile(r".*models/(?P<dataset>[^\\]+)/[^\\]+/(?P<arch>gin-[^_]+)_(?P<norm>[^_]+)_d(?P<dim>\d+)_l(?P<depth>\d+)_s(?P<seed>\d+)_e\d+")

if cudnn.enabled:
    cudnn.deterministic = True
    cudnn.benchmark = False


def build_pathname(
        *,
        prefix: bool=False,
        arch: Optional[str]=None,
        norm: Optional[str]=None,
        residual: bool=False,
        dim_feat: Optional[int]=None,
        depth: Optional[int]=None,
        seed: Optional[int]=None,
        batch_size: Optional[int]=None,
        learning_rate: Optional[Tuple[float, int, float]]=None,
        suffix: Optional[str]=None
) -> str:

    name = ""

    if prefix:
        name = f"{TIMESTR}_{GITHASH}"
    if arch:
        name += f"_{arch}"
    if norm:
        name += f"_{norm}"
    if residual:
        name += f"_res"
    if dim_feat:
        name += f"_d{dim_feat}"
    if depth:
        name += f"_l{depth}"
    if seed is not None:
        name += f"_s{seed}"
    if batch_size is not None:
        name += f"_b{batch_size}"
    if learning_rate is not None:
        name += f"_r{learning_rate[0]}_{learning_rate[1]}_{learning_rate[2]}"
    if suffix:
        name += suffix

    if name.startswith("_"):
        name = name[1:]

    return name


def run_training(
        args_str: str,
        task: Task,
        arch: str,
        norm: str,
        residual: bool,
        dim_feat: int,
        depth: int,
        seed: int,
        epochs: int,
        batch_size: int,
        initial_lr: float,
        lr_interval: int,
        lr_scale: float,
        *,
        save: bool=True,
):
    print(f"Starting the training for arch: {arch}")
    task.build_model(arch, norm, residual, dim_feat, depth)
    task.set_seed(seed)

    log_path = task.log_dir / build_pathname(prefix=True, arch=arch, norm=norm, residual=residual, dim_feat=dim_feat, depth=depth, seed=seed, batch_size=batch_size, learning_rate=(initial_lr, lr_interval, lr_scale), suffix="_train.log")
    save_dir = task.model_dir / build_pathname(prefix=True, arch=arch, norm=norm, residual=residual, dim_feat=dim_feat, depth=depth, seed=seed, batch_size=batch_size, learning_rate=(initial_lr, lr_interval, lr_scale))
    save_dir.mkdir()
    save_name = build_pathname(arch=arch, norm=norm, residual=residual, dim_feat=dim_feat, depth=depth, seed=seed, batch_size=batch_size, learning_rate=(initial_lr, lr_interval, lr_scale))

    log(log_path, args_str)

    optimizer = optim.Adam(task.model.parameters(), lr=initial_lr)
    if lr_interval > 0:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_interval, gamma=lr_scale, verbose=True)
    elif lr_interval < 0:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", factor=lr_scale, patience=-lr_interval, verbose=True)
    else:
        raise ValueError("lr_interval cannot be 0")

    task.train(epochs, optimizer, scheduler, save=(log_path, save_dir, save_name) if save else None)


def run_test(task: Task, model_paths: List[str]):
    for model_path in model_paths:
        if not Path(model_path).exists():
            continue

        match = MODEL_REGEX.search(model_path)
        assert match is not None
        arch = match.group("arch")
        norm = match.group("norm")
        dim_feat = int(match.group("dim"))
        depth = int(match.group("depth"))

        task.build_model(arch, norm, False, dim_feat, depth)
        task.model.load_state_dict(th.load(model_path, map_location=DEVICE))
        valid_perf = task.evaluate(task.valid_loader, silent=True)
        test_perf = task.evaluate(task.test_loader, silent=True)

        print(f"{model_path}: val={valid_perf} test={test_perf}")


def main(
        args_str: str="args",
        dataset_name: str="ogbg-code2",
        batch_size: int=50,
        arch: str="gin-simple",
        norm: str="none",
        residual: bool=False,
        dim_feat: int=20, # default was 200
        depth: int=5,
        seed: int=50,
        epochs: int=50,
        initial_lr: float=1e-3,
        lr_interval: int=15,
        lr_scale: float=0.2,
        train: bool=True,
        test: Optional[List[str]]=None,
        save: bool=True,
        code2_use_subtoken: bool=False,
        code2_decoder_type: Optional[str]=None,
):
    if dataset_name == "ogbg-code2":
        task = OGBCodeTask(dataset_name, DEVICE, save, code2_use_subtoken, code2_decoder_type)
    else:
        raise NotImplementedError
    task.load_dataset(batch_size)

    if train:
        run_training(args_str, task, arch, norm, residual, dim_feat, depth, seed, epochs, batch_size, initial_lr, lr_interval, lr_scale, save=save)
    elif test:
        run_test(task, test)

In [11]:
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import dropout_edge
from torch_geometric.logging import init_wandb, log

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DATASETS_DIR = "./pyg_datasets"

# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py
dataset = Planetoid(root=DATASETS_DIR, name='Cora')
data = dataset[0].to(DEVICE)


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)
        self.conv2 = GCNConv(hidden_channels, out_channels, cached=True)

    def forward(self, x, edge_index, edge_weight=None):
        #edge_index, edge_mask = dropout_edge(edge_index, p=0.8)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return x


model = GCN(dataset.num_features, 16, dataset.num_classes)
model.to(DEVICE)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.edge_attr)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index, data.edge_attr).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = final_test_acc = 0
for epoch in range(20000):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    
    if epoch % 50 == 0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)

Epoch: 000, Loss: 1.9404, Train: 0.6643, Val: 0.4200, Test: 0.4650
Epoch: 050, Loss: 0.0189, Train: 1.0000, Val: 0.7640, Test: 0.7920
Epoch: 100, Loss: 0.0234, Train: 1.0000, Val: 0.7620, Test: 0.7920
Epoch: 150, Loss: 0.0140, Train: 1.0000, Val: 0.7560, Test: 0.7920
Epoch: 200, Loss: 0.0157, Train: 1.0000, Val: 0.7700, Test: 0.7920
Epoch: 250, Loss: 0.0114, Train: 1.0000, Val: 0.7720, Test: 0.7920
Epoch: 300, Loss: 0.0105, Train: 1.0000, Val: 0.7520, Test: 0.7920
Epoch: 350, Loss: 0.0100, Train: 1.0000, Val: 0.7740, Test: 0.7920
Epoch: 400, Loss: 0.0106, Train: 1.0000, Val: 0.7680, Test: 0.7920
Epoch: 450, Loss: 0.0119, Train: 1.0000, Val: 0.7620, Test: 0.7920
Epoch: 500, Loss: 0.0063, Train: 1.0000, Val: 0.7700, Test: 0.7920
Epoch: 550, Loss: 0.0059, Train: 1.0000, Val: 0.7580, Test: 0.8020
Epoch: 600, Loss: 0.0059, Train: 1.0000, Val: 0.7760, Test: 0.8020
Epoch: 650, Loss: 0.0169, Train: 1.0000, Val: 0.7700, Test: 0.8020
Epoch: 700, Loss: 0.0071, Train: 1.0000, Val: 0.7660, Test: 0.