# GNNs

## Taking mass into account

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from datagen import generate_dataset
from tqdm import tqdm


def fully_connected(input_dim, hidden_dim, output_dim, num_layers):
    """Build a simple MLP: ``num_layers`` x (Linear -> ReLU) followed by a Linear head.

    Args:
        input_dim: width of the input features.
        hidden_dim: width of every hidden layer.
        output_dim: width of the output.
        num_layers: number of hidden (Linear + ReLU) layers; 0 gives a single
            Linear(input_dim, output_dim).

    Returns:
        A ``torch.nn.Sequential`` containing the layers in order.
    """
    layers = []
    in_features = input_dim
    for _ in range(num_layers):
        layers.append(torch.nn.Linear(in_features, hidden_dim))
        layers.append(torch.nn.ReLU())
        in_features = hidden_dim
    # The head must read whatever width the last layer produced: hidden_dim
    # after at least one hidden layer, but input_dim when num_layers == 0.
    layers.append(torch.nn.Linear(in_features, output_dim))
    return torch.nn.Sequential(*layers)

# Attach inverse-distance edge weights to a graph (the name is historical:
# the weight depends on node positions, not on mass).
def create_edge_weights_based_on_mass(data, eps=1e-12):
    """Set ``data.edge_weight`` to 1 / distance between each edge's endpoints.

    The position of a node is taken to be the first three columns of
    ``data.x``. For every column ``(row, col)`` of ``data.edge_index`` the
    weight is ``1 / max(||x[row, :3] - x[col, :3]||, eps)``.

    Args:
        data: graph object exposing ``x`` and ``edge_index``; it is modified
            in place.
        eps: lower bound applied to the distance so coincident endpoints
            (e.g. self-loops) get a large finite weight instead of ``inf``,
            which would otherwise turn graph normalisation into NaN.

    Returns:
        The same ``data`` object, with ``edge_weight`` set.
    """
    row, col = data.edge_index  # source / target node indices of each edge
    positions = data.x[:, :3]
    distance = torch.norm(positions[row] - positions[col], dim=1)
    data.edge_weight = 1 / distance.clamp_min(eps)
    return data

class ParticleGNN(torch.nn.Module):
    """Encoder -> stack of GCNConv message-passing layers -> decoder.

    The encoder and decoder are MLPs built by ``fully_connected``; every
    message-passing step is a ``GCNConv(gnn_dim, gnn_dim)`` followed by ReLU.
    """

    def __init__(self, input_dim, fc_dim, fc_layers, gnn_dim, message_passing_steps, output_dim):
        super().__init__()
        # Maps raw node features to the message-passing width.
        self.encoder = fully_connected(input_dim, fc_dim, gnn_dim, fc_layers)

        # Each conv is registered under the name ``gnn_{step}`` (this fixes the
        # state_dict keys) and is also kept in a plain list for iteration.
        self.gnns = []
        for step in range(message_passing_steps):
            conv = GCNConv(gnn_dim, gnn_dim)
            self.add_module(f'gnn_{step}', conv)
            self.gnns.append(conv)

        # Maps the final node embeddings to the requested output width.
        self.decoder = fully_connected(gnn_dim, fc_dim, output_dim, fc_layers)

    def forward(self, data):
        """Return per-node outputs for a graph exposing ``x``, ``edge_index`` and ``edge_weight``."""
        features, edges, weights = data.x, data.edge_index, data.edge_weight
        hidden = self.encoder(features)
        for conv in self.gnns:
            hidden = F.relu(conv(hidden, edges, edge_weight=weights))
        return self.decoder(hidden)



