In [93]:
from typing import Tuple
import numpy as np
from scipy.sparse import base
from sklearn.metrics import cluster

import torch
import torch.nn as nn
import torch.nn.functional as F

from sknetwork.topology import get_connected_components
from cdlib import algorithms
import networkx as nx
from networkx.algorithms import community
from networkx.algorithms.components import connected_components

import matplotlib.pyplot as plt
from matplotlib import offsetbox
from sklearn import manifold
import pandas as pd
import plotly.express as px

from utils import *
from data_utils import *
from models_NRI import *
from models_traj import *
from models_clustering import *
from torch.nn.functional import pdist

## Group Spring Simulation

### Define and Load Models
load the nri, traj models

In [84]:
suffix = "_static_5"

num_atoms = 5
num_features = 4
nri_hidden = 256
edge_types = 2
use_motion = True

traj_emb = 16
traj_heads = 2
c_hidden = 64
c_out = 48
traj_latent = 32
traj_depth = 3
kernel_size = 5

rel_rec_sl, rel_send_sl = create_edgeNode_relation(num_atoms, self_loops=True)
rel_rec, rel_send = create_edgeNode_relation(num_atoms, self_loops=False)


nri_folder = "logs/pipeline"
traj_folder = "logs/pipeline"
save_folder = "logs/pipeline"

nri_encoder = ResCausalCNNEncoder(num_features, nri_hidden, edge_types,
                                 do_prob=0.,factor=True, use_motion=True)
traj_encoder = GraphTCNEncoder(num_features, traj_emb, traj_heads, c_hidden,
                              c_out, kernel_size, traj_depth, traj_latent,
                              use_motion)
traj_decoder = RNNDecoder(traj_latent, num_features, traj_emb, 64, "GRU", True)

nri_encoder_file = os.path.join(os.path.join(save_folder, "nri"), 'nri_encoder.pt')
nri_encoder.load_state_dict(torch.load(nri_encoder_file))
traj_encoder_file = os.path.join(os.path.join(save_folder, "traj"), 'traj_encoder.pt')
traj_encoder.load_state_dict(torch.load(traj_encoder_file))
traj_decoder_file = os.path.join(os.path.join(save_folder, "traj"), 'traj_decoder.pt')
traj_decoder.load_state_dict(torch.load(traj_decoder_file))

Using factor graph ResCausalCNN encoder


<All keys matched successfully>

### Load data

In [94]:
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_spring_sim(64, suffix)
data = test_loader.dataset[:][0]
labels = test_loader.dataset[:][1]
#data = data.unsqueeze(0)
print("Data shape: ",data.size())
print("Labels shape: ",labels.size())

labels_diag = torch.diag_embed(labels)
gr_labels = torch.matmul(rel_send.t(), 
                         torch.matmul(labels_diag.float(), rel_rec))
print("Group relations shape: " ,gr_labels.size())


#create group size labels
gm_labels = gr_labels.sum(-1).unsqueeze(-1)+1
print("Group members shape: ", gm_labels.shape)

#Create Group Relation labels
gr_labels_numpy = gr_labels.cpu().detach().numpy()
gIDs = []
for i in range(gr_labels_numpy.shape[0]):
    if gr_labels_numpy[i].sum()==0:
        gIDs.append(list(range(gr_labels_numpy.shape[1])))
    else:
        gIDs.append(list(get_connected_components(gr_labels_numpy[i])))


gIDs = np.array(gIDs)

gr_labels_clusters = []
for i in range(gr_labels_numpy.shape[0]):
    gi = nx.from_numpy_array(gr_labels_numpy[i])
    com = list(connected_components(gi))
    com = [list(c) for c in com]
    gr_labels_clusters.append(com)

Data shape:  torch.Size([200, 5, 49, 4])
Labels shape:  torch.Size([200, 20])
Group relations shape:  torch.Size([200, 5, 5])
Group members shape:  torch.Size([200, 5, 1])


### Evaluate Performance of Trajectory Representation Learning Model

In [86]:
traj_encoder.eval()
data = data.contiguous()
mu, sigma = traj_encoder(data, rel_rec_sl, rel_send_sl)
latents = mu+sigma*torch.randn_like(sigma)
print(latents.size())
data_rec = traj_decoder(latents, data, teaching_rate=1.)
print("reconstruction loss: ",F.mse_loss(data[:,:,1:,:], data_rec[:,:,1:,:]))

torch.Size([200, 5, 32])
reconstruction loss:  tensor(0.0031, grad_fn=<MseLossBackward0>)


### T-SNE Visualization of Raw Data for Group Sizes

In [87]:
data.size()
data = data.contiguous()
data_reshape = data.view(data.size(0)*data.size(1), data.size(2)*data.size(-1))
data_reshape.size()
data_reshape_numpy = data_reshape.cpu().detach().numpy()
tsne = manifold.TSNE(n_components=2, init="pca", random_state=1)

data_tsne = tsne.fit_transform(data_reshape_numpy)
print(data_tsne.shape)
data_tsne_df = pd.DataFrame(data_tsne)
gm_labels_numpy = gm_labels.detach().numpy()
gm_labels_numpy = gm_labels_numpy.reshape(-1, gm_labels_numpy.shape[-1])
data_tsne_df["label"] = gm_labels_numpy

