In [1]:
import os, sys
from google.colab import drive
from pathlib import Path

NOTEBOOK_NAME = "MLAP_test"

# --- do not change below this ---
DRIVE_PATH = "/content/drive/"
drive.mount(DRIVE_PATH, force_remount=True)

# shell commands for directory with space must be
# quoted, but not necessary in python
COLAB_PATH = "Colab Notebooks"
COLLAB_PATH_ESC = f"\"{COLAB_PATH}\""

# python path
nb_path = (
    "/".join(('drive/MyDrive', 
              COLAB_PATH, 
              "venv_" + NOTEBOOK_NAME)
    )
)

# shell path
nb_path_bash = (
    "/".join(('drive/MyDrive', 
              COLLAB_PATH_ESC, 
              "venv_" + NOTEBOOK_NAME)
    )
)


try:
  os.makedirs(nb_path)
except FileExistsError:
  # already created in G-drive
  print("Google Drive Folder already existed.")

try:
  # create symlink from drive to workspace
  os.symlink(nb_path, "/content/notebooks")
except FileExistsError:
  # already created in G-drive
  print("Symlink already existed.")

sys.path.insert(0, nb_path)

Mounted at /content/drive/
Google Drive Folder already existed.
Symlink already existed.


In [2]:
!cd $nb_path_bash && pwd

/content/drive/MyDrive/Colab Notebooks/venv_MLAP_test


In [None]:
# !pip install --quiet --target=$nb_path_bash torch

In [None]:
# !pip install --quiet --target=$nb_path_bash torchvision

In [7]:
# !pip install --quiet --target=$nb_path_bash torch_geometric -q

In [8]:
# !pip install --quiet --target=$nb_path_bash ogb

In [9]:
# !pip install torch-scatter

In [3]:
import torch
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops

from torch.nn import Linear, BatchNorm1d

class GINConv(MessagePassing):
    def __init__(self, dim_h, mlp, **kwargs):
        super(GINConv, self).__init__(aggr='add', **kwargs)

        self.mlp = mlp
        self.bn = BatchNorm1d(dim_h)
        self.edge_encoder = Linear(2, dim_h)

    
    def forward(self, x, edge_index, edge_attr):
        edge_attr = self.edge_encoder(edge_attr)
        edge_index, _ = remove_self_loops(edge_index)
        output = self.mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr))
        return self.bn(output)
    
    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out, x):
        return aggr_out + x

    def __repr__(self):
        return self.__class__.__name__

In [4]:
import torch
import numpy as np
import torch

class ASTNodeEncoder(torch.nn.Module):
    '''
        Input:
            x: default node feature. the first and second column represents node type and node attributes.
            depth: The depth of the node in the AST.
        Output:
            emb_dim-dimensional vector
    '''
    def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
        super(ASTNodeEncoder, self).__init__()

        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)

        self.node_mlp = torch.nn.Sequential(
            torch.nn.Linear(3 * emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )

    def forward(self, x, depth):
        depth[depth > self.max_depth] = self.max_depth
        return self.type_encoder(x[:,0]) + self.attribute_encoder(x[:,1]) + self.depth_encoder(depth)



def get_vocab_mapping(seq_list, num_vocab):
    '''
        Input:
            seq_list: a list of sequences
            num_vocab: vocabulary size
        Output:
            vocab2idx:
                A dictionary that maps vocabulary into integer index.
                Additioanlly, we also index '__UNK__' and '__EOS__'
                '__UNK__' : out-of-vocabulary term
                '__EOS__' : end-of-sentence
            idx2vocab:
                A list that maps idx to actual vocabulary.
    '''

    vocab_cnt = {}
    vocab_list = []
    for seq in seq_list:
        for w in seq:
            if w in vocab_cnt:
                vocab_cnt[w] += 1
            else:
                vocab_cnt[w] = 1
                vocab_list.append(w)

    cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
    topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab]

    print('Coverage of top {} vocabulary:'.format(num_vocab))
    print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list))

    vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)}
    idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]

    vocab2idx['__UNK__'] = num_vocab
    idx2vocab.append('__UNK__')

    vocab2idx['__EOS__'] = num_vocab + 1
    idx2vocab.append('__EOS__')

    # test the correspondence between vocab2idx and idx2vocab
    for idx, vocab in enumerate(idx2vocab):
        assert(idx == vocab2idx[vocab])

    # test that the idx of '__EOS__' is len(idx2vocab) - 1.
    # This fact will be used in decode_arr_to_seq, when finding __EOS__
    assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)

    return vocab2idx, idx2vocab

