In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from model import GNN
from utils import charge_data
from torch.utils.data import Dataset
from constants import DICT_Y0, DICT_CASE, PARAM_ADIM, M

In [2]:
class GraphDataset(Dataset):
    def __init__(self, X_full, U_full, nb_neighbours):
        self.X_full = X_full
        self.U_full = U_full
        time_ = X_full[:, 2].unique()[5]
        self.delta_t = (X_full[:, 2].unique()[1:] - X_full[:, 2].unique()[:-1]).mean()
        X = X_full[X_full[:, 2] == time_][:, :2]
        indices = torch.argsort(X[:, 0]*1.1 + X[:, 1]*1.5)
        X_sort = X[indices]
        # Broadcasting de génie
        diff = X_sort[:, None, :] - X_sort[None, :, :]
        # Calcul des distances euclidiennes
        distances = torch.sqrt((diff ** 2).sum(dim=2))
        self.edge_neighbours = torch.empty((X_sort.shape[0], nb_neighbours), dtype=torch.long)
        edge_attributes = torch.empty((X_sort.shape[0], nb_neighbours, 3))
        for i in range(X_sort.shape[0]):
            self.edge_neighbours[i, :] = torch.topk(distances[i, :], nb_neighbours + 1, largest=False)[1][1:]
            for k, neighbour in enumerate(self.edge_neighbours[i, :]):
                neighbour = neighbour.item()
                # print(X_sort[i,0]-X_sort[neighbour, 0])
                # print(i, neighbour)
                # print(X_sort[i,0]-X_sort[neighbour, 0])
                # print(X_sort[i,1]-X_sort[neighbour, 1])
                # print(torch.sqrt(torch.sum((X_sort[i, :]-X_sort[neighbour, :])**2)))
                edge_attributes[i, k, :] = torch.stack((X_sort[i,0]-X_sort[neighbour, 0], X_sort[i,1]-X_sort[neighbour, 1], torch.sqrt(torch.sum((X_sort[i]-X_sort[neighbour])**2))))
                print(torch.stack((X_sort[i,0]-X_sort[neighbour, 0], X_sort[i,1]-X_sort[neighbour, 1], torch.sqrt(torch.sum((X_sort[i]-X_sort[neighbour])**2)))))
        self.edge_attributes = edge_attributes


    def __len__(self):
        return self.X_full[:, 2].unique().shape[0] - 1

    def __getitem__(self, idx):
        time_ = self.X_full[:, 2].unique()[idx]
        masque_time = self.X_full[:, 2]==time_
        X = self.X_full[masque_time][:, :2]
        indices = torch.argsort(X[:, 0]*1.1 + X[:, 1]*1.5)
        U_sort = self.U_full[masque_time][indices]

        # le next time
        time_n = self.X_full[:, 2].unique()[idx + 1]
        masque_time_n = self.X_full[:, 2]==time_n
        X_n = self.X_full[masque_time_n][:, :2]
        indices_n = torch.argsort(X_n[:, 0]*1.1 + X_n[:, 1]*1.5)
        U_sort_n = self.U_full[masque_time_n][indices_n]
        return U_sort, U_sort_n

In [3]:

hyper_param = {
    "num": 1,
    "case": 2,
    "nb_epoch": 1000,
    "save_rate": 10,
    "batch_size": 10,
    "lr": 1e-3,
    "gamma_scheduler": 0.999,
    "x_min": -0.06,
    "x_max": 0.06,
    "y_min": -0.06,
    "y_max": 0.06,
    "t_min": 6.5,
    'nb_hidden': 2,
    'dim_latent': 32,
    'nb_gn': 5,
    'nb_hidden_encode': 4,
    'nb_neighbours': 4
}

hyper_param['H'] = [DICT_CASE[str(hyper_param['case'])]]
hyper_param['ya0'] = [DICT_Y0[str(hyper_param['num'])]]
hyper_param['file'] = [
    f"model_{hyper_param['num']}_case_{hyper_param['case']}.csv"
    ]
hyper_param['m']=M

In [4]:
dataset = GraphDataset(X_f, U_f, 4)

NameError: name 'X_f' is not defined

In [100]:
torch.stack((torch.tensor(0.), torch.tensor(-0.1081) ))

