In [119]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F

from data_utils import *
from utils import *

from sklearn.cross_decomposition import CCA
from scipy.stats import pearsonr

import networkx as nx
from networkx.algorithms import community
from networkx.algorithms.components import connected_components

## Define the Magnitude-Correlation Model

In [135]:
#Define the correlation model

class MagnitudeCorrelation(nn.Module):
    """Parameter-free module that scores atom pairs by magnitude correlation.

    For every edge described by the incidence matrices ``rel_rec`` / ``rel_send``
    (shape [num_edges, num_atoms], as required by the matmuls below), the L2 norm
    over the feature axis is taken at every timestep for the edge's sender and
    receiver. The Pearson correlation of these two magnitude series is the edge
    score. Scores are then projected back onto a
    [batch_size, num_atoms, num_atoms] matrix.
    """

    def __init__(self):
        super(MagnitudeCorrelation, self).__init__()

    @staticmethod
    def _rowwise_pearson(x, y):
        """Pearson correlation between matching rows of two [N, T] tensors.

        Returns a tensor of shape [N] on the inputs' device. Rows with zero
        variance give NaN (0 / 0), as scipy.stats.pearsonr does for constant input.
        """
        # float64 for accuracy; MPS has no float64 support, so use float32 there.
        dtype = torch.float32 if x.device.type == "mps" else torch.float64
        x = x.to(dtype)
        y = y.to(dtype)
        x_centered = x - x.mean(dim=-1, keepdim=True)
        y_centered = y - y.mean(dim=-1, keepdim=True)
        numerator = (x_centered * y_centered).sum(dim=-1)
        denominator = x_centered.norm(dim=-1) * y_centered.norm(dim=-1)
        # Rounding can push |r| slightly above 1; clamp (NaN passes through unchanged).
        return (numerator / denominator).clamp(-1.0, 1.0)

    def node2edge_corr(self, inputs, rel_rec, rel_send):
        """Score every edge by magnitude correlation and project onto node pairs.

        args:
          inputs: [batch_size, num_atoms, num_timesteps, num_features]
          rel_rec: [num_edges, num_atoms] receiver incidence matrix
          rel_send: [num_edges, num_atoms] sender incidence matrix
        returns:
          float32 tensor [batch_size, num_atoms, num_atoms] on ``inputs.device``;
          entry [i, j] is the sum over edges e of
          rel_send[e, i] * score_e * rel_rec[e, j].
          Edges whose correlation is undefined (a constant magnitude series)
          contribute 0.
        """
        batch_size, num_atoms, num_timesteps, num_features = inputs.size()
        x = inputs.reshape(batch_size, num_atoms, -1)
        # shape: [batch_size, num_edges, num_timesteps*num_features]
        receivers = torch.matmul(rel_rec, x)
        senders = torch.matmul(rel_send, x)
        num_edges = receivers.size(1)

        # Per-timestep feature magnitude of each edge endpoint:
        # [batch_size*num_edges, num_timesteps]. Detached, so no gradient flows
        # back into `inputs` through this score.
        receivers_norm = torch.norm(
            receivers.reshape(batch_size * num_edges, num_timesteps, num_features),
            dim=-1).detach()
        senders_norm = torch.norm(
            senders.reshape(batch_size * num_edges, num_timesteps, num_features),
            dim=-1).detach()

        # Computed on the inputs' device so the projection below also works when
        # rel_rec / rel_send live on a GPU.
        corrs = self._rowwise_pearson(senders_norm, receivers_norm)
        # An undefined (NaN) score must stay local: inside the incidence matmuls
        # 0 * NaN = NaN would otherwise poison every entry of the sample's matrix.
        corrs = torch.nan_to_num(corrs, nan=0.0).float()
        corrs = corrs.view(batch_size, num_edges)

        corrs = torch.diag_embed(corrs)  # [batch_size, num_edges, num_edges]
        return torch.matmul(rel_send.t(), torch.matmul(corrs, rel_rec))

    def forward(self, inputs, rel_rec, rel_send):
        """
        args:
          inputs: [batch_size, num_atoms, num_timesteps, num_features]
          rel_rec, rel_send: [num_edges, num_atoms] receiver / sender incidence matrices
        returns:
          [batch_size, num_atoms, num_atoms] matrix of projected edge correlations
        """
        return self.node2edge_corr(inputs, rel_rec, rel_send)

