In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from itertools import chain

import simple_pinn
import utils
import cv_mesh
import cv_solver
import athena_reader

# Development convenience: re-execute the project-local modules every time this
# cell runs so that edits to their source are picked up without a kernel restart.
import importlib

for _dev_module in (simple_pinn, utils, cv_mesh, cv_solver, athena_reader):
    importlib.reload(_dev_module)

# Run on the GPU whenever PyTorch reports that CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# torch.set_default_dtype(torch.float64)

In [None]:
# Mesh configuration for the control-volume solver.
# x_domain and y_domain are bound to the *same* list object, so the domain is the
# square [-10, 10] x [-10, 10]; mutating one name in place also changes the other.
x_domain = y_domain = [-10, 10]
# Same sample count in both directions.
# NOTE(review): presumably the number of control volumes per axis - confirm against cv_mesh.CVMesh.
Nx_sample = Ny_sample = 100
mesh = cv_mesh.CVMesh(
    x_domain,
    y_domain,
    Nx_sample,
    Ny_sample,
    quad_pts=(8, 8),
    quad_rule="composite_trapezoid",
    # Later cells differentiate with respect to points taken from this mesh
    # (torch.autograd.grad on eval points), which is why gradients are requested here.
    requires_grad=True,
)
# NOTE(review): presumably moves the mesh tensors to DEVICE in place (the result is not re-assigned) - confirm.
mesh.to(DEVICE)

In [None]:
# # The problem consists of an annulus of inner radius a and outer radius b, centered at the origin, with dielectric constant epsilon_1,
# # immersed in a medium of dielectric constant epsilon_0; an external field of strength E_0 is applied in the positive x-direction.
# a = 0.1
# b = 5.0
# epsilon_0 = 1.0
# epsilon_1 = 5.0
# # epsilon_0 = 5.0
# # epsilon_1 = 1.0
# E_0 = 1.0

# # The functions below provide the correct electric potential and its derivatives, as well as some other helpful things.
# # Note that all have a dummy "eos=None" argument, because the CV solver framework expects to pass around an equation of state.


# # function that provides the correct electric potential for inputs (x, y)
# def synthetic_data(inputs, eos=None):
#     x = inputs[..., 0]
#     y = inputs[..., 1]

#     rho = torch.sqrt(x**2 + y**2)
#     phi = torch.atan2(y, x)

#     # Create conditions for the different regions
#     condition1 = rho < a
#     condition2 = (rho >= a) & (rho < b)
#     condition3 = rho >= b

#     # Initialize the potential tensor
#     Phi = torch.zeros_like(rho)

#     # Calculate Phi for rho < a
#     Phi[condition1] = (
#         (
#             -4
#             * b**2
#             * epsilon_1
#             * epsilon_0
#             / (
#                 (
#                     b**2 * (epsilon_1 + epsilon_0) ** 2
#                     - a**2 * (epsilon_1 - epsilon_0) ** 2
#                 )
#             )
#         )
#         * E_0
#         * rho[condition1]
#         * torch.cos(phi[condition1])
#     )

#     # Calculate Phi for a <= rho < b
#     Phi[condition2] = (
#         (
#             -2
#             * a
#             * b**2
#             * epsilon_0
#             / (
#                 (
#                     b**2 * (epsilon_1 + epsilon_0) ** 2
#                     - a**2 * (epsilon_1 - epsilon_0) ** 2
#                 )
#             )
#         )
#         * (
#             (epsilon_1 + epsilon_0) * (rho[condition2] / a)
#             + (epsilon_1 - epsilon_0) * (a / rho[condition2])
#         )
#         * E_0
#         * torch.cos(phi[condition2])
#     )

#     # Calculate Phi for rho >= b
#     Phi[condition3] = (
#         (
#             -rho[condition3]
#             + (b**2 - a**2)
#             * (epsilon_1**2 - epsilon_0**2)
#             / (
#                 b**2 * (epsilon_1 + epsilon_0) ** 2
#                 - a**2 * (epsilon_1 - epsilon_0) ** 2
#             )
#             * b**2
#             / rho[condition3]
#         )
#         * E_0
#         * torch.cos(phi[condition3])
#     )

