In [1]:
# NN related libraries
import torch 
import torch.nn as nn 
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

import sys 
sys.path.append('./')

# from the code 
from model.GAT import GAT
from utils.layers import GAT_layer

# data related 
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

import time
import os



In [2]:
dataset = Planetoid(root='/tmp/Cora', name='Cora',transform=T.NormalizeFeatures())
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Dataset: Cora():
Number of graphs: 1
Number of features: 1433
Number of classes: 7


In [3]:
edge_index=dataset[0].edge_index
nodes_features=dataset[0].x
nodes_labels=dataset[0].y

#parameters_GAT_network={'num_features_per_layer'={}}



## Training the network in the cora dataset

### First we define the hyperparameters of the network

In [4]:
C=7 # number of classes of the cora dataset
params_network={'num_layers':2,
               'num_nodes':nodes_features.shape[0],
                'num_features_per_layer':[nodes_features.shape[1],8,C],
                'num_heads_per_layer':[8,1],
                 'num_epochs':500
               }


### Divide the dataset in training, validation and test

In [5]:
# indices of each set according to the masks given in the dataset (i.e we use the same assignation as in the original paper)
training_set_indices=(dataset[0].train_mask).nonzero(as_tuple=False).flatten()
test_set_indices=(dataset[0].test_mask).nonzero(as_tuple=False).flatten()
val_set_indices=(dataset[0].val_mask).nonzero(as_tuple=False).flatten()
print('The training dataset starts in node {:} and comprises {:} nodes'.format(training_set_indices[0].numpy(),training_set_indices.shape[0]))
print('The validation dataset starts in node {:} and comprises {:} nodes'.format(test_set_indices[0].numpy(),test_set_indices.shape[0]))
print('The test dataset starts in node {:} and comprises {:} nodes'.format(val_set_indices[0].numpy(),val_set_indices.shape[0]))


The training dataset starts in node 0 and comprises 140 nodes
The validation dataset starts in node 1708 and comprises 1000 nodes
The test dataset starts in node 140 and comprises 500 nodes


In [6]:
#Extract the labels for the training set
nodes_labels_training_set=nodes_labels.index_select(0,training_set_indices)
#validation
nodes_labels_validation_set=nodes_labels.index_select(0,val_set_indices)
#test
nodes_labels_test_set=nodes_labels.index_select(0,test_set_indices)

Now we have everything we need in order to start the training process. Let's define the model and run the learning process

