<a href="https://colab.research.google.com/gist/Steboss89/b21d6abe548d106119666fec6b65965f/publicgat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### CPU implementation 



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as funct


class GraphAttentionLayer(nn.Module):
    r""" Main GAT class implementing Bahdanau's Attention"""
    def __init__(self, in_features, out_features, dropout=0.6, concat_output=True):
        r""" Constructor
        Define input and output sizes
        Parameters
        ----------
        in_features: int, input size
        out_features: int, output size
        dropout: float, dropout rate 
        concat_output: Bool, default True, concatenate attentions' outputs
        """
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.concat_output = concat_output
        self.dropout = dropout
    
        # initialize weight matrix W 
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data)
        # initialize attention nnet 'a'--> 2 layers and output is 1
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data)
        
    def forward(self, h, adjacency):
        r""" Define the forward step where the attention is computed
        Parameters
        ----------
        h: np.array input nodes' features
        adjacency: np.array, input nodes' adjacency matrix 
        """
        # first compute W*h
        WH = torch.matmul(h, self.W) # size N (nodes) * out_features
        # compute WH*first layer from neural network a 
        # for info check
        # https://towardsdatascience.com/graph-neural-networks-a-learning-journey-since-2008-graph-attention-networks-f8c39189e7fc#feee
        # take first layer 
        WH1 = torch.matmul(WH, self.a[:self.out_features, :])
        # take second layer 
        WH2 = torch.matmul(WH, self.a[self.out_features:, :])
        # sum up everything to compute the alignment score 
        e = nn.LeakyReLU(0.2)(WH1 + WH2.T)
        # here we gather adjacency matrix with e
        # to avoid overflow substitue 0 values with -1e9
        # and where adjacency values > 0 use e values
        attnt = torch.where(adjacency > 0, e, -1e9*torch.ones_like(e) ) 
        # compute the softmax 
        attnt = funct.softmax(attnt, dim=1)
        # add the dropout step. self.training is a boolean define if we're at training time
        attnt = funct.dropout(attnt, self.dropout, training=self.training)
        # at this point we compute the new nodes' features representation h'
        hfirst = torch.matmul(attnt, WH)
        # here we can concatenate the output or 
        # return the final hrist 
        if self.concat_output:
            return funct.elu(hfirst)
        else:
            return hfirst

In [None]:
class GAT(nn.Module):
    r""" Main GAT model, ready to be trained, with multi head attention"""
    def __init__(self, in_features, out_features, nclass, nheads, dropout=0.6):
        r""" Constructore for GAT 
        Parameters
        ----------
        in_features: int, number of input features 
        out_features: int, number of output features from attention layers 
        nclass: int, total number of class to be predicted 
        nheads: int, number of attention heads
        """
        super(GAT, self).__init__()
        self.dropout = dropout
        # now create all the attention heads
        self.attheads = [GraphAttentionLayer(in_features, out_features, concat_output=True) for _ in range(nheads)]
        # output is a final attention without concat, which takes as input all the previous outputs 
        self.output = GraphAttentionLayer(out_features*nheads, nclass, concat_output=False)

        
    def forward(self, X, adjacency):
        r""" Main forwards step
        Parameters
        ----------
        X: np.array, input nodes' featuers 
        adjacency: np.array, adjacency matrix 
        """
        X = funct.dropout(X, self.dropout, training=self.training)
        # compute the attention from each head given the input 
        X = torch.cat([attn(X, adjacency) for attn in self.attheads], dim=1) 
        # dropout 
        x = funct.dropout(X, self.dropout, training=self.training)
        # ELU activation 
        X = funct.elu(self.output(X, adjacency))
        # return the final result 
        return funct.log_softmax(X, dim=1)

In [None]:
import numpy as np
import scipy.sparse as sp
from tensorflow import keras 
import os