def augment_edge(data):
    '''
        Input:
            data: PyG data object
        Output:
            data (edges are augmented in the following ways):
                data.edge_index: Added next-token edge. The inverse edges were also added.
                data.edge_attr (torch.Long):
                    data.edge_attr[:,0]: whether it is AST edge (0) for next-token edge (1)
                    data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
    '''

    ##### AST edge
    edge_index_ast = data.edge_index
    edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))

    ##### Inverse AST edge
    edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim = 0)
    edge_attr_ast_inverse = torch.cat([torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim = 1)


    ##### Next-token edge

    ## Obtain attributed nodes and get their indices in dfs order
    # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
    # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))]

    ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.
    attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1,) == 1)[0]

    ## build next token edge
    # Given: attributed_node_idx_in_dfs_order
    #        [1, 3, 4, 5, 8, 9, 12]
    # Output:
    #    [[1, 3, 4, 5, 8, 9]
    #     [3, 4, 5, 8, 9, 12]
    edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0)
    edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim = 1)


    ##### Inverse next-token edge
    edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim = 0)
    edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))

    data.edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim = 1)
    data.edge_attr = torch.cat([edge_attr_ast,   edge_attr_ast_inverse, edge_attr_nextoken,  edge_attr_nextoken_inverse], dim = 0)

    return data

def encode_y_to_arr(data, vocab2idx, max_seq_len):
    '''
    Input:
        data: PyG graph object
        output: add y_arr to data 
    '''

    # PyG >= 1.5.0
    seq = data.y

    data.y_arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len)

    return data

def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
    '''
    Input:
        seq: A list of words
        output: add y_arr (torch.Tensor)
    '''

    augmented_seq = seq[:max_seq_len] + ['__EOS__'] * max(0, max_seq_len - len(seq))
    return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] for w in augmented_seq]], dtype = torch.long)


def decode_arr_to_seq(arr, idx2vocab):
    '''
        Input: torch 1d array: y_arr
        Output: a sequence of words.
    '''

    eos_idx_list = (arr == len(idx2vocab) - 1).nonzero() # find the position of __EOS__ (the last vocab in idx2vocab)
    if len(eos_idx_list) > 0:
        clippted_arr = arr[: torch.min(eos_idx_list)] # find the smallest __EOS__
    else:
        clippted_arr = arr

    return list(map(lambda x: idx2vocab[x], clippted_arr.cpu()))

In [5]:
import torch
from torch import nn
from torch.nn import functional as F

class LSTMDecoder(torch.nn.Module):
    def __init__(self, dim_h, max_seq_len, vocab2idx, device):
        super(LSTMDecoder, self).__init__()
        
        self.max_seq_len = max_seq_len
        self.vocab2idx = vocab2idx

        self.lstm = nn.LSTMCell(dim_h, dim_h)
        self.w_hc = nn.Linear(dim_h * 2, dim_h)
        self.layernorm = nn.LayerNorm(dim_h)
        self.vocab_encoder = nn.Embedding(len(vocab2idx), dim_h)
        self.vocab_bias = nn.Parameter(torch.zeros(len(vocab2idx)))

        self.device = device
    
    def forward(self, batch_size, layer_reps, labels, training=False):
        if (training):
            batched_label = torch.vstack([encode_seq_to_arr(label, self.vocab2idx, self.max_seq_len - 1) for label in labels])
            batched_label = torch.hstack((torch.zeros((batch_size, 1), dtype=torch.int64), batched_label))
            true_emb = self.vocab_encoder(batched_label.to(device=self.device))
        
        h_t, c_t = layer_reps[-1].clone(), layer_reps[-1].clone()

        layer_reps = layer_reps.transpose(0,1)
        output = []

        pred_emb = self.vocab_encoder(torch.zeros((batch_size), dtype=torch.int64, device=self.device))
        vocab_mat = self.vocab_encoder(torch.arange(len(self.vocab2idx), dtype=torch.int64, device=self.device))

        for i in range(self.max_seq_len):
            if (training): # teacher forcing
                input = true_emb[:, i]
            else:
                input = pred_emb
            
            h_t, c_t = self.lstm(input, (h_t, c_t))

            a = F.softmax(torch.bmm(layer_reps, h_t.unsqueeze(-1)).squeeze(-1), dim=1)  # (batch_size, L + 1)
            context = torch.bmm(a.unsqueeze(1), layer_reps).squeeze(1)
            pred_emb = torch.tanh(self.layernorm(self.w_hc(torch.hstack((h_t, context)))))  # (batch_size, dim_h)

            # (batch_size, len(vocab)) x max_seq_len
            output.append(torch.matmul(pred_emb, vocab_mat.T) + self.vocab_bias.unsqueeze(0))
        
        return output