#     return Phi.unsqueeze(-1)


# # function that provides the x and y partial derivatives of the (correct) potential, for inputs (x, y)
# def synthetic_data_derivatives(inputs, eos=None):
#     phi = synthetic_data(inputs)

#     # Calculate the derivatives of Phi
#     Phi_x = torch.autograd.grad(
#         phi,
#         inputs,
#         grad_outputs=torch.ones_like(phi),
#         retain_graph=True,
#         create_graph=True,
#     )[0][..., 0]
#     Phi_y = torch.autograd.grad(
#         phi,
#         inputs,
#         grad_outputs=torch.ones_like(phi),
#         retain_graph=True,
#         create_graph=True,
#     )[0][..., 1]
#     return Phi_x, Phi_y


# def synthetic_data_bc_derivatives(inputs, eos=None):
#     # this version gives "ideal" derivatives that are valid at the boundary,
#     # pretending the boundary is actually at infinity
#     Phi_x = E_0 * torch.ones_like(inputs[..., 0])
#     Phi_y = torch.zeros_like(inputs[..., 1])
#     return Phi_x, Phi_y


# # provide the correct epsilon at the given inputs (x, y)
# def true_epsilon(inputs, eos=None):
#     x = inputs[..., 0]
#     y = inputs[..., 1]

#     rho = torch.sqrt(x**2 + y**2)

#     # Create conditions for the different regions
#     # condition1 = rho < a
#     condition2 = (rho >= a) & (rho < b)
#     # condition3 = rho >= b

#     # Set epsilon for a <= rho < b
#     return torch.where(condition2, epsilon_1, epsilon_0)

In [None]:
# The problem consists of a cylinder of radius b, centered at the origin, with dielectric constant epsilon_1,
# immersed in a medium of dielectric constant epsilon_0; an external field of strength E_0 is applied in the positive x-direction.
b = 5.0
epsilon_0 = 1.0
epsilon_1 = 5.0
E_0 = 1.0

# The functions below provide the correct electric potential and its derivatives, as well as some other helpful things.
# Note that all have a dummy "eos=None" argument, because the CV solver framework expects to pass around an equation of state.


# function that provides the correct electric potential for inputs (x, y)
def synthetic_data(inputs, eos=None):
    """Analytic electric potential of a dielectric cylinder in a uniform field.

    With rho = sqrt(x^2 + y^2) and k = (epsilon_1 - epsilon_0) / (epsilon_1 + epsilon_0):

        rho <  b:  Phi = -2 * epsilon_0 / (epsilon_0 + epsilon_1) * E_0 * x
        rho >= b:  Phi = (-x + k * b^2 * x / rho^2) * E_0

    These are the same expressions as the polar forms ``rho * cos(phi)`` and
    ``cos(phi) / rho``, written in Cartesian coordinates. Going through
    ``sqrt`` and ``atan2`` yields NaN gradients at the origin (the derivative
    of sqrt is infinite at 0 and that of atan2 is 0/0 there); the Cartesian
    form is differentiable everywhere.

    Args:
        inputs: tensor whose last axis holds (x, y) in positions 0 and 1.
        eos: unused; accepted because the CV solver framework passes one around.

    Returns:
        Tensor of potentials with a trailing singleton axis appended.
    """
    x = inputs[..., 0]
    y = inputs[..., 1]

    rho_sq = x**2 + y**2
    # Boolean region mask. It is computed from sqrt(rho_sq) so that points are
    # classified exactly as in true_epsilon; a comparison carries no gradient,
    # so the square root never takes part in a backward pass.
    inside = torch.sqrt(rho_sq) < b

    # Potential for rho < b (uniform field inside the cylinder)
    interior = -2 * (epsilon_0 / (epsilon_0 + epsilon_1)) * E_0 * x

    # Potential for rho >= b (applied field plus the induced dipole term).
    # The denominator is replaced by 1 at interior points: those values are
    # discarded by the final `where`, but an unguarded 0/0 at the origin would
    # still poison the gradients.
    safe_rho_sq = torch.where(inside, torch.ones_like(rho_sq), rho_sq)
    exterior = (
        -x
        + ((epsilon_1 - epsilon_0) / (epsilon_1 + epsilon_0))
        * b**2
        * x
        / safe_rho_sq
    ) * E_0

    Phi = torch.where(inside, interior, exterior)

    return Phi.unsqueeze(-1)


