# Heterogeneous Graph Transformer Pretraining

First, load relevant packages.

In [1]:
'''
PRETRAIN NODE EMBEDDING MODEL
This script contains the main function for pretraining the node embedding model.
'''

# standard imports
import numpy as np
import pandas as pd
from datetime import datetime

# import PyTorch and DGL
import torch
import dgl

# import PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# path manipulation
from pathlib import Path

# import project config file
import sys
sys.path.append('../..')
import project_config

# custom imports
# from hyperparameters import parse_args, get_hyperparameters
# from dataloaders import load_graph, partition_graph, create_dataloaders
# from models import HGT
from utils import generate_subgraph

# check if CUDA is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyperparameters

Define hyperparameters using code in `hyperparameters.py`.

In [2]:
'''
HYPERPARAMETERS

This file contains the hyperparameters for the node embedder.
'''

# argument parser for command line arguments
import argparse

# COMMAND LINE ARGUMENTS REMOVED

# PRE-TRAINING HYPERPARAMETERS
def get_hyperparameters():   
    '''
    Return hyperparameters for node embedder. Combine tunable hyperparameters with fixed hyperparameters.
    See parse_args() for all possible command line arguments.

    Args:
        args: command line arguments
    
    Tunable Parameters:
        num_feat: dimension of embedding layer
        num_heads: number of attention heads
        hidden_dim: dimension of hidden layer
        output_dim: dimension of output layer
        wd: weight decay
        dropout: dropout probability
        lr: learning rate
    '''

    # generate dictionary from command-line arguments
    # args_dict = vars(args)
    args_dict = {
        'node_list': project_config.NEUROKG_DIR / 'neuroKG_nodes.csv',
        'edge_list': project_config.NEUROKG_DIR / 'neuroKG_edges.csv',
        'save_dir': project_config.RESULTS_DIR / 'CIPHER' / 'pretraining',
        'num_feat': 2048,
        'num_heads': 4,
        'hidden_dim': 32,
        'output_dim': 128,
        'wd': 0.0,
        'dropout_prob': 0.3,
        'lr': 0.0001,
        'max_epochs': 250,
        'resume': None,
        'best_ckpt': None,
        'save_embeddings': False,
        'debug': True
    }

    # define fanout
    fanout = [1, 1, 1] # [1, 1, 1]

    # default hyperparameters
    hparams_dict = {
                # fixed parameters
                'pred_threshold': 0.5,
                'n_gpus': 1,
                'num_workers': 4,
                'train_batch_size': 1024, # 8
                'val_batch_size': 1024,
                'test_batch_size': 1024,
                'sampler_fanout': fanout,
                'num_layers': len(fanout),
                'negative_k': 1,
                'grad_clip': 1.0,
                'lr_factor': 0.01,
                'lr_patience': 100,
                'lr_threshold': 1e-4,
                'lr_threshold_mode': 'rel',
                'lr_cooldown': 0,
                'min_lr': 0,
                'eps': 1e-8,
                'seed': 42,
                'profiler': None,
                # see https://github.com/wandb/wandb/issues/714
                'wandb_save_dir': project_config.RESULTS_DIR / 'wandb' / 'pretraining',
                'log_every_n_steps': 10,
                'time': False,
                'sample_subgraph': False,
                'seed_node': 1,
                'n_walks': 100,
                'walk_length': 10
        }
    
    # combine tunable hyperparameters with fixed hyperparameters
    hparams = dict(args_dict, **hparams_dict)
    
    print('Pre-Training Hyperparameters: ', hparams)
    
    return hparams

# Dataloaders
Define dataloaders using code in `dataloaders.py`. First, define the `load_graph()` function.

In [3]:
# LOAD KNOWLEDGE GRAPH
def load_graph(hparams):

    # read in nodes and edges
    # could also provide as args.node_list and args.edge_list with arg as argument of function
    nodes = pd.read_csv(hparams['node_list'], dtype = {'node_index': int}, low_memory = False)
    edges = pd.read_csv(hparams['edge_list'], dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)

    # if sample subgraph, subsample nodes and edges
    if hparams['sample_subgraph']:
        nodes, edges = generate_subgraph(nodes, edges, hparams['seed_node'], hparams['n_walks'], hparams['walk_length'])
        print("Number of subgraph nodes: ", len(nodes))
        print("Number of subgraph edges: ", len(edges))

    # group the nodes DataFrame by 'node_type' and use cumcount to generate the 'node_type_index'
    nodes['node_type_index'] = nodes.groupby('node_type').cumcount()

    # use the 'node_type_index' column to create the 'x_type_index' and 'y_type_index' columns in the edges DataFrame
    edges['x_type_index'] = nodes.loc[edges['x_index'], 'node_type_index'].values
    edges['y_type_index'] = nodes.loc[edges['y_index'], 'node_type_index'].values

    # define empty dictionary to store graph data
    neuroKG_data = {}

    # group the edges DataFrame by unique combinations of x_type, relation, and y_type
    grouped_edges = edges.groupby(['x_type', 'relation', 'y_type'], sort = False)

    # iterate over the groups
    for (x_type, relation, y_type), edges_subset in grouped_edges:

        # convert edge indices to torch tensor
        edge_indices = (torch.tensor(edges_subset['x_type_index'].values), torch.tensor(edges_subset['y_type_index'].values))

        # add edge indices to data object
        neuroKG_data[(x_type, relation, y_type)] = edge_indices

        # print update
        # print(f'Added edge relation: {x_type} - {relation} - {y_type}')

    # instantiate a DGL HeteroGraph
    neuroKG = dgl.heterograph(neuroKG_data)

    # add node features to the heterograph
    # first, group the nodes DataFrame by node_type
    grouped_nodes = nodes.groupby('node_type', sort = False)

    # iterate over the groups and add global node indices to the graph
    for node_type, nodes_subset in grouped_nodes:

        neuroKG.nodes[node_type].data['node_index'] = torch.tensor(nodes_subset['node_index'].values)

    # return the graph
    return neuroKG