In [6]:
import torch
from torch_scatter import scatter
from torch_geometric.utils import dropout_adj, degree, to_undirected
import networkx as nx

# from: https://github.com/CRIPAC-DIG/GCA/blob/cd6a9f0cf06c0b8c48e108a6aab743585f6ba6f1/pGRACE/functional.py
# and: https://github.com/CRIPAC-DIG/GCA/blob/cd6a9f0cf06c0b8c48e108a6aab743585f6ba6f1/pGRACE/utils.py

def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    # page rank
    # interesting comment: https://github.com/CRIPAC-DIG/GCA/issues/4
    num_nodes = edge_index.max().item() + 1
    deg_out = degree(edge_index[0])
    x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32)

    for i in range(k):
        edge_msg = x[edge_index[0]] / deg_out[edge_index[0]]
        agg_msg = scatter(edge_msg, edge_index[1], reduce='sum')

        x = (1 - damp) * x + damp * agg_msg

    return x

def eigenvector_centrality(data):
    graph = to_networkx(data)
    x = nx.eigenvector_centrality_numpy(graph)
    x = [x[i] for i in range(data.num_nodes)]
    return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device)


def drop_feature(x, drop_prob):
    drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < drop_prob
    x = x.clone()
    x[:, drop_mask] = 0

    return x


def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w.repeat(x.size(0)).view(x.size(0), -1)

    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)

    x = x.clone()
    x[drop_mask] = 0.

    return x


def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w

    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)

    x = x.clone()
    x[:, drop_mask] = 0.

    return x


def feature_drop_weights(x, node_c):
    x = x.to(torch.bool).to(torch.float32)
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())

    return s


def feature_drop_weights_dense(x, node_c):
    x = x.abs()
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())

    return s


def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    edge_weights = edge_weights / edge_weights.mean() * p
    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)
    sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)

    return edge_index[:, sel_mask]


def degree_drop_weights(edge_index):
    edge_index_ = to_undirected(edge_index)
    deg = degree(edge_index_[1])
    deg_col = deg[edge_index[1]].to(torch.float32)
    s_col = torch.log(deg_col)
    weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean())

    return weights


def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    pv = compute_pr(edge_index, k=k)
    pv_row = pv[edge_index[0]].to(torch.float32)
    pv_col = pv[edge_index[1]].to(torch.float32)
    s_row = torch.log(pv_row)
    s_col = torch.log(pv_col)
    if aggr == 'sink':
        s = s_col
    elif aggr == 'source':
        s = s_row
    elif aggr == 'mean':
        s = (s_col + s_row) * 0.5
    else:
        s = s_col
    weights = (s.max() - s) / (s.max() - s.mean())

    return weights


def evc_drop_weights(data):
    evc = eigenvector_centrality(data)
    evc = evc.where(evc > 0, torch.zeros_like(evc))
    evc = evc + 1e-8
    s = evc.log()

    edge_index = data.edge_index
    s_row, s_col = s[edge_index[0]], s[edge_index[1]]
    s = s_col

    return (s.max() - s) / (s.max() - s.mean())

