In [1]:
import os, sys
from google.colab import drive
from pathlib import Path

NOTEBOOK_NAME = "MLAP_test_eval"

# --- do not change below this ---
DRIVE_PATH = "/content/drive/"
drive.mount(DRIVE_PATH, force_remount=True)

# shell commands for directory with space must be
# quoted, but not necessary in python
COLAB_PATH = "Colab Notebooks"
COLLAB_PATH_ESC = f"\"{COLAB_PATH}\""

# python path
nb_path = (
    "/".join(('drive/MyDrive', 
              COLAB_PATH, 
              "venv_" + NOTEBOOK_NAME)
    )
)

# shell path
nb_path_bash = (
    "/".join(('drive/MyDrive', 
              COLLAB_PATH_ESC, 
              "venv_" + NOTEBOOK_NAME)
    )
)


try:
  os.makedirs(nb_path)
except FileExistsError:
  # already created in G-drive
  print("Google Drive Folder already existed.")

try:
  # create symlink from drive to workspace
  os.symlink(nb_path, "/content/notebooks")
except FileExistsError:
  # already created in G-drive
  print("Symlink already existed.")

sys.path.insert(0, nb_path)

Mounted at /content/drive/
Google Drive Folder already existed.
Symlink already existed.


In [2]:
!pip install --quiet torch_geometric ogb

### Encoder

In [3]:
import numpy as np
import torch
from torch_geometric.utils import dropout_edge, degree, to_undirected, scatter, to_networkx

class ASTNodeEncoder(torch.nn.Module):
    '''
        Input:
            x: default node feature. the first and second column represents node type and node attributes.
            depth: The depth of the node in the AST.

        Output:
            emb_dim-dimensional vector

    '''
    def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
        super(ASTNodeEncoder, self).__init__()

        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)

        self.node_mlp = torch.nn.Sequential(
            torch.nn.Linear(3 * emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )

    def forward(self, x, depth):
        depth[depth > self.max_depth] = self.max_depth
        mlp_input = torch.hstack(
            (
                self.type_encoder(x[:,0]), 
                self.attribute_encoder(x[:,1]), 
                self.depth_encoder(depth)
             )
        )
        return self.node_mlp(mlp_input)

### Utils - AST / MLAP

Utilities for editing and parsing the AST inputs.

In [None]:
import numpy as np
import torch
from torch_geometric.utils import dropout_edge, degree, to_undirected, scatter, to_networkx
import networkx as nx

def get_vocab_mapping(seq_list, num_vocab):
    '''
        Input:
            seq_list: a list of sequences
            num_vocab: vocabulary size
        Output:
            vocab2idx:
                A dictionary that maps vocabulary into integer index.
                Additioanlly, we also index '__UNK__' and '__EOS__'
                '__UNK__' : out-of-vocabulary term
                '__EOS__' : end-of-sentence

            idx2vocab:
                A list that maps idx to actual vocabulary.
    '''

    vocab_cnt = {}
    vocab_list = []
    for seq in seq_list:
        for w in seq:
            if w in vocab_cnt:
                vocab_cnt[w] += 1
            else:
                vocab_cnt[w] = 1
                vocab_list.append(w)

    cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
    topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab]

    print('Coverage of top {} vocabulary:'.format(num_vocab))
    print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list))

    vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)}
    idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]

    vocab2idx['__UNK__'] = num_vocab
    idx2vocab.append('__UNK__')

    vocab2idx['__EOS__'] = num_vocab + 1
    idx2vocab.append('__EOS__')

    # test the correspondence between vocab2idx and idx2vocab
    for idx, vocab in enumerate(idx2vocab):
        assert(idx == vocab2idx[vocab])

    # test that the idx of '__EOS__' is len(idx2vocab) - 1.
    # This fact will be used in decode_arr_to_seq, when finding __EOS__
    assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)

    return vocab2idx, idx2vocab

def augment_edge(data):
    '''
        Input:
            data: PyG data object
        Output:
            data (edges are augmented in the following ways):
                data.edge_index: Added next-token edge. The inverse edges were also added.
                data.edge_attr (torch.Long):
                    data.edge_attr[:,0]: whether it is AST edge (0) for next-token edge (1)
                    data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
    '''
    ##### AST edge
    edge_index_ast = data.edge_index
    edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))

    ##### Inverse AST edge
    edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim = 0)
    edge_attr_ast_inverse = torch.cat([torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim = 1)

    ##### Next-token edge

    ## Obtain attributed nodes and get their indices in dfs order
    # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
    # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))]

    ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.
    attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1,) == 1)[0]

    ## build next token edge
    # Given: attributed_node_idx_in_dfs_order
    #        [1, 3, 4, 5, 8, 9, 12]
    # Output:
    #    [[1, 3, 4, 5, 8, 9]
    #     [3, 4, 5, 8, 9, 12]
    edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0)
    edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim = 1)

    ##### Inverse next-token edge
    edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim = 0)
    edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))

    data.edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim = 1)
    data.edge_attr = torch.cat([edge_attr_ast,   edge_attr_ast_inverse, edge_attr_nextoken,  edge_attr_nextoken_inverse], dim = 0)

    return data