# function that provides the x and y partial derivatives of the (correct) potential, for inputs (x, y)
def synthetic_data_derivatives(inputs, eos=None):
    """Return (dPhi/dx, dPhi/dy) of the analytic potential at ``inputs``.

    The derivatives are obtained by automatic differentiation of
    ``synthetic_data``, so ``inputs`` must require gradients. Both components
    come from a single backward pass; the graph is retained and made
    differentiable so callers can backpropagate through the result.

    Args:
        inputs: tensor whose last axis holds (x, y) in positions 0 and 1.
        eos: forwarded unchanged to ``synthetic_data``.

    Returns:
        Tuple ``(Phi_x, Phi_y)``, each with the shape of ``inputs[..., 0]``.
    """
    phi = synthetic_data(inputs, eos)

    # One autograd call yields the full gradient; slice out each component.
    grad_phi = torch.autograd.grad(
        phi,
        inputs,
        grad_outputs=torch.ones_like(phi),
        retain_graph=True,
        create_graph=True,
    )[0]
    return grad_phi[..., 0], grad_phi[..., 1]


def synthetic_data_bc_derivatives(inputs, eos=None):
    """Return idealised far-field derivatives (dPhi/dx, dPhi/dy) of the potential.

    This version gives "ideal" derivatives that are valid at the boundary,
    pretending the boundary is actually at infinity. Far from the cylinder the
    potential in ``synthetic_data`` tends to ``-E_0 * x``, so dPhi/dx is
    ``-E_0`` (the electric field ``-grad(Phi)`` is ``+E_0`` along x) and
    dPhi/dy is zero.

    Args:
        inputs: tensor whose last axis holds (x, y) in positions 0 and 1.
        eos: unused; accepted because the CV solver framework passes one around.

    Returns:
        Tuple ``(Phi_x, Phi_y)``, each with the shape of ``inputs[..., 0]``.
    """
    # Sign matches synthetic_data_derivatives: Phi -> -E_0 * x at large radius.
    Phi_x = -E_0 * torch.ones_like(inputs[..., 0])
    Phi_y = torch.zeros_like(inputs[..., 1])
    return Phi_x, Phi_y


# provide the correct epsilon at the given inputs (x, y)
def true_epsilon(inputs, eos=None):
    """Exact dielectric constant at each point.

    Points strictly inside the cylinder (radius < b) get epsilon_1; all other
    points get epsilon_0. ``eos`` is an unused placeholder kept for the CV
    solver calling convention.
    """
    radius = torch.sqrt(inputs[..., 0] ** 2 + inputs[..., 1] ** 2)
    inside_cylinder = radius < b
    return torch.where(inside_cylinder, epsilon_1, epsilon_0)

In [None]:
# Plot the correct electric potential

# Evaluate the analytic potential and its gradient on the mesh's evaluation
# points. Gradient tracking is switched on for the inputs because
# synthetic_data_derivatives differentiates with respect to them.
x_plt, y_plt, inputs_plt = mesh.get_eval_points()
inputs_plt.requires_grad_(True)
Phi_plt = synthetic_data(inputs_plt)
Phi_x_plt, Phi_y_plt = synthetic_data_derivatives(inputs_plt)

# The array is transposed before display so that x runs along the horizontal axis.
potential_image = Phi_plt.detach().cpu().squeeze().T
plot_extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]
plt.imshow(potential_image, origin="lower", extent=plot_extent, cmap="viridis")
plt.colorbar(label="Phi")
plt.xlabel("x")
plt.ylabel("y")

In [None]:
# Plot a visualization of the electric field E = -grad(Phi)

fig = plt.figure()
ax = fig.add_subplot(111)