Next, define the `partition_graph()` function.

In [4]:
# PARTITION GRAPH
def partition_graph(neuroKG, hparams):

    # define dictionaries for train, validation, and test eids
    train_eids = {}
    val_eids = {}
    test_eids = {}

    # split the edges into train, validation, and test sets
    forward_edge_types = [x for x in neuroKG.canonical_etypes if "rev" not in x[1]]
    for etype in forward_edge_types:

        # subset edge IDs for the current edge type
        etype_eids = neuroKG.edges(etype = etype, form = 'eid')

        # randomly shuffle edge IDs
        num_edges = etype_eids.shape[0]
        type_eids = etype_eids[torch.randperm(num_edges)]

        # get train, validation, and test lengths
        # here, we use a 80/15/5 split
        test_length = int(np.ceil(0.05 * num_edges))
        val_length = int(np.ceil(0.15 * num_edges))
        train_length = num_edges - test_length - val_length

        # split the edge IDs into train, validation, and test sets
        etype_train_eids = etype_eids[:train_length]
        etype_val_eids = etype_eids[train_length:(train_length + val_length)]
        etype_test_eids = etype_eids[(train_length + val_length):]

        # print number of edges in each set
        # print("Edges of type {} split into {} train, {} validation, and {} test edges.".format(etype, len(etype_train_eids), len(etype_val_eids), len(etype_test_eids)))

        # add the edge IDs to the dictionaries
        train_eids[etype] = etype_train_eids
        val_eids[etype] = etype_val_eids
        test_eids[etype] = etype_test_eids

        # get reverse edge type
        reverse_etype = (etype[2], "rev_" + etype[1], etype[0])

        # add the reverse edge IDs to the dictionaries
        train_eids[reverse_etype] = etype_train_eids
        val_eids[reverse_etype] = etype_val_eids
        test_eids[reverse_etype] = etype_test_eids
    
    # define new training graph
    train_neuroKG = neuroKG.edge_subgraph(train_eids, relabel_nodes = False)

    # combine train and validation edge IDs
    train_val_eids = {}
    for etype in train_eids.keys():
        train_val_eids[etype] = torch.cat((train_eids[etype], val_eids[etype]))

    # define new validation graph
    val_neuroKG = neuroKG.edge_subgraph(train_val_eids, relabel_nodes = False)

    # define new test graph
    test_neuroKG = neuroKG.edge_subgraph(test_eids, relabel_nodes = False)

    # return the graphs
    return train_neuroKG, val_neuroKG, test_neuroKG, train_eids, val_eids, test_eids

Define new node sampler from `samplers.py`.

In [5]:
# from ..sampling.utils import EidExcluder
# from .. import transforms
# from ..base import NID
# from .base import set_node_lazy_features, set_edge_lazy_features, Sampler
from dgl.sampling.utils import EidExcluder
from dgl.dataloading.shadow import set_node_lazy_features, set_edge_lazy_features, Sampler
from collections import defaultdict