def encode_y_to_arr(data, vocab2idx, max_seq_len):
    '''
    Input:
        data: PyG graph object
        output: add y_arr to data 
    '''
    # PyG >= 1.5.0
    seq = data.y
    data.y_arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len)
    return data

def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
    '''
    Input:
        seq: A list of words
        output: add y_arr (torch.Tensor)
    '''
    augmented_seq = seq[:max_seq_len] + ['__EOS__'] * max(0, max_seq_len - len(seq))
    return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] for w in augmented_seq]], dtype = torch.long)


def decode_arr_to_seq(arr, idx2vocab):
    '''
        Input: torch 1d array: y_arr
        Output: a sequence of words.
    '''
    # find the position of __EOS__ (the last vocab in idx2vocab)
    eos_idx_list = (arr == len(idx2vocab) - 1).nonzero() 
    if len(eos_idx_list) > 0:
        # find the smallest __EOS__
        clippted_arr = arr[: torch.min(eos_idx_list)] 
    else:
        clippted_arr = arr

    return list(map(lambda x: idx2vocab[x], clippted_arr.cpu()))

### Utils - CAP
Utilities for generating Graph Contrastive Pairs

In [4]:
# ---- CAP functions ----
# from: https://github.com/CRIPAC-DIG/GCA/blob/cd6a9f0cf06c0b8c48e108a6aab743585f6ba6f1/pGRACE/functional.py
# and: https://github.com/CRIPAC-DIG/GCA/blob/cd6a9f0cf06c0b8c48e108a6aab743585f6ba6f1/pGRACE/utils.py
def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    # page rank
    # interesting comment: https://github.com/CRIPAC-DIG/GCA/issues/4
    num_nodes = edge_index.max().item() + 1
    deg_out = degree(edge_index[0])
    x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32)

    for i in range(k):
        edge_msg = x[edge_index[0]] / deg_out[edge_index[0]]
        agg_msg = scatter(edge_msg, edge_index[1], reduce='sum')

        x = (1 - damp) * x + damp * agg_msg

    return x

def eigenvector_centrality(data):
    graph = to_networkx(data)
    x = nx.eigenvector_centrality_numpy(graph)
    x = [x[i] for i in range(data.num_nodes)]
    return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device)


def drop_feature(x, drop_prob):
    drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < drop_prob
    x = x.clone()
    x[:, drop_mask] = 0
    return x


def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w.repeat(x.size(0)).view(x.size(0), -1)

    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)

    x = x.clone()
    x[drop_mask] = 0.

    return x

def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7, dgi_task=False):
    w = w / w.mean() * p
    # if (dgi_task):
    #     threshold = 0.9

    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w

    if (dgi_task):
        drop_mask = torch.bernoulli(1. - drop_prob).to(torch.bool)
    else:
        drop_mask = torch.bernoulli(drop_prob).to(torch.bool)

    x = x.clone()
    x[:, drop_mask] = 0.

    return x

def feature_drop_weights(x, node_c):
    x = x.to(torch.bool).to(torch.float32)
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())

    return s


def feature_drop_weights_dense(x, node_c):
    x = x.abs()
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())

    return s


def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1., dgi_task=False):
    edge_weights = edge_weights / edge_weights.mean() * p
    # if (dgi_task):
    #     threshold = 0.9

    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)

    if (dgi_task): 
        # drop edges by importance
        sel_mask = torch.bernoulli(edge_weights).to(torch.bool)
    else:
        sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)

    return edge_index[:, sel_mask]


def degree_drop_weights(edge_index):
    edge_index_ = to_undirected(edge_index)
    deg = degree(edge_index_[1])
    deg_col = deg[edge_index[1]].to(torch.float32)
    s_col = torch.log(deg_col)
    weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean())

    return weights


