In [1]:
import jax.numpy as jnp
from jax import lax, random
from jax.nn.initializers import glorot_normal, glorot_uniform
import jax.nn as nn
import jax
from jax import jit
import numpy.random as npr
from jax import jit, grad, random
from jax.example_libraries import optimizers
from typing import List
import numpy as np
import scipy.sparse as sp
from tensorflow import keras 
import os
import jax
import jax.numpy as np
from jax import jit, grad, random
from jax.example_libraries import optimizers

In [2]:
jax.default_backend()

'gpu'

In [3]:
#function to speed up the creation of random nunmbers
@jit 
def create_random():
    return random.split(random.PRNGKey(npr.randint(0,100)), 4)

In [4]:
# Layer construction function for a dropout layer with given rate.
def Dropout(rate):
    #Constructor function
    def init_fun(input_shape):
        return input_shape, ()
    #Function to compute dropout
    def apply_fun(inputs, is_training, **kwargs):
        # generate a random number generate a bernoulli prob
        rng, rng2, rng3, rng4 = create_random()
        # keep rate
        keep = random.bernoulli(rng,  1.0 - rate, inputs.shape)
        # output that is kept from input features
        outs = keep*inputs/(1.0 -rate) 
        # if not training, just return inputs and discard any computation done
        out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
        return out

    return init_fun, apply_fun

In [5]:
def GraphAttentionLayer(out_dim, dropout):
    #Main layer for graph attention 
    _, drop_fun = Dropout(dropout)

    def init_fun(input_shape):
        # Constructor, generate input weights
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2, k3, k4 = create_random()
        # initialize weight
        W = glorot_uniform()(k1, (input_shape[-1], out_dim))
        # initialize nn weight
        a_init = glorot_uniform()
        a1 = a_init(k2, (out_dim, 1))
        a2 = a_init(k3, (out_dim, 1))

        return output_shape, (W, a1, a2)
       
    def apply_fun(params, x, adj, activation=nn.elu, is_training=False, **kwargs):
        #Apply function, compute the attention 
        W, a1, a2 = params
        # initial dropout
        x = drop_fun(x, is_training=is_training)
        # weights matmult
        x = np.dot(x, W)
        # neural netw + alignment score
        f_1 = np.dot(x, a1) 
        f_2 = np.dot(x, a2)
        logits = f_1 + f_2.T
        # softmax of leakyReLu for e
        coefs = nn.softmax( nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9))
        # final dropout
        coefs = drop_fun(coefs, is_training=is_training)
        x = drop_fun(x, is_training=is_training)

        ret = np.matmul(coefs, x)

        return activation(ret)
    return init_fun, apply_fun

def MultiHeadLayer(nheads: int, nhid: int, dropout: float,last_layer: bool=False):
    #Multi head attention layer
    
    layer_funs, layer_inits = [], []
    # define the heads layers
    for head_i in range(nheads):
        att_init, att_fun = GraphAttentionLayer(nhid, dropout=dropout)
        # initialize layers of attention
        layer_inits.append(att_init)
        # grab the functions for running attentions
        layer_funs.append(att_fun)
    
    def init_fun(input_shape):
        #Initialize each attention head
        params = []
        # for each head initialize parameters
        for att_init_fun in layer_inits:
            #rng, layer_rng = random.split(rng)
            layer_shape, param = att_init_fun(input_shape)
            params.append(param)

        input_shape = layer_shape
        if not last_layer:
            # multiply by the number of heads
            input_shape = input_shape[:-1] + (input_shape[-1]*len(layer_inits),)
        return input_shape, params
    
    def apply_fun(params, x, adj, is_training=False, **kwargs):
        #Function to apply parameters to head 

        layer_outs = []
        assert len(params) == nheads
        for head_i in range(nheads):
            layer_params = params[head_i]
            layer_outs.append(layer_funs[head_i](layer_params, x, adj, is_training=is_training))
        # concatenate or average
        if not last_layer:
            x = np.concatenate(layer_outs, axis=1)
        else:
            # average last layer heads
            x = np.mean(np.stack(layer_outs), axis=0)

        return x

    return init_fun, apply_fun

In [6]:
def GAT(nheads: List[int], nhid: List[int], nclass: int, dropout: float):
    # Graph Attention Network model definition.
    init_funs = []
    attn_funs = []

    nhid += [nclass]
    for layer_i in range(len(nhid)):
        last = layer_i == len(nhid) - 1
        layer_init, layer_fun = MultiHeadLayer(nheads[layer_i], nhid[layer_i],dropout=dropout,last_layer=last)
        attn_funs.append(layer_fun)
        init_funs.append(layer_init)

    def init_fun(input_shape):
        params = []
        for i, init_fun in enumerate(init_funs):
            layer_shape, param = init_fun(input_shape)
            params.append(param)
            input_shape = layer_shape
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):

        for i, layer_fun in enumerate(attn_funs):
            x = layer_fun(params[i], x, adj, is_training=is_training)
        
        return nn.log_softmax(x)

    return init_fun, apply_fun

