## Measuring the latent representations: activations of the message functions

In [31]:
import sys
import os

parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

from model import *
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Load the simulated dataset and unpack it into model inputs X and targets y
# (they are passed together as (X, y) pairs to `train` below).
data = load_data('../simulations/datasets/r1_n=3_dim=2_nt=1000_dt=0.005')
X, y = data

# Order-preserving splits (shuffle=False): 75% train / 25% held out,
# then the held-out part is divided 75% test / 25% validation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, shuffle=False
)
X_test, X_val, y_test, y_val = train_test_split(
    X_test, y_test, test_size=0.25, shuffle=False
)

In [3]:
cutoff = 10000      # max number of training samples used
val_cutoff = 2500   # max number of validation samples used
n_epochs = 25       # third argument to `train`; matches the "Epoch: k/25" progress output

# Train the model on a small amount of data; the validation subset is capped
# separately by val_cutoff so validation stays cheap.
model = train(
    (X_train[:cutoff], y_train[:cutoff]),
    (X_val[:val_cutoff], y_val[:val_cutoff]),
    n_epochs,
)

Epoch: 1/25: 100%|██████████| 313/313 [00:02<00:00, 156.36it/s]


training loss: 0.9727, val loss: 1.5181


Epoch: 2/25: 100%|██████████| 313/313 [00:01<00:00, 202.07it/s]


training loss: 0.7122, val loss: 1.3754


Epoch: 3/25: 100%|██████████| 313/313 [00:01<00:00, 186.89it/s]


training loss: 0.5743, val loss: 1.1518


Epoch: 4/25: 100%|██████████| 313/313 [00:01<00:00, 203.02it/s]


training loss: 0.4792, val loss: 1.0499


Epoch: 5/25: 100%|██████████| 313/313 [00:01<00:00, 205.83it/s]


training loss: 0.3847, val loss: 0.9625


Epoch: 6/25: 100%|██████████| 313/313 [00:01<00:00, 191.71it/s]


training loss: 0.3622, val loss: 0.8016


Epoch: 7/25: 100%|██████████| 313/313 [00:01<00:00, 178.80it/s]


training loss: 0.3206, val loss: 0.7420


Epoch: 8/25: 100%|██████████| 313/313 [00:01<00:00, 196.12it/s]


training loss: 0.3023, val loss: 0.8440


Epoch: 9/25: 100%|██████████| 313/313 [00:01<00:00, 178.63it/s]


training loss: 0.3029, val loss: 0.7552


Epoch: 10/25: 100%|██████████| 313/313 [00:01<00:00, 187.06it/s]


training loss: 0.2748, val loss: 0.8685


Epoch: 11/25: 100%|██████████| 313/313 [00:02<00:00, 126.02it/s]


training loss: 0.2460, val loss: 0.7322


Epoch: 12/25: 100%|██████████| 313/313 [00:01<00:00, 163.88it/s]


training loss: 0.2496, val loss: 0.6599


Epoch: 13/25: 100%|██████████| 313/313 [00:01<00:00, 169.97it/s]


training loss: 0.2339, val loss: 0.6177


Epoch: 14/25: 100%|██████████| 313/313 [00:01<00:00, 161.85it/s]


training loss: 0.2524, val loss: 0.6608


Epoch: 15/25: 100%|██████████| 313/313 [00:02<00:00, 154.32it/s]


training loss: 0.2134, val loss: 0.5945


Epoch: 16/25: 100%|██████████| 313/313 [00:01<00:00, 157.72it/s]


training loss: 0.2077, val loss: 0.6185


Epoch: 17/25: 100%|██████████| 313/313 [00:01<00:00, 163.96it/s]


training loss: 0.2447, val loss: 0.6241


Epoch: 18/25: 100%|██████████| 313/313 [00:01<00:00, 157.75it/s]


training loss: 0.2056, val loss: 0.5874


Epoch: 19/25: 100%|██████████| 313/313 [00:01<00:00, 161.90it/s]


training loss: 0.2052, val loss: 0.6562


Epoch: 20/25: 100%|██████████| 313/313 [00:01<00:00, 158.93it/s]


training loss: 0.2325, val loss: 0.6317


Epoch: 21/25: 100%|██████████| 313/313 [00:02<00:00, 149.31it/s]


training loss: 0.1841, val loss: 0.6637


Epoch: 22/25: 100%|██████████| 313/313 [00:02<00:00, 134.11it/s]


