In [39]:
from constants import DICT_CASE, M
import pandas as pd
import torch
from utils import charge_data, GraphDataset
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

In [21]:
# Experiment configuration passed to `charge_data` below. Most entries are not
# read anywhere in this notebook; they are presumably consumed by the project
# helpers in `utils` — TODO confirm which ones are still needed.
hyper_param = {
    "num": 1,
    "case": 2,
    "nb_epoch": 1000,
    "save_rate": 10,
    "batch_size": 10000,
    "lr_init": 1e-4,
    "gamma_scheduler": 0.999,
    "nb_layers": 15,
    "nb_neurons": 64,
    "n_data_test": 5000,
    "nb_points": 100,
    "x_min": -0.06,
    "x_max": 0.06,
    "y_min": -0.06,
    "y_max": 0.06,
    "t_min": 6.5,
    # This key used to appear twice in the literal (same value); a duplicated
    # key silently keeps only the last value, so it is now listed once.
    "nb_period": 20,
    "nb_period_plot": 2,
    "force_inertie_bool": False,
    "u_border": True,
    "v_border": False,
    "p_border": True,
    "r_min": 0.026/2,
    "theta_border_min": 0.1,
    "is_res": True,
    "nb_blocks": 60,  # For ResNet
    "nb_layer_block": 3,  # For ResNet
    "nb_timestep": 10,
}

# Entries derived from the ones above.
# 'H' is the DICT_CASE entry of the selected case, wrapped in a one-element list.
hyper_param['H'] = [DICT_CASE[str(hyper_param['case'])]]
# 'file' is the CSV name built from the model number and the case, also a
# one-element list.
hyper_param['file'] = [
    f"model_{hyper_param['num']}_case_{hyper_param['case']}.csv"
    ]
hyper_param['m'] = M

# Second argument of `charge_data`; presumably the reference velocity, length
# and density used for non-dimensionalisation — TODO confirm meaning and units.
param_adim = {"V": 1.0, "L": 0.025, "rho": 1.2}

In [24]:
# Project helper driven by the two configuration dicts; returns three objects.
# Judging by the names: inputs and targets over all time steps, plus
# normalisation statistics — TODO confirm against `utils.charge_data`.
X_full_time, U_full_time, mean_std = charge_data(hyper_param, param_adim)

In [26]:
# Wrap the loaded data in the project's graph dataset; the instance also
# exposes `edge_neighbours`, which is read in a later cell.
dataset = GraphDataset(X_full_time, U_full_time)

In [27]:
# Mini-batches of 2 samples, reshuffled on every pass, fetched by 2 worker
# processes. No random seed is set in this notebook, so the batch order
# differs from run to run.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

In [28]:
# Neighbour-index table of the graph, taken from the dataset; a later shape
# check in this notebook prints torch.Size([5540, 4]).
edge_neighbours = dataset.edge_neighbours

In [34]:
# Pull a single (shuffled) batch for interactive inspection.
# Presumably U is the network input and U_n the paired target (e.g. the next
# time step) — TODO confirm against GraphDataset.__getitem__.
U, U_n = next(iter(dataloader))

In [35]:
# Preview the batch. The printed values are signed (the second component is
# negative for several nodes).
U

tensor([[[ 0.5035, -0.1058,  0.6239],
         [ 0.5091, -0.1260,  0.6131],
         [ 0.4960, -0.1081,  0.6413],
         ...,
         [ 0.8478,  0.1131,  0.1834],
         [ 0.8259,  0.1562,  0.2070],
         [ 0.8280,  0.1215,  0.2039]],

        [[ 0.5036, -0.0605,  0.5536],
         [ 0.5091, -0.0810,  0.5431],
         [ 0.4959, -0.0626,  0.5708],
         ...,
         [ 0.8245,  0.2343,  0.1758],
         [ 0.8037,  0.2591,  0.1889],
         [ 0.8081,  0.2332,  0.1904]]])

In [41]:
# Encoder used for manual experiments: dim_in=3, dim_out=128, nb_hidden=2.
# NOTE(review): `MLP` is defined in a later cell of this notebook, so on a
# fresh kernel this line raises NameError unless that cell is run first.
MLP_to_latent = MLP(3, 128, 2)

In [42]:
# Encode the sample batch; the MLP acts on the last axis only (3 -> 128).
U_latent = MLP_to_latent(U)

In [45]:
# Shape check: 2 samples (the DataLoader batch size), 5540 nodes (the row
# count of edge_neighbours), 3 features per node.
U.shape

torch.Size([2, 5540, 3])

In [44]:
# Shape check after encoding: only the last axis changes (3 -> 128).
U_latent.shape

torch.Size([2, 5540, 128])