## Group Spring Simulation

In [136]:
num_atoms = 5
# Edge/node incidence matrices for `num_atoms` atoms: first with self-loops
# (the "_sl" pair), then without.
(rel_rec_sl, rel_send_sl), (rel_rec, rel_send) = (
    create_edgeNode_relation(num_atoms, self_loops=with_loops)
    for with_loops in (True, False)
)

### Load Data

In [137]:
# Load the group-spring simulation data without normalization and keep the full test split.
suffix = "_static_5"

train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_spring_sim(
    64, suffix, normalize=False)

# Index the test dataset once and take (data, labels) from that same slice,
# instead of materialising the whole split twice.
test_split = test_loader.dataset[:]
data = test_split[0]
labels = test_split[1]
print("Data shape: ", data.size())
print("Labels shape: ", labels.size())

# Project per-edge group labels onto a node-by-node matrix; assuming rel_send /
# rel_rec are one-hot incidence matrices, entry [s, r] holds the label of the
# edge from sender s to receiver r.
labels_diag = torch.diag_embed(labels)
gr_labels = rel_send.t() @ (labels_diag.float() @ rel_rec)
print("Group relations shape: ", gr_labels.size())

gr_labels_numpy = gr_labels.detach().cpu().numpy()
print(gr_labels_numpy.shape)

# Ground-truth groups per sample: the connected components of that sample's
# label graph, each component as a list of node indices.
gr_labels_clusters = [
    [list(component) for component in connected_components(nx.from_numpy_array(adjacency))]
    for adjacency in gr_labels_numpy
]

Data shape:  torch.Size([200, 5, 49, 4])
Labels shape:  torch.Size([200, 20])
Group relations shape:  torch.Size([200, 5, 5])
(200, 5, 5)


### Graph Clustering with the Louvain Algorithm

In [139]:
from cdlib import algorithms
import networkx as nx
from networkx.algorithms import community
from networkx.algorithms.components import connected_components

In [140]:
# Minimum symmetrized magnitude correlation for two atoms to be linked in the
# graph that Louvain clusters.
CORR_THRESHOLD = 0.5

magnitude_correlation = MagnitudeCorrelation()
magnitude_correlation.eval()
with torch.no_grad():
    raw_correlations = magnitude_correlation(data, rel_rec, rel_send)

# `symmetrize` comes from utils (TODO confirm its exact formula). Any NaN scores
# compare False against the threshold and therefore never create an edge.
correlations = (symmetrize(raw_correlations) > CORR_THRESHOLD).float()
correlations_numpy = correlations.cpu().detach().numpy()

In [144]:
# Group-MITRE: cluster each sample's thresholded correlation graph with Louvain
# and score the detected communities against the ground-truth groups.
LOUVAIN_SEED = 0
# NOTE(review): Louvain's node visiting order can be random. When no explicit random
# state is forwarded, python-louvain falls back to numpy's global RNG, so seeding it
# keeps reruns reproducible. TODO confirm for the installed cdlib version.
np.random.seed(LOUVAIN_SEED)

assert correlations_numpy.shape[0] == len(gr_labels_clusters), \
    "need exactly one ground-truth clustering per correlation matrix"

precision_test = []
recall_test = []
F1_test = []

for adjacency, true_groups in zip(correlations_numpy, gr_labels_clusters):
    louvain_graph = nx.from_numpy_array(adjacency)
    communities = algorithms.louvain(louvain_graph).communities

    # compute_groupMitre's results are unpacked as (recall, precision, F1).
    recall, precision, F1 = compute_groupMitre(true_groups, communities)

    recall_test.append(recall)
    precision_test.append(precision)
    F1_test.append(F1)
    

# Average each Group-MITRE metric over all test samples and report it.
precision_mean, recall_mean, F1_mean = (
    np.mean(scores) for scores in (precision_test, recall_test, F1_test)
)

for caption, mean_value in (("Average Precision: ", precision_mean),
                            ("Average Recall: ", recall_mean),
                            ("Average F1 Score: ", F1_mean)):
    print(caption, mean_value)

Average Precision:  0.18458333333333332
Average Recall:  0.4233333333333333
Average F1 Score:  0.21575901875901876
