# In [None]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import numpy as np
import matplotlib.pyplot as plt

# ===========================
# 0. Basic configuration
# ===========================
# Use GPU if available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix random seeds for reproducibility (both PyTorch and NumPy)
torch.manual_seed(1234)
np.random.seed(1234)

# Spatial and temporal domain:
#   x ∈ [x_min, x_max],  t ∈ [t_min, t_max]
x_min, x_max = -5.0, 5.0
t_min, t_max = 0.0, np.pi / 2

# Create a regular grid in (x, t) for evaluation and plotting
#   x: 1D array of spatial points
#   t: 1D array of time points
x = np.linspace(x_min, x_max, 200)    # Nx = 200 points in x
t = np.linspace(t_min, t_max, 100)    # Nt = 100 points in t

# Meshgrid for 2D visualization: X, T ∈ R^{Nt×Nx}
X, T = np.meshgrid(x, t)

# Flatten (X, T) into a list of points (Nt*Nx, 2) with columns [x, t]
X_test = np.hstack([X.flatten()[:, None], T.flatten()[:, None]])


# ===========================
# 1. PINN model definition
# ===========================
class SchrodingerPINN(nn.Module):
    """
    Fully connected network used as a PINN for the nonlinear Schrödinger equation.

    It maps a space-time point to the two components of the complex field:
        (x, t) -> [u(x,t), v(x,t)]
    with
        u(x,t) : real part of the solution
        v(x,t) : imaginary part of the solution
    """

    def __init__(self, layers):
        """
        Build the stack of Linear/Tanh layers.

        Args:
            layers: sequence of layer widths, e.g. [2, 100, 100, 100, 100, 2]
                    layers[0]   = input width  (x, t)
                    layers[-1]  = output width (u, v)
                    everything in between is a hidden layer width
        """
        super().__init__()
        modules = []
        # Every hidden layer is an affine map followed by Tanh. The layers are
        # created in input-to-output order, so seeded initialisation is stable.
        for n_in, n_out in zip(layers[:-2], layers[1:-1]):
            modules.extend((nn.Linear(n_in, n_out), nn.Tanh()))
        # Output layer: purely linear, no activation on (u, v).
        modules.append(nn.Linear(layers[-2], layers[-1]))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """
        Evaluate the network.

        Args:
            x: tensor of shape (N, 2) whose columns are [x, t]

        Returns:
            tensor of shape (N, 2) whose columns are [u_pred, v_pred]
        """
        return self.net(x)


# Architecture modelled on the DeepXDE example:
#   2 inputs (x, t) -> 4 hidden layers of 100 units -> 2 outputs (u, v)
layers = [2, 100, 100, 100, 100, 2]
model = SchrodingerPINN(layers)
# nn.Module.to moves the parameters in place and returns the same module,
# so `model` keeps referring to the moved network.
model.to(device)


# ===========================
# 2. Sampling functions
# ===========================
def sample_interior(N_f):
    """
    Draw uniformly random collocation points in the space-time rectangle
    [x_min, x_max] x [t_min, t_max]; these are where the PDE residual is evaluated.

    Args:
        N_f: how many interior points to draw

    Returns:
        float32 tensor of shape (N_f, 2) on `device`; each row is [x, t]
    """
    # The x column is drawn before the t column, which fixes the order in
    # which NumPy's global RNG is consumed.
    columns = [
        np.random.uniform(low, high, (N_f, 1))
        for low, high in ((x_min, x_max), (t_min, t_max))
    ]
    points = np.column_stack(columns)
    return torch.tensor(points, dtype=torch.float32, device=device)


def sample_initial(N_ic):
    """
    Draw random points on the initial slice t = t_min together with the
    target values of the initial condition

        u(x, 0) = 2 / cosh(x)
        v(x, 0) = 0

    Args:
        N_ic: how many initial-condition points to draw

    Returns:
        xt_ic: float32 tensor (N_ic, 2), rows [x, t_min]
        u0   : float32 tensor (N_ic, 1), target real part
        v0   : float32 tensor (N_ic, 1), target imaginary part (all zeros)
    """

    def to_tensor(array):
        # All three outputs share the same dtype and device.
        return torch.tensor(array, dtype=torch.float32, device=device)

    x_ic = np.random.uniform(x_min, x_max, (N_ic, 1))

    # Every sampled x is paired with the constant time t_min.
    points = np.column_stack((x_ic, np.full_like(x_ic, t_min)))

    # Exact initial profile evaluated at the sampled positions.
    real_part = 2.0 / np.cosh(x_ic)
    imag_part = np.zeros_like(x_ic)

    return to_tensor(points), to_tensor(real_part), to_tensor(imag_part)


