## The Heterogeneous Graph Transformer
Based on the paper “Heterogeneous Graph Transformer” by Ziniu Hu et al.: https://arxiv.org/pdf/2003.01332.pdf
The code below is based on the DGL example at
https://github.com/dmlc/dgl/tree/master/examples/pytorch/hgt

In [1]:
import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

Using backend: pytorch


In [2]:
class HGTLayer(nn.Module):
    """One Heterogeneous Graph Transformer layer.

    Every node type owns its own key/query/value/output projections (and an
    optional LayerNorm); every relation owns a per-head attention matrix, a
    per-head message matrix and a per-head scalar prior.

    Args:
        in_dim: size of the input node features.
        out_dim: size of the output node features; it is split across heads
            as ``d_k = out_dim // n_heads``.
        num_types: number of node types (one set of projections per type).
        num_relations: number of edge types (one set of relation parameters
            per type).
        n_heads: number of attention heads.
        dropout: dropout probability applied to the layer output.
        use_norm: when true, a per-node-type LayerNorm is applied before the
            dropout.

    NOTE(review): ``forward`` blends the transformed output with the input
    features element-wise, so it appears to require ``in_dim == out_dim``
    (or broadcastable shapes) -- confirm against callers.
    """
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = False):
        super(HGTLayer, self).__init__()

        self.in_dim        = in_dim
        self.out_dim       = out_dim
        self.num_types     = num_types
        self.num_relations = num_relations
        self.n_heads       = n_heads
        self.d_k           = out_dim // n_heads
        self.sqrt_dk       = math.sqrt(self.d_k)
        
        self.k_linears   = nn.ModuleList()
        self.q_linears   = nn.ModuleList()
        self.v_linears   = nn.ModuleList()
        self.a_linears   = nn.ModuleList()
        self.norms       = nn.ModuleList()
        self.use_norm    = use_norm
        
        # One key/query/value/output projection (and optional norm) per node type.
        for _ in range(num_types):
            self.k_linears.append(nn.Linear(in_dim,   out_dim))
            self.q_linears.append(nn.Linear(in_dim,   out_dim))
            self.v_linears.append(nn.Linear(in_dim,   out_dim))
            self.a_linears.append(nn.Linear(out_dim,  out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
            
        # Per-relation parameters: a per-head prior (initialised to 1) plus
        # per-head d_k x d_k attention and message matrices.
        self.relation_pri   = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att   = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg   = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        # Per-node-type skip weight; passed through a sigmoid in forward().
        self.skip           = nn.Parameter(torch.ones(num_types))
        self.drop           = nn.Dropout(dropout)
        
        # torch.Tensor(...) allocates uninitialised memory, so fill it explicitly.
        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def edge_attention(self, edges):
        """Compute the unnormalised attention score and message for each edge.

        Reads ``edges.src['k']``, ``edges.src['v']`` and ``edges.dst['q']``
        (shaped ``(num_edges, n_heads, d_k)`` by ``forward``) and returns a
        dict with ``'a'`` of shape ``(num_edges, n_heads)`` and ``'v'`` of
        shape ``(num_edges, n_heads, d_k)``.
        """
        # The relation id is read from the first edge only, so every edge in
        # the batch is treated as belonging to that relation.
        etype = edges.data['id'][0]
        relation_att = self.relation_att[etype]
        relation_pri = self.relation_pri[etype]
        relation_msg = self.relation_msg[etype]
        # bmm batches over heads, hence the transpose to (n_heads, num_edges, d_k) and back.
        key   = torch.bmm(edges.src['k'].transpose(1,0), relation_att).transpose(1,0)
        # Dot product of query and transformed key, scaled by the relation prior and sqrt(d_k).
        att   = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
        val   = torch.bmm(edges.src['v'].transpose(1,0), relation_msg).transpose(1,0)
        return {'a': att, 'v': val}
    
    def message_func(self, edges):
        """Forward the per-edge message ``'v'`` and score ``'a'`` to the destination mailbox."""
        return {'v': edges.data['v'], 'a': edges.data['a']}
    
    def reduce_func(self, nodes):
        """Aggregate the mailbox into one feature vector per node, stored under ``'t'``.

        Scores are softmax-normalised along dim 1 (presumably the
        incoming-message axis of the DGL mailbox -- confirm), used to weight
        the messages, and the per-head results are flattened to ``out_dim``.
        """
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h   = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
        return {'t': h.view(-1, self.out_dim)}
        
    def forward(self, G, inp_key, out_key):
        """Run one round of heterogeneous attention message passing on ``G``.

        Args:
            G: graph exposing ``node_dict`` / ``edge_dict`` (type name ->
                integer id) in addition to the DGL heterograph interface
                used below.
            inp_key: node-data key holding the input features.
            out_key: node-data key under which the output is written.

        Returns nothing; results are written to ``G.nodes[ntype].data[out_key]``.
        """
        node_dict, edge_dict = G.node_dict, G.edge_dict
        for srctype, etype, dsttype in G.canonical_etypes:
            k_linear = self.k_linears[node_dict[srctype]]
            v_linear = self.v_linears[node_dict[srctype]] 
            q_linear = self.q_linears[node_dict[dsttype]]
            
            # Project into per-head keys/values (source side) and queries (destination side).
            G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            
            # Store the per-edge score 'a' and message 'v' on the edges of this relation.
            G.apply_edges(func=self.edge_attention, etype=etype)
        # Aggregate per relation, then combine the per-relation results with 'mean'.
        G.multi_update_all({etype : (self.message_func, self.reduce_func) \
                            for etype in edge_dict}, cross_reducer = 'mean')
        for ntype in G.ntypes:
            n_id = node_dict[ntype]
            # Learned gate between the transformed aggregate and the input features.
            alpha = torch.sigmoid(self.skip[n_id])
            trans_out = self.a_linears[n_id](G.nodes[ntype].data['t'])
            trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1-alpha)
            if self.use_norm:
                G.nodes[ntype].data[out_key] = self.drop(self.norms[n_id](trans_out))
            else:
                G.nodes[ntype].data[out_key] = self.drop(trans_out)
    def __repr__(self):
        """Return a short summary of the layer's dimensions."""
        # The last field reports the relation count, so it is labelled as such.
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)


