In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch import GATConv
import dgl

Using backend: pytorch


In [4]:
class HUGAT_J(nn.Module):
    """Joint model: one HUGAT per year plus a HUGAT over the unified multi-year graph.

    Each yearly HUGAT embeds its slice of ``h`` on that year's graph. The yearly
    embeddings are concatenated year-major and passed through ``hugat_total`` on
    ``data['heterograph_unified_3']``.

    Arguments
    ---------
    data : mapping
        Must provide ``'heterograph_unified2018'`` .. ``'heterograph_unified2021'``
        and ``'heterograph_unified_3'``.
    meta_paths : list of meta-paths used by the four yearly HUGATs
    in_size, hidden_size, out_size, num_heads, dropout : forwarded to HUGAT
    num_nodes : int, optional
        Number of rows of ``h`` that belong to one year (default 77).
    """

    # Years handled by hugat1..hugat4, in the order their feature rows appear in ``h``.
    YEARS = (2018, 2019, 2020, 2021)
    # Class-level default so instances unpickled from older versions still work.
    num_nodes = 77

    def __init__(self, data, meta_paths, in_size, hidden_size, out_size, num_heads, dropout,
                 num_nodes=77):
        super().__init__()
        self.num_nodes = num_nodes

        # Explicit attributes keep the state_dict keys (hugat1.*, ..., hugat4.*) stable.
        self.hugat1 = HUGAT(data['heterograph_unified{}'.format(2018)], meta_paths, in_size, hidden_size, out_size, num_heads, dropout)
        self.hugat2 = HUGAT(data['heterograph_unified{}'.format(2019)], meta_paths, in_size, hidden_size, out_size, num_heads, dropout)
        self.hugat3 = HUGAT(data['heterograph_unified{}'.format(2020)], meta_paths, in_size, hidden_size, out_size, num_heads, dropout)
        self.hugat4 = HUGAT(data['heterograph_unified{}'.format(2021)], meta_paths, in_size, hidden_size, out_size, num_heads, dropout)

        # Meta-paths for the l-urban-HIN (the unified multi-year graph).
        self.tot_meta = [['zone-zone'], ['src-time', 'time-src'], ['dst-time', 'time-dst'],
                         ['years-after'], ['years-before'], ['zone-year', 'year-zone']]

        # Consumes the yearly embeddings, hence in_size=out_size.
        self.hugat_total = HUGAT(data['heterograph_unified_3'], self.tot_meta, out_size, hidden_size, out_size, num_heads, dropout)

    def forward(self, data, h):
        """Embed each year, then fuse all years on the unified graph.

        Arguments
        ---------
        data : mapping holding the yearly and unified graphs (same keys as __init__)
        h : tensor whose rows are grouped by year, ``num_nodes`` rows per year,
            in the order of ``YEARS``

        Returns
        -------
        emb_total : first output of ``hugat_total``
        beta_whole : list ``[beta2018, beta2019, beta2020, beta2021, beta_total]``
        """
        n = self.num_nodes
        yearly_models = (self.hugat1, self.hugat2, self.hugat3, self.hugat4)

        yearly_embs, beta_whole = [], []
        for k, (year, model) in enumerate(zip(self.YEARS, yearly_models)):
            graph = data['heterograph_unified{}'.format(year)]
            # Call the module (not .forward) so registered hooks run.
            emb, beta = model(graph, h[n * k:n * (k + 1), :])
            yearly_embs.append(emb)
            beta_whole.append(beta)

        # Year-major concatenation: rows [k*n, (k+1)*n) hold year k's embeddings
        # (same layout as stacking and reshaping to (4*n, -1)).
        embs = torch.cat(yearly_embs, dim=0)

        emb_total, beta_total = self.hugat_total(data['heterograph_unified_3'], embs)
        beta_whole.append(beta_total)
        return emb_total, beta_whole

In [16]:
#import utils

class HUGAT(nn.Module):
    """Thin wrapper around HAN for a single heterograph.

    ``g`` is accepted by ``__init__`` for interface compatibility but is not used
    there. The graph is supplied to ``forward`` instead.
    """

    def __init__(self, g, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super().__init__()
        self.han = HAN(meta_paths, in_size, hidden_size, out_size, num_heads, dropout)

    def forward(self, g, h):
        """Run HAN on graph ``g`` with node features ``h``.

        Returns the ``(embeddings, beta)`` pair produced by HAN. The embeddings
        are also stored on ``self.emb``.
        """
        # Call the module (not .forward) so registered hooks run.
        # NOTE(review): self.emb keeps the last output (and its autograd graph)
        # alive between calls; kept for callers that may read it.
        self.emb, beta = self.han(g, h)
        return self.emb, beta

In [7]:
def Hellinger_pairwise(a, b):
    """Hellinger distance between every row of ``a`` and every row of ``b``.

    H(p, q) = (1 / sqrt(2)) * || sqrt(p) - sqrt(q) ||_2

    Rows are presumably probability distributions. Negative entries would make
    the square root produce NaN.
    """
    scale = 0.5 ** 0.5
    root_a = a.sqrt()
    root_b = b.sqrt()
    return scale * torch.cdist(root_a, root_b, p=2)

In [8]:
class HAN(nn.Module):
    """Stack of HANLayers followed by a linear prediction head.

    Arguments
    ---------
    meta_paths : list of metapaths, each as a list of edge types (shared by all layers)
    in_size : input feature dimension
    hidden_size : per-head hidden dimension of every HANLayer
    out_size : output dimension of the final linear layer
    num_heads : sequence with one head count per HANLayer (must be non-empty)
    dropout : dropout probability passed to every HANLayer
    """

    def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super(HAN, self).__init__()

        self.num_heads = num_heads

        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout))
        for l in range(1, len(num_heads)):
            # Each layer's input is the previous layer's concatenated heads.
            self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1],
                                        hidden_size, num_heads[l], dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        """Apply every HANLayer to ``h`` on graph ``g``, then the linear head.

        Returns ``(prediction, beta)``. ``beta`` is the semantic-attention output
        of the last HANLayer only.
        """
        beta = None
        # Every HANLayer returns (embeddings, beta); there is one layer per
        # entry of num_heads.
        for layer in self.layers:
            h, beta = layer(g, h)
        return self.predict(h), beta

