In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import simple_pinn
import utils
import cv_mesh
import cv_solver

# For development: reload the local modules every time this cell runs, so edits
# to them are picked up without restarting the kernel. The reload order is kept
# as before in case later modules import names from earlier ones.
import importlib

for _dev_module in (simple_pinn, utils, cv_mesh, cv_solver):
    importlib.reload(_dev_module)

# Seed torch's RNGs (e.g. network weight initialization) so that
# Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Uncomment to run everything in double precision.
# torch.set_default_dtype(torch.float64)

In [None]:
# t_domain = [0.0, 1.0]
# x_domain = [-2.5, 2.5]
# Nt = 128
# Nx = 128
# mesh = cv_mesh.CVMesh(
#     t_domain, x_domain, Nt, Nx, quad_pts=(6, 6), quad_rule="gauss-legendre"
# )
# mesh.to(DEVICE)


# # let's try the set of conserved variables (rho, M = rho * v, E)


# def euler_ideal_gas_eos(state_vec, gamma=1.4):
#     rho = state_vec[..., 0]
#     M = state_vec[..., 1]
#     E = state_vec[..., 2]

#     e = torch.divide(E - 0.5 * torch.divide(M * M, rho), rho)

#     eps = 1.0e-6
#     e = torch.clamp(e, min=eps)
#     rho = torch.clamp(rho, min=eps)

#     s = torch.log(e) / (gamma - 1.0) - torch.log(rho)
#     p = (gamma - 1.0) * rho * e
#     return s, p


# def euler_state_vec_to_fluxes(state_vec, eos):
#     rho = state_vec[..., 0]
#     M = state_vec[..., 1]
#     E = state_vec[..., 2]

#     _, p = eos(state_vec)

#     F_t = torch.stack([rho, M, E], dim=-1)
#     F_x = torch.stack(
#         [M, torch.divide(M * M, rho) + p, (E + p) * torch.divide(M, rho)], dim=-1
#     )
#     return F_t, F_x


# def euler_state_vec_to_entropy_fluxes(state_vec, eos):
#     rho = state_vec[..., 0]
#     M = state_vec[..., 1]

#     s, _ = eos(state_vec)

#     F_t = (-s * rho).unsqueeze(-1)
#     F_x = (-s * M).unsqueeze(-1)
#     return F_t, F_x


# def sod_ic_state_vec_evaluation(eval_points, eos):
#     gamma = 1.4
#     rho_left, rho_right = 3.0, 1.0
#     E_left, E_right = 3.0 / (gamma - 1.0), 1.0 / (gamma - 1.0)
#     rho = torch.where(eval_points[..., 1] < 0.0, rho_left, rho_right)
#     M = torch.zeros_like(rho)
#     E = torch.where(eval_points[..., 1] < 0.0, E_left, E_right)
#     return torch.stack([rho, M, E], dim=-1)


# def analytic_sod_soln(eval_points, eos, gamma=1.4):
#     rho_left, rho_right = 3.0, 1.0
#     M_left, M_right = 0.0, 0.0
#     E_left, E_right = 3.0 / (gamma - 1.0), 1.0 / (gamma - 1.0)
#     left_state_vec = (rho_left, M_left, E_left)
#     right_state_vec = (rho_right, M_right, E_right)
#     return utils.analytic_sod_soln(eval_points, left_state_vec, right_state_vec, gamma)


# model = simple_pinn.DirichletPINN(
#     2,
#     [64, 64, 64, 64, 64, 64, 64, 64],
#     3,
#     mesh,
#     sod_ic_state_vec_evaluation,
#     euler_ideal_gas_eos,
#     use_bias_in_output_layer=False,
# ).to(DEVICE)
# print(model)

# solver = cv_solver.CVSolver(
#     mesh,
#     model,
#     euler_state_vec_to_fluxes,
#     euler_state_vec_to_entropy_fluxes,
#     sod_ic_state_vec_evaluation,
#     analytic_sod_soln,
#     euler_ideal_gas_eos,
#     component_names=["rho", "M", "E"],
# )

In [None]:
# Space-time extent and resolution of the control-volume mesh.
t_domain, x_domain = [0.0, 1.0], [-2.5, 2.5]
Nt = Nx = 128
mesh = cv_mesh.CVMesh(
    t_domain,
    x_domain,
    Nt,
    Nx,
    quad_rule="gauss-legendre",
    quad_pts=(6, 6),
)
# NOTE(review): the return value of .to() is discarded, so this relies on
# CVMesh.to moving the mesh in place -- confirm against cv_mesh.
mesh.to(DEVICE)


# let's try the set of variables (rho, v, E): density, velocity and total energy