def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    pv = compute_pr(edge_index, k=k)
    pv_row = pv[edge_index[0]].to(torch.float32)
    pv_col = pv[edge_index[1]].to(torch.float32)
    s_row = torch.log(pv_row)
    s_col = torch.log(pv_col)
    if aggr == 'sink':
        s = s_col
    elif aggr == 'source':
        s = s_row
    elif aggr == 'mean':
        s = (s_col + s_row) * 0.5
    else:
        s = s_col
    weights = (s.max() - s) / (s.max() - s.mean())

    return weights


def evc_drop_weights(data):
    evc = eigenvector_centrality(data)
    evc = evc.where(evc > 0, torch.zeros_like(evc))
    evc = evc + 1e-8
    s = evc.log()

    edge_index = data.edge_index
    s_row, s_col = s[edge_index[0]], s[edge_index[1]]
    s = s_col

    return (s.max() - s) / (s.max() - s.mean())

def graph_perturb(data, drop_scheme='pr'):
  if drop_scheme == 'degree':
      drop_weights = degree_drop_weights(data.edge_index)
      edge_index_ = to_undirected(data.edge_index)
      node_deg = degree(edge_index_[1])
      feature_weights = feature_drop_weights(data.x, node_c=node_deg)
  elif drop_scheme == 'pr':
      drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200)
      node_pr = compute_pr(data.edge_index)
      feature_weights = feature_drop_weights(data.x, node_c=node_pr)
  elif drop_scheme == 'evc':
      drop_weights = evc_drop_weights(data)
      node_evc = eigenvector_centrality(data)
      feature_weights = feature_drop_weights(data.x, node_c=node_evc)
  else:
      feature_weights = torch.ones((data.x.size(1),))
      drop_weights = None
  
  return feature_weights, drop_weights

def drop_edge(data, drop_edge_rate, drop_weights, drop_scheme='pr', drop_edge_weighted_threshold=0.7, dgi_task=False):
  if drop_scheme == 'uniform':
      return dropout_edge(data.edge_index, p=drop_edge_rate)[0]
  elif drop_scheme in ['degree', 'evc', 'pr']:
      return drop_edge_weighted(
          data.edge_index, 
          drop_weights, 
          p=drop_edge_rate, 
          threshold=drop_edge_weighted_threshold,
          dgi_task=dgi_task
        )
  else:
      raise Exception(f'undefined drop scheme: {drop_scheme}')

def get_contrastive_graph_pair(data, drop_scheme='pr', drop_feature_rates=(0.7, 0.7), drop_edge_rates=(0.5, 0.5), dgi_task=False):
  # use augmentation scheme to determine the weights of each node
  # i.e. pagerank, eigenvector centrality, node degree
  feat_weights, drop_weights = graph_perturb(data, drop_scheme)

  # apply drop edge according to computed features
  dr_e_1, dr_e_2 = drop_edge_rates
  edge_index_1 = drop_edge(data, dr_e_1, drop_weights, drop_scheme, dgi_task=dgi_task)

  if (not dgi_task):
    edge_index_2 = drop_edge(data, dr_e_2, drop_weights, drop_scheme)

  dr_f_1, dr_f_2 = drop_feature_rates

  if drop_scheme in ['pr', 'degree', 'evc']:
    # graph-aware drop feature
    x_1 = drop_feature_weighted_2(data.x, feat_weights, dr_f_1, dgi_task=dgi_task)
    #e_1 = drop_feature_weighted_2(data.edge_attr, feat_weights, dr_f_1)

    if not dgi_task:
        x_2 = drop_feature_weighted_2(data.x, feat_weights, dr_f_2, dgi_task=dgi_task)
        #e_2 = drop_feature_weighted_2(data.edge_attr, feat_weights, dr_f_2)
  else:
    # naive drop feature
    x_1 = drop_feature(data.x, dr_f_1)
    #e_1 = drop_feature(data.edge_attr, dr_f_1)
    
    x_2 = drop_feature(data.x, dr_f_2)
    e_2 = drop_feature(data.edge_attr, dr_f_2)
  
  if dgi_task:
      return (x_1, edge_index_1)

  return (
      # graph 1
      (x_1, edge_index_1),
      # graph 2
      (x_2, edge_index_2)
  )

### GIN

In [5]:
import torch
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops

from torch.nn import Linear, BatchNorm1d