# streamplot expects u/v arrays of shape (ny, nx), i.e. indexed [iy, ix].
# Every imshow cell in this notebook transposes the mesh arrays before display,
# which implies they are indexed [ix, iy]; transpose here as well so the field
# is not mirrored about the diagonal.
# NOTE(review): confirm the axis order returned by mesh.get_eval_points().
E_x = -Phi_x_plt.detach().cpu().numpy().T
E_y = -Phi_y_plt.detach().cpu().numpy().T
n_rows, n_cols = E_x.shape

# Colour the streamlines by (twice the log of) the field magnitude.
color = 2 * np.log(np.hypot(E_x, E_y))
ax.streamplot(
    # Grid sizes are taken from the data rather than guessed from Nx_sample/Ny_sample.
    np.linspace(x_domain[0], x_domain[1], n_cols),
    np.linspace(y_domain[0], y_domain[1], n_rows),
    E_x,
    E_y,
    color=color,
    linewidth=1,
    cmap=plt.cm.inferno,
    density=1,
    arrowstyle="-",
    arrowsize=1.5,
    broken_streamlines=True,
)
ax.set_aspect("equal")
ax.set_xlabel("x")
ax.set_ylabel("y")

In [None]:
# Network for the electric potential Phi.
# NOTE(review): the positional arguments are presumably (n_inputs, hidden layer widths,
# n_outputs, mesh, boundary-data function, eos) - confirm against simple_pinn.DirichletPINN.
Phi_model = simple_pinn.DirichletPINN(
    2,
    # [64, 64, 64, 64, 64, 64, 64, 64],
    [64, 64, 64, 64],
    1,
    mesh,
    synthetic_data,
    None,
    activation=torch.nn.Tanh(),
    use_bias_in_output_layer=True,
    upwind_only=False,
).to(DEVICE)
# Phi_model.load_state_dict(torch.load("Phi_model.pth"))

# Alternative epsilon network (disabled).
# epsilon_model = simple_pinn.SegmentationPINN(
#     2,
#     [64, 64, 64, 64, 64],
#     # [32, 32, 32, 32],
#     torch.tensor([epsilon_0, epsilon_1]).reshape(2, -1),
#     # activation=torch.nn.Tanh(),
#     dropout_rate=0.3,
# ).to(DEVICE)

# Network for the dielectric constant epsilon.
epsilon_model = simple_pinn.SimplePINN(
    2,
    [64, 64, 64, 64, 64],
    1,
).to(DEVICE)

# Combined model handed to the solver. The downstream code reads epsilon from
# channel 0 and Phi from channel 1 of its output (see state_vec_to_fluxes).
model = simple_pinn.CombinedPINN(Phi_model, epsilon_model).to(DEVICE)

In [None]:
def state_vec_to_fluxes(state_vec, eos, inputs, use_true_derivatives=True):
    """Convert the model state into the fluxes (epsilon * Phi_x, epsilon * Phi_y).

    On the domain boundary (x or y at or beyond the limits in ``x_domain`` /
    ``y_domain``) the exact epsilon and the exact potential derivatives are
    always used.

    Args:
        state_vec: model output; channel 0 is epsilon and channel 1 is Phi.
        eos: unused; accepted because the CV solver framework passes one around.
        inputs: points at which ``state_vec`` was evaluated; last axis is
            (x, y). Must require gradients.
        use_true_derivatives: when True (the default, matching the current
            experiment) the exact derivatives of the analytic potential are
            used everywhere and the model's Phi is not differentiated at all.
            When False, the model's own derivatives are used in the interior.

    Returns:
        Tuple ``(F_x, F_y)``, each with a trailing singleton axis appended.
    """
    x = inputs[..., 0]
    y = inputs[..., 1]
    is_bc = (
        (x <= x_domain[0])
        | (x >= x_domain[1])
        | (y <= y_domain[0])
        | (y >= y_domain[1])
    )

    Phi_x_bc, Phi_y_bc = synthetic_data_derivatives(inputs)

    # Boundary points always carry the exact dielectric constant.
    epsilon = torch.where(is_bc, true_epsilon(inputs), state_vec[..., 0])

    if use_true_derivatives:
        # as a test, just use the correct derivatives everywhere
        Phi_x, Phi_y = Phi_x_bc, Phi_y_bc
    else:
        Phi = state_vec[..., 1]
        # A single backward pass gives both gradient components.
        grad_Phi = torch.autograd.grad(
            Phi, inputs, torch.ones_like(Phi), retain_graph=True, create_graph=True
        )[0]
        Phi_x = torch.where(is_bc, Phi_x_bc, grad_Phi[..., 0])
        Phi_y = torch.where(is_bc, Phi_y_bc, grad_Phi[..., 1])

    F_x = (epsilon * Phi_x).unsqueeze(-1)
    F_y = (epsilon * Phi_y).unsqueeze(-1)
    return F_x, F_y