def encode_onehot(labels):
    r""" Transform labels into a one hot encoded vector
    Parameters
    ----------
    labels: np.array, this is the vector of labels 

    Return 
    ------
    labels_onehot: np.array, one hot encoded label 
    """
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    print(classes_dict)
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def normalize(mx):
    r""" Function to normalize values of a given sparse array mx
    Parameters
    ----------
    mx: scipy.sparse.coo_matrix, input sparse matrix to be normalized 

    Return 
    ------
    mx: scipy.sparse.coo_matrix, normalize array
    """

    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def load_data():
    r""" Function to load the Cora dataset. 
    The Cora dataset is downloaded thorugh tensorflow keras. Nodes' features 
    and adjacency matrix are created from cora.content and cora.cites
    """
    
    # Download file
    zip_file = keras.utils.get_file(
        fname="cora.tgz",
        origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
        extract=True,
    )
    # create the path
    data_dir = os.path.join(os.path.dirname(zip_file), "cora")

    # content data is converted to numpy vector
    idx_features_labels = np.genfromtxt(f"{data_dir}/cora.content", dtype=np.dtype(str))
    
    # Take the bag-of-words vector of each paper as the feature vector of each article and store it in a sparse matrix format
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    
    # Take the type of each paper as a label and convert it into a one hot vector
    labels = encode_onehot(idx_features_labels[:, -1])

    # Take out the id of each paper
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(idx)}
    
    # cites data is converted to numpy vector
    edges_unordered = np.genfromtxt(f"{data_dir}/cora.cites",dtype=np.int32)
    
    # Map the id in the cites data to the interval [0, 2708]
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)
    
    # Store the citation relationship between papers in a sparse matrix format
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)
    
    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    
    # Normalize the characteristics of the article
    features = normalize(features)
    adj = normalize(adj + sp.eye(adj.shape[0]))
    
    # Produce the final vector
    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500) 

    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(np.where(labels)[1])
    adj = torch.LongTensor(adj.todense())

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    edge_tensor = torch.LongTensor(edges)

    print(f"Train elements {len(idx_train)}")
    print(f"Validation elements {len(idx_val)}")
    print(f"Test elements {len(idx_test)}")

    return adj, features, labels, idx_train, idx_val, idx_test, edge_tensor


adj, features, labels, idx_train, idx_val, idx_test, edge_tensor = load_data()

Downloading data from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
{'Neural_Networks': array([1., 0., 0., 0., 0., 0., 0.]), 'Reinforcement_Learning': array([0., 1., 0., 0., 0., 0., 0.]), 'Theory': array([0., 0., 1., 0., 0., 0., 0.]), 'Rule_Learning': array([0., 0., 0., 1., 0., 0., 0.]), 'Case_Based': array([0., 0., 0., 0., 1., 0., 0.]), 'Genetic_Algorithms': array([0., 0., 0., 0., 0., 1., 0.]), 'Probabilistic_Methods': array([0., 0., 0., 0., 0., 0., 1.])}
Train elements 140
Validation elements 300
Test elements 1000


In [None]:
import torch.optim as optim
from torch.autograd import Variable

seed = 42
epochs = 100
lr = 0.005 
weight_decay = 5e-4 
hidden = 8 
heads = 8 

model = GAT(in_features = features.shape[1], 
            out_features = hidden,
            nclass=7,
            nheads = heads)

optimizer = optim.Adam(model.parameters(), 
                       lr=lr, 
                       weight_decay=weight_decay)