class FixedSampler(Sampler):
    """Subgraph sampler that heterogeneous sampler that sets an upper 
    bound on the number of nodes included in each layer of the sampled subgraph.
    
    At each layer, the frontier is randomly subsampled. Rare node types can also be 
    upsampled by taking the scaled square root of the sampling probabilities.

    It performs node-wise neighbor sampling and returns the subgraph induced by
    all the sampled nodes.

    Parameters
    ----------
    fanouts : list[int] or list[dict[etype, int]]
        List of neighbors to sample per edge type for each GNN layer, with the i-th
        element being the fanout for the i-th GNN layer.

        If only a single integer is provided, DGL assumes that every edge type
        will have the same fanout.

        If -1 is provided for one edge type on one layer, then all inbound edges
        of that edge type will be included.
    fixed_k : int
            The number of nodes to sample for each GNN layer.
    upsample_rare_types : bool
        Whether or not to upsample rare node types.
    replace : bool, default True
        Whether to sample with replacement
    prob : str, optional
        If given, the probability of each neighbor being sampled is proportional
        to the edge feature value with the given name in ``g.edata``. The feature must be
        a scalar on each edge.
    """
    def __init__(self, fanouts, fixed_k, upsample_rare_types, replace=False, prob=None, 
                 prefetch_node_feats=None, prefetch_edge_feats=None, output_device=None):        
        super().__init__()
        self.fanouts = fanouts
        self.replace = replace
        self.fixed_k = fixed_k
        self.upsample_rare_types = upsample_rare_types
        self.prob = prob
        self.prefetch_node_feats = prefetch_node_feats
        self.prefetch_edge_feats = prefetch_edge_feats
        self.output_device = output_device

    def sample(self, g, seed_nodes, exclude_eids=None):
        """Sampling function.

        Parameters
        ----------
        g : DGLGraph
            The graph to sampler from.
        seed_nodes : Tensor or dict[str, Tensor]
            The nodes sampled in the current minibatch.
        exclude_eids : Tensor or dict[etype, Tensor], optional
            The edges to exclude from neighborhood expansion.

        Returns
        -------
        input_nodes, output_nodes, subg
            A triplet containing (1) the node IDs inducing the subgraph, (2) the node
            IDs that are sampled in this minibatch, and (3) the subgraph itself.
        """

        # define empty dictionary to store reached nodes
        output_nodes = seed_nodes
        all_reached_nodes = [seed_nodes]

        # iterate over fanout
        for fanout in reversed(self.fanouts):

            # sample frontier
            frontier = g.sample_neighbors(
                seed_nodes, fanout, output_device=self.output_device,
                replace=self.replace, prob=self.prob, exclude_edges=exclude_eids)

            # get reached nodes
            curr_reached = defaultdict(list)
            for c_etype in frontier.canonical_etypes:
                (src_type, rel_type, dst_type) = c_etype
                src, _ = frontier.edges(etype = c_etype)
                curr_reached[src_type].append(src)

            # de-duplication
            curr_reached = {ntype : torch.unique(torch.cat(srcs)) for ntype, srcs in curr_reached.items()}

            # generate type sampling probabilties
            type_count = {node_type: indices.shape[0] for node_type, indices in curr_reached.items()}
            total_count = sum(type_count.values())
            probs = {node_type: count / total_count for node_type, count in type_count.items()}

            # upsample rare node types
            if self.upsample_rare_types:

                # take scaled square root of probabilities
                prob_dist = list(probs.values())
                prob_dist = np.sqrt(prob_dist)
                prob_dist = prob_dist / prob_dist.sum()

                # update probabilities
                probs = {node_type: prob_dist[i] for i, node_type in enumerate(probs.keys())}

            # generate node counts per type
            n_per_type = {node_type: int(self.fixed_k * prob) for node_type, prob in probs.items()}
            remainder = self.fixed_k - sum(n_per_type.values())
            for _ in range(remainder):
                node_type = np.random.choice(list(probs.keys()), p=list(probs.values()))
                n_per_type[node_type] += 1

            # downsample nodes
            curr_reached_k = {}
            for node_type, node_IDs in curr_reached.items():

                # get number of total nodes and number to sample
                num_nodes = node_IDs.shape[0]
                n_to_sample = min(num_nodes, n_per_type[node_type])

                # downsample nodes of current type
                random_indices = torch.randperm(num_nodes)[:n_to_sample]
                curr_reached_k[node_type] = node_IDs[random_indices]

            # update seed nodes
            seed_nodes = curr_reached_k
            all_reached_nodes.append(curr_reached_k)

        # merge all reached nodes before sending to DGLGraph.subgraph
        merged_nodes = {}
        for ntype in g.ntypes:
            merged_nodes[ntype] = torch.unique(torch.cat([reached.get(ntype, []) for reached in all_reached_nodes]))
        subg = g.subgraph(merged_nodes, relabel_nodes=True, output_device=self.output_device)

        if exclude_eids is not None:
            subg = EidExcluder(exclude_eids)(subg)

        set_node_lazy_features(subg, self.prefetch_node_feats)
        set_edge_lazy_features(subg, self.prefetch_edge_feats)

        return seed_nodes, output_nodes, subg

Finally, define the `create_dataloaders()` function.