In [99]:
class GNN(nn.Module):
    """Encode -> process -> decode network acting on per-node features.

    The model chains, in this order:

    1. an encoder ``MLP`` lifting the 3 features of every node to
       ``dim_latent`` channels,
    2. ``nb_gn`` ``GN`` message-passing blocks working in latent space,
    3. a decoder ``MLP`` projecting back to 3 features per node.

    Args:
        hyper_param: mapping providing the integer entries ``'nb_hidden'``
            (hidden layers of each MLP inside a GN block),
            ``'nb_hidden_encode'`` (hidden layers of the encoder and decoder
            MLPs), ``'dim_latent'`` (latent width) and ``'nb_gn'`` (number of
            GN blocks).
        edge_neighbours: neighbour-index table, forwarded unchanged to every
            ``GN`` block.

    Raises:
        KeyError: if one of the four entries is missing from ``hyper_param``.
    """

    def __init__(self, hyper_param, edge_neighbours):
        super().__init__()
        self.nb_hidden = hyper_param['nb_hidden']
        self.nb_hidden_encode = hyper_param['nb_hidden_encode']
        self.dim_latent = hyper_param['dim_latent']
        self.nb_gn = hyper_param['nb_gn']
        self.edge_neighbours = edge_neighbours

        encoder = MLP(
            dim_in=3, dim_out=self.dim_latent, nb_hidden=self.nb_hidden_encode
        )
        # `nb_gn` was previously read but ignored: a single GN block was built
        # whatever the setting. Build one independent block per requested layer.
        processors = [
            GN(self.edge_neighbours, self.nb_hidden, self.dim_latent)
            for _ in range(self.nb_gn)
        ]
        # NOTE(review): MLP applies ReLU after every layer, including its last
        # one, so the decoder output is always >= 0 while the sample batch
        # printed in this notebook contains negative values. Worth revisiting
        # before training.
        decoder = MLP(
            dim_in=self.dim_latent, dim_out=3, nb_hidden=self.nb_hidden_encode
        )
        self.gnn = nn.ModuleList([encoder, *processors, decoder])

    def forward(self, x):
        """Run the blocks in sequence.

        Args:
            x: tensor whose last axis holds the 3 node features; in this
                notebook the shape is (batch_size, nb_nodes, 3).

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of 3.
        """
        for block in self.gnn:
            x = block(x)
        return x



class GN(nn.Module):
    """Residual message-passing block for a graph with a fixed neighbour count.

    One MLP processes each node's own features and one additional MLP is
    dedicated to each neighbour slot (column of ``edge_neighbours``); the MLP of
    slot ``k`` is shared by all nodes. The per-slot outputs are summed into a
    message, passed through ReLU, added to the input (skip connection) and
    layer-normalised over the latent axis.

    Args:
        edge_neighbours: 2-D integer index table of shape
            (nb_nodes, nb_neighbours); column ``k`` is used to gather, for every
            node, the features of its ``k``-th neighbour.
        nb_hidden: number of hidden layers of each internal MLP.
        dim_latent: width of the node features, for both input and output.
    """

    def __init__(self, edge_neighbours, nb_hidden, dim_latent):
        super().__init__()
        self.nb_neighbours = edge_neighbours.shape[1]
        # Index 0 handles the node itself, indices 1..nb_neighbours the slots.
        self.mlp_neigh = nn.ModuleList(
            [MLP(dim_latent, dim_latent, nb_hidden)]
            + [
                MLP(dim_latent, dim_latent, nb_hidden)
                for _ in range(self.nb_neighbours)
            ]
        )
        self.edge_neighbours = edge_neighbours  # nb_nodes x nb_neighbours
        self.LayerNorm = torch.nn.LayerNorm(dim_latent)

    def forward(self, x):
        """Apply one round of message passing.

        Args:
            x: node features of shape (batch_size, nb_nodes, dim_latent).

        Returns:
            Tensor of the same shape as ``x``.
        """
        message = self.mlp_neigh[0](x)  # batch_size x nb_nodes x dim_latent
        for k in range(self.nb_neighbours):
            neighbour_features = x[:, self.edge_neighbours[:, k]]
            # Out-of-place add on purpose: `message` is the output of a ReLU,
            # which autograd keeps for the backward pass. Updating it with
            # `+=` makes backward() raise "modified by an inplace operation".
            message = message + self.mlp_neigh[k + 1](neighbour_features)
        return self.LayerNorm(x + F.relu(message))