def _hub_edge_index(masses, num_hubs=2):
    """Return a [2, E] edge index linking every non-hub node to each of the ``num_hubs`` heaviest nodes, in both directions."""
    num_nodes = masses.shape[0]
    num_hubs = min(num_hubs, num_nodes)  # tolerate graphs with < num_hubs nodes
    hubs = torch.argsort(masses, descending=True)[:num_hubs]
    is_hub = torch.zeros(num_nodes, dtype=torch.bool, device=masses.device)
    is_hub[hubs] = True
    others = torch.arange(num_nodes, device=masses.device)[~is_hub]
    # Cartesian product (other, hub): `others` tiled once per hub, each hub
    # repeated once per other node. Row 0 = source, row 1 = target.
    src = others.repeat(num_hubs)
    dst = hubs.repeat_interleave(others.numel())
    to_hubs = torch.stack([src, dst])
    return torch.cat([to_hubs, to_hubs.flip(0)], dim=1)


# Function to convert particles to a torch_geometric Data object
def transform_particles_to_graph_with_radius_and_mass_edge_weights(features, positions, radius):
    """Build a particle graph from radius neighbourhoods plus links to the heaviest particles.

    Edges are the union of:
      * ``radius_graph(positions, r=radius)``;
      * for every node that is not one of the two heaviest, an edge to and
        from each of the two heaviest nodes (mass = last column of
        ``features``).
    Duplicate edges are removed. Edge weights are then set by
    ``create_edge_weights_based_on_mass`` (inverse distance over the first
    three feature columns).

    NOTE(review): radius_graph caps neighbours per node by default
    (``max_num_neighbors``) — confirm that cap is acceptable for dense scenes.

    Args:
        features: node feature matrix; assumed to hold position in columns
            0-2 and mass in the last column — TODO confirm against callers.
        positions: node coordinates passed to ``radius_graph``.
        radius: neighbourhood radius for ``radius_graph``.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` and ``edge_weight``.
    """
    edge_index = radius_graph(positions, r=radius)
    hub_edges = _hub_edge_index(features[:, -1]).to(edge_index.device)
    edge_index = torch.cat([edge_index, hub_edges], dim=1)
    if edge_index.numel() > 0:
        # A node already within `radius` of a hub would otherwise carry the
        # same edge twice and be counted twice by sum-aggregating convolutions.
        edge_index = torch.unique(edge_index, dim=1)

    graph_data = Data(x=features, edge_index=edge_index)
    return create_edge_weights_based_on_mass(graph_data)



  


In [54]:
def euclidean_distance(a, b):
    """Euclidean norm of ``a - b`` over the last axis.

    A constant 1e-12 is added under the square root so the result (and its
    gradient) stays finite when ``a == b``.
    """
    squared = ((a - b) ** 2).sum(dim=-1)
    return (squared + 1e-12).sqrt()


def mean_distance(a, b):
    """Mean of ``euclidean_distance(a, b)`` over all remaining entries."""
    return euclidean_distance(a, b).mean()

In [55]:
import torch
def generate_graph_dataset(data, radius, gravity=1.0, softening=0.1):
    """Convert raw simulation scenes into graph samples.

    Per-node features are ``[pos, vel, G, softening, mass]``, where ``G`` and
    ``softening`` are constant columns filled with ``gravity`` and
    ``softening`` respectively.

    Args:
        data: iterable of scene records, each indexed with the keys
            ``'masses'``, ``'pos'``, ``'vel'`` and ``'acc'`` (presumably
            per-particle arrays — TODO confirm shapes against datagen).
        radius: neighbourhood radius forwarded to the graph builder.
        gravity: value of the constant G feature column (default 1.0).
        softening: value of the constant softening feature column
            (default 0.1).

    Returns:
        A list with one graph per scene; each graph's ``y`` holds the scene's
        ``'acc'`` values.
    """
    graphs = []
    for scene in data:
        masses = torch.tensor(scene['masses']).unsqueeze(-1)  # (N,) -> (N, 1)
        positions = torch.tensor(scene['pos'])
        velocities = torch.tensor(scene['vel'])
        g_column = torch.ones_like(masses) * gravity
        softening_column = torch.ones_like(masses) * softening
        features = torch.cat([positions, velocities, g_column, softening_column, masses], dim=-1)
        graph = transform_particles_to_graph_with_radius_and_mass_edge_weights(features, positions, radius)
        graph.y = torch.tensor(scene['acc'])
        graphs.append(graph)
    return graphs