In [6]:
# CREATE DATA LOADERS
def create_dataloaders(neuroKG, train_neuroKG, val_neuroKG, test_neuroKG,
                       train_eids, val_eids, test_eids,
                       sampler_fanout = [1, 1, 1], negative_k = 5,
                       train_batch_size = 8, val_batch_size = 8, test_batch_size = 8,
                       num_workers = 0):

    print('Creating mini-batch pre-training dataloader...')

    # define dictionary mapping forward edges to reverse edges, and vice versa
    forward_edge_types = [x for x in neuroKG.canonical_etypes if "rev" not in x[1]]
    reverse_edge_dict = {(u, r, v): (v, "rev_" + r, u) for u, r, v in forward_edge_types}
    reverse_edge_dict.update({value: key for key, value in reverse_edge_dict.items()})

    # define positive sampler
    sampler = FixedSampler(sampler_fanout, fixed_k = 10, upsample_rare_types = True)

    # other choices for positive sampler
    # see https://docs.dgl.ai/en/latest/generated/dgl.dataloading.as_edge_prediction_sampler.html
    # sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3) # 3-layer full neighbor sampler
    # sampler = dgl.dataloading.NeighborSampler([1, 1, 1]) # requires blocks
    # sampler = dgl.dataloading.ShaDowKHopSampler(sampler_fanout)

    # define negative sampler
    # generate 5 negative samples per edge using uniform distribution
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(negative_k)
    # define reverse edge types for each positive edge type and vice versa
    
    # convert to edge sampler
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler,
        exclude = "reverse_types", # exclude reverse edges
        reverse_etypes = reverse_edge_dict, # define reverse edge types
        negative_sampler = neg_sampler)

    # define training dataloader
    train_dataloader = dgl.dataloading.DataLoader(
        train_neuroKG, train_eids, sampler,
        batch_size = train_batch_size,
        shuffle = True,
        drop_last = False,
        num_workers = num_workers)

    # define validation dataloader
    val_dataloader = dgl.dataloading.DataLoader(
        val_neuroKG, val_eids, sampler,
        batch_size = val_batch_size,
        shuffle = True,
        drop_last = False,
        num_workers = num_workers)

    # define test dataloader
    test_dataloader = dgl.dataloading.DataLoader(
        test_neuroKG, test_eids, sampler,
        batch_size = test_batch_size,
        shuffle = True,
        drop_last = False,
        num_workers = num_workers)
    
    # return the dataloaders
    return train_dataloader, val_dataloader, test_dataloader

# Model

Define the model using code in `models.py`. First, define the imports for the model.

In [7]:
# import PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# import DGL
import dgl
from dgl.nn.pytorch.conv import HGTConv

# custom imports
from utils import calculate_metrics

Next, define the `BilinearDecoder` class.

In [8]:
# BILINEAR DECODER CLASS
class BilinearDecoder(pl.LightningModule): # overrides nn.Module

    # INITIALIZATION
    def __init__(self, num_etypes, embedding_dim):
        '''
        This function initializes a bilinear decoder.

        Args:
            num_etypes (int): Number of edge types.
            embedding_dim (int): Dimension of embedding (i.e., output dimension * number of attention heads).
        '''
        super().__init__()

        # edge-type specific learnable weights
        self.relation_weights = nn.Parameter(torch.Tensor(num_etypes, embedding_dim))

        # initialize weights
        nn.init.xavier_uniform_(self.relation_weights, gain = nn.init.calculate_gain('leaky_relu'))
    

    # ADD EDGE TYPE INDEX
    def add_edge_type_index(self, edge_graph):
        '''
        This function adds an integer edge type label to each edge in the graph. This is required for the decoder.
        Specifically, the edge type label is used to subset the right row of the relation weight matrix.
        
        Args:
            edge_graph (dgl.DGLGraph): Positive or negative edge graph.
        '''

        # iterate over the canonical edge types
        for edge_index, edge_type in enumerate(edge_graph.canonical_etypes):
        
            # get number of edges of that type
            num_edges = edge_graph.num_edges(edge_type)

            # add integer label to edge
            edge_graph.edges[edge_type].data['edge_type_index'] = torch.tensor([edge_index] * num_edges, device = self.device) #.to(device)
    
    
    # DECODER
    def decode(self, edges):
        '''
        This is a user-defined function over the edges to generate the score for each edge.
        See https://docs.dgl.ai/en/0.9.x/generated/dgl.DGLGraph.apply_edges.html.
        '''
        
        # get source embeddings
        src_embeddings = edges.src['node_embedding']
        dst_embeddings = edges.dst['node_embedding']

        # get relation weight for specific edge type
        # note that, because the decode function is applied by edge type, we can use the first edge to get the edge type
        edge_type_index = edges.data['edge_type_index'][0] # see torch.unique(edges.data['edge_type'])
        rel_weights = self.relation_weights[edge_type_index]

        # compute weighted dot product
        # each row of src_embeddings is multiplied by rel_weights, then element-wise multiplied by dst_embeddings
        # finally, a row-wise sum is performed to get a single score per edge
        score = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)

        return {'score': score}


    # COMPUTE SCORE
    def compute_score(self, edge_graph):
        '''
        This function computes the score for positive or negative edges using dgl.DGLGraph.apply_edges.

        Args:
            edge_graph (dgl.DGLGraph): Positive or negative edge graph.
        '''

        with edge_graph.local_scope():

            # get edge types with > 0 number of edges in the positive graph
            nonzero_edge_types = [etype for etype in edge_graph.canonical_etypes if edge_graph.num_edges(etype) != 0]

            # compute score for positive graph
            for etype in nonzero_edge_types:
                edge_graph.apply_edges(self.decode, etype = etype)
            
            # return scores
            return edge_graph.edata['score']
    
    
    # FORWARD PASS
    def forward(self, subgraph, pos_graph, neg_graph, node_embeddings):
        '''
        This function performs a forward pass of the bilinear decoder.

        Args:
            subgraph (dgl.DGLHeteroGraph): Subgraph.
            pos_graph (dgl.DGLHeteroGraph): Positive graph.
            neg_graph (dgl.DGLHeteroGraph): Negative graph.
            node_embeddings (torch.Tensor): Node embeddings.
        '''

        # get subgraph node IDs
        subgraph_nodes = subgraph.ndata['node_index']

        # assign node embeddings to positive and negative graphs
        # iterate over node types in positive graph
        for ntype in pos_graph.ntypes:

            # get positive graph node IDs
            pos_graph_nodes = pos_graph.ndata['node_index'][ntype].unsqueeze(1)
            
            # find indices of positive graph nodes in subgraph
            # note, that indices are same for negative graph
            # compare pos_graph.ndata['_ID'] vs. neg_graph.ndata['_ID']
            pos_graph_indices = torch.where(subgraph_nodes == pos_graph_nodes)[1]

            # add embeddings as feature to graph
            pos_graph.nodes[ntype].data['node_embedding'] = node_embeddings[pos_graph_indices]
            neg_graph.nodes[ntype].data['node_embedding'] = node_embeddings[pos_graph_indices]

        # add edge indices to positive and negative graphs
        self.add_edge_type_index(pos_graph)
        self.add_edge_type_index(neg_graph)

        # compute scores for positive and negative graphs
        pos_graph_scores = self.compute_score(pos_graph)
        neg_graph_scores = self.compute_score(neg_graph)

        # return scores
        return pos_graph_scores, neg_graph_scores