In [None]:
def train(epoch):
    r""" This function train the GAT model. The function works through variables
    defined in the code - but we can refactor this function to work anyway - 
    Parameters
    ----------
    epoch: int, current epoch 
    """

    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = funct.nll_loss(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # validation 
    loss_val = funct.nll_loss(output[idx_val], labels[idx_val])
    
    if epoch%10==0:
        print(f'Epoch: {epoch+1}'.format(epoch+1),
            f'loss_train: {loss_train.data.item()}'
            f'loss_val: {loss_val.data.item()}',
            )


In [None]:
%%time

import time 

for epoch in range(epochs):
    train(epoch)

Epoch: 1 loss_train: 1.945987343788147loss_val: 1.9458523988723755
Epoch: 11 loss_train: 1.943637490272522loss_val: 1.9434393644332886
Epoch: 21 loss_train: 1.941507339477539loss_val: 1.9412577152252197
Epoch: 31 loss_train: 1.9395989179611206loss_val: 1.9393465518951416
Epoch: 41 loss_train: 1.9375416040420532loss_val: 1.9372817277908325
Epoch: 51 loss_train: 1.9359586238861084loss_val: 1.9355454444885254
Epoch: 61 loss_train: 1.9342877864837646loss_val: 1.9336856603622437
Epoch: 71 loss_train: 1.9333196878433228loss_val: 1.9326894283294678
Epoch: 81 loss_train: 1.9325860738754272loss_val: 1.9318381547927856
Epoch: 91 loss_train: 1.9313238859176636loss_val: 1.9303865432739258
CPU times: user 2min 46s, sys: 8.12 s, total: 2min 55s
Wall time: 2min 55s


### CPU implementation with torch geometric

In [None]:
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 5.4 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_sparse-0.6.13-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 5.0 MB/s 
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.13
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.5 MB)
[K     |████████████████████████████████| 2.5

In [None]:
!nvidia-smi

Fri Mar 25 09:27:57 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    31W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch 
import torch.nn as nn 
import torch.nn.functional as funct 

from torch_geometric.data import Data 
from torch_geometric.nn import GATConv 
from torch_geometric.datasets import Planetoid 
import torch_geometric.transforms as T 

dataset = Planetoid(root="/tmp/Cora", name="Cora")
dataset.transform = T.NormalizeFeatures() 
#print(f"Number of Classes in:", dataset.num_classes)
#print(f"Number of Node Features in:", dataset.num_node_features)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [None]:
class GAT(torch.nn.Module):
    r""" Main class for GAT"""
    def __init__(self, in_features, out_features, nclass, nheads): # in_features, out_features, nclass, nheads):
        r""" Constructor, define input features, output features, nclass and nheads 
        for attention layer"""
        super(GAT, self).__init__()
        self.hid = out_features
        self.in_head = nheads
        self.out_head = out_features
        self.nclass = nclass
        
        #GATConv(in_channels, out_channels, heads, concat)
        # This is the multi-head layer so concat outputs
        self.conv1 = GATConv(in_features,
                             self.hid,
                             heads=self.in_head,
                             dropout=0.6)
        # here we want to average
        # in_feautres = output from self.conv1 (self.hid)*number of heads
        # out features --> number of classes to predict
        self.conv2 = GATConv(self.hid*self.in_head,
                             self.nclass,
                             concat=False,
                             heads=self.out_head,
                             dropout=0.6)

    def forward(self, X, adjacency):
        r""" forward step for training"""
        x, edge_index = X, adjacency
                
        x = funct.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = funct.elu(x)
        x = funct.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        
        return funct.log_softmax(x, dim=1)

hidden = 8 
nheads = 8
nclass = dataset.num_classes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device {device}")

# data 
adj, features, labels, idx_train, idx_val, idx_test, edge_tensor = load_data()
# send model to device 
model = GAT(in_features = features.shape[1],#dataset.num_features, 
            out_features = hidden,
            nclass = 7, #nclass,
            nheads = nheads).to(device)
#data = dataset[0].to(device)

# define optimizer 
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

Device cpu
{'Neural_Networks': array([1., 0., 0., 0., 0., 0., 0.]), 'Theory': array([0., 1., 0., 0., 0., 0., 0.]), 'Reinforcement_Learning': array([0., 0., 1., 0., 0., 0., 0.]), 'Probabilistic_Methods': array([0., 0., 0., 1., 0., 0., 0.]), 'Genetic_Algorithms': array([0., 0., 0., 0., 1., 0., 0.]), 'Rule_Learning': array([0., 0., 0., 0., 0., 1., 0.]), 'Case_Based': array([0., 0., 0., 0., 0., 0., 1.])}
Train elements 140
Validation elements 300
Test elements 1000


In [None]:
%%time
model.train()
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(features, edge_tensor.T)#model(data.x, data.edge_index)
    loss =  funct.nll_loss(output[idx_train], labels[idx_train])#funct.nll_loss(output[data.train_mask], data.y[data.train_mask])
    # validation 
    loss_val = funct.nll_loss(output[idx_val], labels[idx_val])

    if epoch%10==0:
        print(f'Epoch: {epoch+1}'.format(epoch+1),
            f'loss_train: {loss.data.item()}'
            f'loss_val: {loss_val.data.item()}',
            )


    loss.backward()
    optimizer.step()

Epoch: 1 loss_train: 1.9463061094284058loss_val: 1.945523738861084
Epoch: 11 loss_train: 1.8703289031982422loss_val: 1.8736236095428467
Epoch: 21 loss_train: 1.7788177728652954loss_val: 1.7778674364089966
Epoch: 31 loss_train: 1.7344424724578857loss_val: 1.7231016159057617
Epoch: 41 loss_train: 1.6397923231124878loss_val: 1.7029311656951904
Epoch: 51 loss_train: 1.554422378540039loss_val: 1.6239001750946045
Epoch: 61 loss_train: 1.446123480796814loss_val: 1.5282834768295288
Epoch: 71 loss_train: 1.3288929462432861loss_val: 1.4940941333770752
Epoch: 81 loss_train: 1.2511863708496094loss_val: 1.4255460500717163
Epoch: 91 loss_train: 1.1075749397277832loss_val: 1.3526391983032227
CPU times: user 9.44 s, sys: 139 ms, total: 9.58 s
Wall time: 9.58 s


# GAT in JAX


In [None]:
from typing import List

import jax.numpy as np
from jax import lax, random
from jax.nn.initializers import glorot_normal, glorot_uniform
import jax.nn as nn


In [None]:
!nvidia-smi

Mon Mar 28 11:22:49 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P0    75W / 149W |  10442MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from typing import List

import jax.numpy as np
from jax import lax, random
from jax.nn.initializers import glorot_normal, glorot_uniform
import jax.nn as nn
from jax import jit
import numpy.random as npr

@jit 
def create_random():
    r""" Use this function to speed up the creation of random nunmbers

    Return 
    ------
    jax.random.PRNGKey
    """
    return random.split(random.PRNGKey(npr.randint(0,100)), 4)


def Dropout(rate):
    r""" Layer construction function for a dropout layer with given rate.
    This Dropout layer is modified from stax.experimental.Dropout, to use
    `is_training` as an argument to apply_fun, instead of defining it at
    definition time.

    Parameters
    -----------
    rate: float, probability of keeping an element

    Return 
    ------
    init_fun: initializer function 
    apply_fun: application function
    """

    def init_fun(input_shape):
        r""" Constructor
        Parameters
        ----------
        input_shape: input dimension
        """
        return input_shape, ()

    def apply_fun(inputs, is_training, **kwargs):
        r""" Function to compute dropout 
        Parameters
        ----------
        inputs: input features 
        is_training: bool, if the model is in training

        Return 
        ------
        out: lax.condition 
        """
        # generate a random number generate a bernoulli prob
        rng, rng2, rng3, rng4 = create_random()
        # keep rate
        keep = random.bernoulli(rng,  1.0 - rate, inputs.shape)
        # output that is kept from input features
        outs = keep*inputs/(1.0 -rate) 
        # if not training, just return inputs and discard any computation done
        out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
        return out

    return init_fun, apply_fun


def GraphAttentionLayer(out_dim, dropout):
    r""" Main layer for graph attention 
    Parameters
    ----------
    out_dim: output dimension 
    dropout: float, dropout rate 
    """
    _, drop_fun = Dropout(dropout)

    def init_fun(input_shape):
        r""" Constructor, generate input weights"""
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2, k3, k4 = create_random()
        # initialize weight
        W = glorot_uniform()(k1, (input_shape[-1], out_dim))
        # initialize nn weight
        a_init = glorot_uniform()
        a1 = a_init(k2, (out_dim, 1))
        a2 = a_init(k3, (out_dim, 1))

        return output_shape, (W, a1, a2)
       
    def apply_fun(params, x, adj, activation=nn.elu, is_training=False, **kwargs):
        r"""Apply function, compute the attention 
        Parameters
        ----------
        params: input parameters for weights W, a1, a2
        x: input nodes'features
        adj: adjacency matrix 
        """
        W, a1, a2 = params
        # initial dropout
        x = drop_fun(x, is_training=is_training)
        # weights matmult
        x = np.dot(x, W)
        # neural netw + alignment score
        f_1 = np.dot(x, a1) 
        f_2 = np.dot(x, a2)
        logits = f_1 + f_2.T
        # softmax of leakyReLu for e
        coefs = nn.softmax( nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9))
        # final dropout
        coefs = drop_fun(coefs, is_training=is_training)
        x = drop_fun(x, is_training=is_training)

        ret = np.matmul(coefs, x)

        return activation(ret)

    return init_fun, apply_fun


