In [None]:
%matplotlib widget
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from simple_pinn import SchrodingersEqDataset, SimplePINN, evalAndCompare
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

In [None]:
# --- Configuration (tune here) ---
DATA_PATH = "../PINNs/main/Data/NLS.mat"
BATCH_SIZE = 10_000
LEARNING_RATE = 0.01
SEED = 0
LOG_EVERY = 10  # print the epoch loss every N epochs instead of flooding the output

# Seed before building the model and the loader so weight init and shuffling are reproducible.
torch.manual_seed(SEED)

# Load the reference data and wrap it for shuffled mini-batch training.
dataset = SchrodingersEqDataset(DATA_PATH)
dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Network with 2 inputs and 2 outputs.
# NOTE(review): presumably inputs are (x, t) and outputs are the real/imaginary parts of h —
# confirm against SchrodingersEqDataset.
model = SimplePINN(input_size=2, hidden_layers=[20, 20, 20, 20], output_size=2)
print(model)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Plain supervised MSE against the reference data (no physics residual term here).
# TODO: switch to model.loss_fn(outputs, targets) once the PINN loss is ready — confirm its signature.
loss_fn = torch.nn.MSELoss()  # built once, not per batch

n_epochs = 200
for epoch in range(n_epochs):
    running_loss = 0.0
    n_samples = 0
    for inputs, targets in dl:
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss (reduction="mean") averages within the batch; weight by batch size so the
        # epoch figure is a true per-sample mean even when the last batch is smaller.
        batch_n = len(inputs)
        running_loss += loss.item() * batch_n
        n_samples += batch_n

    # Guard against an empty loader so the report never divides by zero.
    epoch_loss = running_loss / max(n_samples, 1)
    if (epoch + 1) % LOG_EVERY == 0 or epoch == 0:
        print(f"epoch {epoch + 1}/{n_epochs}, mean loss={epoch_loss:.3e}")

In [None]:
# Evaluate the trained model against the reference solution; the four returned arrays
# feed the comparison plot below.
(x_grid, t_grid, pred_h, target_h) = evalAndCompare(
    model,
    dataset,
)

In [None]:
# Compare predicted vs. target h(x) at a few sampled time columns.
t_sample = [0, 50, 100, 150]  # column indices into the time axis — NOT time values
colors = ["r", "g", "b", "k"]  # one color per sampled time index

# Convert once, outside the loop; .cpu() keeps this working if the tensors live on a GPU.
pred_np = pred_h.detach().cpu().numpy()
target_np = target_h.detach().cpu().numpy()
assert max(t_sample) < pred_np.shape[1], "t_sample index out of range for the time axis"

fig, ax = plt.subplots()
for t_ndx, color in zip(t_sample, colors):
    # Labels name the column index; the old "t=" label implied a physical time value.
    ax.plot(
        x_grid[:, t_ndx],
        pred_np[:, t_ndx],
        linestyle="--",
        label=f"pred_h, t_ndx={t_ndx}",
        color=color,
    )
    ax.plot(
        x_grid[:, t_ndx],
        target_np[:, t_ndx],
        linestyle="-",
        label=f"target_h, t_ndx={t_ndx}",
        color=color,
    )

ax.set_xlabel("x")
ax.set_ylabel("h")
ax.set_title("Predicted (dashed) vs. target (solid) h at sampled time indices")
ax.legend(loc="best")
plt.show()

In [None]:
# Probe derivatives of the trained network w.r.t. its inputs on one batch.
# Draw the batch explicitly from `dl` instead of reusing `inputs` leaked from the training
# loop, and make it a leaf that requires grad: torch.autograd.grad raises if the tensor being
# differentiated against does not require grad.
probe_inputs, _ = next(iter(dl))
probe_inputs = probe_inputs.detach().clone().requires_grad_(True)

# Forward pass
outputs = model(probe_inputs)

# Gradient of output column 0 w.r.t. every input column: h_wrt_x[:, j] = d out[:, 0] / d in[:, j].
# NOTE(review): column 0 of the output is presumably the real part of h, and input column 0
# presumably x — confirm against the dataset. Only h_wrt_x[:, 0] is then d/dx.
out0 = outputs[:, 0]
h_wrt_x = torch.autograd.grad(
    out0, probe_inputs, grad_outputs=torch.ones_like(out0), create_graph=True
)[0]

# Differentiate the input-column-0 derivative again:
# h_wrt_xx[:, 0] = d2 out0 / d in0^2, h_wrt_xx[:, 1] = mixed d2 out0 / (d in0 d in1).
# NOTE(review): if SimplePINN uses a piecewise-linear activation (e.g. ReLU) this second
# derivative vanishes or autograd reports an unused input — confirm the activation.
dh_din0 = h_wrt_x[:, 0]
h_wrt_xx = torch.autograd.grad(
    dh_din0, probe_inputs, grad_outputs=torch.ones_like(dh_din0), create_graph=True
)[0]

In [None]:
# Show the shape and a few rows rather than the whole batch; detach to drop the grad_fn noise.
h_wrt_xx.shape, h_wrt_xx[:5].detach()