In [1]:
import torch
from torch import nn

from labml_helpers.module import Module

In [2]:
class GraphAttentionLayer(Module):
    """Single graph attention layer
    
    in features: is the number of input features per node,
    out features: is the number of output features per node,
    n_heads: number of attention heads
    is concat: should the output concatinated or averaged,
    dropout: prrobability of dropout,
    leaky_relu_negative_slope: negagtive slope for leaky relu activation
    
    """
    
    def __init__(self, in_features: int, out_features: int, n_heads: int, is_concat: bool = True,
                dropout: int = 0.6, leaky_relu_negative_slope: float = 0.2):
        
        super().__init__()
        self.is_concat = is_concat
        self.n_heads = n_heads
        
        if is_concat:
            assert out_features % n_heads == 0
            
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features
            
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
        
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        
        # to compute attention
        self.softmax = nn.Softmax(dim=1)
        
        self.dropout = nn.Dropout(dropout)
        
    def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor):
        """
        adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads]. 
        We use shape [n_nodes, n_nodes, 1] since the adjacency is the same for each head.
        """
        n_nodes = h.shape[0]
        # for each head we do a linear transformation and split
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
        # Calculate attention score
        
        # where each node embedding is repeated n_nodes times.
        g_repeat = g.repeat(n_nodes, 1, 1)
        
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
        
        # Reshape
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        
        e = self.activation(self.attn(g_concat))
        
        # Remove last dimension
        e = e.squeeze(-1)
        
        # [n_nodes, n_nodes, n_heads] or[n_nodes, n_nodes, 1]
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        
        # Mask eij based on adjacency matrix. eij is set to −∞ if there is no edge from i to j.
        e = e.masked_fill(adj_mat == 0, float('-inf'))
        
        a = self.softmax(e)
        
        a = self.dropout(a)
        
        #Calculate final output for each head
        attn_res= torch.einsum('ijh,jhf->ihf', a, g)
        
        if self.concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            return attn_res.mean(dim=1)
        
        

In [3]:
from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

In [4]:
class CoraDataset:
    """
    Cora dataset is for research papers,
    For each paper we have binary feature vector that indicates the presence of words
    Each paper is classified into one of 7 classes
    
    The papers are the nodes of the graph and the edges are the citations.
    """
    
    labels: torch.Tensor
    classes: Dict[str, int]
    features: torch.Tensor
    adj_mat: torch.Tensor
        
    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
    
    def __init__(self, include_edges: bool = True):
        self.include_edges = include_edges
        
        self._download()
        
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
            
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
            
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
        
        self.features = features / features.sum(dim=1, keepdim=True)
        
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
        
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
        
        paper_ids = np.array(content[:, 0], dtype=np.int32)
        
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
        #Mark citations in adj matrix
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
        
        if self.include_edges:
            for e in citations:
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
        
        
    

In [5]:
class GAT(Module):
    
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        
        super().__init__()
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        
        self.activation = nn.ELU()
        
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        
        x = self.dropout(x)
        
        x = self.layer1(x, adj_mat)
        
        x = self.activation(x)
        
        x = self.dropout(x)
        
        return self.output(x, adj_mat)
        
        
def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)


    

In [7]:
class Configs(BaseConfigs):
    model: GAT
    training_samples: int = 500
    in_features: int
    n_hidden: int = 64
    n_head: int = 8
    n_classes: int
    dropout: float = 0.6
    include_edges: bool = True
    dataset: CoraDataset
    epochs: int = 1_000
    loss_func = nn.CrossEntropyLoss()
    device: torch.device = DeviceConfigs()
    optimizer: torch.optim.Adam
        
    def run(self):
        features = self.dataset.features.to(self.device)
        
        labels = self.dataset.labels.to(self.device)
        
        edges_adj = self.dataset.adj_mat.to(self.device)
        
        edges_adj = edges_adj.unsqueeze(-1)
        
        idx_rand = torch.randperm(len(labels))
        
        idx_train = idx_rand[:self.training_samples]
        
        idx_valid = idx_rand[self.training_samples:]
        
        for epoch in monit.loop(self.epochs):
            
            self.model.train()
            
            self.optimizer.zero_grad()
            
            output = self.model(features, edges_adj)
            
            loss = self.loss_func(output[idx_train], labels[idx_train])
            
            loss.backward()
            
            self.optimizer.step()
            
            tracker.add('loss.train', loss)
            
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
            
            self.model.eval()
            
            with torch.no_grad():
                output = self.model(features, edges_adj)
                
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                
                tracker.add('loss.valid', loss)
                
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
                
            tracker.save()
            
@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)

In [8]:
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))


<function __main__.<lambda>(c)>

In [9]:
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])


<function __main__.<lambda>(c)>

In [10]:
@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

In [11]:
@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


In [12]:
def main():
    conf = Configs()
    experiment.create(name='gat')
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })
    
    with experiment.start():
        conf.run()
        
    

In [17]:
main()

[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


AttributeError: Configs has no attribute `n_heads`