Finally, define the `HGT` class.

In [9]:
# HETEROGENEOUS GRAPH TRANSFORMER
class HGT(pl.LightningModule):
    
    # INITIALIZATION
    def __init__(self, num_nodes, num_ntypes, num_etypes, num_feat = 1024, num_heads = 4,
                 hidden_dim = 256, output_dim = 128, num_layers = 2,
                 dropout_prob = 0.5, pred_threshold = 0.5,
                 lr = 0.0001, wd = 0.0, lr_factor = 0.01, lr_patience = 100, lr_threshold = 1e-4,
                 lr_threshold_mode = 'rel', lr_cooldown = 0, min_lr = 1e-8, eps = 1e-8,
                 hparams = None):
        '''
        This function initializes the model and defines the model hyperparameters and architecture.

        Args:
            num_nodes (int): Number of nodes in the graph.
            num_ntypes (int): Number of node types in the graph.
            num_etypes (int): Number of edge types in the graph.
            num_feat (int): Number of input features (i.e., hidden embedding dimension).
            num_heads (int): Number of attention heads.
            hidden_dim (int): Number of hidden units in the second to last HGT layer.
            output_dim (int): Number of output units.
            num_layers (int): Number of HGT layers.
            dropout_prob (float): Dropout probability.
            pred_threshold (float): Prediction threshold to compute metrics.
            lr (float): Learning rate.
            wd (float): Weight decay.
            lr_factor (float): Factor by which to reduce learning rate.
            lr_patience (int): Number of epochs with no improvement after which learning rate will be reduced.
            lr_threshold (float): Threshold for measuring the new optimum, to only focus on significant changes.
            lr_threshold_mode (str): One of ['rel', 'abs'].
            lr_cooldown (int): Number of epochs to wait before resuming normal operation after lr reduction.
            min_lr (float): A lower bound on the learning rate of all param groups or each group respectively.
            eps (float): Term added to the denominator to improve numerical stability.
            hparams (dict): Dictionary of model hyperparameters. Will override all other arguments if not None.
        '''

        super().__init__()

        # if hparams_dict is None, construct dictionary from arguments
        if hparams is None:
            hparams = locals()

        # save model hyperparameters
        self.save_hyperparameters(hparams)
        self.num_feat = hparams['num_feat']
        self.num_heads = hparams['num_heads']
        self.hidden_dim = hparams['hidden_dim']
        self.output_dim = hparams['output_dim']
        self.num_layers = hparams['num_layers']
        self.dropout_prob = hparams['dropout_prob']
        self.pred_threshold = hparams['pred_threshold']

        # learning rate parameters
        self.lr = hparams['lr']
        self.wd = hparams['wd']
        self.lr_factor = hparams['lr_factor']
        self.lr_patience = hparams['lr_patience']
        self.lr_threshold = hparams['lr_threshold']
        self.lr_threshold_mode = hparams['lr_threshold_mode']
        self.lr_cooldown = hparams['lr_cooldown']
        self.min_lr = hparams['min_lr']
        self.eps = hparams['eps']

        # calculate sizes of hidden dimensions
        self.h_dim_1 = hidden_dim * 2
        self.h_dim_2 = hidden_dim

        # define node embeddings
        self.emb = nn.Embedding(num_nodes, num_feat)

        # layer 1
        self.conv1 = HGTConv(in_size = num_feat, head_size = self.h_dim_1, num_heads = num_heads,
                                num_ntypes = num_ntypes, num_etypes = num_etypes, dropout = 0.2, use_norm = True)

        # layer normalization 1
        self.norm1 = nn.LayerNorm(self.h_dim_1 * num_heads)
        
        if self.num_layers == 2:
        
            # layer 2
            self.conv2 = HGTConv(in_size = self.h_dim_1 * num_heads, head_size = output_dim, num_heads = num_heads,
                                    num_ntypes = num_ntypes, num_etypes = num_etypes, dropout = 0.2, use_norm = True)
            
        elif self.num_layers == 3:
        
            # layer 2
            self.conv2 = HGTConv(in_size = self.h_dim_1 * num_heads, head_size = self.h_dim_2, num_heads = num_heads,
                                    num_ntypes = num_ntypes, num_etypes = num_etypes, dropout = 0.2, use_norm = True)

            # layer normalization 2
            self.norm2 = nn.LayerNorm(self.h_dim_2 * num_heads)

            # layer 3
            self.conv3 = HGTConv(in_size = self.h_dim_2 * num_heads, head_size = output_dim, num_heads = num_heads,
                                    num_ntypes = num_ntypes, num_etypes = num_etypes, dropout = 0.2, use_norm = True)
            
        else:

            # raise error
            raise ValueError('Number of layers must be 2 or 3.')

        # define decoder
        self.decoder = BilinearDecoder(num_etypes, output_dim * num_heads)
        
    
    # FORWARD PASS
    def forward(self, subgraph):
        '''
        This function performs a forward pass of the model. Note that the subgraph must be converted to from a 
        heterogeneous graph to homogeneous graph for efficiency.

        Args:
            subgraph (dgl.DGLHeteroGraph): Subgraph containing the nodes and edges for the current batch.
        '''

        # get global indices
        global_node_indices = subgraph.ndata['node_index']

        # get node embeddings from the first MFG layer
        x = self.emb(global_node_indices)      

        # pass node embedding through first two layers
        x = self.conv1(subgraph, x, subgraph.ndata[dgl.NTYPE], subgraph.edata[dgl.ETYPE])
        x = self.norm1(x)
        x = F.leaky_relu(x)
        x = self.conv2(subgraph, x, subgraph.ndata[dgl.NTYPE], subgraph.edata[dgl.ETYPE])

        # check if 3 layers
        if self.num_layers == 3:

            # pass node embedding through layer 3
            x = self.norm2(x)
            x = F.leaky_relu(x)
            x = self.conv3(subgraph, x, subgraph.ndata[dgl.NTYPE], subgraph.edata[dgl.ETYPE])
        
        # return node embeddings
        return x
    

    # STEP FUNCTION USED FOR TRAINING, VALIDATION, AND TESTING
    def _step(self, input_nodes, pos_graph, neg_graph, subgraph, mode):
        '''Defines the step that is run on each batch of data. PyTorch Lightning handles steps including:
            - Moving data to the correct device.
            - Epoch and batch iteration.
            - optimizer.step(), loss.backward(), optimizer.zero_grad() calls.
            - Calling of model.eval(), enabling/disabling grads during evaluation.
            - Logging of metrics.
        
        Args:
            input_nodes (torch.Tensor): Input nodes.
            pos_graph (dgl.DGLHeteroGraph): Positive graph.
            neg_graph (dgl.DGLHeteroGraph): Negative graph.
            subgraph (dgl.DGLHeteroGraph): Subgraph.
            mode (str): The mode of the step (train, val, test).
        '''

        # get batch size by summing number of nodes in each node type
        batch_size = sum([x.shape[0] for x in input_nodes.values()])

        # convert heterogeneous graph to homogeneous graph for efficiency
        # see https://docs.dgl.ai/en/latest/generated/dgl.to_homogeneous.html
        subgraph = dgl.to_homogeneous(subgraph, ndata = ['node_index'])
        
        # send to GPU
        # subgraph = subgraph.to(device)
        # pos_graph = pos_graph.to(device)
        # neg_graph = neg_graph.to(device)

        # get node embeddings
        node_embeddings = self.forward(subgraph)

        # compute score from decoder
        pos_scores, neg_scores = self.decoder(subgraph, pos_graph, neg_graph, node_embeddings)

        # compute loss
        loss, metrics = self.compute_loss(pos_scores, neg_scores)

        # return loss and metrics
        return loss, metrics, batch_size
    

    # TRAINING STEP
    def training_step(self, batch, batch_idx):
        '''Defines the step that is run on each batch of training data.'''

        # get batch elements
        input_nodes, pos_graph, neg_graph, subgraph = batch

        # get loss and metrics
        loss, metrics, batch_size = self._step(input_nodes, pos_graph, neg_graph, subgraph, mode = 'train')

        # log loss and metrics
        values = {"train/loss": loss.detach(),
                  "train/accuracy": metrics['accuracy'],
                  "train/ap": metrics['ap'],
                  "train/f1": metrics['f1'],
                  "train/auroc": metrics['auroc']}
        self.log_dict(values, batch_size = batch_size)

        # return loss
        return loss
    

    # VALIDATION STEP
    def validation_step(self, batch, batch_idx):
        '''Defines the step that is run on each batch of validation data.'''

        # get batch elements
        input_nodes, pos_graph, neg_graph, subgraph = batch

        # get loss and metrics
        loss, metrics, batch_size = self._step(input_nodes, pos_graph, neg_graph, subgraph, mode = 'val')

        # log loss and metrics
        values = {"val/loss": loss.detach(),
                  "val/accuracy": metrics['accuracy'],
                  "val/ap": metrics['ap'],
                  "val/f1": metrics['f1'],
                  "val/auroc": metrics['auroc']}
        self.log_dict(values, batch_size = batch_size)


    # TEST STEP
    def test_step(self, batch, batch_idx):
        '''Defines the step that is run on each batch of test data.'''

        # get batch elements
        input_nodes, pos_graph, neg_graph, subgraph = batch

        # get loss and metrics
        loss, metrics, batch_size = self._step(input_nodes, pos_graph, neg_graph, subgraph, mode = 'test')

        # log loss and metrics
        values = {"test/loss": loss.detach(),
                  "test/accuracy": metrics['accuracy'],
                  "test/ap": metrics['ap'],
                  "test/f1": metrics['f1'],
                  "test/auroc": metrics['auroc']}
        self.log_dict(values, batch_size = batch_size)

    
    # LOSS FUNCTION
    def compute_loss(self, pos_scores, neg_scores):
        '''
        This function computes the loss and metrics for the current batch.
        '''

        # concatenate positive and negative scores across edge types
        pos_pred = torch.cat(list(pos_scores.values()))
        neg_pred = torch.cat(list(neg_scores.values()))
        raw_pred = torch.cat((pos_pred, neg_pred))

        # transform with activation function
        pred = torch.sigmoid(raw_pred)

        # construct target vector
        pos_target = torch.ones(pos_pred.shape[0])
        neg_target = torch.zeros(neg_pred.shape[0])
        target = torch.cat((pos_target, neg_target)).to(self.device) #.to(device)

        # compute loss
        loss = F.binary_cross_entropy(pred, target, reduction = "mean")

        # calculate metrics
        metrics = calculate_metrics(pred.cpu().detach().numpy(), target.cpu().detach().numpy(), self.pred_threshold)
        return loss, metrics
    

    # OPTIMIZER AND SCHEDULER
    def configure_optimizers(self):
        '''
        This function is called by PyTorch Lightning to get the optimizer and scheduler.
        We reduce the learning rate by a factor of lr_factor if the validation loss does not improve for lr_patience epochs.

        Args:
            None

        Returns:
            dict: Dictionary containing the optimizer and scheduler.
        '''
        
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay = self.wd)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode = 'min', factor = self.lr_factor, patience = self.lr_patience,
            threshold = self.lr_threshold, threshold_mode = self.lr_threshold_mode,
            cooldown = self.lr_cooldown, min_lr = self.min_lr, eps = self.eps
        )
        
        return {"optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    'name': 'curr_lr'
                    },
                }