In [3]:
class HGT(nn.Module):
    """Stack of HGT layers with per-node-type input adapters and a linear output head.

    Args:
        G: graph object exposing ``node_dict`` and ``edge_dict``; only their
            lengths are used at construction time.
        n_inp: size of the raw input features (node-data key ``'inp'``).
        n_hid: hidden size used by every layer.
        n_out: size of the output produced by the final linear layer.
        n_layers: number of stacked ``HGTLayer`` modules.
        n_heads: number of attention heads per layer.
        use_norm: forwarded to every ``HGTLayer``.
    """

    def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
        super(HGT, self).__init__()
        num_node_types = len(G.node_dict)
        num_edge_types = len(G.edge_dict)

        self.n_inp, self.n_hid, self.n_out = n_inp, n_hid, n_out
        self.n_layers = n_layers

        # Registered first (still empty) so the module order stays gcs, adapt_ws, out.
        self.gcs = nn.ModuleList()
        # One input projection per node type, mapping n_inp -> n_hid.
        self.adapt_ws = nn.ModuleList(
            [nn.Linear(n_inp, n_hid) for _ in range(num_node_types)]
        )
        self.gcs.extend([
            HGTLayer(n_hid, n_hid, num_node_types, num_edge_types, n_heads, use_norm = use_norm)
            for _ in range(n_layers)
        ])
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        """Return the output-head result for the nodes of type ``out_key``.

        Hidden states are written to ``G.nodes[ntype].data['h']`` for every
        node type and updated in place by each layer.
        """
        for ntype in G.ntypes:
            adapter = self.adapt_ws[G.node_dict[ntype]]
            raw_feat = G.nodes[ntype].data['inp']
            G.nodes[ntype].data['h'] = torch.tanh(adapter(raw_feat))

        for layer_idx in range(self.n_layers):
            layer = self.gcs[layer_idx]
            layer(G, 'h', 'h')

        hidden = G.nodes[out_key].data['h']
        return self.out(hidden)

    def __repr__(self):
        """Return a short summary of the model's dimensions."""
        return (
            f'{self.__class__.__name__}(n_inp={self.n_inp}, n_hid={self.n_hid}, '
            f'n_out={self.n_out}, n_layers={self.n_layers})'
        )

In [4]:
import scipy.io
import urllib.request
import dgl
import math
import numpy as np
# Download the ACM dataset (.mat file) next to the notebook and load it.
# The file is fetched on every run.
data_file_path = 'ACM.mat'
data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'

urllib.request.urlretrieve(url=data_url, filename=data_file_path)
data = scipy.io.loadmat(data_file_path)


In [5]:
# Build the heterogeneous graph. Each relation is added in both directions:
# the reverse direction uses the transposed matrix.
paper_author = data['PvsA']
paper_paper = data['PvsP']
paper_subject = data['PvsL']