def MultiHeadLayer(nheads: int, nhid: int, dropout: float,last_layer: bool=False):
    r""" Multi head attention layer
    Parameters
    ----------
    nheads: int, number of attention heads 
    nhid: int, number of hidden units 
    dropout: float, percentage of dropout
    last_layer: bool, if last lyer average, otherwise concat
    """
    
    layer_funs, layer_inits = [], []
    # define the heads layers
    for head_i in range(nheads):
        att_init, att_fun = GraphAttentionLayer(nhid, dropout=dropout)
        # initialize layers of attention
        layer_inits.append(att_init)
        # grab the functions for running attentions
        layer_funs.append(att_fun)
    
    def init_fun(input_shape):
        r""" Initialize each attention head
        Parameters
        ----------
        input_shapee: int, input shape

        Return 
        ------
        input_shape: int, input shape 
        params: list, list of parameters for each attention head
        """
        params = []
        # for each head initialize parameters
        for att_init_fun in layer_inits:
            #rng, layer_rng = random.split(rng)
            layer_shape, param = att_init_fun(input_shape)
            params.append(param)

        input_shape = layer_shape
        if not last_layer:
            # multiply by the number of heads
            input_shape = input_shape[:-1] + (input_shape[-1]*len(layer_inits),)
        return input_shape, params
    
    def apply_fun(params, x, adj, is_training=False, **kwargs):
        r""" Function to apply parameters to head 
        Parameters
        ----------
        params: list, lis tof parameters for attention 
        x: array, input array
        adj: array, input adjacency 
        is_training:  bool
        """

        layer_outs = []
        assert len(params) == nheads
        for head_i in range(nheads):
            layer_params = params[head_i]
            layer_outs.append(layer_funs[head_i](layer_params, x, adj, is_training=is_training))
        # concatenate or average
        if not last_layer:
            x = np.concatenate(layer_outs, axis=1)
        else:
            # average last layer heads
            x = np.mean(np.stack(layer_outs), axis=0)

        return x

    return init_fun, apply_fun