# Pretrain Model

Expand `pretrain()` function in `pretrain.py` to test model.

In [10]:
# get hyperparameters
# args = parse_args()
hparams = get_hyperparameters() 

# set seed
pl.seed_everything(hparams['seed'], workers = True)

# load NeuroKG knowledge graph
neuroKG = load_graph(hparams)

# partition graph into train, validation, and test sets
train_neuroKG, val_neuroKG, test_neuroKG, train_eids, val_eids, test_eids = partition_graph(neuroKG, hparams)

# get dataloaders
train_dataloader, val_dataloader, test_dataloader = create_dataloaders(
    neuroKG, train_neuroKG, val_neuroKG, test_neuroKG, train_eids, val_eids, test_eids,
    sampler_fanout = hparams['sampler_fanout'], negative_k = hparams['negative_k'],
    train_batch_size = hparams['train_batch_size'], val_batch_size = hparams['val_batch_size'], 
    test_batch_size = hparams['test_batch_size'], num_workers = hparams['num_workers']
)

# enable CPU affinity
train_dataloader.enable_cpu_affinity()
val_dataloader.enable_cpu_affinity()
test_dataloader.enable_cpu_affinity()

# instantiate logger
curr_time = datetime.now()
run_name = curr_time.strftime('%H:%M:%S on %m/%d/%Y')
run_id = curr_time.strftime('%Y_%m_%d_%H_%M_%S')
# wandb_logger = WandbLogger(name = run_name, project = 'cipher-pretraining', entity = 'ayushnoori',
#                             save_dir = hparams['wandb_save_dir'], id = run_id, resume = "allow")