In [56]:
# Initialize model and optimizer.
# input_dim=9 matches the node feature layout [pos(3), vel(3), G, softening, mass]
# built by generate_graph_dataset; output_dim=3 is the predicted acceleration.
# The model is moved to its device *before* the optimizer is constructed, as the
# torch.optim documentation recommends, so the optimizer holds the final parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ParticleGNN(input_dim=9, fc_dim=128, fc_layers=1, gnn_dim=128, message_passing_steps=5, output_dim=3)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [57]:

import os

ROUNDS = 50
EPOCHS = 4
SCENES = 25
WINDOW_SIZE = 0
RADIUS = 2.0
CHECKPOINT_DIR = './models'


def _train_one_epoch(model, optimizer, dataloader, device):
    """Run one optimisation pass over ``dataloader`` and return the per-batch losses."""
    losses = []
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = mean_distance(out, batch.y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


# Training loop: each round generates a fresh dataset, trains on it for
# EPOCHS passes, then checkpoints the model.
# Create the checkpoint directory up front so torch.save cannot fail after a
# full (slow) round of data generation and training.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
model.train()
for r in range(ROUNDS):
    print(f'--- Round {r} ---')
    scenes = generate_dataset(n_scenes=SCENES, window_size=WINDOW_SIZE)
    graphs = generate_graph_dataset(scenes, RADIUS)
    dataloader = DataLoader(graphs, batch_size=1)
    for epoch in range(EPOCHS):
        epoch_losses = _train_one_epoch(model, optimizer, dataloader, device)
        print(f'Epoch {epoch+1}, Loss: {torch.tensor(epoch_losses).mean()}')
    # Drop this round's data before generating the next one to keep peak memory down.
    del scenes, graphs, dataloader
    torch.cuda.empty_cache()
    torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/particle_gnn_{r}.pth')
    







--- Round 0 ---
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:52<00:00, 18.90s/it]


Epoch 1, Loss: 2.5632526874542236
Epoch 2, Loss: 2.428107261657715
Epoch 3, Loss: 2.373164415359497
Epoch 4, Loss: 2.3450422286987305
--- Round 1 ---
Generating dataset with 25 scenes...


100%|██████████| 25/25 [08:02<00:00, 19.29s/it]


Epoch 1, Loss: 2.265080690383911
Epoch 2, Loss: 2.167938709259033
Epoch 3, Loss: 2.134875774383545
Epoch 4, Loss: 2.1150715351104736
--- Round 2 ---
Generating dataset with 25 scenes...


100%|██████████| 25/25 [08:08<00:00, 19.55s/it]


Epoch 1, Loss: 3.04636549949646
Epoch 2, Loss: 2.944793701171875
Epoch 3, Loss: 2.907914638519287
Epoch 4, Loss: 2.8868069648742676
--- Round 3 ---
Generating dataset with 25 scenes...


100%|██████████| 25/25 [08:02<00:00, 19.32s/it]


Epoch 1, Loss: 2.735196590423584
Epoch 2, Loss: 2.6247127056121826
Epoch 3, Loss: 2.5868544578552246
Epoch 4, Loss: 2.564666271209717
--- Round 4 ---
Generating dataset with 25 scenes...


100%|██████████| 25/25 [08:05<00:00, 19.43s/it]


Epoch 1, Loss: 2.3206770420074463
Epoch 2, Loss: 2.233513593673706
Epoch 3, Loss: 2.1996114253997803
Epoch 4, Loss: 2.182072639465332
--- Round 5 ---
Generating dataset with 25 scenes...


 80%|████████  | 20/25 [06:35<01:38, 19.75s/it]


KeyboardInterrupt: 