def sample_boundary(N_bc):
    """
    Draw random times and pair each of them with both spatial edges, giving
    matched points for the periodic boundary conditions:

        left  : (x_min, t)
        right : (x_max, t)

    The caller uses these pairs to enforce
        u(x_min,t) = u(x_max,t),     v(x_min,t) = v(x_max,t)
        u_x(x_min,t) = u_x(x_max,t), v_x(x_min,t) = v_x(x_max,t)

    Args:
        N_bc: how many boundary times to draw

    Returns:
        xt_left  : float32 tensor (N_bc, 2), rows [x_min, t]
        xt_right : float32 tensor (N_bc, 2), rows [x_max, t]
    """
    t_bc = np.random.uniform(t_min, t_max, (N_bc, 1))

    def edge_points(x_edge):
        # Same time samples on both edges; only the x column differs.
        stacked = np.column_stack((np.full_like(t_bc, x_edge), t_bc))
        return torch.tensor(stacked, dtype=torch.float32, device=device)

    return edge_points(x_min), edge_points(x_max)


# ===========================
# 3. PDE residual (Schrödinger)
# ===========================
def pde_residual(model, xt):
    """
    Evaluate the residuals of the nonlinear Schrödinger equation, split
    into its two real-valued equations:

        f_u =   u_t + 0.5 v_xx + (u^2 + v^2) v = 0
        f_v = - v_t + 0.5 u_xx + (u^2 + v^2) u = 0

    where u is the real part and v the imaginary part of the field.

    Note: `xt` is switched to requires_grad=True in place, so the caller's
    tensor is modified.

    Args:
        model: callable mapping an (N, 2) tensor [x, t] to an (N, 2) tensor [u, v]
        xt   : tensor (N, 2) with columns [x, t]

    Returns:
        f_u, f_v: tensors (N, 1) holding the two residuals; both keep their
                  autograd graph so they can be back-propagated through
    """
    # Derivatives are taken w.r.t. the inputs, so they must track gradients.
    xt.requires_grad_(True)

    prediction = model(xt)  # shape (N, 2)
    u, v = prediction[:, 0:1], prediction[:, 1:2]

    def d_dxt(field):
        # Gradient of `field` w.r.t. xt, shape (N, 2): column 0 is d/dx and
        # column 1 is d/dt. The graph is kept for higher-order derivatives
        # and for the later backward pass.
        return autograd.grad(
            field,
            xt,
            torch.ones_like(field),
            create_graph=True,
            retain_graph=True,
        )[0]

    # First-order derivatives.
    du = d_dxt(u)
    u_x, u_t = du[:, 0:1], du[:, 1:2]
    dv = d_dxt(v)
    v_x, v_t = dv[:, 0:1], dv[:, 1:2]

    # Second-order spatial derivatives: differentiate u_x / v_x again and
    # keep only the x column.
    u_xx = d_dxt(u_x)[:, 0:1]
    v_xx = d_dxt(v_x)[:, 0:1]

    # Squared modulus |h|^2 = u^2 + v^2 of the complex field.
    density = u**2 + v**2

    f_u = u_t + 0.5 * v_xx + density * v
    f_v = -v_t + 0.5 * u_xx + density * u
    return f_u, f_v


# ===========================
# 4. Total loss function
# ===========================
mse_loss = nn.MSELoss()


def _x_derivative(field, xt):
    """Return d(field)/dx as an (N, 1) tensor, keeping the graph for backprop."""
    grads = autograd.grad(
        field, xt, torch.ones_like(field),
        create_graph=True, retain_graph=True
    )[0]
    # Column 0 of xt is x, column 1 is t.
    return grads[:, 0:1]