In [7]:
# Let's run the training loop 
def train_gat(params_network,num_epochs=40,val_lapse=50,perform_test='True',starting_epoch=0,):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device=torch.device('mps')
    time_start=time.time()
    model_gat=GAT(params_network['num_layers'],params_network['num_nodes'],params_network['num_features_per_layer'],params_network['num_heads_per_layer']).to(device)
    
    loss_fn = nn.CrossEntropyLoss(reduction='mean').to(device)
    optimizer = Adam(model_gat.parameters(), lr=5e-3,weight_decay=5e-4) # weight decay corresponds to the L2 penalty, which in the original implementation is chosen to the value we put here
    
    #nodes_features_dev=nodes_features.to(device)
    #edge_index_dev=edge_index.to(device)
    graph_data=(nodes_features,edge_index)
    
    #nodes_labels_training_set_dev=nodes_labels_training_set.to(device)
    #nodes_labels_validation_set_dev=nodes_labels_validation_set.to(device)
    #nodes_labels_test_set_dev=nodes_labels_test_set.to(device)
    
    #training_set_indices_dev=training_set_indices.to(device)
    #test_set_indices_dev=test_set_indices.to(device)
    #val_set_indices_dev=val_set_indices.to(device)
    
    if starting_epoch!=0:
        trained=torch.load(os.path.join(os.path.dirname('model'), 'model','saved_model','model_gat_trained.pt'))
        model_gat.state_dict=trained['state_dict']
    #TensorBoard summary writter 
    writer=SummaryWriter()
    
    for epoch in range(starting_epoch,num_epochs):
        
        #print(epoch)
        
        model_gat.train() #set model in training mode
        
        #print(next(model_gat.parameters()).device)
        # We do a forward pass of the model and extract the unnormalized logits for the training set 
        # shape = (N, C) where N is the number of nodes in the split (train/val/test) and C is the number of classes
        nodes_unnormalized_out_train = model_gat(graph_data)[0].index_select(0,training_set_indices)
        
        #print(nodes_unnormalized_out_train.device)
        #print(nodes_unnormalized_out_train.shape)
        loss=loss_fn(nodes_unnormalized_out_train,nodes_labels_training_set)
        
        #Optimizer backward evaluation
        
        optimizer.zero_grad()  
        loss.backward()  
        optimizer.step()  
        
        # Compute the accuracy

        # Finds the index of maximum (unnormalized) score for every node and that's the class prediction for that node.
        # Compare those to true (ground truth) labels and find the fraction of correct predictions -> accuracy metric.
        predictions = torch.argmax(nodes_unnormalized_out_train, dim=-1)
        accuracy = torch.sum(torch.eq(predictions, nodes_labels_training_set).long()).item() / len(nodes_labels_training_set)
        
        
        
        writer.add_scalar('Loss/train',loss.item(),epoch)
        writer.add_scalar('Accuracy/train',accuracy,epoch)
        
        #print(f'time elapsed={(time.time()-time_start):.2f} [s]')
        #print(f'accuracy train={accuracy:.3f}')
        if (epoch+1)%val_lapse==0:
            with torch.no_grad():
                model_gat.eval()
                nodes_unnormalized_out_val = model_gat(graph_data)[0].index_select(0,val_set_indices)
                loss_val=loss_fn(nodes_unnormalized_out_val,nodes_labels_validation_set)
                predictions = torch.argmax(nodes_unnormalized_out_val, dim=-1)
                accuracy = torch.sum(torch.eq(predictions, nodes_labels_validation_set).long()).item() / len(nodes_labels_validation_set)
        
                
                writer.add_scalar('Loss/validation',loss_val.item(),epoch)
                writer.add_scalar('Accuracy/validation',accuracy,epoch)
                print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] | epoch={epoch + 1} | val acc={accuracy}')
            torch.save(model_gat,os.path.join(os.path.dirname('model'), 'model','saved_model',f'model_gat_trained_epoch_{epoch+1}_.pt'))
    
    if perform_test:
        with torch.no_grad():
            model_gat.eval()
            nodes_unnormalized_out_test = model_gat(graph_data)[0].index_select(0,test_set_indices)
            loss_test=loss_fn(nodes_unnormalized_out_test,nodes_labels_test_set)
            predictions = torch.argmax(nodes_unnormalized_out_test, dim=-1)
            accuracy = torch.sum(torch.eq(predictions, nodes_labels_test_set).long()).item() / len(nodes_labels_test_set)
        print(f'Test accuracy = {accuracy}')
    torch.save(model_gat,os.path.join(os.path.dirname('model'), 'model','saved_model','model_gat_trained.pt'))
    writer.close()
        
        

        


        
        
    
    
    
    
    

In [8]:
train_gat(params_network,starting_epoch=0,num_epochs=10000,val_lapse=100)

GAT training: time elapsed= 3.88 [s] | epoch=100 | val acc=0.496
GAT training: time elapsed= 7.75 [s] | epoch=200 | val acc=0.518
GAT training: time elapsed= 12.13 [s] | epoch=300 | val acc=0.726
GAT training: time elapsed= 16.18 [s] | epoch=400 | val acc=0.744
GAT training: time elapsed= 20.34 [s] | epoch=500 | val acc=0.734
GAT training: time elapsed= 24.65 [s] | epoch=600 | val acc=0.726
GAT training: time elapsed= 29.07 [s] | epoch=700 | val acc=0.732
GAT training: time elapsed= 34.09 [s] | epoch=800 | val acc=0.716
GAT training: time elapsed= 38.19 [s] | epoch=900 | val acc=0.736
GAT training: time elapsed= 42.21 [s] | epoch=1000 | val acc=0.722
GAT training: time elapsed= 46.39 [s] | epoch=1100 | val acc=0.72
GAT training: time elapsed= 50.61 [s] | epoch=1200 | val acc=0.722
GAT training: time elapsed= 54.42 [s] | epoch=1300 | val acc=0.734
GAT training: time elapsed= 58.74 [s] | epoch=1400 | val acc=0.73
GAT training: time elapsed= 62.54 [s] | epoch=1500 | val acc=0.73
GAT train