G = dgl.heterograph({
    ('paper', 'written-by', 'author'): paper_author,
    ('author', 'writing', 'paper'): paper_author.transpose(),
    ('paper', 'citing', 'paper'): paper_paper,
    ('paper', 'cited', 'paper'): paper_paper.transpose(),
    ('paper', 'is-about', 'subject'): paper_subject,
    ('subject', 'has', 'paper'): paper_subject.transpose(),
})
print(G)

Graph(num_nodes={'author': 17431, 'paper': 12499, 'subject': 73},
      num_edges={('paper', 'written-by', 'author'): 37055, ('author', 'writing', 'paper'): 37055, ('paper', 'citing', 'paper'): 30789, ('paper', 'cited', 'paper'): 30789, ('paper', 'is-about', 'subject'): 12499, ('subject', 'has', 'paper'): 12499},
      metagraph=[('author', 'paper'), ('paper', 'author'), ('paper', 'paper'), ('paper', 'paper'), ('paper', 'subject'), ('subject', 'paper')])


In [35]:
pvc = data['PvsC'].tocsr()
# Look for papers that appeared in these conferences:
#   sosp = 7
#   soda = 5
#   sigcom =  9
#   vldb = 13
c_selected = [7, 5, 9, 13]

p_selected = pvc[:, c_selected].tocoo()
# Generate labels from the conference (column) index of each stored entry.
# NOTE(review): using pvc.indices directly as per-paper labels assumes exactly
# one conference per paper row -- confirm against the dataset.
# Work on a copy: pvc.indices is the matrix's own index array, and remapping
# it in place would silently rewrite pvc itself.
labels = pvc.indices.copy()
# Order matters: each target id 0-3 is first vacated by moving its original
# occupant to the catch-all id 14, then the selected conference takes it.
label_remap = [
    (0, 14), (7, 0),
    (1, 14), (5, 1),
    (2, 14), (9, 2),
    (3, 14), (13, 3),
]
for old_id, new_id in label_remap:
    labels[labels == old_id] = new_id
labels = torch.tensor(labels).long()


In [36]:
# Generate the train/val/test split: shuffle the ids of the selected papers,
# then take the first 1400 for training, the next 100 for validation and the
# remainder for testing.
pid = p_selected.row
shuffle = np.random.permutation(pid)
train_idx, val_idx, test_idx = (
    torch.tensor(shuffle[start:stop]).long()
    for start, stop in ((0, 1400), (1400, 1500), (1500, None))
)

In [37]:
# Report the number of training and test papers.
print(train_idx.shape[0], test_idx.shape[0])

1400 719


In [38]:
# Give every node type and edge type a consecutive integer id, and tag each
# edge with the id of its edge type (read back by HGTLayer.edge_attention).
G.node_dict = {ntype: type_id for type_id, ntype in enumerate(G.ntypes)}
G.edge_dict = {}
for etype in G.etypes:
    rel_id = len(G.edge_dict)
    G.edge_dict[etype] = rel_id
    G.edges[etype].data['id'] = torch.full(
        (G.number_of_edges(etype),), rel_id, dtype=torch.long
    )


In [39]:
# Randomly initialise a fixed (non-trainable) 400-dimensional input feature
# for every node, stored under the 'inp' key.
for ntype in G.ntypes:
    n_nodes = G.number_of_nodes(ntype)
    emb = nn.Parameter(torch.Tensor(n_nodes, 400), requires_grad=False)
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data['inp'] = emb

In [40]:
# The output size is the largest label id plus one.
num_classes = labels.max().item() + 1
model = HGT(
    G,
    n_inp=400,
    n_hid=200,
    n_out=num_classes,
    n_layers=2,
    n_heads=4,
    use_norm=True,
)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=200,
    pct_start=0.05,
)