class MLP(nn.Module):
    """Stack of fully connected layers with ReLU activations.

    The network is ``Linear(dim_in, dim_out)`` followed by ``nb_hidden`` layers
    ``Linear(dim_out, dim_out)``. Weights use Xavier-uniform initialisation and
    biases start at zero.

    Args:
        dim_in: size of the last axis of the input.
        dim_out: size of the last axis of the output (also the hidden width).
        nb_hidden: number of ``dim_out -> dim_out`` layers after the first one.
        final_activation: when True (default, the historical behaviour) ReLU is
            applied after every layer, so the output is non-negative. Pass
            False to leave the last layer linear, e.g. for a regression head
            that must produce signed values.
    """

    def __init__(self, dim_in, dim_out, nb_hidden, final_activation=True):
        super().__init__()
        self.final_activation = final_activation
        self.linear_first = nn.ModuleList([nn.Linear(dim_in, dim_out)])
        self.hidden = nn.ModuleList([
            nn.Linear(dim_out, dim_out) for _ in range(nb_hidden)
        ])
        # `mlp` holds the very same Linear objects as `linear_first` and
        # `hidden` (shared parameters, not copies). They therefore show up
        # twice in repr() and in state_dict(); the layout is kept so existing
        # attribute names and checkpoint keys stay valid.
        self.mlp = self.linear_first + self.hidden
        self.initial_param()

    def forward(self, x):
        """Apply the layers to the last axis of ``x``.

        Args:
            x: tensor whose last axis has size ``dim_in``.

        Returns:
            Tensor with the same leading axes and a last axis of ``dim_out``.
        """
        last_index = len(self.mlp) - 1
        for index, layer in enumerate(self.mlp):
            x = layer(x)
            if index < last_index or self.final_activation:
                x = F.relu(x)
        return x

    def initial_param(self):
        """Reset every layer: Xavier-uniform weights, zero biases."""
        for layer in self.mlp:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)



In [102]:
# Architecture settings read by GNN.__init__:
#   nb_hidden        - hidden layers of each MLP inside a GN block
#   dim_latent       - width of the latent node features
#   nb_gn            - number of GN blocks requested
#   nb_hidden_encode - hidden layers of the encoder and decoder MLPs
# NOTE(review): this rebinds `hyper_param`, replacing the data-loading
# configuration created near the top of the notebook; re-running the
# `charge_data` cell afterwards would receive this dict instead.
hyper_param = dict(
    nb_hidden=2,
    dim_latent=128,
    nb_gn=5,
    nb_hidden_encode=4,
)

In [103]:
# Build the full model from the architecture settings and the dataset's
# neighbour table.
gnn = GNN(hyper_param, edge_neighbours)

In [104]:
# Smoke test: forward pass on the sample batch; the output keeps the input
# shape (2, 5540, 3).
gnn(U).shape

torch.Size([2, 5540, 3])

In [105]:
# Shape of the second tensor of the batch, for comparison with the model
# output.
U_n.shape

torch.Size([2, 5540, 3])

In [106]:
# Recompute the latent encoding of the sample batch (same call as the earlier
# encoder cell).
U_latent = MLP_to_latent(U)

In [107]:
# Stand-alone message-passing block for manual testing; dim_latent matches the
# 128-wide output of MLP_to_latent.
GN_first_latent = GN(edge_neighbours, nb_hidden=2, dim_latent=128)

In [108]:
# (nb_nodes, nb_neighbours): 5540 rows with 4 neighbour slots each.
edge_neighbours.shape

torch.Size([5540, 4])

In [109]:
# The GN block preserves the latent shape (2, 5540, 128).
GN_first_latent(U_latent).shape

torch.Size([2, 5540, 128])

In [60]:
# Latent shape once more. This cell's execution count (60) is lower than that
# of the cells above it, i.e. the notebook was run out of order.
U_latent.shape

torch.Size([2, 5540, 128])

In [63]:
# Gather, for every node, the latent features of the node listed in column 0
# of edge_neighbours; this is the indexing pattern used inside GN.forward.
# The many exact zeros in the output come from the ReLU that ends the encoder.
U_latent[:, edge_neighbours[:, 0]]

tensor([[[0.0000, 0.0000, 0.0090,  ..., 0.0204, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0094,  ..., 0.0202, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0092,  ..., 0.0211, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0090,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0080,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0098,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0078,  ..., 0.0184, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0088,  ..., 0.0182, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0081,  ..., 0.0185, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0051,  ..., 0.0000, 0.0009, 0.0000],
         [0.0000, 0.0000, 0.0042,  ..., 0.0000, 0.0029, 0.0000],
         [0.0000, 0.0000, 0.0060,  ..., 0.0000, 0.0000, 0.0000]]],
       grad_fn=<IndexBackward0>)

In [57]:
# Module summary: 5 MLPs (1 for the node itself + 4 neighbour slots). Each MLP
# lists its Linear layers twice, under linear_first/hidden and again under mlp.
GN_first_latent

GN(
  (mlp_neigh): ModuleList(
    (0-4): 5 x MLP(
      (linear_first): ModuleList(
        (0): Linear(in_features=128, out_features=128, bias=True)
      )
      (hidden): ModuleList(
        (0-1): 2 x Linear(in_features=128, out_features=128, bias=True)
      )
      (mlp): ModuleList(
        (0-2): 3 x Linear(in_features=128, out_features=128, bias=True)
      )
    )
  )
)