def euler_ideal_gas_eos(state_vec, gamma=1.4):
    """Ideal-gas equation of state for (rho, v, E) state vectors.

    Args:
        state_vec: tensor whose last axis holds (rho, v, E); E is the total
            energy per unit volume, so the specific internal energy is
            e = (E - 0.5 * rho * v**2) / rho.
        gamma: adiabatic index in p = (gamma - 1) * rho * e.

    Returns:
        (s, p) with s = log(e) / (gamma - 1) - log(rho) and
        p = (gamma - 1) * rho * e, both shaped like ``state_vec[..., 0]``.

    Density and internal energy are floored at ``eps`` so that the logs stay
    finite for the unphysical states an untrained network can produce. The
    density is floored *before* it is used as a divisor: rho <= 0 previously
    produced inf/NaN in e (and hence in s and p) before any clamp applied.
    """
    eps = 1.0e-6
    rho = torch.clamp(state_vec[..., 0], min=eps)
    v = state_vec[..., 1]
    E = state_vec[..., 2]

    e = torch.divide(E - 0.5 * rho * v * v, rho)
    e = torch.clamp(e, min=eps)

    s = torch.log(e) / (gamma - 1.0) - torch.log(rho)
    p = (gamma - 1.0) * rho * e
    return s, p


def euler_state_vec_to_fluxes(state_vec, eos):
    """Map (rho, v, E) states to the time and space fluxes of the 1D Euler equations.

    Args:
        state_vec: tensor whose last axis holds (rho, v, E).
        eos: callable ``eos(state_vec) -> (s, p)``; only the pressure ``p`` is used.

    Returns:
        (F_t, F_x), each stacked along a new last axis of size 3:
        F_t = (rho, rho * v, E) and F_x = (rho * v, rho * v**2 + p, (E + p) * v).
    """
    density, velocity, energy = (state_vec[..., k] for k in range(3))
    pressure = eos(state_vec)[1]

    momentum = density * velocity
    flux_t = torch.stack([density, momentum, energy], dim=-1)
    flux_x = torch.stack(
        [momentum, momentum * velocity + pressure, (energy + pressure) * velocity],
        dim=-1,
    )
    return flux_t, flux_x


def euler_state_vec_to_entropy_fluxes(state_vec, eos):
    """Entropy pair (U, F) = (-rho * s, -rho * v * s) for (rho, v, E) states.

    Args:
        state_vec: tensor whose last axis holds (rho, v, E); E is not read here.
        eos: callable ``eos(state_vec) -> (s, p)``; only the entropy ``s`` is used.

    Returns:
        (F_t, F_x), each with a trailing singleton axis appended.
    """
    density = state_vec[..., 0]
    velocity = state_vec[..., 1]
    entropy_density = -eos(state_vec)[0] * density
    entropy_flux = entropy_density * velocity
    return entropy_density.unsqueeze(-1), entropy_flux.unsqueeze(-1)


def sod_ic_state_vec_evaluation(eval_points, eos, gamma=1.4):
    """Shock-tube initial condition in (rho, v, E) variables.

    Gas at rest with a jump at x = 0: for x < 0, rho = 3 and p = 3; for x >= 0,
    rho = 1 and p = 1. With v = 0 the total energy is E = p / (gamma - 1).

    Args:
        eval_points: tensor whose index 1 on the last axis is the x coordinate
            (presumably laid out as (t, x) -- confirm against cv_mesh).
        eos: unused; accepted so the signature matches the other callbacks.
        gamma: adiabatic index used to convert pressure to energy. The default
            1.4 is the same as the defaults of ``euler_ideal_gas_eos`` and
            ``analytic_sod_soln``; pass a different value to keep them consistent.

    Returns:
        Tensor of shape ``eval_points.shape[:-1] + (3,)`` holding (rho, v, E).
    """
    left_of_jump = eval_points[..., 1] < 0.0
    rho_left, rho_right = 3.0, 1.0
    p_left, p_right = 3.0, 1.0
    E_left, E_right = p_left / (gamma - 1.0), p_right / (gamma - 1.0)
    rho = torch.where(left_of_jump, rho_left, rho_right)
    v = torch.zeros_like(rho)
    E = torch.where(left_of_jump, E_left, E_right)
    return torch.stack([rho, v, E], dim=-1)


def analytic_sod_soln(eval_points, eos, gamma=1.4):
    """Exact shock-tube solution at ``eval_points`` in (rho, v, E) variables.

    The left/right states are handed to ``utils.analytic_sod_soln`` in the form
    (rho, M, E), with momentum M = 0 on both sides. The second component of its
    result is then divided by the first, overwriting that tensor in place, to
    turn momentum into velocity.

    Args:
        eval_points: points at which to evaluate the solution.
        eos: unused; accepted so the signature matches the other callbacks.
        gamma: adiabatic index passed through to ``utils.analytic_sod_soln``.
    """
    energy_left = 3.0 / (gamma - 1.0)
    energy_right = 1.0 / (gamma - 1.0)
    left_state_vec = (3.0, 0.0, energy_left)
    right_state_vec = (1.0, 0.0, energy_right)

    soln = utils.analytic_sod_soln(eval_points, left_state_vec, right_state_vec, gamma)
    density, momentum = soln[..., 0], soln[..., 1]
    soln[..., 1] = torch.divide(momentum, density)
    return soln


