# TP 4 - Neural Galerkin

In [12]:
import torch
from torch.autograd import grad as grad
from torch.func import functional_call, jacrev, vmap
import matplotlib.pyplot as plt

In [11]:
# Compute on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# print(f"torch loaded; device is {device}")

# Tensors and modules created without an explicit dtype will be float64.
torch.set_default_dtype(torch.float64)

### Neural network

In [305]:
class mlp(torch.nn.Module):
    """Fully connected network with tanh activations between layers.

    ``layer_widths = [n_in, h_1, ..., h_k, n_out]`` creates one float64
    ``torch.nn.Linear`` per pair of consecutive widths. Tanh is applied after
    every layer except the last one, so the output layer is purely affine.
    """

    def __init__(self, layer_widths: list):
        """Build the layers.

        Args:
            layer_widths: layer sizes, at least ``[n_in, n_out]``; any
                entries in between are hidden-layer sizes.

        Raises:
            ValueError: if fewer than two widths are given (there would be
                no layer at all and ``forward`` could not run).
        """
        super().__init__()

        if len(layer_widths) < 2:
            raise ValueError(
                "layer_widths needs at least an input and an output size, "
                f"got {layer_widths!r}"
            )

        self.layer_widths = layer_widths
        # A ModuleList (not a plain list) registers the layers' parameters, so
        # they appear in named_parameters() and follow .to(device).
        self.hidden_layers = torch.nn.ModuleList(
            torch.nn.Linear(n_in, n_out, dtype=torch.double)
            for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
        )
        self.activation = torch.nn.Tanh()

    def forward(self, inputs):
        """Apply the network; the last dim of ``inputs`` must be ``layer_widths[0]``."""
        *hidden, output_layer = self.hidden_layers
        for layer in hidden:
            inputs = self.activation(layer(inputs))
        # No activation on the last layer.
        return output_layer(inputs)


### Sampling data

In [1]:
class SamplerBox:
    """Placeholder for the data sampler of this TP.

    TODO: not implemented yet; the class currently has no attributes or
    methods.
    """

In [2]:
def test_make_data():
    """Placeholder test, presumably for the data-sampling code.

    TODO: currently checks nothing and always succeeds.
    """

### Neural Galerkin class

In [13]:
class NeuralGalerkin:
    """Neural Galerkin assembly for a neural-network ansatz u(x, mu; theta).

    No ``__init__`` is defined here. The methods below expect other code to
    set these attributes on the instance:

    - ``network``: a ``torch.nn.Module`` evaluated on ``cat([x, mu])``;
    - ``sampler_space(n)`` / ``sampler_param(n)``: callables returning ``n``
      spatial points and ``n`` parameter values;
    - ``regularization_matrix``: a tensor added to the
      ``(nb_params, nb_params)`` mass matrix.

    ``compute_M_and_F`` stores its results in ``self.M`` and ``self.F``.
    """

    def compute_jacobian(self, x, mu):
        """Per-sample Jacobians of the network output w.r.t. its parameters.

        For ``n`` points ``(x_i, mu_i)`` this returns the ``n`` Jacobians
        ``J(theta)(x_i, mu_i)``. The parameters are flattened and
        concatenated in ``self.network.named_parameters()`` order.

        Args:
            x: spatial points, batched along dim 0 (presumably shape
                ``(n, d_x)`` -- TODO confirm against the sampler).
            mu: parameter values, batched along dim 0 (presumably shape
                ``(n, d_mu)``).

        Returns:
            Tensor of shape ``(n, nb_outputs, nb_params)``.
        """
        # Detached copy of the weights: we differentiate w.r.t. this dict
        # explicitly through functional_call, not through the module graph.
        theta = {name: p.detach() for name, p in self.network.named_parameters()}

        def fnet(theta, x, mu):
            # Called once per sample inside vmap: x and mu have lost their
            # batch dimension, so dim 0 is their feature dimension.
            return functional_call(self.network, theta, torch.cat([x, mu], dim=0))

        # jacrev differentiates w.r.t. its first argument (theta).
        # in_dims=(None, 0, 0) shares theta across samples and maps over
        # dim 0 of x and mu.
        jac = vmap(jacrev(fnet), in_dims=(None, 0, 0))(theta, x, mu)

        # jac[name] has shape (n, *output_shape, *param_shape). Flatten the
        # parameter dims using the known parameter size so that a network
        # with several outputs keeps one row per output. Then concatenate
        # along the parameter axis.
        nb_data = x.shape[0]
        return torch.cat(
            [jac[name].reshape(nb_data, -1, param.numel()) for name, param in theta.items()],
            dim=-1,
        )

    def compute_M_and_F(self, rhs=None, nb_data=10000):
        """Assemble the Neural Galerkin mass matrix and right-hand side.

        Draw ``N = nb_data`` samples ``(x_i, mu_i)`` from ``self.sampler_space``
        and ``self.sampler_param``, and let ``J_i = J(theta)(x_i, mu_i)`` be
        the Jacobians from ``compute_jacobian``. Then:

            M(theta) = R + (1/N) * sum_i J_i^T J_i
            F(theta) = (1/N) * sum_i J_i^T f_i

        with ``R = self.regularization_matrix`` and ``f_i = rhs(x_i, mu_i)``.
        The results are stored in ``self.M`` (shape ``(nb_params, nb_params)``)
        and ``self.F`` (shape ``(nb_params,)``).

        Args:
            rhs: callable ``rhs(x, mu)`` returning the PDE right-hand side
                (e.g. the advection term) at the sampled points. It must
                return a tensor of shape ``(nb_data,)`` or
                ``(nb_data, nb_outputs)``.
            nb_data: number of Monte-Carlo samples.

        Raises:
            NotImplementedError: if ``rhs`` is not provided. The right-hand
                side is problem specific and is not defined in this class.
        """
        if rhs is None:
            raise NotImplementedError(
                "compute_M_and_F needs the PDE right-hand side: "
                "pass rhs=callable(x, mu) returning one value per sample"
            )

        x = self.sampler_space(nb_data)
        mu = self.sampler_param(nb_data)
        jacobian = self.compute_jacobian(x, mu)  # (nb_data, nb_outputs, nb_params)

        self.M = (
            self.regularization_matrix
            + torch.einsum("bjs,bjr->sr", jacobian, jacobian) / nb_data
        )

        # Reshape f to (nb_data, nb_outputs) so it lines up with the Jacobian
        # rows. This accepts both (nb_data,) and (nb_data, nb_outputs).
        f = rhs(x, mu).reshape(nb_data, -1)
        # The einsum output is already 1-D: (nb_params,).
        self.F = torch.einsum("bjs,bj->s", jacobian, f) / nb_data


    