class GINConv(MessagePassing):
    def __init__(self, dim_h, mlp, **kwargs):
        super(GINConv, self).__init__(aggr='add', **kwargs)

        self.mlp = mlp
        self.bn = BatchNorm1d(dim_h)
        self.edge_encoder = Linear(2, dim_h)

    
    def forward(self, x, edge_index, edge_attr):
        edge_attr = self.edge_encoder(edge_attr)
        edge_index, _ = remove_self_loops(edge_index)
        output = self.mlp(
            self.propagate(edge_index, x=x, edge_attr=edge_attr)
        )
        return self.bn(output)
    
    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out, x):
        return aggr_out + x

    def __repr__(self):
        return self.__class__.__name__

### Decoders

In [6]:
import torch
from torch import nn
from torch.nn import functional as F


class LinearDecoder(torch.nn.Module):
    def __init__(self, dim_h, max_seq_len, vocab2idx, device):
        """
        Noted in the MLAP paper to have performed better than the LSTM
        """
        super().__init__()
        self.max_seq_len = max_seq_len
        self.vocab2idx = vocab2idx
        self.decoders = nn.ModuleList(
            [nn.Linear(dim_h, len(vocab2idx)) for _ in range(max_seq_len)]
        )

    def forward(self, batch_size, layer_reps, labels, training=False):
        return [d(layer_reps[-1]) for d in self.decoders]


class LSTMDecoder(torch.nn.Module):
    def __init__(self, dim_h, max_seq_len, vocab2idx, device):
        super(LSTMDecoder, self).__init__()
        
        self.max_seq_len = max_seq_len
        self.vocab2idx = vocab2idx

        self.lstm = nn.LSTMCell(dim_h, dim_h)
        self.w_hc = nn.Linear(dim_h * 2, dim_h)
        self.layernorm = nn.LayerNorm(dim_h)
        self.vocab_encoder = nn.Embedding(len(vocab2idx), dim_h)
        self.vocab_bias = nn.Parameter(torch.zeros(len(vocab2idx)))

        self.device = device
    
    def forward(self, batch_size, layer_reps, labels, training=False):
        if (training):
            batched_label = torch.vstack(
                [
                    encode_seq_to_arr(label, self.vocab2idx, self.max_seq_len - 1) 
                    for label in labels
                ]
            )
            batched_label = torch.hstack((torch.zeros((batch_size, 1), dtype=torch.int64), batched_label))
            true_emb = self.vocab_encoder(batched_label.to(device=self.device))
        
        h_t, c_t = layer_reps[-1].clone(), layer_reps[-1].clone()

        layer_reps = layer_reps.transpose(0,1)
        output = []

        pred_emb = self.vocab_encoder(torch.zeros((batch_size), dtype=torch.int64, device=self.device))
        vocab_mat = self.vocab_encoder(torch.arange(len(self.vocab2idx), dtype=torch.int64, device=self.device))

        for i in range(self.max_seq_len):
            if (training): 
                # teacher forcing
                input = true_emb[:, i]
            else:
                input = pred_emb
            
            h_t, c_t = self.lstm(input, (h_t, c_t))

            # (batch_size, L + 1)
            a = F.softmax(torch.bmm(layer_reps, h_t.unsqueeze(-1)).squeeze(-1), dim=1)  
            context = torch.bmm(a.unsqueeze(1), layer_reps).squeeze(1)

            # (batch_size, dim_h)
            pred_emb = torch.tanh(self.layernorm(self.w_hc(torch.hstack((h_t, context)))))  

            # (batch_size, len(vocab)) x max_seq_len
            output.append(torch.matmul(pred_emb, vocab_mat.T) + self.vocab_bias.unsqueeze(0))
        
        return output

### MLAP

In [7]:
import torch
from torch.nn import Linear, Sequential, ReLU, ELU, Sigmoid

from torch_geometric.nn.conv import GINConv
from torch_geometric.nn.norm import GraphNorm
from torch_geometric.nn.glob import AttentionalAggregation

from torch.nn import functional as F

class DISC(torch.nn.Module):
    def __init__(self, dim_h):
        super(DISC, self).__init__()

        W = torch.empty(dim_h, dim_h)
        torch.nn.init.xavier_normal_(W)

        self.W = torch.nn.Parameter(W)
        self.W.requires_grad = True

        self.sig = Sigmoid()
    
    def forward(self, h, s):
        out = torch.matmul(self.W, s)
        out = torch.matmul(h, out.unsqueeze(-1))
        return self.sig(out)


