In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import simple_pinn
import cv_mesh
import cv_solver

# Development convenience: `import` is a no-op for modules already cached in sys.modules,
# so explicitly reload the local modules each time this cell runs to pick up source edits.
import importlib

for _dev_module in (simple_pinn, cv_mesh, cv_solver):
    importlib.reload(_dev_module)
del _dev_module  # don't leave the loop variable behind in the notebook namespace

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Space-time computational domain: t in [0, 1], x in [-1, 1].
t_domain, x_domain = [0.0, 1.0], [-1.0, 1.0]
# Resolution in t and x (presumably number of control volumes per direction -- confirm
# against cv_mesh.CVMesh).
Nt, Nx = 100, 300
mesh = cv_mesh.CVMesh(
    t_domain,
    x_domain,
    Nt,
    Nx,
    quad_pts=(4, 4),
    quad_rule="composite_trapezoid",
)
# NOTE(review): the return value of `mesh.to` is discarded, so this relies on CVMesh.to
# moving the mesh in place (as torch.nn.Module.to does). Confirm; if it instead returns a
# new object, this should be `mesh = mesh.to(DEVICE)`.
mesh.to(DEVICE)


def burgers_state_vec_to_fluxes(state_vec, eos):
    """Space-time fluxes of Burgers' equation in conservation form u_t + (u^2 / 2)_x = 0.

    Only the first state component ``state_vec[..., 0]`` (u) is read.

    Args:
        state_vec: tensor whose last axis holds the state components.
        eos: unused; accepted so the signature matches the other callbacks in this notebook.

    Returns:
        ``(F_t, F_x)``: the time flux u and the space flux u^2 / 2, each with shape
        ``state_vec.shape[:-1] + (1,)``.
    """
    u_col = state_vec[..., :1]  # keep a trailing length-1 axis
    time_flux = u_col
    space_flux = 0.5 * u_col * u_col
    return time_flux, space_flux


def burgers_state_vec_to_entropy_fluxes(state_vec, eos):
    """Entropy / entropy-flux pair for Burgers' equation.

    Uses the entropy eta(u) = u^2 with entropy flux q(u) = (2/3) u^3, which satisfies
    q'(u) = eta'(u) * f'(u) for the Burgers flux f(u) = u^2 / 2.
    Only the first state component ``state_vec[..., 0]`` (u) is read.

    Args:
        state_vec: tensor whose last axis holds the state components.
        eos: unused; accepted so the signature matches the other callbacks in this notebook.

    Returns:
        ``(F_t, F_x)`` = ``(eta(u), q(u))``, each with shape ``state_vec.shape[:-1] + (1,)``.
    """
    u_col = state_vec[..., :1]  # keep a trailing length-1 axis
    entropy = u_col * u_col
    entropy_flux = (2.0 / 3.0) * u_col * u_col * u_col
    return entropy, entropy_flux


def burgers_ic_state_vec_evaluation(eval_points, eos):
    """Step initial condition for the Burgers Riemann problem: u(0, x) = 0 for x < 0, else 1.

    Args:
        eval_points: floating-point tensor whose last axis holds (t, x); only
            ``eval_points[..., 1]`` (x) is read.
        eos: unused; accepted so the signature matches the other callbacks in this notebook.

    Returns:
        Tensor of shape ``eval_points.shape[:-1] + (1,)`` (a single state component) with the
        same dtype and device as ``eval_points``.
    """
    x = eval_points[..., 1]
    # Build the step from zeros_like/ones_like so the result inherits eval_points' dtype and
    # device; `torch.where(cond, 0.0, 1.0)` yields the default float dtype regardless of input.
    u0 = torch.where(x < 0, torch.zeros_like(x), torch.ones_like(x))
    return u0.unsqueeze(-1)


