# Initialise the SympNet

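We will use symplectic neural networks with quadratic ridge polynomials by setting the optional arguments `max_degree=2` and `method='P'`. This is the natural choice for quadratic Hamiltonians. Their exact flow is a linear symplectic map, and degree-2 ridge polynomials give layers whose updates are at most affine in the state.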

```python
from strupnet import SympNet
import torch

sympnet = SympNet(
    dim=1,  # dimension of p (equal to the dimension of q)
    layers=2,
    max_degree=2,  # quadratic ridge polynomials
    method='P',
)
```
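One way to see why degree-2 layers suit this problem: the exact time-$h$ flow of $H=\frac{1}{2}(p^2+q^2)$ is a rotation of the $(p,q)$ plane, and a rotation factors into three shears. Each shear is the exact flow of a quadratic Hamiltonian that depends on $q$ alone or on $p$ alone. The pure-Python sketch below checks this factorisation numerically. The helper names are hypothetical, and the sketch only illustrates the idea. It is not how `strupnet` implements `method='P'` internally.

```python
import math


def shear_p(p, q, beta):
    # Unit-time flow of H(q) = -beta * q**2 / 2: p <- p + beta * q, q unchanged.
    return p + beta * q, q


def shear_q(p, q, alpha):
    # Unit-time flow of H(p) = alpha * p**2 / 2: q <- q + alpha * p, p unchanged.
    return p, q + alpha * p


def three_shear_step(p, q, h):
    # Rotation by angle h as shear_p(-tan(h/2)) o shear_q(sin h) o shear_p(-tan(h/2)).
    t = math.tan(h / 2)
    p, q = shear_p(p, q, -t)
    p, q = shear_q(p, q, math.sin(h))
    p, q = shear_p(p, q, -t)
    return p, q


def exact_step(p, q, h):
    # Exact harmonic-oscillator flow: (p, q) -> (p cos h - q sin h, p sin h + q cos h).
    c, s = math.cos(h), math.sin(h)
    return c * p - s * q, s * p + c * q


h = 0.05
states = [(1.0, 0.0), (0.3, -0.7), (-1.2, 0.4)]
max_err = max(
    max(abs(a - b) for a, b in zip(three_shear_step(p0, q0, h), exact_step(p0, q0, h)))
    for p0, q0 in states
)
print("max difference between three shears and exact flow:", max_err)
```

Since every factor is itself a quadratic-Hamiltonian flow, a network composed of such layers can in principle reproduce the one-step map up to floating-point round-off.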

# Generate training and testing data 

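We will generate data of the form $\{x(ih)\}_{i=0}^{n+1}=\{(p(ih), q(ih))\}_{i=0}^{n+1}$, where $x(t)=(p(t),q(t))$ is the solution to the Hamiltonian ODE $\dot{x} = J\nabla H(x)$, with $J=\begin{pmatrix}0&-1\\1&0\end{pmatrix}$ and the simple harmonic oscillator Hamiltonian $H = \frac{1}{2}(p^2 + q^2)$, so that $\dot{p} = -q$ and $\dot{q} = p$. With initial condition $(p(0), q(0)) = (1, 0)$ the exact solution is $p(t) = \cos t$, $q(t) = \sin t$. The data is arranged as input/target pairs $x_0 = \{x(ih)\}_{i=0}^{n}$ and $x_1 = \{x((i+1)h)\}_{i=0}^{n}$, and likewise for $t$.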
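As a quick sanity check on this setup, the pure-Python sketch below does three things. It verifies via central differences that $(\cos t, \sin t)$ satisfies $\dot{p}=-q$, $\dot{q}=p$. It confirms that $H$ stays at $\frac{1}{2}$ along the samples. It also shows how the shifted pairs $x_0$, $x_1$ line up. Names such as `vector_field` are illustrative only and are not part of `strupnet`.

```python
import math


def hamiltonian(p, q):
    # H = (p^2 + q^2) / 2
    return 0.5 * (p * p + q * q)


def vector_field(p, q):
    # J grad H with x = (p, q) and J = [[0, -1], [1, 0]] gives (dp/dt, dq/dt) = (-q, p).
    return -q, p


h = 0.05
num_points = 21  # x(ih) for i = 0, ..., n + 1 with n = 19
traj = [(math.cos(i * h), math.sin(i * h)) for i in range(num_points)]

# Input/target pairs: x0[i] = x(ih), x1[i] = x((i + 1)h).
x0, x1 = traj[:-1], traj[1:]

# Central-difference derivative at interior points vs. the vector field (error is O(h^2)).
fd_err = max(
    abs((traj[i + 1][k] - traj[i - 1][k]) / (2 * h) - vector_field(*traj[i])[k])
    for i in range(1, num_points - 1)
    for k in (0, 1)
)
# The exact solution stays on the level set H = 1/2.
energy_err = max(abs(hamiltonian(p, q) - 0.5) for p, q in traj)
print(len(x0), fd_err, energy_err)
```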


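```python
def simple_harmonic_oscillator_solution(t_start, t_end, timestep):
    # round() guards against the float ratio landing just below an integer
    num_points = round((t_end - t_start) / timestep) + 1
    time_grid = torch.linspace(t_start, t_end, num_points)
    # Exact solution with (p(0), q(0)) = (1, 0): p = cos t, q = sin t
    p_sol = torch.cos(time_grid)
    q_sol = torch.sin(time_grid)
    pq_sol = torch.stack([p_sol, q_sol], dim=-1)  # shape (num_points, 2)
    return pq_sol, time_grid.unsqueeze(dim=1)     # times with shape (num_points, 1)

timestep = 0.05
# Training data on [0, 1], test data on [1, 4]
x, t = simple_harmonic_oscillator_solution(t_start=0, t_end=1, timestep=timestep)
x_test, t_test = simple_harmonic_oscillator_solution(t_start=1, t_end=4, timestep=timestep)
# Pair each state with the next one: inputs x0 = x(ih), targets x1 = x((i+1)h)
x0, x1, t0, t1 = x[:-1, :], x[1:, :], t[:-1, :], t[1:, :]
x0_test, x1_test, t0_test, t1_test = x_test[:-1, :], x_test[1:, :], t_test[:-1, :], t_test[1:, :]
```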
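Counting grid points from a floating-point ratio needs some care. `int()` truncates toward zero, so a ratio that lands just below an integer silently loses a step, whereas `round()` recovers the intended count. The minimal illustration below uses a hypothetical helper `num_grid_points`.

```python
def num_grid_points(t_start, t_end, timestep):
    # round() rather than int(): the float ratio may fall just below the intended integer.
    return round((t_end - t_start) / timestep) + 1


ratio = 0.3 / 0.1  # slightly below 3.0 in IEEE double precision
print(int(ratio) + 1, num_grid_points(0.0, 0.3, 0.1))
```

With `int()`, the grid on $[0, 0.3]$ with step $0.1$ would get 3 points instead of 4.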

# Train the SympNet like any PyTorch model
All the models in `strupnet` inherit from `torch.nn.Module` and can be trained as such. Let $\Phi_h^{\theta}(x)$ denote the SympNet, where $\theta$ is its set of trainable parameters. We want to find the $\theta$ that minimises the sum of squared one-step errors below. `torch.nn.MSELoss` takes the mean over all entries instead of the sum, which only rescales the loss by a constant factor.

$$\mathrm{loss}(\theta)=\sum_{i=0}^{n}\left\|\Phi_h^{\theta}\big(x(ih)\big)-x\big((i+1)h\big)\right\|^2$$
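To make the objective concrete, the sketch below evaluates it for a hypothetical one-parameter stand-in for $\Phi_h^{\theta}$: a rotation of the $(p,q)$ plane by angle $\theta$, not a SympNet. The sum vanishes up to round-off at $\theta = h$, the exact step, and is positive elsewhere. Dividing by the number of entries gives the mean that `torch.nn.MSELoss` reports with its default reduction.

```python
import math


def rotation_model(p, q, theta):
    # Hypothetical stand-in for Phi_h^theta: rotate (p, q) by angle theta.
    c, s = math.cos(theta), math.sin(theta)
    return c * p - s * q, s * p + c * q


def sum_squared_loss(theta, pairs):
    # sum_i || Phi(x(ih)) - x((i+1)h) ||^2
    total = 0.0
    for (p0, q0), (p1, q1) in pairs:
        p_pred, q_pred = rotation_model(p0, q0, theta)
        total += (p_pred - p1) ** 2 + (q_pred - q1) ** 2
    return total


h = 0.05
traj = [(math.cos(i * h), math.sin(i * h)) for i in range(21)]
pairs = list(zip(traj[:-1], traj[1:]))

loss_exact = sum_squared_loss(h, pairs)       # rotation by h is the exact flow
loss_off = sum_squared_loss(h + 0.01, pairs)  # wrong step angle
mse_off = loss_off / (2 * len(pairs))         # mean over all N x 2 entries
print(loss_exact, loss_off, mse_off)
```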


```python
optimizer = torch.optim.Adam(sympnet.parameters(), lr=0.01)
mse = torch.nn.MSELoss()  # mean of squared errors over all entries
for epoch in range(1000):
    optimizer.zero_grad()
    x1_pred = sympnet(x=x0, dt=t1 - t0)  # one step of size h from each x(ih)
    loss = mse(x1_pred, x1)              # (prediction, target)
    loss.backward()
    optimizer.step()

print("Final loss value: ", loss.item())
```

```text
Final loss value:  3.02371001268796e-33
```
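The loop above follows the usual PyTorch pattern: clear gradients, run the model, compute the loss, backpropagate, then update. The pure-Python sketch below spells out each part of that pattern with explicit substitutes. It uses the hypothetical rotation-by-$\theta$ stand-in for the SympNet, plain gradient descent in place of Adam, and a central finite-difference gradient in place of `loss.backward()`. The learning rate and step count are illustrative choices.

```python
import math


def rotation_model(p, q, theta):
    # Hypothetical one-parameter stand-in for the SympNet.
    c, s = math.cos(theta), math.sin(theta)
    return c * p - s * q, s * p + c * q


def loss_fn(theta, pairs):
    # Sum of squared one-step errors over all training pairs.
    total = 0.0
    for (p0, q0), (p1, q1) in pairs:
        p_pred, q_pred = rotation_model(p0, q0, theta)
        total += (p_pred - p1) ** 2 + (q_pred - q1) ** 2
    return total


def grad_fd(theta, pairs, eps=1e-6):
    # Central finite difference, standing in for loss.backward().
    return (loss_fn(theta + eps, pairs) - loss_fn(theta - eps, pairs)) / (2 * eps)


h = 0.05
traj = [(math.cos(i * h), math.sin(i * h)) for i in range(21)]
pairs = list(zip(traj[:-1], traj[1:]))

theta, lr = 0.0, 0.01
for step in range(200):
    theta -= lr * grad_fd(theta, pairs)  # plain gradient-descent update

print(theta, loss_fn(theta, pairs))
```

The fitted angle converges to the true step $h$, and the loss drops to round-off level. This mirrors, in a much simpler setting, the near-zero training loss reported above.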


# Evaluate the trained model on the test data set

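```python
with torch.no_grad():  # no gradients needed for evaluation
    x1_test_pred = sympnet(x=x0_test, dt=t1_test - t0_test)

# 2-norm of all one-step prediction errors on the test interval [1, 4]
print("Test set error", torch.norm(x1_test_pred - x1_test).item())
```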

```text
Test set error 6.808010630459377e-16
```
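The test error above measures single-step accuracy on $t\in[1,4]$, outside the training interval. A complementary check for a learned symplectic map is a long rollout: iterate the one-step map from $x(1)$ and watch the energy. The sketch below uses two hypothetical stand-ins rather than the trained SympNet. Symplectic Euler is symplectic, while explicit Euler is not. Over the same 60 steps, the former keeps $H$ close to $\frac{1}{2}$ while the latter lets it grow.

```python
import math


def hamiltonian(p, q):
    # H = (p^2 + q^2) / 2
    return 0.5 * (p * p + q * q)


def symplectic_euler(p, q, h):
    # Symplectic: update p with the old q, then q with the new p.
    p = p - h * q
    return p, q + h * p


def explicit_euler(p, q, h):
    # Not symplectic: both updates use the old state.
    return p - h * q, q + h * p


def rollout(step, x_init, h, num_steps):
    # Iterate a one-step map and record the trajectory, including the start.
    traj = [x_init]
    for _ in range(num_steps):
        traj.append(step(*traj[-1], h))
    return traj


h, num_steps = 0.05, 60  # t from 1 to 4, matching the test interval
x_init = (math.cos(1.0), math.sin(1.0))
traj_se = rollout(symplectic_euler, x_init, h, num_steps)
traj_ee = rollout(explicit_euler, x_init, h, num_steps)
energy_se = [hamiltonian(*x) for x in traj_se]
energy_ee = [hamiltonian(*x) for x in traj_ee]
print(max(abs(e - 0.5) for e in energy_se), energy_ee[-1])
```

To roll out the trained model instead, the step function would wrap the call used above, e.g. `sympnet(x=x, dt=dt)` with `x` of shape `(1, 2)` and `dt` of shape `(1, 1)`. Those shapes are inferred from the training code rather than from the `strupnet` documentation.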