In [41]:
# Full-graph training: every epoch runs the model on the whole graph, takes
# one optimizer step, and tracks the test accuracy at the best validation
# accuracy seen so far.
best_val_acc = 0
best_test_acc = 0
train_step = 0
for epoch in range(150):
    logits = model(G, 'paper')
    # The loss is computed only for labeled nodes.
    loss = F.cross_entropy(logits[train_idx], labels[train_idx])

    # NOTE(review): the accuracies come from the same training-mode forward
    # pass as the loss (model.eval() is never called), so dropout is
    # presumably active while they are measured -- confirm this is intended.
    pred = logits.argmax(1).cpu()
    train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
    val_acc   = (pred[val_idx] == labels[val_idx]).float().mean()
    test_acc  = (pred[test_idx] == labels[test_idx]).float().mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_step += 1
    # Advance the schedule by one step. The explicit epoch argument is
    # deprecated; since train_step rises by one per iteration, the plain call
    # yields the same learning rates.
    scheduler.step()

    # Model selection on validation accuracy; report the matching test accuracy.
    if best_val_acc < val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    
    if epoch % 5 == 0:
        print('ep:%3d. LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)' % (epoch,
            optimizer.param_groups[0]['lr'], 
            loss.item(),
            train_acc.item(),
            val_acc.item(),
            best_val_acc,
            test_acc.item(),
            best_test_acc,
        ))



ep:  0. LR: 0.00007 Loss 2.5634, Train Acc 0.1129, Val Acc 0.1500 (Best 0.1500), Test Acc 0.1029 (Best 0.1029)
ep:  5. LR: 0.00076 Loss 2.1809, Train Acc 0.2593, Val Acc 0.3000 (Best 0.3000), Test Acc 0.2545 (Best 0.2545)
ep: 10. LR: 0.00100 Loss 1.3803, Train Acc 0.4079, Val Acc 0.3700 (Best 0.4000), Test Acc 0.3922 (Best 0.3811)
ep: 15. LR: 0.00100 Loss 0.8651, Train Acc 0.7336, Val Acc 0.7500 (Best 0.7500), Test Acc 0.7163 (Best 0.7163)
ep: 20. LR: 0.00099 Loss 0.4271, Train Acc 0.8850, Val Acc 0.9300 (Best 0.9300), Test Acc 0.8679 (Best 0.8679)
ep: 25. LR: 0.00098 Loss 0.3519, Train Acc 0.9121, Val Acc 0.9300 (Best 0.9600), Test Acc 0.9013 (Best 0.8957)
ep: 30. LR: 0.00097 Loss 0.2918, Train Acc 0.9236, Val Acc 0.9500 (Best 0.9600), Test Acc 0.9082 (Best 0.8957)
ep: 35. LR: 0.00095 Loss 0.2180, Train Acc 0.9357, Val Acc 0.9500 (Best 0.9600), Test Acc 0.9082 (Best 0.8957)
ep: 40. LR: 0.00093 Loss 0.1874, Train Acc 0.9450, Val Acc 0.9500 (Best 0.9600), Test Acc 0.9124 (Best 0.8957)
e

In [42]:
# 4x4 count matrix, filled in by the following cells.
mat = np.zeros((4, 4))
mat

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [43]:
# Accumulate counts over the training papers: the row is the true class
# (0-3), the column is the predicted class.
for paper in train_idx:
    true_cls = labels[paper]
    for cls in range(4):
        if true_cls == cls:
            mat[cls, pred[paper]] += 1
mat

array([[213.,   0.,   0.,   0.],
       [  0., 416.,   0.,   0.],
       [  0.,   0., 408.,   0.],
       [  0.,   0.,   0., 363.]])

In [44]:
# Start from a fresh matrix and accumulate counts over the test papers:
# the row is the true class (0-3), the column is the predicted class.
mat = np.zeros((4, 4))
for paper in test_idx:
    true_cls = labels[paper]
    for cls in range(4):
        if true_cls == cls:
            mat[cls, pred[paper]] += 1
mat

array([[ 84.,   2.,  17.,   2.],
       [  1., 202.,   5.,   3.],
       [ 16.,   9., 188.,   5.],
       [  0.,   2.,   3., 180.]])

In [45]:
# Convert each row of counts into whole-number percentages of that row's total.
for row in mat:
    row_total = sum(row)
    # `row` is a view into `mat`, so the slice assignment updates it in place.
    row[:] = np.round(100.0 * row / row_total, 0)
mat

array([[80.,  2., 16.,  2.],
       [ 0., 96.,  2.,  1.],
       [ 7.,  4., 86.,  2.],
       [ 0.,  1.,  2., 97.]])

In [46]:
import pandas as pd
# Label the rows with the conference names for display.
conference_names = ['sosp', 'soda', 'sigcom', 'vldb']
df = pd.DataFrame(data=mat, index=conference_names)
df

Unnamed: 0,0,1,2,3
sosp,80.0,2.0,16.0,2.0
soda,0.0,96.0,2.0,1.0
sigcom,7.0,4.0,86.0,2.0
vldb,0.0,1.0,2.0,97.0