# instantiate models
model = HGT(
    num_nodes = train_neuroKG.num_nodes(), num_ntypes = len(train_neuroKG.ntypes),
    num_etypes = len(train_neuroKG.canonical_etypes), hparams = hparams
)

Global seed set to 42


Pre-Training Hyperparameters:  {'node_list': PosixPath('/n/data1/hms/dbmi/zitnik/lab/users/an252/NeuroKG/neuroKG/Data/NeuroKG/4_final_KG/neuroKG_nodes.csv'), 'edge_list': PosixPath('/n/data1/hms/dbmi/zitnik/lab/users/an252/NeuroKG/neuroKG/Data/NeuroKG/4_final_KG/neuroKG_edges.csv'), 'save_dir': PosixPath('/n/data1/hms/dbmi/zitnik/lab/users/an252/NeuroKG/neuroKG/Results/pretrain'), 'num_feat': 2048, 'num_heads': 4, 'hidden_dim': 32, 'output_dim': 128, 'wd': 0.0, 'dropout_prob': 0.3, 'lr': 0.0001, 'max_epochs': 250, 'resume': '', 'best_ckpt': None, 'save_embeddings': False, 'debug': True, 'pred_threshold': 0.5, 'n_gpus': 1, 'num_workers': 4, 'train_batch_size': 1024, 'val_batch_size': 1024, 'test_batch_size': 1024, 'sampler_fanout': [1, 1, 1], 'num_layers': 3, 'negative_k': 1, 'grad_clip': 1.0, 'lr_factor': 0.01, 'lr_patience': 100, 'lr_threshold': 0.0001, 'lr_threshold_mode': 'rel', 'lr_cooldown': 0, 'min_lr': 0, 'eps': 1e-08, 'seed': 42, 'profiler': None, 'wandb_save_dir': PosixPath('/



In [11]:
# move model to GPU
model = model.to(device)

# going to be training_step function
for input_nodes, pos_graph, neg_graph, subgraph in train_dataloader:

    # convert heterogeneous graph to homogeneous graph for efficiency
    # see https://docs.dgl.ai/en/latest/generated/dgl.to_homogeneous.html
    subgraph = dgl.to_homogeneous(subgraph, ndata = ['node_index'])
        
    # send to GPU
    subgraph = subgraph.to(device)
    pos_graph = pos_graph.to(device)
    neg_graph = neg_graph.to(device)

    # get node embeddings
    node_embeddings = model.forward(subgraph)

    # compute score from decoder
    pos_scores, neg_scores = model.decoder(subgraph, pos_graph, neg_graph, node_embeddings)

    # compute loss
    loss, metrics = model.compute_loss(pos_scores, neg_scores)
    
    break



In [12]:
subgraph.num_nodes()

2718

In [13]:
subgraph.num_edges()

340226

Define new sampler with fixed subgraph size.

In [16]:
# get node types
node_types = train_neuroKG.ntypes

# for each node type, select five random indices from graph, and construct dict
original_seed_nodes = {}
for node_type in node_types:
    node_subset = train_neuroKG.nodes(node_type)
    total_elements = node_subset.numel()

    num_to_sample = 5
    # # get number to sample
    # if node_type == 'gene/protein':
    #     num_to_sample = 1
    # else:
    #     num_to_sample = 0

    random_indices = torch.randperm(total_elements)[:num_to_sample]
    original_seed_nodes[node_type] = node_subset[random_indices]

seed_nodes = original_seed_nodes

In [15]:
# from collections import defaultdict

# fanouts = [4, 4, 4]
# fixed_k = 10

# # define empty dictionary to store reached nodes
# all_reached_nodes = [seed_nodes]

# # iterate over fanout
# for fanout in reversed(fanouts):

#     # sample frontier
#     frontier = train_neuroKG.sample_neighbors(seed_nodes, fanout, replace=False, prob=None)

#     # get reached nodes
#     curr_reached = defaultdict(list)
#     for c_etype in frontier.canonical_etypes:
#         (src_type, rel_type, dst_type) = c_etype
#         src, _ = frontier.edges(etype = c_etype)
#         curr_reached[src_type].append(src)

#     # de-duplication
#     curr_reached = {ntype : torch.unique(torch.cat(srcs)) for ntype, srcs in curr_reached.items()}

#     # set upper limit
#     fixed_k = 20
#     upsample_rare_types = True

#     # generate type sampling probabilties
#     type_count = {node_type: indices.shape[0] for node_type, indices in curr_reached.items()}
#     total_count = sum(type_count.values())
#     probs = {node_type: count / total_count for node_type, count in type_count.items()}

#     # upsample rare node types
#     if upsample_rare_types:

#         # take scaled square root of probabilities
#         prob_dist = list(probs.values())
#         prob_dist = np.sqrt(prob_dist)
#         prob_dist = prob_dist / prob_dist.sum()

#         # update probabilities
#         probs = {node_type: prob_dist[i] for i, node_type in enumerate(probs.keys())}

#     # generate node counts per type
#     n_per_type = {node_type: int(fixed_k * prob) for node_type, prob in probs.items()}
#     remainder = fixed_k - sum(n_per_type.values())
#     for _ in range(remainder):
#         node_type = np.random.choice(list(probs.keys()), p=list(probs.values()))
#         n_per_type[node_type] += 1

#     # downsample nodes
#     curr_reached_k = {}
#     for node_type, node_IDs in curr_reached.items():

#         # get number of total nodes and number to sample
#         num_nodes = node_IDs.shape[0]
#         n_to_sample = min(num_nodes, n_per_type[node_type])

#         # downsample nodes of current type
#         random_indices = torch.randperm(num_nodes)[:n_to_sample]
#         curr_reached_k[node_type] = node_IDs[random_indices]

#     # update seed nodes
#     seed_nodes = curr_reached_k
#     all_reached_nodes.append(curr_reached_k)

# # merge all reached nodes before sending to DGLGraph.subgraph
# merged_nodes = {}
# for ntype in train_neuroKG.ntypes:
#     merged_nodes[ntype] = torch.unique(torch.cat([reached.get(ntype, []) for reached in all_reached_nodes]))
# subg = train_neuroKG.subgraph(merged_nodes, relabel_nodes=True)