### Train GAE

In [1]:
import time
import math
from importlib import reload 
import pdb 
import numpy as np
import scipy.sparse as sp
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as ds
from distributions import VonMisesFisher, HypersphericalUniform

from gae.utils import load_data, mask_test_edges, preprocess_graph, get_roc_score
from gae.layers import GraphConvolution, GraphAttentionLayer
from gae.model import InnerProductDecoder
from gae.sparse import SpGraphAttentionLayer

In [2]:
class obj(object):
    """Expose the entries of a (possibly nested) dict as attributes.

    Every key of `d` becomes an attribute on the instance. Values that are
    dicts are wrapped recursively in `obj`. Lists and tuples are copied into
    a new list (tuples therefore come back as lists) in which dict elements
    are wrapped and all other elements are kept as they are. Any other value
    is stored unchanged.
    """

    def __init__(self, d):
        for key, value in d.items():
            if isinstance(value, dict):
                converted = obj(value)
            elif isinstance(value, (list, tuple)):
                # Only dicts directly inside the sequence are wrapped
                converted = [obj(item) if isinstance(item, dict) else item
                             for item in value]
            else:
                converted = value
            setattr(self, key, converted)
# Run configuration, wrapped so that settings read as attributes (e.g. args.lr)
args = obj({
    'model': 'gcn_vae',
    'seed': 72,
    'epochs': 200,
    'lr': 7.5e-3,
    'dropout': 0.6,
})

# Device that the tensors and the model are moved to in the cells below
device = torch.device('cpu')

Below, we load and preprocess our data. We load the data with the loading functions from the GCN repo, because that repo provides the node labels, and we preprocess it with functions from the GAE repo.

In [3]:
# Read the saved bundle holding the graph, features, labels and index splits
# NOTE(review): torch.load unpickles the file, so only load trusted data
data = torch.load('../data/cora/preprocessed_gcn_data_for_gae.pth')

# Pull the individual pieces out of the bundle
adj = sp.csr_matrix(data['adj_orig'])
features = (data['feats'] != 0).float()  # binarise: 1.0 where a feature is non-zero
labels = data['labels']
idx_train, idx_val, idx_test = (data['idx_' + split] for split in ('train', 'val', 'test'))

# Dataset dimensions
n_nodes, feat_dim = features.shape

# Keep the full adjacency matrix, minus its diagonal entries, for later use
adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj_orig.eliminate_zeros()

# Split the edges into train / val / test; from here on `adj` is the training graph
adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
adj = adj_train

# GAE-repo preprocessing of the training graph, plus the dense training
# adjacency with an added identity (self-loops) as a float tensor
adj_norm = preprocess_graph(adj)
adj_label = torch.FloatTensor((adj_train + sp.eye(adj_train.shape[0])).toarray())

# Weighting terms derived from the ratio of zero to non-zero adjacency entries
pos_weight = float((adj.shape[0] ** 2 - adj.sum()) / adj.sum())
norm = adj.shape[0] ** 2 / float(2 * (adj.shape[0] ** 2 - adj.sum()))

# Move everything to the configured device (adjacency is densified first)
adj_norm = adj_norm.to_dense().to(device)
adj_label, features, labels = (tensor.to(device) for tensor in (adj_label, features, labels))

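Below, we define a GAE intended to have two outputs: link prediction and node classification. Note that the class as written only decodes node-class logits from the latent code; no link-prediction decoder is wired in.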

In [4]:
class GAVAE(nn.Module):
    """Graph-attention variational encoder with a linear classification head.

    Two stacked multi-head attention layers feed two single-head attention
    layers that produce the mean (`mu`) and the raw scale parameter (`lv`) of
    a Normal distribution over the latent code. A linear layer maps the
    latent code to class logits.

    Args:
        in_dim: input size of the first attention layer (feature size per node).
        h_dim: output size of every attention head; also the latent size.
        c_dim: number of outputs of the linear head (number of classes).
        dropout: dropout probability, used between layers and forwarded to
            the attention layers.
        n_heads: number of heads in each of the two multi-head layers.
        sparse: accepted for backward compatibility; currently unused.
    """

    def __init__(self, in_dim, h_dim, c_dim, dropout, n_heads, sparse=True):
        super().__init__()
        self.dropout = dropout

        # First multi-head attention layer. The heads live in a plain list and
        # are registered one by one so their parameter names stay
        # 'attention_<i>.*' (keeps existing state_dict keys valid).
        self.attentions = [SpGraphAttentionLayer(in_dim, h_dim, dropout=dropout, concat=True)
                           for _ in range(n_heads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Second multi-head attention layer; its input is the concatenation of
        # the first layer's heads, hence h_dim * n_heads
        self.attentions2 = [SpGraphAttentionLayer(h_dim * n_heads, h_dim, dropout=dropout, concat=True)
                            for _ in range(n_heads)]
        for i, attention in enumerate(self.attentions2):
            self.add_module('attention2_{}'.format(i), attention)

        # Encoder outputs: h_dim-sized mean and a single scale value per node
        self.mu_att = GraphAttentionLayer(h_dim * n_heads, h_dim, dropout=dropout, concat=False)
        self.lv_att = GraphAttentionLayer(h_dim * n_heads,     1, dropout=dropout, concat=False)

        # Decoder: latent code -> class logits
        self.linear = nn.Linear(h_dim, c_dim)

    def encode(self, x, adj):
        """Return `(mu, lv)` for node features `x` and adjacency `adj`.

        Dropout is applied before each of the three stages. `x` is presumably
        (n_nodes, in_dim); head outputs are concatenated along dim 1.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions2], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        mu = self.mu_att(x, adj)
        lv = self.lv_att(x, adj)
        return mu, lv

    def reparameterize(self, mu, lv):
        """Build the variational distribution `q` and the prior `p`.

        `q` is Normal(mu, softplus(lv) + 1), so `lv` is a raw scale parameter
        (not a log std dev) and the std dev is always greater than 1. `p` is a
        standard Normal; `lv` broadcasts against `mu` in both.
        """
        q = ds.Normal(mu, F.softplus(lv) + 1)
        p = ds.Normal(torch.zeros_like(mu), torch.ones_like(lv))
        # Alternative hyperspherical posterior/prior, kept for reference:
        #mu, lv = mu.cpu(), lv.cpu()
        #q = VonMisesFisher(mu / mu.norm(dim=-1, keepdim=True), F.softplus(lv) + 1)
        #p = HypersphericalUniform(7)
        return q, p

    def forward(self, x, adj):
        """Run the model and return `(o, q, p, z, mu, lv)`.

        `o` are the class logits, `q`/`p` the variational distribution and
        prior, `z` the latent code (a reparameterised sample in training mode,
        the mean of `q` in eval mode), and `mu`/`lv` the encoder outputs.
        """
        mu, lv = self.encode(x, adj)
        q, p = self.reparameterize(mu, lv)
        z = q.rsample() if self.training else q.loc
        # Keep z on the encoder's device instead of relying on a module-level
        # `device` global; this is a no-op unless q was built on another device
        z = z.to(mu.device)
        o = self.linear(z)
        return o, q, p, z, mu, lv

In [5]:
def accuracy(y_hat, y):
    '''Fraction of rows of `y_hat` whose highest-scoring column equals `y`.

    The arg-max along dim 1 is taken as the prediction, cast to the type of
    `y`, and compared element-wise; the result is a float tensor.
    '''
    _, predicted = torch.max(y_hat, 1)
    hits = (predicted.type_as(y) == y).float()
    return hits.sum() / len(y)

In [6]:
# Seed torch's RNG so that initialisation and sampling are repeatable
torch.manual_seed(args.seed)

# Model: 8 attention heads of width 8, one output per class
model = GAVAE(in_dim=feat_dim,
              h_dim=8,
              c_dim=labels.max().item() + 1,
              dropout=args.dropout,
              n_heads=8).to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

val_every = 10   # run the eval-mode evaluation every this many epochs
kl_weight = 0    # weight of the KL term; 0 disables the variational regulariser
hidden_emb = None
for epoch in range(args.epochs):
    t = time.time()

    # ---- Training step ----
    model.train()
    optimizer.zero_grad()

    # Forward pass: logits, q(z), p(z), latent sample, and the encoder outputs
    out, q, p, z, mu, lv = model(features, adj_norm)

    # Loss. `log_likelihood` is the classification cross-entropy on the
    # training nodes (i.e. a negative log-likelihood); cross_entropy already
    # averages over the nodes.
    log_likelihood = F.cross_entropy(out[idx_train], labels[idx_train])
    kl_divergence = torch.distributions.kl.kl_divergence(q, p).mean().to(device)
    loss = log_likelihood + kl_weight * kl_divergence

    # Backprop
    loss.backward()
    optimizer.step()

    # Validation accuracy from the training-mode forward pass (dropout active,
    # z sampled), hence "tmode"
    acc_val_tmode = accuracy(out[idx_val], labels[idx_val])
    print_string = ('Epoch [{:4d}]   '.format(epoch + 1) +
                    'Train loss = {:.3f}   '.format(loss.item()) +
                    'Val_tmode acc = {:.3f}   '.format(acc_val_tmode.item()))

    # ---- Periodic eval-mode evaluation ----
    if epoch % val_every == 0:
        model.eval()
        with torch.no_grad():
            out, q, p, z, mu, lv = model(features, adj_norm)
            # NOTE(review): this is computed on the TEST split and is labelled
            # as such; switch to idx_val if a validation metric was intended.
            test_acc = accuracy(out[idx_test], labels[idx_test])
        print_string += 'Test acc = {:.3f}   '.format(test_acc.item())

    # Report progress (the string was previously built but never shown)
    print_string += 'Time = {:.2f}s'.format(time.time() - t)
    print(print_string)

KeyboardInterrupt: 