In [None]:
# Only the first two of the six values returned by the mesh are used here: the
# points at which the x- and y-fluxes are evaluated. The other four are discarded.
(
    F_x_eval_points,
    F_y_eval_points,
    _,
    _,
    _,
    _,
) = mesh.get_training_eval_points_and_weights()

# Flatten both point sets to rows of (x, y) and stack them into one batch that
# is reused for the data comparison and for the error metrics during training.
eval_points = torch.cat(
    [F_x_eval_points.reshape(-1, 2), F_y_eval_points.reshape(-1, 2)], dim=0
)


def model_to_data_comparison(model, eos=None, deriv_to_phi_weight=10.0):
    """Build matching model / reference vectors for the data loss.

    Both returned tensors concatenate, along the last axis, the potential and
    its two partial derivatives at the module-level ``eval_points``; the
    derivative entries are scaled by ``deriv_to_phi_weight``.

    Args:
        model: callable returning a state vector whose channel 1 is Phi.
        eos: forwarded to ``synthetic_data`` and ``synthetic_data_derivatives``.
        deriv_to_phi_weight: relative weight of the derivative terms with
            respect to the potential term (default 10.0).

    Returns:
        Tuple ``(model_and_deriv_outputs, synthetic_data_and_deriv_outputs)``.
    """
    state_vec = model(eval_points)
    Phi = state_vec[..., 1]

    # A single backward pass gives both gradient components of the model's Phi.
    grad_Phi = torch.autograd.grad(
        Phi,
        eval_points,
        torch.ones_like(Phi),
        retain_graph=True,
        create_graph=True,
    )[0]

    synthetic_data_output = synthetic_data(eval_points, eos)
    synthetic_data_output_x, synthetic_data_output_y = synthetic_data_derivatives(
        eval_points, eos
    )

    model_and_deriv_outputs = torch.cat(
        [
            Phi,
            deriv_to_phi_weight * grad_Phi[..., 0],
            deriv_to_phi_weight * grad_Phi[..., 1],
        ],
        dim=-1,
    )
    synthetic_data_and_deriv_outputs = torch.cat(
        [
            # drop the trailing singleton axis so the shape matches Phi
            synthetic_data_output.squeeze(-1),
            deriv_to_phi_weight * synthetic_data_output_x,
            deriv_to_phi_weight * synthetic_data_output_y,
        ],
        dim=-1,
    )
    return model_and_deriv_outputs, synthetic_data_and_deriv_outputs

In [None]:
# Control-volume solver tying together the mesh, the combined model, the flux
# function and the data-comparison function defined in the cells above.
# NOTE(review): the fourth and fifth positional arguments are presumably the
# initial/boundary data function and the equation of state (None here) - confirm
# against cv_solver.CVSolver.
solver = cv_solver.CVSolver(
    mesh,
    model,
    state_vec_to_fluxes,
    synthetic_data,
    None,
    model_to_data_comparison=model_to_data_comparison,
)

In [None]:
def _relative_l2_error(predicted, reference):
    """Return ||predicted - reference|| / ||reference|| as a 0-d tensor."""
    return torch.norm(predicted - reference) / torch.norm(reference)