class MLAP_GIN(torch.nn.Module):
    def __init__(self, dim_h, batch_size, depth, node_encoder, norm=False, residual=False, dropout=False):
        super(MLAP_GIN, self).__init__()

        self.dim_h = dim_h
        self.batch_size = batch_size
        self.depth = depth

        self.node_encoder = node_encoder

        self.norm = norm
        self.residual = residual
        self.dropout = dropout

        self.loss_fn = torch.nn.BCELoss(reduction='sum')
        self.discriminator = DISC(dim_h)

        # non-linear projection function for cl task
        self.projection = Sequential(
            Linear(dim_h, int(dim_h/8)),
            ELU(),
            Linear(int(dim_h/8), dim_h)
        )

        # GIN layers
        self.layers = torch.nn.ModuleList(
            [GINConv(Sequential(
                Linear(dim_h, dim_h),
                ReLU(),
                Linear(dim_h, dim_h))) for _ in range(depth)])
            
        # normalization layers
        self.norm = torch.nn.ModuleList([GraphNorm(dim_h) for _ in range(self.depth)])
        
        # layer-wise attention poolings
        self.att_poolings = torch.nn.ModuleList(
            [
                AttentionalAggregation(
                Sequential(Linear(self.dim_h, 2*self.dim_h), 
                           ReLU(), 
                           Linear(2*self.dim_h, 1))) 
                for _ in range(depth)
            ]
        )
        
    def contrastive_loss(self, g1_x, g2_x):
        # compute projections + L2 row-wise normalizations
        g1_projections = torch.nn.functional.normalize(
            self.projection(g1_x), p=2, dim=1
        )
        g2_projections = torch.nn.functional.normalize(
            self.projection(g2_x), p=2, dim=1
        )
        
        g1_proj_T = torch.transpose(g1_projections, 0, 1)
        g2_proj_T = torch.transpose(g2_projections, 0, 1)

        inter_g1 = torch.exp(torch.matmul(g1_projections, g1_proj_T))
        inter_g2 = torch.exp(torch.matmul(g2_projections, g2_proj_T))
        intra_view = torch.exp(torch.matmul(g1_projections, g2_proj_T))

        # main diagonal
        corresponding_terms = torch.diagonal(intra_view, 0) 
        non_matching_intra = torch.diagonal(intra_view, -1).sum()
        non_matching_inter_g1 = torch.diagonal(inter_g1, -1).sum()
        non_matching_inter_g2 = torch.diagonal(inter_g2, -1).sum()

        # inter-view pairs using g1
        corresponding_terms_g1 = corresponding_terms / (
            corresponding_terms + 
            non_matching_inter_g1 + 
            non_matching_intra
        )
        corresponding_terms_g1 = torch.log(corresponding_terms_g1)

        # inter-view pairs using g2
        corresponding_terms_g2 = corresponding_terms / (
            corresponding_terms + 
            non_matching_inter_g2 + 
            non_matching_intra
        )
        corresponding_terms_g2 = torch.log(corresponding_terms_g2)

        # contrasting terms of both divided by total nodes
        loss = (
            corresponding_terms_g1.sum() + 
            corresponding_terms_g2.sum()
        ) / (
            g1_x.shape[0] + 
            g2_x.shape[0]
        )
        
        loss = loss / self.batch_size
        return loss
    
    def layer_loop(self, x, edge_index, batch, cl=False, cl_all=False, dgi_task=False):
        cl_embs = []
        for d in range(self.depth):
            x_in = x

            # get node representation at layer d
            x = self.layers[d](x, edge_index)
            
            if self.norm:
                x = self.norm[d](x, batch)
            
            if d < self.depth - 1:
                x = F.relu(x)
            
            if self.dropout:
                x = F.dropout(x)
            
            if self.residual:
                x = x + x_in

            if not cl:
                # use attention pooling for given depth
                h_g = self.att_poolings[d](x, batch)
                self.graph_embs.append(h_g)

            if (
                (cl and cl_all) or 
                (cl and (d == self.depth-1)) or 
                (dgi_task and (d == self.depth-1))
            ):
                # if using contrastive learning or DGI
                cl_embs += [x]
            
        return cl_embs

    def forward(self, batched_data, cl=False, cl_all=False, dgi_task=False):
        self.graph_embs = []
        # non-augmented graph
        # note: populates self.graph_embs

        node_depth = batched_data.node_depth
        x_emb = self.node_encoder(batched_data.x, node_depth.view(-1,))
        edge_index = batched_data.edge_index
        batch = batched_data.batch

        self.layer_loop(x_emb, edge_index, batch, dgi_task=dgi_task)

        agg = self.aggregate()
        self.graph_embs.append(agg)
        output = torch.stack(self.graph_embs, dim=0)

        # dgi task
        dgi_loss = 0
        if dgi_task:
            for i in range(self.batch_size // 5):
                g = batched_data.get_example(i)

                nd = g.node_depth
                b = g.batch
                
                # contrastive pair
                g1, g2 = self.get_contrastive_pair_from_batch(g, dgi_task=True)
                g_diff_embs = self.layer_loop(g1, g2, b, dgi_task=True)[0]

                g.x = self.node_encoder(g.x, nd.view(-1,).clone())
                g_embs = self.layer_loop(g.x, g.edge_index, g.batch, dgi_task=True)[0]

                # dgi objective on final_layer_embs, g_diff_embs, and output
                agg = agg.clone()
                positive = self.discriminator(g_embs, agg[i])
                ones = torch.ones_like(positive)
                negative = self.discriminator(g_diff_embs, agg[i])
                zeros = torch.zeros_like(negative)

                dgi_loss += (
                    self.loss_fn(positive, ones) + self.loss_fn(negative, zeros)
                ) / (positive.shape[0] + negative.shape[0])
            
            dgi_loss /= (self.batch_size // 5)

        # contrastive learning task
        cl_loss = 0
        if cl:
            for i in range(self.batch_size // 5):
                g = batched_data.get_example(i)

                # contrastive pair
                g1, g2 = self.get_contrastive_pair_from_batch(g, dgi_task=False)
                g1_embs = self.get_node_embedding(g.batch, g1, cl=True, cl_all=cl_all)
                g2_embs = self.get_node_embedding(g.batch, g2, cl=True, cl_all=cl_all)

                batch_cl_loss = 0
                for j in range(len(g1_embs)):
                    batch_cl_loss += self.contrastive_loss(g1_embs[j], g2_embs[j])
                
                batch_cl_loss = batch_cl_loss / len(g1_embs)
                cl_loss += batch_cl_loss
            
            cl_loss /= (self.batch_size // 5)

        return output, cl_loss, dgi_loss

    def get_node_embedding(self, batch, g, cl, cl_all):
        g_x, g_edge_index = g
        return self.layer_loop(
            g_x.clone(), 
            g_edge_index, 
            batch, 
            cl=cl, 
            cl_all=cl_all
        )

    def get_contrastive_pair_from_batch(self, g, dgi_task=False):
        g_clone = g.clone()
        nd = g.node_depth
        # encode the nodes in the clone of g using given encoding network
        g_clone.x = self.node_encoder(g_clone.x, nd.view(-1,).clone())

        # create contrastive pairs from the input graph
        return get_contrastive_graph_pair(g_clone, dgi_task=dgi_task)

    def aggregate(self):
        pass

class MLAP_Sum(MLAP_GIN):
    def aggregate(self):
        return torch.stack(self.graph_embs, dim=0).sum(dim=0)

class MLAP_Weighted(MLAP_GIN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = torch.nn.Parameter(torch.ones(self.depth, 1, 1))

    def aggregate(self):
        a = F.softmax(self.weight, dim=0)
        h = torch.stack(self.graph_embs, dim=0)
        return (a * h).sum(dim=0)

### Model

In [8]:
import torch

class Model(torch.nn.Module):
    def __init__(self, batch_size, depth, dim_h, max_seq_len, node_encoder, vocab2idx, device):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.depth = depth
        self.dim_h = dim_h
        self.max_seq_len = max_seq_len
        self.device = device

        # token to idx lookup
        self.vocab2idx = vocab2idx 
        
        # architecture choices
        self.node_encoder = node_encoder
        self.gnn = MLAP_Weighted(
            dim_h, batch_size, depth, 
            node_encoder, 
            norm=True, 
            residual=True, 
            dropout=True
        )
        self.decoder = LinearDecoder(
            dim_h, max_seq_len, vocab2idx, device
        )

    def forward(self, batched_data, labels, training=False, cl=False, cl_all=False, dgi_task=False):
        # GNN layer, contrastive work done here
        embeddings, cl_loss, dgi_loss = self.gnn(
            batched_data, 
            cl=cl, 
            cl_all=cl_all, 
            dgi_task=dgi_task
        )

        predictions = self.decoder(len(labels), embeddings, labels, training=training)

        # for each batch, the prediction for the ith word is a logit
        # decoding each prediction to a word is done in the evaluation task in main
        return predictions, cl_loss, dgi_loss

In [9]:
!cd $nb_path_bash && mkdir "checkpoints"

mkdir: cannot create directory ‘checkpoints’: File exists


### Main

Model configuration and training loop.

In [15]:
import datetime
import numpy as np
import pandas as pd
import os

import torch
from torch_geometric.loader import DataLoader

import torch.optim as optim
from torchvision import transforms
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator


def train(model, device, loader, optimizer, scheduler, multicls_criterion, epoch, 
          alpha=0.05, 
          cl=False, 
          cl_all=False, 
          dgi_task=False,
          eval_hook=lambda x: x,
    ):
    # total loss for this epoch
    loss_accum = 0

    chkpt_folder = nb_path + '/checkpoints/epoch' + str(epoch)
    if not os.path.exists(chkpt_folder):
        os.mkdir(chkpt_folder)

    if cl and dgi_task:
        raise Exception("Cannot use both a contrastive and dgi loss term\n")

    for step, batch in enumerate(loader):
        # run eval if requested
        eval_hook(step)

        batch = batch.to(device)
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            # train
            labels = [batch.y[i] for i in range(len(batch.y))]
            pred_list, cl_loss, dgi_loss = model(
                batch, labels, training=True,
                cl=cl, 
                cl_all=cl_all, 
                dgi_task=dgi_task
            )
            optimizer.zero_grad()

            # loss + update
            loss = 0
            for i in range(len(pred_list)):
                loss += (1-alpha) * multicls_criterion(
                    pred_list[i].to(torch.float32), 
                    batch.y_arr[:, i]
                )

            loss /= len(pred_list)
            if cl:
                loss -= alpha * cl_loss
            if dgi_task:
                loss -= alpha * dgi_loss

            with torch.autograd.set_detect_anomaly(True):
                loss.backward()
            optimizer.step()

            # report loss after batch
            loss_accum += loss.item()
            print('Average loss after batch ' + str(step) + ': ' + str(loss_accum / (step + 1)))
            print(f"\tContrastive Term: {cl_loss:.3f}")
        
        if ((step+1) % 35 == 0 or step == len(loader)-1): 
            # save model after every 35 batches
            print("Checkpoint saved.")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss_accum / (step + 1),
            }, chkpt_folder + '/model' + str((step+1) // 35) + '.pt')

    # end of this epoch
    print('Average training loss: {}'.format(loss_accum / (step + 1)))
    return loss_accum / (step + 1)

def eval(model, device, loader, evaluator, arr_to_seq):
    """
    Use official OGB evaluator to test results of model output
    """
    seq_ref_list = []
    seq_pred_list = []
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                labels = [batch.y[i] for i in range(len(batch.y))]
                # no cl by default
                pred_list, _, _ = model(batch, labels) 

            mat = []
            for i in range(len(pred_list)):
                # get model's predictions
                mat.append(torch.argmax(pred_list[i], dim=1).view(-1, 1))
            
            # save for eval
            seq_ref_list.extend(labels)
            mat = torch.cat(mat, dim=1)
            seq_pred = [arr_to_seq(arr) for arr in mat]
            seq_pred_list.extend(seq_pred)

    input_dict = {"seq_ref": seq_ref_list, "seq_pred": seq_pred_list}
    return evaluator.eval(input_dict)

def randomly_mask(dataset, size):
    bool_mask = np.zeros(len(dataset), dtype=bool)
    bool_mask[:size] = True
    np.random.shuffle(bool_mask)
    out = dataset[bool_mask]
    return out


def main(
      starting_chkpt=None, 
      cl=False, 
      cl_all=False, 
      dgi_task=False, 
      run_eval_every_n_batches=None, 
      # CL hyperparameter
      alpha=0.05
  ):
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model & training conf
    depth = 3
    epochs = 50
    learning_rate = 0.001
    step_size = 10
    decay_rate = 0.1
    weight_decay = 0.00005
    dim_h = 512

    # model initialization
    node_encoder = ASTNodeEncoder(
        dim_h, 
        num_nodetypes=len(nodetypes_mapping['type']), 
        num_nodeattributes=len(nodeattributes_mapping['attr']), 
        max_depth=20
    )
    model = Model(
        batch_size, 
        depth, 
        dim_h, 
        max_seq_len, 
        node_encoder, 
        vocab2idx, 
        DEVICE
    ).to(DEVICE)
    num_params = sum(p.numel() for p in model.parameters())
    print(f'Model # Params: {num_params}')
    print("-------------\n\n\n")

    # training configuration
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=decay_rate)
    multicls_criterion = torch.nn.CrossEntropyLoss()

    starting_epoch = 1

    if (starting_chkpt != None):
        checkpoint = torch.load(starting_chkpt)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        starting_epoch = checkpoint['epoch']

    valid_curve = []
    test_curve = []
    train_curve = []
    trainL_curve = []

    def eval_hook():
      print('Evaluating...')
      train_perf = eval(model, DEVICE, train_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
      valid_perf = eval(model, DEVICE, valid_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
      test_perf = eval(model, DEVICE, test_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))

      print(
          'Train:', train_perf[dataset.eval_metric],
          'Validation:', valid_perf[dataset.eval_metric],
          'Test:', test_perf[dataset.eval_metric]
      )
      
      return train_perf, valid_perf, test_perf

    for epoch in range(starting_epoch, epochs + 1):
        print (datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S'))
        print("Epoch {} training...".format(epoch))
        print ("lr: ", optimizer.param_groups[0]['lr'])
        
        # training model
        train_loss = train(
            model, 
            DEVICE, 
            train_loader, 
            optimizer, 
            scheduler, 
            multicls_criterion, 
            epoch, 
            cl=cl, 
            alpha=alpha,
            cl_all=cl_all, 
            dgi_task=dgi_task,
            # run evaluation every n batches
            eval_hook=lambda x: (
                eval_hook() 
                  if run_eval_every_n_batches is not None and 
                  (x != 0 and x % run_eval_every_n_batches == 0) 
                else None
            ),
        )
        scheduler.step()

        # run evaluation after each epoch anyways
        train_perf, valid_perf, test_perf = eval_hook()

        print(f"Train Loss: {train_loss}")

        train_curve.append(train_perf[dataset.eval_metric])
        valid_curve.append(valid_perf[dataset.eval_metric])
        test_curve.append(test_perf[dataset.eval_metric])
        trainL_curve.append(train_loss)

    print('F1')
    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)
    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))
    print('Finished test: {}, Validation: {}, Train: {}, epoch: {}, best train: {}, best loss: {}'
          .format(
              test_curve[best_val_epoch], 
              valid_curve[best_val_epoch], 
              train_curve[best_val_epoch],
              best_val_epoch, 
              best_train, 
              min(trainL_curve)
          )
    )

In [12]:
num_vocab = 5000
max_seq_len = 5
batch_size = 50

# dataset objects
# best to load these only once in colab
# otherwise, memory never freed and runtime crashes
dataset_name = "ogbg-code2"
dataset = PygGraphPropPredDataset(dataset_name)
evaluator = Evaluator(dataset_name)

split_idx = dataset.get_idx_split()
vocab2idx, idx2vocab = get_vocab_mapping([dataset.data.y[i] for i in split_idx['train']], num_vocab)
dataset.transform = transforms.Compose([augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, max_seq_len)])

nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))