def loss_fn(model, N_f, N_ic, N_bc, w_ic=10.0, w_bc_val=1.0, w_bc_deriv=1.0):
    """
    Build the global PINN loss on freshly sampled points:

        L = L_PDE + w_ic * L_IC + w_bc_val * L_BC_val + w_bc_deriv * L_BC_deriv

    Where:
        - L_PDE      : mean squared PDE residual at interior points
        - L_IC       : mismatch with the initial condition at t = t_min
        - L_BC_val   : mismatch of u, v between x_min and x_max (periodicity)
        - L_BC_deriv : mismatch of u_x, v_x between x_min and x_max

    New points are drawn on every call, in the order
    interior -> initial -> boundary.

    Args:
        model      : SchrodingerPINN model
        N_f        : number of interior points for the PDE residual
        N_ic       : number of points for the initial condition
        N_bc       : number of time samples for the boundary condition
        w_ic       : weight of the initial-condition term (default 10.0)
        w_bc_val   : weight of the periodic value term (default 1.0)
        w_bc_deriv : weight of the periodic derivative term (default 1.0)

    Returns:
        total_loss, lpde, lic, lbc_val, lbc_deriv
        total_loss is a scalar tensor attached to the autograd graph; the
        other four are plain Python floats of the unweighted terms.
    """

    # ----- 4.1 Interior (PDE) loss -----
    xt_f = sample_interior(N_f)
    f_u, f_v = pde_residual(model, xt_f)
    loss_pde = mse_loss(f_u, torch.zeros_like(f_u)) + \
               mse_loss(f_v, torch.zeros_like(f_v))

    # ----- 4.2 Initial condition loss -----
    # Points at t = t_min with targets u(x,0)=2/cosh(x), v(x,0)=0
    xt_ic, u0, v0 = sample_initial(N_ic)
    uv_ic = model(xt_ic)
    loss_ic = mse_loss(uv_ic[:, 0:1], u0) + mse_loss(uv_ic[:, 1:2], v0)

    # ----- 4.3 / 4.4 Periodic boundary losses -----
    # Gradient tracking is enabled on the boundary inputs *before* the forward
    # pass so that one evaluation per edge serves both the value term and the
    # x-derivative term (no second forward pass is needed).
    xt_left, xt_right = sample_boundary(N_bc)
    xt_left.requires_grad_(True)
    xt_right.requires_grad_(True)

    uv_left = model(xt_left)
    uv_right = model(xt_right)
    u_left, v_left = uv_left[:, 0:1], uv_left[:, 1:2]
    u_right, v_right = uv_right[:, 0:1], uv_right[:, 1:2]

    # Values: u(x_min,t) = u(x_max,t), v(x_min,t) = v(x_max,t)
    loss_bc_val = mse_loss(u_left, u_right) + mse_loss(v_left, v_right)

    # Derivatives: u_x(x_min,t) = u_x(x_max,t), v_x(x_min,t) = v_x(x_max,t)
    u_x_left = _x_derivative(u_left, xt_left)
    v_x_left = _x_derivative(v_left, xt_left)
    u_x_right = _x_derivative(u_right, xt_right)
    v_x_right = _x_derivative(v_right, xt_right)

    loss_bc_deriv = mse_loss(u_x_left, u_x_right) + mse_loss(v_x_left, v_x_right)

    # ----- 4.5 Combine losses with weights -----
    # The defaults (10, 1, 1) reproduce the original fixed weighting; tune
    # them through the keyword arguments depending on training behavior.
    loss = (
        loss_pde
        + w_ic * loss_ic
        + w_bc_val * loss_bc_val
        + w_bc_deriv * loss_bc_deriv
    )

    return loss, loss_pde.item(), loss_ic.item(), loss_bc_val.item(), loss_bc_deriv.item()


# ===========================
# 5. Training loop (Adam)
# ===========================
# Adam over all network parameters with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Points drawn per epoch: interior collocation, initial condition, boundary.
N_f, N_ic, N_bc = 5000, 200, 200

num_epochs = 5000

for epoch in range(1, num_epochs + 1):
    # One optimisation step on a freshly sampled batch of points.
    optimizer.zero_grad()
    loss, lpde, lic, lbcv, lbcd = loss_fn(model, N_f, N_ic, N_bc)
    loss.backward()
    optimizer.step()

    # Only report on every 500th epoch.
    if epoch % 500 != 0:
        continue

    print(
        f"Epoch {epoch:5d} | "
        f"Total: {loss.item():.4e} | "
        f"PDE: {lpde:.4e} | IC: {lic:.4e} | "
        f"BC_val: {lbcv:.4e} | BC_deriv: {lbcd:.4e}"
    )

# (Optional) You could add an L-BFGS refinement step here if you want.


# ===========================
# 6. Inference on the grid
# ===========================
model.eval()

# Evaluate the trained network on every grid point without tracking gradients.
with torch.no_grad():
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)
    pred = model(X_test_tensor)
    pred = pred.cpu().numpy()   # shape (Nt*Nx, 2)

# Bring each output column back onto the (Nt, Nx) = T.shape grid:
#   column 0 -> real part u(x,t), column 1 -> imaginary part v(x,t)
u_pred, v_pred = (pred[:, col].reshape(T.shape) for col in (0, 1))

# Modulus |h| = sqrt(u^2 + v^2) of the predicted complex field
h_pred = np.sqrt(np.square(u_pred) + np.square(v_pred))