def train(
    n_epochs,
    lr_Phi,
    lr_epsilon,
    cv_pde_loss_w,
    data_loss_w,
    use_scheduler=False,
    lr_scheduler=1e-3,
):
    """Train ``Phi_model`` and ``epsilon_model`` with Adam.

    The loss is ``cv_pde_loss_w * cv_pde_loss + data_loss_w * data_loss``, with
    both terms taken from ``solver.forward()``. Every 10 epochs, and on the
    final epoch, the losses are re-evaluated in eval mode and printed together
    with the relative L2 errors of Phi and epsilon at ``eval_points``.

    Args:
        n_epochs: number of optimisation steps.
        lr_Phi: learning rate of the Phi network (ignored when
            ``use_scheduler`` is True).
        lr_epsilon: learning rate of the epsilon network; ``None`` means
            "same as ``lr_Phi``" (ignored when ``use_scheduler`` is True).
        cv_pde_loss_w: weight of the control-volume PDE loss.
        data_loss_w: weight of the data loss.
        use_scheduler: if True, use one shared learning rate, starting at
            ``lr_scheduler``, that is halved by ``ReduceLROnPlateau``.
        lr_scheduler: initial learning rate used when ``use_scheduler`` is True.
    """
    if lr_epsilon is None:
        lr_epsilon = lr_Phi

    scheduler = None
    if use_scheduler:
        optimizer = optim.Adam(
            chain(Phi_model.parameters(), epsilon_model.parameters()),
            lr=lr_scheduler,
        )
        # `verbose` is deliberately not passed: it is deprecated in recent
        # PyTorch releases, and the current learning rate is already included
        # in the progress line printed below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=100,
            factor=0.5,
            threshold=1e-3,
            min_lr=1e-6,
        )
    else:
        optimizer = optim.Adam(
            [
                {"params": Phi_model.parameters(), "lr": lr_Phi},
                {"params": epsilon_model.parameters(), "lr": lr_epsilon},
            ]
        )

    for epoch in range(n_epochs):
        Phi_model.train()
        epsilon_model.train()
        cv_pde_loss, data_loss = solver.forward()
        optimizer.zero_grad()
        loss = cv_pde_loss_w * cv_pde_loss + data_loss_w * data_loss
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            # Hand the scheduler a plain float rather than a graph-attached tensor.
            scheduler.step(loss.item())

        if epoch % 10 == 0 or epoch == n_epochs - 1:
            Phi_model.eval()
            epsilon_model.eval()
            # Losses after the parameter update, in eval mode.
            cv_pde_loss, data_loss = solver.forward()
            Phi_eval = Phi_model(eval_points).detach().cpu()
            epsilon_eval = epsilon_model(eval_points).detach().cpu()
            Phi_true = synthetic_data(eval_points).detach().cpu()
            epsilon_true = true_epsilon(eval_points).detach().cpu().unsqueeze(-1)
            # Calculate the L2 relative error between the true and predicted Phi and epsilon
            Phi_error = _relative_l2_error(Phi_eval, Phi_true)
            epsilon_error = _relative_l2_error(epsilon_eval, epsilon_true)
            lr_update_str = (
                f"Current LR: {optimizer.param_groups[0]['lr']:.2e}"
                if use_scheduler
                else ""
            )
            print(
                f"Epoch {epoch}: {lr_update_str} "
                f"PDE loss: {cv_pde_loss.item():.10e}"
                f" Data loss: {data_loss.item():.3e}"
                f" Phi error: {Phi_error.item():.3e}"
                f" epsilon error: {epsilon_error.item():.3e}"
                f" argmax: {torch.abs(solver.cv_pde_loss_structure).argmax().item()}"
            )

In [None]:
# Current experiment: only the epsilon network is optimised.
# lr_Phi=0 gives the Phi parameter group a zero learning rate, and
# data_loss_w=0.0 removes the data term, so the loss is the PDE loss alone.
train(
    1000,
    lr_Phi=0,
    lr_epsilon=1e-3,
    cv_pde_loss_w=1.0,
    data_loss_w=0.0,
    use_scheduler=False,
)

In [None]:
# Put both networks in evaluation mode before the plotting cells below.
Phi_model.eval()
epsilon_model.eval()

In [None]:
# torch.save(Phi_model.state_dict(), "Phi_model.pth")