def GAT(nheads: List[int], nhid: List[int], nclass: int, dropout: float):
    """
    Graph Attention Network model definition.
    """

    init_funs = []
    attn_funs = []

    nhid += [nclass]
    for layer_i in range(len(nhid)):
        last = layer_i == len(nhid) - 1
        layer_init, layer_fun = MultiHeadLayer(nheads[layer_i], nhid[layer_i],dropout=dropout,last_layer=last)
        attn_funs.append(layer_fun)
        init_funs.append(layer_init)

    def init_fun(input_shape):
        params = []
        for i, init_fun in enumerate(init_funs):
            layer_shape, param = init_fun(input_shape)
            params.append(param)
            input_shape = layer_shape
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):

        for i, layer_fun in enumerate(attn_funs):
            x = layer_fun(params[i], x, adj, is_training=is_training)
        
        return nn.log_softmax(x)

    return init_fun, apply_fun


In [None]:
# recreate the load data funciton
import numpy as np
import scipy.sparse as sp
from tensorflow import keras 
import os


def encode_onehot(labels):
    r""" Transform labels into a one hot encoded vector
    Parameters
    ----------
    labels: np.array, this is the vector of labels 

    Return 
    ------
    labels_onehot: np.array, one hot encoded label 
    """
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    print(classes_dict)
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def normalize(mx):
    r""" Function to normalize values of a given sparse array mx
    Parameters
    ----------
    mx: scipy.sparse.coo_matrix, input sparse matrix to be normalized 

    Return 
    ------
    mx: scipy.sparse.coo_matrix, normalize array
    """

    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def load_data():
    r""" Function to load the Cora dataset. 
    The Cora dataset is downloaded thorugh tensorflow keras. Nodes' features 
    and adjacency matrix are created from cora.content and cora.cites
    """
    
    # Download file
    zip_file = keras.utils.get_file(
        fname="cora.tgz",
        origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
        extract=True,
    )
    # create the path
    data_dir = os.path.join(os.path.dirname(zip_file), "cora")

    # content data is converted to numpy vector
    idx_features_labels = np.genfromtxt(f"{data_dir}/cora.content", dtype=np.dtype(str))
    
    # Take the bag-of-words vector of each paper as the feature vector of each article and store it in a sparse matrix format
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    
    # Take the type of each paper as a label and convert it into a one hot vector
    labels = encode_onehot(idx_features_labels[:, -1])

    # Take out the id of each paper
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(idx)}
    
    # cites data is converted to numpy vector
    edges_unordered = np.genfromtxt(f"{data_dir}/cora.cites",dtype=np.int32)
    
    # Map the id in the cites data to the interval [0, 2708]
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)
    
    # Store the citation relationship between papers in a sparse matrix format
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)
    
    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    
    # Normalize the characteristics of the article
    features = normalize(features)
    adj = normalize(adj + sp.eye(adj.shape[0]))
    
    # Produce the final vector
    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)

    features = np.array(features.todense())

    # JAX doesn't support sparse matrices yet
    adj = np.asarray(adj.todense())

    return adj, features, labels, np.array(idx_train), np.array(idx_val), np.array(idx_test)

In [None]:
adj, features, labels, idx_train, idx_val, idx_test = load_data()