training loss: 0.1972, val loss: 0.6327


Epoch: 23/25: 100%|██████████| 313/313 [00:01<00:00, 163.86it/s]


training loss: 0.1802, val loss: 0.5610


Epoch: 24/25: 100%|██████████| 313/313 [00:01<00:00, 168.50it/s]


training loss: 0.1983, val loss: 0.5488


Epoch: 25/25: 100%|██████████| 313/313 [00:01<00:00, 167.02it/s]


training loss: 0.1872, val loss: 0.5762


In [None]:
def message_features(test_data, model, message_dim=100):
    """
    Compute the edge message for every ordered pair of distinct nodes.

    For each snapshot in ``test_data`` and every ordered pair ``(i, j)`` with
    ``i != j``, stores ``model.message(x_i, x_j)`` in ``messages[batch, i, j]``.
    Diagonal entries (``i == j``) are never computed and remain zero.

    Args:
        test_data (torch.Tensor): 3-D tensor ``[batches, nodes, features]``
            (rank is fixed by the ``.size()`` unpacking below).
        model: Object exposing ``eval()`` and ``message(x_i, x_j)``; each
            returned message must hold exactly ``message_dim`` values.
        message_dim (int): Width of one message vector. Defaults to 100, the
            standard message size used in this notebook.

    Returns:
        torch.Tensor: ``[batches, nodes, nodes, message_dim]`` message tensor.
    """
    # Same inference mode as get_message_info, so train-only layers are off.
    model.eval()
    batches, nodes, _ = test_data.size()
    messages = torch.zeros(batches, nodes, nodes, message_dim)
    with torch.no_grad():
        for batch in range(batches):
            for i in range(nodes):
                x_i = test_data[batch, i].unsqueeze(0)  # invariant over j
                for j in range(nodes):
                    if i == j:
                        continue
                    x_j = test_data[batch, j].unsqueeze(0)
                    # message() is called with a leading batch dim of 1;
                    # flatten so the result fills one message_dim slot.
                    messages[batch, i, j] = model.message(x_i, x_j).reshape(-1)
    return messages

In [None]:
# Pairwise messages for the first 500 test snapshots; self-pairs (i == j) stay zero.
wee = message_features(test_data=X_test[:500], model=model)

In [20]:
# Shape of the (i=0, j=0) self-pair slice across all snapshots: [n_snapshots, message_dim].
wee[:, 0, 0, :].shape

torch.Size([500, 100])

In [38]:
def get_message_info(model, input_data, batch_size=32):
    """
    Tabulate the messages passed along every edge of the NBodyGNN model.

    For each snapshot in ``input_data`` and each edge from ``get_edge_index``,
    the source- and target-node features are concatenated and fed through
    ``model.edge_model``. One DataFrame row is produced per (snapshot, edge)
    pair, in snapshot-major order (the DataLoader does not shuffle).

    Args:
        model (NBodyGNN): Trained model instance; it is switched to eval mode.
        input_data (torch.Tensor): Input data with shape
            [no_timesteps, no_nodes, node_features]; the feature axis must hold
            the 6 values x, y, vx, vy, q, m (2D layout).
        batch_size (int): Number of snapshots sent through the edge model at once.

    Returns:
        pd.DataFrame: columns x1..m1 (source node), x2..m2 (target node),
        e0..e{k-1} (the k message components), plus dx, dy, r (source minus
        target position and its Euclidean norm) and dvx, dvy, v_rel (same for
        velocity).

    Raises:
        ValueError: if ``input_data`` has no timesteps, or its feature axis does
            not match the 6 expected node features.
    """
    node_features = ['x', 'y', 'vx', 'vy', 'q', 'm']  # Based on the 2D implementation

    # Fail early with a clear message instead of an opaque torch.cat / pandas error.
    if input_data.shape[0] == 0:
        raise ValueError("input_data must contain at least one timestep")
    if input_data.shape[-1] != len(node_features):
        raise ValueError(
            f"expected {len(node_features)} node features {node_features}, "
            f"got {input_data.shape[-1]}"
        )

    model.eval()  # Set to evaluation mode
    edge_index = get_edge_index(input_data.shape[1])

    dataloader = DataLoader(TensorDataset(input_data), batch_size=batch_size, shuffle=False)

    batch_tables = []
    with torch.no_grad():
        for (nodes,) in dataloader:
            s1 = nodes[:, edge_index[0]]  # Source node features per edge
            s2 = nodes[:, edge_index[1]]  # Target node features per edge

            messages = model.edge_model(torch.cat((s1, s2), dim=-1))

            # Row layout: [source features | target features | message]
            rows = torch.cat((s1, s2, messages), dim=-1)
            # Flatten [batch, num_edges, features] -> [batch * num_edges, features]
            batch_tables.append(rows.reshape(-1, rows.shape[-1]))

    # .cpu() so .numpy() also works when the tensors live on a GPU.
    table = torch.cat(batch_tables, dim=0).cpu().numpy()

    source_cols = [f'{f}1' for f in node_features]
    target_cols = [f'{f}2' for f in node_features]
    # Derive the message width from the assembled table rather than from a
    # loop variable leaking out of the batch loop.
    n_message = table.shape[1] - len(source_cols) - len(target_cols)
    message_cols = [f'e{i}' for i in range(n_message)]

    msg_info = pd.DataFrame(table, columns=source_cols + target_cols + message_cols)

    # Relative position (source minus target) and distance
    msg_info['dx'] = msg_info.x1 - msg_info.x2
    msg_info['dy'] = msg_info.y1 - msg_info.y2
    msg_info['r'] = np.sqrt(msg_info.dx**2 + msg_info.dy**2)

    # Relative velocity (source minus target) and its magnitude
    msg_info['dvx'] = msg_info.vx1 - msg_info.vx2
    msg_info['dvy'] = msg_info.vy1 - msg_info.vy2
    msg_info['v_rel'] = np.sqrt(msg_info.dvx**2 + msg_info.dvy**2)

    return msg_info