In [7]:
def encode_onehot(labels):
    #Transform labels into a one hot encoded vector
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    print(classes_dict)
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot
def normalize(mx):
    # Function to normalize values of a given sparse array mx
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx
def load_data():
    #Function to load the Cora dataset. 
    # Download file
    zip_file = keras.utils.get_file(
        fname="cora.tgz",
        origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
        extract=True,
    )
    # create the path
    data_dir = os.path.join(os.path.dirname(zip_file), "cora")

    # content data is converted to numpy vector
    idx_features_labels = np.genfromtxt(f"{data_dir}/cora.content", dtype=np.dtype(str))
    
    # Take the bag-of-words vector of each paper as the feature vector of each article and store it in a sparse matrix format
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    
    # Take the type of each paper as a label and convert it into a one hot vector
    labels = encode_onehot(idx_features_labels[:, -1])

    # Take out the id of each paper
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(idx)}
    
    # cites data is converted to numpy vector
    edges_unordered = np.genfromtxt(f"{data_dir}/cora.cites",dtype=np.int32)
    
    # Map the id in the cites data to the interval [0, 2708]
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)
    
    # Store the citation relationship between papers in a sparse matrix format
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)
    
    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    
    # Normalize the characteristics of the article
    features = normalize(features)
    adj = normalize(adj + sp.eye(adj.shape[0]))
    # Produce the final vector
    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)

    features = np.array(features.todense())

    # JAX doesn't support sparse matrices yet
    adj = np.asarray(adj.todense())

    return adj, features, labels, np.array(idx_train), np.array(idx_val), np.array(idx_test)

In [8]:
adj, features, labels, idx_train, idx_val, idx_test = load_data()

Downloading data from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
{'Probabilistic_Methods': array([1., 0., 0., 0., 0., 0., 0.]), 'Rule_Learning': array([0., 1., 0., 0., 0., 0., 0.]), 'Reinforcement_Learning': array([0., 0., 1., 0., 0., 0., 0.]), 'Neural_Networks': array([0., 0., 0., 1., 0., 0., 0.]), 'Theory': array([0., 0., 0., 0., 1., 0., 0.]), 'Case_Based': array([0., 0., 0., 0., 0., 1., 0.]), 'Genetic_Algorithms': array([0., 0., 0., 0., 0., 0., 1.])}


In [9]:


@jit
def loss(params, batch):
    # The indexes of the batch indicate which nodes are used to compute the loss.
    inputs, targets, adj, is_training, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)**2 
    return ce_loss + l2_loss


@jit
def accuracy(params, batch):
    inputs, targets, adj, is_training, idx = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict_fun(params, inputs, adj, 
        is_training=is_training), axis=1)
    return np.mean(predicted_class[idx] == target_class[idx])


@jit
def loss_accuracy(params, batch):
    inputs, targets, adj, is_training, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training)
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(preds, axis=1)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    acc = np.mean(predicted_class[idx] == target_class[idx])
    return ce_loss, acc

In [10]:
%%time
lr = 0.05
num_epochs = 100
n_nodes = adj.shape[0]
n_feats = features.shape[1]

# GAT params
nheads = [8, 1]
nhid = [8]
dropout = 0.6 # probability of keeping
residual = False

init_fun, predict_fun = GAT(nheads=nheads,
                            nhid=nhid,
                            nclass=7,
                            dropout=dropout,
                            )

input_shape = (-1, n_nodes, n_feats)
_, init_params = init_fun(input_shape)

opt_init, opt_update, get_params = optimizers.sgd(lr)

CPU times: user 2.54 s, sys: 1.02 s, total: 3.57 s
Wall time: 11.2 s


In [11]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)

In [12]:
print("\nStarting training...")
for epoch in range(num_epochs):
    
    batch = (features, labels, adj, True, idx_train)
    opt_state = update(epoch, opt_state, batch)

    params = get_params(opt_state)
    eval_batch = (features, labels, adj, False, idx_val)
    train_batch = (features, labels, adj, False, idx_train)
    # additional step, everything can be loaded onto the GPU:
    train_batch = jax.device_put(train_batch)
    eval_batch = jax.device_put(eval_batch)
    # without that we take about 1 min
    train_loss, train_acc = loss_accuracy(params, train_batch)
    val_loss, val_acc = loss_accuracy(params, eval_batch)
    if epoch%10==0:
        print((f"Iter {epoch}/{num_epochs} train_loss:"+
            f"{train_loss:.4f}, train_acc: {train_acc:.4f}, val_loss:"+
            f"{val_loss:.4f}, val_acc: {val_acc:.4f}"))


Starting training...
Iter 0/100 train_loss:1.9459, train_acc: 0.1143, val_loss:1.9459, val_acc: 0.1500
Iter 10/100 train_loss:1.9456, train_acc: 0.1214, val_loss:1.9456, val_acc: 0.1567
Iter 20/100 train_loss:1.9453, train_acc: 0.1357, val_loss:1.9453, val_acc: 0.1567
Iter 30/100 train_loss:1.9450, train_acc: 0.1286, val_loss:1.9451, val_acc: 0.1567
Iter 40/100 train_loss:1.9447, train_acc: 0.1357, val_loss:1.9448, val_acc: 0.1800
Iter 50/100 train_loss:1.9444, train_acc: 0.1571, val_loss:1.9446, val_acc: 0.1867
Iter 60/100 train_loss:1.9441, train_acc: 0.1714, val_loss:1.9443, val_acc: 0.2000
Iter 70/100 train_loss:1.9438, train_acc: 0.1786, val_loss:1.9441, val_acc: 0.2133
Iter 80/100 train_loss:1.9435, train_acc: 0.2143, val_loss:1.9439, val_acc: 0.2367
Iter 90/100 train_loss:1.9432, train_acc: 0.2357, val_loss:1.9436, val_acc: 0.2567


In [13]:
# now run on the test set
test_batch = (features, labels, adj, False, idx_test)
test_acc = accuracy(params, test_batch)
print(f'Test set acc: {test_acc}')

Test set acc: 0.19700001180171967
