In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

In [14]:
import math


# --- Problem set-up: x''(t) + omega^2 x(t) = 0 solved on [t_min, t_max] ---
x0: float = 1.0  # target initial displacement x(0)
t_min: float = 0.0  # left end of the collocation interval
t_max: float = 1.0  # right end of the collocation interval

# --- Network / optimiser hyper-parameters ---
hidden_size: int = 16  # neurons per hidden layer
lr: float = 0.1  # Adam learning rate

# --- Reproducibility ---
# NOTE(review): only takes effect once passed to torch.manual_seed before the model is built.
seed: int = 1


class MLP(nn.Module):
    """Fully connected tanh network mapping a time coordinate t to a scalar x(t).

    Architecture: ``Linear(1, h) -> Tanh``, then ``num_hidden_layers - 1`` blocks of
    ``Linear(h, h) -> Tanh``, then ``Linear(h, 1)``. Optionally, a trailing ``Tanh``
    squashes the output into (-1, 1).

    Args:
        hidden_size: width h of every hidden layer. Defaults to the module-level
            ``hidden_size``, captured when the class is defined.
        output_tanh: if True, append a ``Tanh`` after the output layer.
        num_hidden_layers: number of ``Linear -> Tanh`` hidden blocks. The default
            of 5 reproduces the original fixed architecture. Layers are created in
            the same order, so seeded initialisation and ``state_dict`` keys are
            unchanged.

    Raises:
        ValueError: if ``num_hidden_layers`` is less than 1.
    """

    def __init__(self, hidden_size=hidden_size, output_tanh=False, num_hidden_layers=5):
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError("num_hidden_layers must be >= 1")

        # Input layer lifts the scalar time coordinate to the hidden width.
        layers = [nn.Linear(1, hidden_size), nn.Tanh()]
        for _ in range(num_hidden_layers - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.Tanh()]
        # Linear read-out to a single scalar prediction.
        layers.append(nn.Linear(hidden_size, 1))
        if output_tanh:
            layers.append(nn.Tanh())

        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Evaluate the network. ``t`` must have a last dimension of size 1, e.g. (N, 1)."""
        return self.mlp(t)
    


def choose_device(device=""):
    """Resolve the torch device to run on.

    A non-empty ``device`` request wins. The exception is an MPS request, which is
    rejected because this notebook trains in float64 and MPS cannot do that. With
    no request, CUDA is preferred when available; otherwise the CPU is used.

    Raises:
        ValueError: if an MPS device is requested.
    """
    if not device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    requested = torch.device(device)
    if requested.type != "mps":
        return requested
    raise ValueError("MPS does not support float64")


def grad(outputs, inputs):
    """Return d(outputs)/d(inputs) as a tensor that is itself differentiable.

    The vector-Jacobian product is seeded with ones, so gradients of all output
    elements are summed. When each output row depends only on its own input row,
    this yields the elementwise derivative. ``create_graph=True`` keeps the result
    in the autograd graph so it can be differentiated again (second derivatives).
    ``retain_graph=True`` lets the same graph be reused by later calls.
    """
    seed_vector = torch.ones_like(outputs)
    derivatives = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=seed_vector,
        create_graph=True,
        retain_graph=True,
    )
    return derivatives[0]

def pde(x, t, omega):
    """Residual of the harmonic-oscillator ODE: x''(t) + omega**2 * x(t).

    ``x`` must be computed from ``t`` through the autograd graph (e.g. ``model(t)``)
    so both time derivatives can be taken. The residual is zero wherever ``x``
    satisfies the ODE exactly.
    """
    velocity = grad(x, t)
    acceleration = grad(velocity, t)
    stiffness = omega * omega
    return acceleration + stiffness * x


def model_loss(model_nn, t, x0_true, dx0dt_true, omega, bc_weight, mse):
    """Physics-informed loss for the oscillator x'' + omega**2 x = 0.

    Args:
        model_nn: network mapping t -> x(t).
        t: collocation times with ``requires_grad=True`` (``train`` passes shape (N, 1)).
        x0_true: target initial displacement x(0), shaped like ``model_nn(t0)``.
        dx0dt_true: target initial velocity x'(0), shaped like ``model_nn(t0)``.
        omega: angular frequency of the oscillator.
        bc_weight: weight applied to the two initial-condition terms.
        mse: loss callable, e.g. ``nn.MSELoss()``.

    Returns:
        ``(loss, metrics)``.

        ``loss`` is the differentiable training objective
        ``loss_pde + bc_weight * (loss_bc + loss_ic)``.

        ``metrics`` maps names to Python floats. ``loss_data`` is a diagnostic only
        and is never back-propagated. It is the MSE against the exact solution
        ``x0 cos(omega t) + (v0 / omega) sin(omega t)``, or ``x0 + v0 t`` when
        omega == 0.
    """
    x = model_nn(t)
    f = pde(x, t, omega)  # ODE residual at every collocation point

    # Fresh leaf at t = 0 so the predicted initial velocity x'(0) can be differentiated.
    t0 = torch.zeros((1, 1), dtype=t.dtype, device=t.device, requires_grad=True)
    x0_pred = model_nn(t0)
    dx0dt_pred = grad(x0_pred, t0)

    # Both terms are initial conditions. The velocity term keeps the "bc" metric key.
    loss_bc = mse(dx0dt_pred, dx0dt_true)
    loss_ic = mse(x0_pred, x0_true)
    loss_pde = mse(f, torch.zeros_like(f))

    # Up-weight the initial conditions so they are not swamped by the residual term.
    loss = loss_pde + bc_weight * (loss_bc + loss_ic)

    # Diagnostic only: compare with the closed-form solution without extending the graph.
    # Includes the initial-velocity term, so the check is also valid for dx0dt_true != 0.
    with torch.no_grad():
        t_ref = t.detach()
        if omega != 0:
            x_true = (x0_true * torch.cos(omega * t_ref)
                      + (dx0dt_true / omega) * torch.sin(omega * t_ref))
        else:
            x_true = x0_true + dx0dt_true * t_ref
        loss_data = mse(x.detach(), x_true)

    metrics = {
        "loss_total": loss.item(),
        "loss_bc": loss_bc.item(),
        "loss_ic": loss_ic.item(),
        "loss_pde": loss_pde.item(),
        "loss_data": loss_data.item(),
    }

    return loss, metrics


def train(model, device, *, steps=1000, n_points=100, bc_weight=1000, log_every=40,
          omega=2.0 * torch.pi, learning_rate=None):
    """Fit ``model`` to the oscillator problem with Adam, printing metrics periodically.

    Collocation points are ``n_points`` evenly spaced times in [t_min, t_max]. The
    initial conditions are x(0) = x0 and x'(0) = 0. Every tensor is created in the
    dtype of the model's parameters, so the inputs always match the weights. A
    float32 model no longer fails with a dtype mismatch; a float64 model behaves as
    before. The model is not moved: the caller is responsible for placing it on
    ``device``.

    Args:
        model: network to train in place.
        device: torch device for the collocation and target tensors.
        steps: number of optimiser steps.
        n_points: number of collocation times.
        bc_weight: weight on the initial-condition losses.
        log_every: print metrics on step 1 and every ``log_every`` steps. A falsy
            value prints only on step 1.
        omega: angular frequency of the oscillator (default 2*pi, i.e. period 1).
        learning_rate: Adam step size. ``None`` means the module-level ``lr``,
            read at call time.
    """
    step_size = lr if learning_rate is None else learning_rate
    # Follow the model's own precision; fall back to float64 if it has no parameters.
    dtype = next((p.dtype for p in model.parameters()), torch.float64)

    mse = nn.MSELoss()
    # requires_grad so the PDE residual can differentiate x(t) with respect to t.
    t = torch.linspace(t_min, t_max, n_points, dtype=dtype, device=device,
                       requires_grad=True).view(-1, 1)
    x0_true = torch.tensor([[x0]], dtype=dtype, device=device)
    dx0dt_true = torch.zeros((1, 1), dtype=dtype, device=device)  # released from rest

    optimizer = torch.optim.Adam(model.parameters(), lr=step_size)

    for step in range(1, steps + 1):
        optimizer.zero_grad()
        loss_val, metrics = model_loss(
            model_nn=model,
            t=t,
            x0_true=x0_true,
            dx0dt_true=dx0dt_true,
            omega=omega,
            bc_weight=bc_weight,
            mse=mse,
        )
        loss_val.backward()
        optimizer.step()

        if step == 1 or (log_every and step % log_every == 0):
            print(metrics)



In [15]:
device = choose_device("cpu")
# Make float64 the default before MLP() is built so the network's weights are
# created in float64, matching the float64 tensors train() works with.
torch.set_default_dtype(torch.float64)
# Seed immediately before weight initialisation so re-running this cell is reproducible.
torch.manual_seed(seed)
model = MLP().to(device)
# NOTE(review): the recorded losses oscillate rather than decrease with Adam at
# lr = 0.1; a smaller learning rate may be needed for convergence.
train(model, device)

{'loss_total': 1669.023345413963, 'loss_bc': 0.00023518054132183432, 'loss_ic': 1.5711219567089207, 'loss_pde': 97.66620816372023, 'loss_data': 0.5736324862236907}
{'loss_total': 261.8001415520049, 'loss_bc': 0.009121935391725964, 'loss_ic': 0.10905503133303293, 'loss_pde': 143.62317482724598, 'loss_data': 0.5304720550417436}
{'loss_total': 1243.0685386318964, 'loss_bc': 0.0033509649797199585, 'loss_ic': 0.0005799169048668543, 'loss_pde': 1239.1376567473096, 'loss_data': 1.2705995265433279}
{'loss_total': 681.0918693926845, 'loss_bc': 8.707520402343972e-08, 'loss_ic': 0.6245240692854183, 'loss_pde': 56.56771303206221, 'loss_data': 0.5172182276483784}
{'loss_total': 540.8496569121605, 'loss_bc': 1.138474861201276e-08, 'loss_ic': 0.23245358659995366, 'loss_pde': 308.3960589274582, 'loss_data': 0.8250145759182536}
{'loss_total': 613.7217283848116, 'loss_bc': 4.084500542225666e-18, 'loss_ic': 0.4243355775365602, 'loss_pde': 189.3861508482514, 'loss_data': 0.6195428805253418}
{'loss_total':