In [39]:
# One row per (snapshot, edge) for the first 500 test snapshots.
get_message_info(model=model, input_data=X_test[:500])

Unnamed: 0,x1,y1,vx1,vy1,q1,m1,x2,y2,vx2,vy2,...,e96,e97,e98,e99,dx,dy,r,dvx,dvy,v_rel
0,1.108470,-3.027177,-0.107625,-0.693829,-0.871233,1.801894,1.019855,-3.755330,0.474216,-1.642119,...,0.032894,0.065366,0.066830,0.357321,0.088615,0.728153,0.733526,-0.581841,0.948290,1.112561
1,1.108470,-3.027177,-0.107625,-0.693829,-0.871233,1.801894,2.941403,-0.552212,2.070446,-1.332116,...,-0.015355,0.007645,-0.071077,0.153527,-1.832933,-2.474966,3.079789,-2.178071,0.638288,2.269670
2,1.019855,-3.755330,0.474216,-1.642119,-1.288934,1.142944,1.108470,-3.027177,-0.107625,-0.693829,...,0.131412,0.010498,-0.475665,0.393173,-0.088615,-0.728153,0.733526,0.581841,-0.948290,1.112561
3,1.019855,-3.755330,0.474216,-1.642119,-1.288934,1.142944,2.941403,-0.552212,2.070446,-1.332116,...,-0.018540,0.012951,-0.075058,0.150679,-1.921548,-3.203119,3.735280,-1.596229,-0.310002,1.626053
4,2.941403,-0.552212,2.070446,-1.332116,0.595053,0.623231,1.108470,-3.027177,-0.107625,-0.693829,...,0.015953,0.131964,-0.052329,0.305343,1.832933,2.474966,3.079789,2.178071,-0.638288,2.269670
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2995,1.490004,0.604459,0.959115,1.252741,0.495583,1.296262,0.835605,0.783813,-0.987951,-0.009257,...,0.146489,-0.021632,-0.271225,0.310987,0.654399,-0.179354,0.678532,1.947065,1.261998,2.320281
2996,-1.636888,1.912824,-0.466934,0.984240,1.361227,0.673946,1.490004,0.604459,0.959115,1.252741,...,-0.052281,0.022073,-0.054082,0.103952,-3.126891,1.308365,3.389582,-1.426049,-0.268502,1.451106
2997,-1.636888,1.912824,-0.466934,0.984240,1.361227,0.673946,0.835605,0.783813,-0.987951,-0.009257,...,-0.056636,0.032347,-0.050118,0.102044,-2.472492,1.129011,2.718066,0.521016,0.993497,1.121826
2998,0.835605,0.783813,-0.987951,-0.009257,-1.039976,1.101690,1.490004,0.604459,0.959115,1.252741,...,-0.114410,0.193582,0.061398,0.211932,-0.654399,0.179354,0.678532,-1.947065,-1.261998,2.320281