px.scatter(data_tsne_df, x=0, y=1, color=data_tsne_df.label.astype(str), opacity=0.7)

(1000, 2)


### T-SNE Visualization of Raw Data for Group Relations

In [92]:
data_tsne_g = data_tsne_df[5:10]
data_tsne_g["label"] = gIDs[1]
data_tsne_g

px.scatter(data_tsne_g, x=0, y=1, color=latents_tsne_g.label.astype(str), opacity=0.7)

### T-SNE Visualization of Latent Variables for Group Sizes

In [89]:
latents_numpy = latents.cpu().detach().numpy()
latents_numpy = latents_numpy.reshape(-1, latents_numpy.shape[-1])
print(latents_numpy.shape)
gm_labels_numpy = gm_labels.detach().numpy()
gm_labels_numpy = gm_labels_numpy.reshape(-1, gm_labels_numpy.shape[-1])
print(gm_labels_numpy.shape)

tsne = manifold.TSNE(n_components=2, init="pca", random_state=1)

latents_tsne = tsne.fit_transform(latents_numpy)
print(latents_tsne.shape)
latents_tsne_df = pd.DataFrame(latents_tsne)
latents_tsne_df["label"] = gm_labels_numpy

px.scatter(latents_tsne_df, x=0, y=1, color=latents_tsne_df.label.astype(str), opacity=0.7)

(1000, 32)
(1000, 1)
(1000, 2)


### T-SNE Visualization of Latent Variables for Group Relations

In [91]:
latents_tsne_g = latents_tsne_df[5:10]
latents_tsne_g["label"] = gIDs[1]
latents_tsne_g

px.scatter(latents_tsne_g, x=0, y=1, color=latents_tsne_g.label.astype(str), opacity=0.7)

### Test of Louvain Algorithm based on NRI Interaction Graphs

In [95]:
nri_encoder.eval()
edges = nri_encoder(data, rel_rec, rel_send)
edges = F.softmax(edges, dim=-1)
interaction = 1-edges[:,:,0]
interaction = torch.matmul(rel_send.t(), torch.matmul(torch.diag_embed(interaction), rel_rec))
interaction = symmetrize(interaction)
interaction = (interaction > 0.5).float()
interaction_numpy = interaction.cpu().detach().numpy()

precision_test = []
recall_test = []
F1_test = []

for i in range(interaction_numpy.shape[0]):
    Ai = interaction_numpy[i]
    Ai = nx.from_numpy_array(Ai)
    coms = algorithms.louvain(Ai)
    communities = coms.communities
    
    recall, precision, F1 = compute_groupMitre(gr_labels_clusters[i], communities)
    
    recall_test.append(recall)
    precision_test.append(precision)
    F1_test.append(F1)
    

precision_mean = np.mean(precision_test)
recall_mean = np.mean(recall_test)
F1_mean = np.mean(F1_test)

print(precision_mean)
print(recall_mean)
print(F1_mean)

0.99625
0.9958333333333332
0.9958571428571429


## Test of GCN Clustering

In [97]:
import networkx as nx
from networkx.algorithms import community
from networkx.algorithms.components import connected_components

In [98]:
# Load GCNEncoder
gcn_hid = 24
gcn_out = 16
n_clusters = 8
gcn_encoder = GCNEncoder(traj_latent, gcn_hid, gcn_out, n_clusters)
gcn_encoder_file = os.path.join(os.path.join(save_folder, "gcn"), 'gcn_encoder.pt')
gcn_encoder.load_state_dict(torch.load(gcn_encoder_file))

<All keys matched successfully>

In [99]:
#get interaction matrix
nri_encoder.eval()
edges = nri_encoder(data, rel_rec, rel_send)
edges = F.softmax(edges, dim=-1)
A = 1-edges[:,:,0]
A = torch.matmul(rel_send.t(), torch.matmul(torch.diag_embed(A), rel_rec))
A = symmetrize(A)
A = (A>0.5).float()
A_norm = normalize_graph(A, add_self_loops=False)

#get trajectory representation
traj_encoder.eval()
mu, sigma = traj_encoder(data, rel_rec_sl, rel_send_sl)
X = mu+sigma*torch.randn_like(sigma)

#get representation by GCN
Z = gcn_encoder(A_norm, X)


# Evaluation of Clustering performance
precision_test = []
recall_test = []
F1_test = []

for i in range(Z_numpy.shape[0]):
    Zi = Z_numpy[i]
    clustering = DBSCAN(eps=0.2, min_samples=1, metric="l1").fit(Zi)
    predicted_labels = clustering.labels_
    target = gIDs[i]
    recall, precision, F1 = compute_groupMitre_labels(target, predicted_labels)
    recall_test.append(recall)
    precision_test.append(precision)
    F1_test.append(F1)
    
print("Average precision: ", np.mean(precision_test))
print("Average recall: ", np.mean(recall_test))
print("Average F1 Score: ", np.mean(F1_test))

Average precision:  0.8393333333333334
Average recall:  0.8401666666666665
Average F1 Score:  0.8265515873015873