def graph_perturb(data, drop_scheme='pr'):
  if drop_scheme == 'degree':
      drop_weights = degree_drop_weights(data.edge_index)
      edge_index_ = to_undirected(data.edge_index)
      node_deg = degree(edge_index_[1])
      feature_weights = feature_drop_weights(data.x, node_c=node_deg)
  elif drop_scheme == 'pr':
      drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200)
      node_pr = compute_pr(data.edge_index)
      feature_weights = feature_drop_weights(data.x, node_c=node_pr)
  elif drop_scheme == 'evc':
      drop_weights = evc_drop_weights(data)
      node_evc = eigenvector_centrality(data)
      feature_weights = feature_drop_weights(data.x, node_c=node_evc)
  else:
      feature_weights = torch.ones((data.x.size(1),))
      drop_weights = None
  
  return feature_weights, drop_weights

def drop_edge(data, drop_edge_rate, drop_weights, drop_scheme='pr', drop_edge_weighted_threshold=0.7):
  if drop_scheme == 'uniform':
      return dropout_adj(data.edge_index, p=drop_edge_rate)[0]
  elif drop_scheme in ['degree', 'evc', 'pr']:
      return drop_edge_weighted(
          data.edge_index, 
          drop_weights, 
          p=drop_edge_rate, 
          threshold=drop_edge_weighted_threshold
      )
  else:
      raise Exception(f'undefined drop scheme: {drop_scheme}')

def get_contrastive_graph_pair(data, drop_scheme='pr', drop_feature_rates=(0.7, 0.7), drop_edge_rates=(0.5, 0.5)):
  # use augmentation scheme to determine the weights of each node
  # i.e. pagerank, eigenvector centrality, node degree
  feat_weights, drop_weights = graph_perturb(data, drop_scheme)

  # apply drop edge according to computed features
  dr_e_1, dr_e_2 = drop_edge_rates
  edge_index_1 = drop_edge(data, dr_e_1, drop_weights, drop_scheme)
  edge_index_2 = drop_edge(data, dr_e_2, drop_weights, drop_scheme)

  dr_f_1, dr_f_2 = drop_feature_rates

  if drop_scheme in ['pr', 'degree', 'evc']:
    # graph-aware drop feature
    x_1 = drop_feature_weighted_2(data.x, feat_weights, dr_f_1)
    e_1 = drop_feature_weighted_2(data.edge_attr, feat_weights, dr_f_1)
    
    x_2 = drop_feature_weighted_2(data.x, feat_weights, dr_f_2)
    e_2 = drop_feature_weighted_2(data.edge_attr, feat_weights, dr_f_2)
  else:
    # naive drop feature
    x_1 = drop_feature(data.x, dr_f_1)
    e_1 = drop_feature(data.edge_attr, dr_f_1)
    
    x_2 = drop_feature(data.x, dr_f_2)
    e_2 = drop_feature(data.edge_attr, dr_f_2)

  return (
      # graph 1
      (x_1, edge_index_1, e_1),
      # graph 2
      (x_2, edge_index_2, e_2),
  )

In [7]:
import torch
from torch.nn import Linear, Sequential, ReLU
from torch_geometric.nn.norm import GraphNorm
from torch_geometric.nn.glob import AttentionalAggregation

from torch.nn import functional as F

class MLAP_GIN(torch.nn.Module):
    def __init__(self, dim_h, depth, node_encoder, norm=False, residual=False):
        super(MLAP_GIN, self).__init__()

        self.dim_h = dim_h
        self.depth = depth

        self.node_encoder = node_encoder

        self.norm = norm
        self.residual = residual

        # GIN layers
        self.layers = torch.nn.ModuleList(
            [GINConv(dim_h, Sequential(
                Linear(dim_h, dim_h),
                ReLU(),
                Linear(dim_h, dim_h))) for _ in range(depth)]
        )
            
        # normalization layers
        self.norm = torch.nn.ModuleList([GraphNorm(dim_h) for _ in range(self.depth)])
        
        # layer-wise attention poolings
        self.att_poolings = torch.nn.ModuleList(
            [AttentionalAggregation(
                Sequential(
                    Linear(self.dim_h, 2*self.dim_h), 
                    ReLU(), 
                    Linear(2*self.dim_h, 1))) for _ in range(depth)
            ]
        )
        
    def forward(self, batched_data):
        self.graph_embs = []

        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        node_depth = batched_data.node_depth
        batch = batched_data.batch

        # graph augmentation step
        g_1, g_2 = get_contrastive_graph_pair(batched_data)
        print(g_1)
        print(g_2)
        
        x = self.node_encoder(x, node_depth.view(-1,))

        for d in range(self.depth):
            x_in = x

            x = self.layers[d](x, edge_index, edge_attr)
            if (self.norm):
                x = self.norm[d](x, batch)
            if (d < self.depth - 1):
                x = F.relu(x)
            if (self.residual):
                x = x + x_in
            
            h_g = self.att_poolings[d](x, batch)
            self.graph_embs.append(h_g)
        
        agg = self.aggregate()
        self.graph_embs.append(agg)
        output = torch.stack(self.graph_embs, dim=0)
        return output
    
    def aggregate(self):
        pass