In [9]:
def pairwise_inner_product(a, b):
    """Inner products between every row of ``b`` and every row of ``a``.

    Arguments
    ---------
    a : tensor of shape (n, d)
    b : tensor of shape (m, d), or (d,) which is treated as (1, d)

    Returns
    -------
    tensor of shape (m, n) with ``out[j, i] = <b[j], a[i]>``

    Raises
    ------
    ValueError if ``a`` is not 2-D.
    """
    if a.dim() != 2:
        raise ValueError("pairwise_inner_product expects a 2-D tensor `a`, got shape {}".format(tuple(a.shape)))
    if b.dim() == 1:
        b = b.unsqueeze(0)
    # One matrix product; avoids materialising an (m, n, d) broadcast intermediate.
    return torch.matmul(b, a.transpose(0, 1))

In [10]:
class HANLayer(nn.Module):
    """
    HAN layer: one GAT per meta-path, fused by semantic attention.

    Dimensions
    ----------
    N : number of nodes
    D : dimension of output feature (out_size * layer_num_heads)
    M : cardinality of meta-paths
    K : number of attention heads

    Arguments
    ---------
    meta_paths : list of metapaths, each as a list of edge types
    in_size : input feature dimension
    out_size : per-head output feature dimension
    layer_num_heads : number of attention heads
    dropout : dropout probability (passed to GATConv for both feature and attention dropout)

    Inputs
    ------
    g : DGLHeteroGraph
        The heterogeneous graph
    h : tensor
        Input features

    Outputs
    -------
    (tensor, tensor)
        The output of SemanticAttention: the fused node embeddings and the
        semantic-attention weights.
    """
    def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta-path-based adjacency matrix.
        self.gat_layers = nn.ModuleList([
            GATConv(in_size, out_size, layer_num_heads,
                    dropout, dropout, activation=F.elu,
                    allow_zero_in_degree=True)
            for _ in meta_paths
        ])
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        # Tuples so meta-paths can be used as cache keys.
        self.meta_paths = [tuple(meta_path) for meta_path in meta_paths]

        # The graph the cache below was built for, and the cache itself
        # (meta-path tuple -> metapath_reachable_graph result).
        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        # Rebuild the meta-path reachable graphs only when a different graph
        # object is passed. The cache is keyed on identity, so in-place edits to
        # the same graph object are NOT picked up.
        if self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path in self.meta_paths:
                self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph(
                        g, meta_path)

        # One embedding per meta-path with heads concatenated (flatten(1)),
        # stacked along dim=1 so that dim indexes the meta-path.
        semantic_embeddings = torch.stack(
            [gat(self._cached_coalesced_graph[meta_path], h).flatten(1)
             for gat, meta_path in zip(self.gat_layers, self.meta_paths)],
            dim=1,
        )

        return self.semantic_attention(semantic_embeddings)

In [11]:
class SemanticAttention(nn.Module):
    """Semantic-level attention that fuses per-meta-path embeddings.

    Each embedding is scored with ``q^T tanh(W z + b)``. The scores are averaged
    over dim 0, softmax-normalised across meta-paths (dim 0 of the average), and
    used as weights in a sum over dim 1 of ``z``.

    Arguments
    ---------
    in_size : feature dimension of each meta-path embedding
    hidden_size : width of the scoring projection (default 128)
    """

    def __init__(self, in_size, hidden_size=128):
        super().__init__()
        # Layer order fixes the state_dict keys project.0.* and project.2.*.
        layers = [
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*layers)

    def forward(self, z):
        """Fuse embeddings ``z`` along dim 1.

        ``z`` is expected as (N, M, D), the layout HANLayer builds with
        ``torch.stack(..., dim=1)``.

        Returns
        -------
        The fused embeddings (N, D) and the per-meta-path weights (M, 1)
        broadcast to (N, M, 1).
        """
        node_scores = self.project(z)
        path_weights = torch.softmax(node_scores.mean(dim=0), dim=0)
        path_weights = path_weights.expand(z.shape[0], *path_weights.shape)
        fused = (path_weights * z).sum(dim=1)
        return fused, path_weights