def burgers_analytic_soln(eval_points, eos, visc=1.0e-5):
    """Exact (inviscid) solution of Burgers' equation for the step initial condition.

    The initial data u(0, x) = 0 for x < 0 and 1 for x >= 0 has u_left < u_right, so the
    entropy solution is a centred rarefaction fan (there is no shock):

        u(t, x) = 0      for x <= 0,
                  x / t  for 0 < x < t,
                  1      for x >= t,

    i.e. u = clamp(x / t, 0, 1) for t > 0. At t = 0 the step initial condition is returned.

    Args:
        eval_points: floating-point tensor whose last axis holds (t, x):
            ``eval_points[..., 0]`` is time and ``eval_points[..., 1]`` is space.
            Times are assumed non-negative (t <= 0 returns the initial condition).
        eos: unused; accepted so the signature matches the other callbacks in this notebook.
        visc: unused -- the returned solution is the inviscid limit. Kept for interface
            compatibility.

    Returns:
        Tensor of shape ``eval_points.shape[:-1] + (1,)`` with the dtype/device of
        ``eval_points``.
    """
    t = eval_points[..., 0]
    x = eval_points[..., 1]
    zeros = torch.zeros_like(x)
    ones = torch.ones_like(x)

    initial = torch.where(x < 0, zeros, ones)

    # Substitute a dummy time where t <= 0 so x / t never yields inf/nan; those entries are
    # masked out below, but a nan in the unselected branch would still poison gradients
    # through torch.where.
    positive_t = t > 0
    safe_t = torch.where(positive_t, t, ones)
    # Clamping x / t to [0, 1] covers all three regions of the fan, including the tail at
    # x == 0 (u = 0) and the head at x == t (u = 1).
    fan = torch.clamp(x / safe_t, 0.0, 1.0)

    u = torch.where(positive_t, fan, initial)
    return u.unsqueeze(-1)


# Seed torch's RNGs so the network initialisation (and hence the whole run) is reproducible
# every time this cell is re-run.
SEED = 0
torch.manual_seed(SEED)

# PINN surrogate for the state. NOTE(review): the positional arguments are presumably
# (input dim = 2 for (t, x), hidden-layer widths, output dim = 1 state component, mesh,
# IC evaluator, EOS = None) -- confirm against simple_pinn.DirichletPINN.
model = simple_pinn.DirichletPINN(
    2, [32, 32, 32], 1, mesh, burgers_ic_state_vec_evaluation, None
).to(DEVICE)
print(model)

LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The assignment below rebinds the global name `cv_solver` from the module to the solver
# instance (later cells use that name). Reach the module through a separate alias so that
# re-running this cell does not fail with AttributeError on `<instance>.CVSolver`.
import cv_solver as cv_solver_module

cv_solver = cv_solver_module.CVSolver(
    mesh,
    model,
    burgers_state_vec_to_fluxes,
    burgers_ic_state_vec_evaluation,
    None,
    state_vec_to_entropy_fluxes=burgers_state_vec_to_entropy_fluxes,
    analytic_soln=burgers_analytic_soln,
)

In [None]:
n_epochs = 20_000
# Weight of the entropy loss relative to the conservation-law (PDE) loss in the objective.
ENTROPY_LOSS_WEIGHT = 0.1
# Print progress every LOG_EVERY epochs; the final epoch is always printed as well.
LOG_EVERY = 100

# NOTE: `model`, `optimizer` and `cv_solver` come from the setup cell, so re-running only this
# cell continues training from the current weights and optimizer state.
for epoch in range(n_epochs):
    # The solver returns two losses, unpacked here as (PDE residual loss, entropy loss).
    cv_pde_loss, cv_entropy_loss = cv_solver.forward()
    optimizer.zero_grad()
    loss = cv_pde_loss + ENTROPY_LOSS_WEIGHT * cv_entropy_loss
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0 or epoch == n_epochs - 1:
        print(
            f"Epoch {epoch}: PDE loss: {cv_pde_loss.item():.3e}, Entropy loss: {cv_entropy_loss.item():.3e}"
        )

In [None]:
# Plot the loss history held by the solver (presumably recorded inside CVSolver.forward()
# during the training cell -- confirm).
cv_solver.plot_loss_history()

In [None]:
# Plot the solver's components at index 0, overlaying the initial condition and the PDE loss.
# NOTE(review): what the positional index selects (time slice? snapshot?) is defined by
# CVSolver.plot_components -- confirm.
cv_solver.plot_components(0, with_ics=True, loss_to_plot="pde")

In [None]:
# Plot the components at the last index (-1) against `burgers_analytic_soln`, which was passed
# to the solver as `analytic_soln`.
cv_solver.plot_components(-1, with_analytic_soln=True)