# Network with 2 inputs and 3 outputs (the components named below: rho, v, E)
# and eight hidden layers of width 64; it also receives the IC and EOS callbacks.
model = simple_pinn.DirichletPINN(
    2,
    [64] * 8,
    3,
    mesh,
    sod_ic_state_vec_evaluation,
    euler_ideal_gas_eos,
    use_bias_in_output_layer=False,
)
model = model.to(DEVICE)
print(model)

solver = cv_solver.CVSolver(
    mesh,
    model,
    # flux callbacks: conservation-law fluxes, then entropy fluxes
    euler_state_vec_to_fluxes,
    euler_state_vec_to_entropy_fluxes,
    # initial condition, analytic reference solution, equation of state
    sod_ic_state_vec_evaluation,
    analytic_sod_soln,
    euler_ideal_gas_eos,
    component_names=["rho", "v", "E"],
)

In [None]:
def train(optimizer, n_epochs, entropy_loss_weight=0.0, log_every=100, pinn_solver=None):
    """Run ``n_epochs`` optimization steps on the control-volume PINN loss.

    Each step minimizes ``cv_pde_loss + entropy_loss_weight * cv_entropy_loss``,
    where both losses come from ``forward()`` of the solver being trained.

    Args:
        optimizer: torch optimizer over the parameters the solver's loss depends on.
        n_epochs: number of optimization steps.
        entropy_loss_weight: weight of the entropy loss term. When it is 0 the
            term is left out of the loss entirely, so a non-finite entropy
            value or gradient cannot leak in through 0 * inf = NaN.
        log_every: print both losses every ``log_every`` epochs. The final epoch
            is always printed. Values < 1 disable the periodic output.
        pinn_solver: solver to train. Defaults to the notebook-global ``solver``,
            looked up at call time so that a rebuilt solver is picked up.
    """
    active_solver = solver if pinn_solver is None else pinn_solver
    for epoch in range(n_epochs):
        cv_pde_loss, cv_entropy_loss = active_solver.forward()
        optimizer.zero_grad()
        if entropy_loss_weight == 0:
            loss = cv_pde_loss
        else:
            loss = cv_pde_loss + entropy_loss_weight * cv_entropy_loss
        loss.backward()
        optimizer.step()
        is_periodic = log_every > 0 and epoch % log_every == 0
        if is_periodic or epoch == n_epochs - 1:
            print(
                f"Epoch {epoch}: PDE loss: {cv_pde_loss.item():.3e}, entropy loss: {cv_entropy_loss.item():.3e}."
            )

In [None]:
# Training schedule: (learning rate, epochs, entropy-loss weight) per stage.
# Each stage creates a fresh Adam optimizer, so its moment estimates restart.
# Learning rates are plain floats: a tensor lr buys nothing here, and in some
# PyTorch versions it makes Adam fall back from its multi-tensor (foreach) path.
training_stages = [
    (1e-2, 2_000, 0.0),
    (1e-3, 20_000, 0.1),
    (1e-4, 20_000, 0.1),
]
for stage_lr, stage_epochs, stage_entropy_weight in training_stages:
    optimizer = optim.Adam(model.parameters(), lr=stage_lr)
    train(optimizer, stage_epochs, entropy_loss_weight=stage_entropy_weight)

In [None]:
# Plot the loss history recorded by the solver during training.
solver.plot_loss_history()

In [None]:
# Components at time index 0 (presumably the first time level), together with
# the initial condition; loss_to_plot selects the PDE loss.
solver.plot_components(0, with_ics=True, centered=False, loss_to_plot="PDE")

In [None]:
# Components at the last time index, with initial condition and analytic
# solution overlaid; loss_to_plot selects the PDE loss.
solver.plot_components(
    -1, loss_to_plot="PDE", with_analytic_soln=True, with_ics=True, centered=False
)

In [None]:
# Re-bind the notebook globals to the model and mesh held by the solver.
model, mesh = solver.model, solver.mesh

In [None]:
# Same final-time view as above, but with plot_components' default loss selection.
solver.plot_components(-1, with_analytic_soln=True, with_ics=True, centered=False)

In [None]:
# Animate the components with the initial condition shown; presumably written to "sod.gif".
solver.animate_components("sod.gif", with_ics=True)