In [23]:
import numpy as np
import torch
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import seaborn as sns

import sys
import os
sys.path.insert(0, os.path.abspath('../../code/'))
import utils 
from utils.functional_graphs import create_graph_dataset
from utils.models import Attention2Conv, Attention3Conv

In [2]:
fcs = np.load('../../local/fcs/combined/sparse_fcs_1316_subjects.npy')
fc = fcs[:1]
del fcs
print(fc.shape)
labels = np.load('../../local/gender_labels/combined_gender_labels_1316_subjects.npy')
label = labels[:1]
print(label)

(1, 100, 100)
[0]


In [12]:
root = '../../local/graph_datasets/single_graph/'
dataset = create_graph_dataset(sparse_fcs=fc, root=root, labels=label)
graph = dataset[0]
print()
print(f'First Graph: {graph}')
print('=============================================================')
print(f'Number of nodes: {graph.num_nodes}')
print(f'Number of edges: {graph.num_edges}')
print(f'Average node degree: {graph.num_edges / graph.num_nodes:.2f}')
print(f'Has isolated nodes: {graph.has_isolated_nodes()}')
print(f'Has self-loops: {graph.has_self_loops()}')
print(f'Is undirected: {graph.is_undirected()}')


First Graph: Data(x=[100, 8], edge_index=[2, 3614], edge_attr=[3614], y=[1])
Number of nodes: 100
Number of edges: 3614
Average node degree: 36.14
Has isolated nodes: False
Has self-loops: True
Is undirected: False


In [13]:
def get_node_importance(model, graph):
    model.eval()
    graph = graph.to(device)
    _, attn_scores = model(graph.x, graph.edge_index, graph.batch)
    return attn_scores.cpu().detach().numpy()

In [20]:
path = 'results/fc_data/models/2conv/'
models = [path + file for file in os.listdir(path)]
print(f'{len(models)} models')
device = torch.device('cpu')
for m in models:
    model = Attention2Conv(num_node_features=dataset.num_features, hidden_channels=128, dropout_rate=0.5, edge_dropout_rate=0.1)
    model.load_state_dict(torch.load(m, weights_only=False))
    
    node_importance = get_node_importance(model, graph)
    importance_factor = np.concatenate([np.expand_dims(range(100),axis=0), node_importance.T])
    indices = np.argsort(importance_factor[1])[::-1]
    importance_factor = importance_factor[:, indices]

    print(m+'\n', f'Nodes: {np.round(importance_factor[0,:5], 2)}','\n')

4 models
results/fc_data/models/2conv/fc2c_model_p10.5_p20.2
 Nodes: [12. 19. 34. 27. 42.] 

results/fc_data/models/2conv/fc2c_model_p10.5_p20.1
 Nodes: [80. 12. 42. 94. 33.] 

results/fc_data/models/2conv/fc2c_model_p10.3_p20.1
 Nodes: [19. 26. 27. 80. 94.] 

results/fc_data/models/2conv/fc2c_model_p10.3_p20.2
 Nodes: [19. 27. 26.  3. 42.] 



In [24]:
path = 'results/fc_data/models/3conv/'
models = [path + file for file in os.listdir(path)]
print(f'{len(models)} models')
device = torch.device('cpu')
for m in models:
    model = Attention3Conv(num_node_features=dataset.num_features, hidden_channels=128, dropout_rate=0.5, edge_dropout_rate=0.1)
    model.load_state_dict(torch.load(m, weights_only=False))
    
    node_importance = get_node_importance(model, graph)
    importance_factor = np.concatenate([np.expand_dims(range(100),axis=0), node_importance.T])
    indices = np.argsort(importance_factor[1])[::-1]
    importance_factor = importance_factor[:, indices]

    print(m+'\n', f'Nodes: {np.round(importance_factor[0,:5], 2)}','\n')

4 models
results/fc_data/models/3conv/fc3c_model_p10.5_p20.2
 Nodes: [ 5.  4. 27. 26. 19.] 

results/fc_data/models/3conv/fc3c_model_p10.3_p20.2
 Nodes: [ 5. 12.  4. 33. 15.] 

results/fc_data/models/3conv/fc3c_model_p10.5_p20.1
 Nodes: [ 5.  4.  0. 12.  7.] 

results/fc_data/models/3conv/fc3c_model_p10.3_p20.1
 Nodes: [ 4.  5.  0. 19. 42.] 