In [None]:
# Spatial structure of the most recent control-volume PDE loss.
# origin="lower" matches every other imshow in this notebook; without it the
# image is drawn with y increasing downwards, i.e. flipped vertically relative
# to the potential and epsilon plots.
# NOTE(review): the axes are array indices, not physical coordinates - no extent is set.
plt.imshow(
    solver.cv_pde_loss_structure.detach().cpu().squeeze().T,
    origin="lower",
    cmap="viridis",
)
plt.colorbar(label="Loss")
plt.xlabel("x")
plt.ylabel("y")

In [None]:
def plot_epsilon_grid(e_model, epoch, show_ground_truth=False):
    """Plot the dielectric constant predicted by ``e_model`` on the mesh eval points.

    Args:
        e_model: callable mapping the mesh evaluation points to epsilon values.
        epoch: label shown in the panel titles.
        show_ground_truth: if True, add a second panel with ``true_epsilon``
            and draw both panels on a common colour scale. The default (False)
            draws the model output only, in a single panel, instead of leaving
            an empty second axes.
    """
    _, _, inputs = mesh.get_eval_points()
    extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]
    # Transposed so that x runs along the horizontal axis.
    e_grid = e_model(inputs).detach().cpu().squeeze().T

    if not show_ground_truth:
        fig, ax1 = plt.subplots(figsize=(5, 5))
        im1 = ax1.imshow(e_grid, origin="lower", extent=extent, cmap="viridis")
        ax1.set_xlabel("x")
        ax1.set_ylabel("y")
        ax1.set_title(f"e_grid (Epoch {epoch})")
        fig.colorbar(im1, ax=ax1, label="e_grid")
        return

    e_gt = true_epsilon(inputs).detach().cpu().squeeze().T

    # Common colour limits so that one colorbar is valid for both panels.
    vmin = min(e_grid.min().item(), e_gt.min().item())
    vmax = max(e_grid.max().item(), e_gt.max().item())

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    ax1.imshow(
        e_grid, origin="lower", extent=extent, cmap="viridis", vmin=vmin, vmax=vmax
    )
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_title(f"e_grid (Epoch {epoch})")

    im2 = ax2.imshow(
        e_gt, origin="lower", extent=extent, cmap="viridis", vmin=vmin, vmax=vmax
    )
    ax2.set_xlabel("x")
    ax2.set_title(f"e_gt (Epoch {epoch})")

    # Create a shared colorbar
    fig.colorbar(im2, ax=[ax1, ax2], label="e_grid")

In [None]:
def plot_u_grid(u_model, epoch):
    """Plot the model potential next to the analytic potential.

    Both panels share the same colour limits, so the single colorbar attached
    to them is valid for the model panel as well as the ground-truth panel.

    Args:
        u_model: callable mapping the mesh evaluation points to the potential.
        epoch: label shown in the panel titles.
    """
    _, _, inputs = mesh.get_eval_points()
    extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]
    # Transposed so that x runs along the horizontal axis.
    u_grid = u_model(inputs).detach().cpu().squeeze().T
    u_gt = synthetic_data(inputs).detach().cpu().squeeze().T

    # Without common limits each image is normalised independently and the
    # shared colorbar (built from the right-hand image) misreports the left one.
    vmin = min(u_grid.min().item(), u_gt.min().item())
    vmax = max(u_grid.max().item(), u_gt.max().item())

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    ax1.imshow(
        u_grid, origin="lower", extent=extent, cmap="viridis", vmin=vmin, vmax=vmax
    )
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_title(f"u_grid (Epoch {epoch})")

    im2 = ax2.imshow(
        u_gt, origin="lower", extent=extent, cmap="viridis", vmin=vmin, vmax=vmax
    )
    ax2.set_xlabel("x")
    ax2.set_title(f"u_gt (Epoch {epoch})")

    # Create a shared colorbar
    fig.colorbar(im2, ax=[ax1, ax2], label="u_grid")

In [None]:
# NOTE(review): x, y and inputs are not used by the call below (plot_u_grid
# fetches the evaluation points itself); kept in case a later cell relies on them.
x, y, inputs = mesh.get_eval_points()
# Compare the trained potential with the analytic one; -1 is only a title label.
plot_u_grid(Phi_model, -1)

In [None]:
# Show the dielectric constant inferred by the epsilon network; -1 is only a title label.
plot_epsilon_grid(epsilon_model, -1)