Coverage of top 5000 vocabulary:
0.9025832389087423


In [13]:
full_training = randomly_mask(dataset[split_idx["train"]], batch_size * 800)
full_valid = randomly_mask(dataset[split_idx["valid"]], batch_size * 800)
full_test = randomly_mask(dataset[split_idx["test"]], batch_size * 800)

train_loader = DataLoader(full_training, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(full_valid, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(full_test, batch_size=batch_size, shuffle=False)

### Evaluation

Leaderboard on OBG for code2: [LINK](https://ogb.stanford.edu/docs/leader_graphprop/#ogbg-code2)

In [None]:
# for multi-layer CL to be enabled, cl and cl_all must both be True
# default alpha is 0.05

# very overfit
# Train: 0.20 Validation: 0.069 Test: 0.072
main(
  cl=True, 
  cl_all=True, 
  dgi_task=False, 
  run_eval_every_n_batches=300, 
)

Model # Params: 23612920
-------------



2023.04.24-20:33:10
Epoch 1 training...
lr:  0.001
Average loss after batch 0: 8.101700782775879
	Contrastive Term: -0.103
Average loss after batch 1: 7.766380310058594
	Contrastive Term: -0.097
Average loss after batch 2: 7.057017644246419
	Contrastive Term: -0.106
Average loss after batch 3: 6.399185299873352
	Contrastive Term: -0.108
Average loss after batch 4: 5.960888767242432
	Contrastive Term: -0.104
Average loss after batch 5: 5.561765074729919
	Contrastive Term: -0.106
Average loss after batch 6: 5.279287678854806
	Contrastive Term: -0.097
Average loss after batch 7: 5.044170379638672
	Contrastive Term: -0.102
Average loss after batch 8: 4.883655839496189
	Contrastive Term: -0.107
Average loss after batch 9: 4.76292052268982
	Contrastive Term: -0.101
Average loss after batch 10: 4.6109444878318095
	Contrastive Term: -0.102
Average loss after batch 11: 4.479612906773885
	Contrastive Term: -0.104
Average loss after batch 12: 4.3923354882