{'Case_Based': array([1., 0., 0., 0., 0., 0., 0.]), 'Rule_Learning': array([0., 1., 0., 0., 0., 0., 0.]), 'Reinforcement_Learning': array([0., 0., 1., 0., 0., 0., 0.]), 'Neural_Networks': array([0., 0., 0., 1., 0., 0., 0.]), 'Probabilistic_Methods': array([0., 0., 0., 0., 1., 0., 0.]), 'Theory': array([0., 0., 0., 0., 0., 1., 0.]), 'Genetic_Algorithms': array([0., 0., 0., 0., 0., 0., 1.])}


In [None]:
import jax
import jax.numpy as np
from jax import jit, grad, random
from jax.experimental import optimizers

@jit
def loss(params, batch):
    """
    The idxes of the batch indicate which nodes are used to compute the loss.
    """
    inputs, targets, adj, is_training, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)**2 # tf doesn't use sqrt
    return ce_loss + l2_loss


@jit
def accuracy(params, batch):
    inputs, targets, adj, is_training, idx = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict_fun(params, inputs, adj, 
        is_training=is_training), axis=1)
    return np.mean(predicted_class[idx] == target_class[idx])


@jit
def loss_accuracy(params, batch):
    inputs, targets, adj, is_training, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training)
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(preds, axis=1)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    acc = np.mean(predicted_class[idx] == target_class[idx])
    return ce_loss, acc


In [None]:
%%time
lr = 0.05
num_epochs = 100
n_nodes = adj.shape[0]
n_feats = features.shape[1]

# GAT params
nheads = [8, 1]
nhid = [8]
dropout = 0.6 # probability of keeping
residual = False

init_fun, predict_fun = GAT(nheads=nheads,
                            nhid=nhid,
                            nclass=7,
                            dropout=dropout,
                            )

input_shape = (-1, n_nodes, n_feats)
_, init_params = init_fun(input_shape)

opt_init, opt_update, get_params = optimizers.adam(lr)

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)

print("\nStarting training...")
for epoch in range(num_epochs):
    
    batch = (features, labels, adj, True, idx_train)
    opt_state = update(epoch, opt_state, batch)

    params = get_params(opt_state)
    eval_batch = (features, labels, adj, False, idx_val)
    train_batch = (features, labels, adj, False, idx_train)
    # additional step, everything can be loaded onto the GPU:
    train_batch = jax.device_put(train_batch)
    eval_batch = jax.device_put(eval_batch)
    # without that we take about 1 min
    train_loss, train_acc = loss_accuracy(params, train_batch)
    val_loss, val_acc = loss_accuracy(params, eval_batch)
    if epoch%10==0:
        print((f"Iter {epoch}/{num_epochs} train_loss:"+
            f"{train_loss:.4f}, train_acc: {train_acc:.4f}, val_loss:"+
            f"{val_loss:.4f}, val_acc: {val_acc:.4f}"))

    # new random key at each iteration, othwerwise dropout uses always 
    # the same mask 

# now run on the test set
test_batch = (features, labels, adj, False, idx_test)
test_acc = accuracy(params, test_batch)
print(f'Test set acc: {test_acc}')


Starting training...
Iter 0/100 train_loss:1.9087, train_acc: 0.4500, val_loss:1.9120, val_acc: 0.4667
Iter 10/100 train_loss:1.6392, train_acc: 0.5857, val_loss:1.6892, val_acc: 0.5200
Iter 20/100 train_loss:1.7689, train_acc: 0.5286, val_loss:1.7940, val_acc: 0.4767
Iter 30/100 train_loss:1.8067, train_acc: 0.5071, val_loss:1.8219, val_acc: 0.4833
Iter 40/100 train_loss:1.8210, train_acc: 0.5571, val_loss:1.8353, val_acc: 0.5000
Iter 50/100 train_loss:1.8317, train_acc: 0.5786, val_loss:1.8458, val_acc: 0.5067
Iter 60/100 train_loss:1.8327, train_acc: 0.5214, val_loss:1.8463, val_acc: 0.4900
Iter 70/100 train_loss:1.8372, train_acc: 0.5643, val_loss:1.8492, val_acc: 0.5133
Iter 80/100 train_loss:1.8373, train_acc: 0.5214, val_loss:1.8487, val_acc: 0.5033
Iter 90/100 train_loss:1.8403, train_acc: 0.5357, val_loss:1.8496, val_acc: 0.5133
Test set acc: 0.39900001883506775
CPU times: user 42.8 s, sys: 397 ms, total: 43.2 s
Wall time: 38.8 s


In [None]:
# on Tesla T4 --> 17.2 seconds