class MLAP_Sum(MLAP_GIN):
    def aggregate(self):
        return torch.stack(self.graph_embs, dim=0).sum(dim=0)

class MLAP_Weighted(MLAP_GIN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = torch.nn.Parameter(torch.ones(self.depth, 1, 1))

    def aggregate(self):
        a = F.softmax(self.weight, dim=0)
        h = torch.stack(self.graph_embs, dim=0)
        return (a * h).sum(dim=0)

In [8]:
import torch

class Model(torch.nn.Module):
    def __init__(self, batch_size, depth, dim_h, max_seq_len, node_encoder, vocab2idx, device):
        super(Model, self).__init__()
        
        self.batch_size = batch_size

        self.depth = depth
        self.dim_h = dim_h
        self.max_seq_len = max_seq_len

        self.node_encoder = node_encoder

        self.vocab2idx = vocab2idx

        self.device = device

        # contrastive learning must go here over each of the layers
        self.gnn = MLAP_Weighted(dim_h, depth, node_encoder, norm=True, residual=True)

        self.decoder = LSTMDecoder(dim_h, max_seq_len, vocab2idx, device)

    def forward(self, batched_data, labels, training=False):
        embeddings = self.gnn(batched_data)
        predictions = self.decoder(self.batch_size, embeddings, labels, training=training)

        # for each batch, the prediction for the ith word is a logit
        # decoding each prediction to a word is done in the evaluation task in main
        return predictions

In [9]:
import random
import torch
from torch_geometric.data import DataLoader

import torch.optim as optim
from torchvision import transforms

import numpy as np
import pandas as pd
import os

# importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

import sys
sys.path.append('../..')

DATA_ROOT = nb_path + "/ogb"


import datetime

def train(model, device, loader, optimizer, multicls_criterion):
    loss_accum = 0
    print('New epoch: ', 'loader size = ' + str(len(loader)))

    # TODO: cut down on loader size
    for step, batch in enumerate(loader):
        batch = batch.to(device)

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            # pyg does batching in an interesting way
            labels = [batch.y[i] for i in range(len(batch.y))]
            pred_list = model(batch, labels, training=True)
            optimizer.zero_grad()

            loss = 0
            for i in range(len(pred_list)):
                loss += multicls_criterion(
                    pred_list[i].to(torch.float32), 
                    batch.y_arr[:, i]
                )

            # TODO: add contrastive learning objective
            loss = loss / len(pred_list)

            with torch.autograd.set_detect_anomaly(True):
                loss.backward()
            optimizer.step()

            loss_accum += loss.item()
            print('Average loss after batch ' + str(step) + ': ' + str(loss_accum / (step + 1)))

    print('Average training loss: {}'.format(loss_accum / (step + 1)))
    return loss_accum / (step + 1)

def eval(model, device, loader, evaluator, arr_to_seq):
    seq_ref_list = []
    seq_pred_list = []

    for step, batch in enumerate(loader):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                labels = [batch.y[i] for i in range(len(batch.y))]
                pred_list = model(batch, labels) # training=False by default

            mat = []
            for i in range(len(pred_list)):
                mat.append(torch.argmax(pred_list[i], dim=1).view(-1, 1))
            mat = torch.cat(mat, dim=1)

            seq_pred = [arr_to_seq(arr) for arr in mat]

            seq_ref = labels

            seq_ref_list.extend(seq_ref)
            seq_pred_list.extend(seq_pred)

    input_dict = {"seq_ref": seq_ref_list, "seq_pred": seq_pred_list}
    return evaluator.eval(input_dict)


def main():
    # constants
    dataset_name = "ogbg-code2"

    num_vocab = 5000
    max_seq_len = 5

    depth = 3
    batch_size = 50
    epochs = 50
    learning_rate = 0.001
    step_size = 10
    decay_rate = 0.1
    weight_decay = 0.00005

    dim_h = 512

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # TODO: add colab path as root
    dataset = PygGraphPropPredDataset(dataset_name, root="")


    split_idx = dataset.get_idx_split()
    vocab2idx, idx2vocab = get_vocab_mapping(
        [dataset.data.y[i] for i in split_idx['train']], 
        num_vocab
    )
    dataset.transform = transforms.Compose(
        [augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, max_seq_len)]
    )

    evaluator = Evaluator(dataset_name)

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=batch_size, shuffle=False)

    nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
    nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))

    node_encoder = ASTNodeEncoder(
        dim_h, 
        num_nodetypes=len(nodetypes_mapping['type']), 
        num_nodeattributes=len(nodeattributes_mapping['attr']), 
        max_depth=20
    )

    model = Model(
        batch_size, 
        depth, 
        dim_h, 
        max_seq_len, 
        node_encoder, 
        vocab2idx, 
        device
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=decay_rate)

    multicls_criterion = torch.nn.CrossEntropyLoss()

    valid_curve = []
    test_curve = []
    train_curve = []
    trainL_curve = []

    for epoch in range(1, epochs + 1):
        print (datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S'))
        print("Epoch {} training...".format(epoch))
        print ("lr: ", optimizer.param_groups[0]['lr'])
        train_loss = train(model, device, train_loader, optimizer, multicls_criterion)

        scheduler.step()

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
        valid_perf = eval(model, device, valid_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
        test_perf = eval(model, device, test_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))

        print('Train:', train_perf[dataset.eval_metric],
              'Validation:', valid_perf[dataset.eval_metric],
              'Test:', test_perf[dataset.eval_metric],
              'Train loss:', train_loss)

        train_curve.append(train_perf[dataset.eval_metric])
        valid_curve.append(valid_perf[dataset.eval_metric])
        test_curve.append(test_perf[dataset.eval_metric])
        trainL_curve.append(train_loss)

    print('F1')
    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)
    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    print('Finished test: {}, Validation: {}, Train: {}, epoch: {}, best train: {}, best loss: {}'
          .format(test_curve[best_val_epoch], valid_curve[best_val_epoch], train_curve[best_val_epoch],
                  best_val_epoch, best_train, min(trainL_curve)))