tensor([ 0.0000, -0.1081])

In [101]:
torch.concat((torch.zeros(1), torch.zeros(1)))

tensor([0., 0.])

In [102]:
piche = dataset.edge_attributes

In [106]:
dataset.edge_attributes
dataset.edge_attributes
dataset.edge_attributes

tensor([[[ 0.0000, -0.1081,  0.1081],
         [-0.1428,  0.0000,  0.1428],
         [-0.1428, -0.1081,  0.1791],
         [-0.0636, -0.1906,  0.2009]],

        [[-0.0825, -0.0288,  0.0874],
         [ 0.0000, -0.1081,  0.1081],
         [ 0.1428,  0.0000,  0.1428],
         [-0.1681,  0.0470,  0.1745]],

        [[-0.0636, -0.0825,  0.1042],
         [ 0.0000,  0.1081,  0.1081],
         [-0.1428,  0.0000,  0.1428],
         [ 1.0000,  2.0000,  3.0000]],

        ...,

        [[ 0.0298,  0.0825,  0.0877],
         [ 0.0000, -0.1081,  0.1081],
         [ 0.1090,  0.0000,  0.1090],
         [ 0.1090, -0.1081,  0.1535]],

        [[ 0.0825,  0.0288,  0.0874],
         [ 0.0000,  0.1081,  0.1081],
         [-0.1090,  0.0000,  0.1090],
         [-0.1090,  0.1081,  0.1535]],

        [[ 0.0000,  0.1081,  0.1081],
         [ 0.1090,  0.0000,  0.1090],
         [ 0.1090,  0.1081,  0.1535],
         [ 0.0298,  0.1906,  0.1929]]])

In [103]:
piche[2, 3]=torch.tensor([1.,2.,3.])

In [104]:
piche[2, 3]

tensor([1., 2., 3.])

In [105]:
dataset.edge_attributes[2, 3]

tensor([1., 2., 3.])

In [5]:
import torch 

In [6]:
a = torch.randn((5,3, 4))

In [11]:
a

tensor([[[-1.2838,  0.0213,  0.1135,  2.4483],
         [-0.2755, -0.1877,  0.6954, -0.9889],
         [ 0.0054,  1.0827,  0.1994,  0.0959]],

        [[ 1.3519,  0.9063,  0.8944,  0.1244],
         [ 0.7181, -0.2826,  0.5639,  0.7737],
         [-1.9730, -0.1986, -0.3044, -0.2061]],

        [[ 0.1128,  2.3611,  1.3186,  0.4827],
         [-0.8359,  0.0433,  0.8415, -0.0158],
         [-1.8652, -1.5363, -1.2953,  0.9155]],

        [[ 1.7934, -0.8430,  0.5325,  0.0409],
         [-0.3138, -0.6120,  2.2826, -0.8945],
         [-0.9073,  1.4630, -0.0495,  0.9709]],

        [[ 0.5358, -0.0045,  1.4072,  0.2024],
         [-0.0117, -0.8032, -1.7285, -0.6000],
         [-1.3600, -1.5378, -1.2304, -0.0209]]])

In [10]:
a + 0.05*torch.randn_like(a)

tensor([[[-1.2745,  0.0083,  0.0226,  2.5374],
         [-0.3226, -0.2034,  0.6836, -0.9504],
         [-0.0141,  1.0703,  0.0439,  0.0201]],

        [[ 1.2690,  0.9001,  0.8887,  0.0970],
         [ 0.7107, -0.2790,  0.6068,  0.8280],
         [-2.0173, -0.2709, -0.3346, -0.1197]],

        [[ 0.0907,  2.2702,  1.3379,  0.4878],
         [-0.8706,  0.0127,  0.8036,  0.0300],
         [-1.8692, -1.5607, -1.3504,  1.0059]],

        [[ 1.7593, -0.9113,  0.5366,  0.1056],
         [-0.3094, -0.5791,  2.2427, -0.8925],
         [-0.9039,  1.4117, -0.0286,  0.8650]],

        [[ 0.5382, -0.0038,  1.3397,  0.2231],
         [-0.0072, -0.8301, -1.7087, -0.5797],
         [-1.3330, -1.5420, -1.2513, -0.0695]]])