if __name__ == "__main__":
    main()

Downloading http://snap.stanford.edu/ogb/data/graphproppred/code2.zip


Downloaded 0.91 GB: 100%|██████████| 934/934 [00:26<00:00, 34.87it/s]


Extracting code2.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 452741/452741 [00:01<00:00, 339626.70it/s]


Converting graphs into PyG objects...


100%|██████████| 452741/452741 [00:21<00:00, 21075.63it/s]


Saving...


Done!


Coverage of top 5000 vocabulary:
0.9025832389087423




#Params: 15655312
2023.04.12-16:55:20
Epoch 1 training...
lr:  0.001
New epoch:  loader size = 8160
(tensor([[    0, 10028],
        [    0,  1461],
        [    0, 10028],
        ...,
        [    0, 10028],
        [    0,  4370],
        [    0, 10028]], device='cuda:0'), tensor([[    1,     5,     8,  ..., 10066, 10070, 10072],
        [    5,     6,    15,  ..., 10062, 10066, 10071]], device='cuda:0'), tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [1., 1.],
        [1., 1.],
        [1., 1.]], device='cuda:0'))
(tensor([[   59, 10028],
        [   35,  1461],
        [   93, 10028],
        ...,
        [   54, 10028],
        [   61,  4370],
        [   54, 10028]], device='cuda:0'), tensor([[    1,     1,     2,  ..., 10070, 10072, 10076],
        [    5,     7,     3,  ..., 10066, 10071, 10072]], device='cuda:0'), tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 1.],
        [0., 1.],
        [0., 1.]], device='cuda:0

KeyboardInterrupt: ignored