In [None]:
import numpy as np
import torch
import torch.optim
import simple_pinn
import matplotlib
import matplotlib.pyplot as plt
from pyrecorder.recorder import Recorder
from pyrecorder.writers.video import Video
from pyrecorder.converters.matplotlib import Matplotlib
import time

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

In [None]:
x_domain = y_domain = [0, 1]
Nx_sample = Ny_sample = 100
# 1-D coordinate axes. These are plain tensors: derivatives are only ever taken
# with respect to the stacked grid `inputs` below.
x = torch.linspace(x_domain[0], x_domain[1], Nx_sample).to(DEVICE)
y = torch.linspace(y_domain[0], y_domain[1], Ny_sample).to(DEVICE)
# "ij" indexing: X[i, j] = x[i], Y[i, j] = y[j], so grid axis 0 is x and axis 1 is y.
X, Y = torch.meshgrid(x, y, indexing="ij")
# (Nx, Ny, 2) collocation grid, made a leaf tensor that requires grad. The
# autograd.grad(u, inputs) calls in the losses need this. Because it is a leaf,
# backward() does not traverse the linspace/meshgrid/stack ops every epoch or
# accumulate gradients into x and y (which, on CUDA, meant a device->host copy).
inputs = torch.stack((X, Y), dim=-1).requires_grad_(True)

In [None]:
# Coefficient values on the two sides of x = 0.5 (K_0 on the left, K_1 on the right).
K_0 = 1.0
K_1 = 4.0


def synthetic_data(inputs):
    """Piecewise-linear reference solution u_gt evaluated at `inputs`.

    u_gt = x for x < 0.5, and 0.5 + (K_0 / K_1) * (x - 0.5) otherwise. The slope
    ratio K_0 / K_1 keeps K * du/dx equal to K_0 on both sides of x = 0.5.
    The y coordinate is ignored.

    Args:
        inputs: tensor whose last dimension holds (x, y) coordinates.

    Returns:
        Tensor of shape inputs.shape[:-1] + (1,).
    """
    x_coord = inputs[..., 0]
    right_branch = 0.5 + (K_0 / K_1) * (x_coord - 0.5)
    return torch.where(x_coord < 0.5, x_coord, right_branch)[..., None]

In [None]:
# Network for the solution u(x, y). Arguments are presumably
# (input dim, hidden layer widths, output dim) -- confirm against simple_pinn.
u_model = simple_pinn.SimplePINN(
    2,
    [64] * 8,
    1,
    # Optional settings that are currently disabled:
    # activation=torch.nn.Tanh(),
    # use_bias_in_output_layer=True,
).to(DEVICE)

# Network for the coefficient field K(x, y). The (2, 1) tensor [[K_0], [K_1]]
# presumably lists the per-segment K values -- confirm against simple_pinn.
K_model = simple_pinn.SegmentationPINN(
    2,
    [64] * 8,
    torch.tensor([[K_0], [K_1]]),
).to(DEVICE)

In [None]:
def eval_bulk_loss(u_model, K_model, inputs):
    """Sum of squared residuals of div(K grad u) = 0 over the collocation points.

    Args:
        u_model: callable mapping points (..., 2) to u of shape (..., 1).
        K_model: callable mapping points (..., 2) to K; assumed to return the
            same (..., 1) shape as u, as the boundary-loss functions also assume.
        inputs: points (..., 2) that participate in autograd (require grad).

    Returns:
        Scalar tensor sum((d(K u_x)/dx + d(K u_y)/dy)^2), built with
        create_graph=True so it can be back-propagated into the networks.
    """
    u = u_model(inputs)
    K = K_model(inputs)
    # One gradient call yields both partials: [..., 0] is u_x, [..., 1] is u_y.
    grad_u = torch.autograd.grad(
        u, inputs, torch.ones_like(u), retain_graph=True, create_graph=True
    )[0]
    # K must sit inside the derivative. Differentiating u_x and u_y alone reduces
    # the residual to the Laplacian of u and ignores the coefficient field.
    kux = K * grad_u[..., 0].unsqueeze(-1)
    kuy = K * grad_u[..., 1].unsqueeze(-1)
    d_kux_dx = torch.autograd.grad(
        kux,
        inputs,
        torch.ones_like(kux),
        retain_graph=True,
        create_graph=True,
    )[0][..., 0]
    d_kuy_dy = torch.autograd.grad(
        kuy,
        inputs,
        torch.ones_like(kuy),
        retain_graph=True,
        create_graph=True,
    )[0][..., 1]
    return torch.sum(torch.square(d_kux_dx + d_kuy_dy))

In [None]:
# impose Dirichlet boundary conditions on Gamma1 (the right boundary): u(1, y) = u_gt(1, y)
def eval_Gamma1_loss(model, inputs):
    """Squared-error penalty pinning u to the exact value on the right edge.

    Args:
        model: callable mapping points (..., 2) to u.
        inputs: (Nx, Ny, 2) grid; index -1 on axis 0 is the x = 1 edge for the
            module's "ij"-indexed grid.

    Returns:
        Scalar tensor sum((u - u_gt)^2) over that edge.
    """
    boundary_pts = inputs[-1, :, :]
    # Exact solution at x = 1, i.e. 0.5 + (K_0 / K_1) * 0.5 in factored form.
    target = 0.5 * (1.0 + (K_0 / K_1))
    residual = model(boundary_pts) - target
    return torch.sum(torch.square(residual))

In [None]:
# impose Neumann boundary conditions on Gamma2 (the top boundary): K(x, 1) * u_y(x, 1) = 0
def eval_Gamma2_loss(u_model, K_model, inputs):
    """Penalty for zero normal flux K * u_y on the top edge.

    Args:
        u_model: callable mapping points (N, 2) to u of shape (N, 1).
        K_model: callable mapping points (N, 2) to K of shape (N, 1).
        inputs: (Nx, Ny, 2) grid in the autograd graph; index -1 on axis 1 is
            the y = 1 edge for the module's "ij"-indexed grid.

    Returns:
        Scalar tensor sum((K * u_y)^2) over that edge.
    """
    top_pts = inputs[:, -1, :]
    u_top = u_model(top_pts)
    (grad_top,) = torch.autograd.grad(
        u_top, top_pts, torch.ones_like(u_top), retain_graph=True, create_graph=True
    )
    # Slice 1:2 keeps the trailing singleton dim so the product matches K's shape.
    flux_y = K_model(top_pts) * grad_top[..., 1:2]
    return flux_y.square().sum()

In [None]:
# impose Neumann boundary conditions on Gamma0 (the bottom boundary): K(x, 0) * u_y(x, 0) = 0
def eval_Gamma0_loss(u_model, K_model, inputs):
    """Penalty for zero normal flux K * u_y on the bottom edge.

    Args:
        u_model: callable mapping points (N, 2) to u of shape (N, 1).
        K_model: callable mapping points (N, 2) to K of shape (N, 1).
        inputs: (Nx, Ny, 2) grid in the autograd graph; index 0 on axis 1 is
            the y = 0 edge for the module's "ij"-indexed grid.

    Returns:
        Scalar tensor sum((K * u_y)^2) over that edge.
    """
    bottom_pts = inputs[:, 0, :]
    u_bottom = u_model(bottom_pts)
    (grad_bottom,) = torch.autograd.grad(
        u_bottom,
        bottom_pts,
        torch.ones_like(u_bottom),
        retain_graph=True,
        create_graph=True,
    )
    # Slice 1:2 keeps the trailing singleton dim so the product matches K's shape.
    flux_y = K_model(bottom_pts) * grad_bottom[..., 1:2]
    return flux_y.square().sum()

In [None]:
# impose Neumann boundary conditions on Gamma3 (the left boundary): K(0, y) * u_x(0, y) - K_0 = 0
def eval_Gamma3_loss(u_model, K_model, inputs):
    """Penalty for the prescribed flux K * u_x = K_0 on the left edge.

    Args:
        u_model: callable mapping points (N, 2) to u of shape (N, 1).
        K_model: callable mapping points (N, 2) to K of shape (N, 1).
        inputs: (Nx, Ny, 2) grid in the autograd graph; index 0 on axis 0 is
            the x = 0 edge for the module's "ij"-indexed grid.

    Returns:
        Scalar tensor sum((K * u_x - K_0)^2) over that edge.
    """
    left_pts = inputs[0, :, :]
    u_left = u_model(left_pts)
    (grad_left,) = torch.autograd.grad(
        u_left, left_pts, torch.ones_like(u_left), retain_graph=True, create_graph=True
    )
    # Slice 0:1 keeps the trailing singleton dim so the product matches K's shape.
    flux_x = K_model(left_pts) * grad_left[..., 0:1]
    return (flux_x - K_0).square().sum()

In [None]:
def eval_data_loss(u_model, K_model, inputs):
    """Sum of squared differences between u_model and synthetic_data on `inputs`.

    K_model is accepted but not used (presumably kept so every loss function
    shares the same call signature).
    """
    mismatch = u_model(inputs) - synthetic_data(inputs)
    return mismatch.square().sum()

In [None]:
def plot_K_grid(K_grid, epoch, vmin=None, vmax=None):
    """Draw the coefficient field K on the current pyplot axes, with a colorbar.

    Args:
        K_grid: K values on the module's "ij"-indexed grid (axis 0 = x,
            axis 1 = y, optionally a trailing singleton). It is squeezed and
            transposed so x runs horizontally.
        epoch: epoch number shown in the title.
        vmin, vmax: colour limits. They default to the smaller and larger of
            K_0 and K_1, so the scale follows the configured coefficient
            values instead of hard-coded numbers.
    """
    if vmin is None:
        vmin = min(K_0, K_1)
    if vmax is None:
        vmax = max(K_0, K_1)
    plt.imshow(
        K_grid.squeeze().T,
        origin="lower",
        extent=[x_domain[0], x_domain[1], y_domain[0], y_domain[1]],
        cmap="viridis",
        vmin=vmin,
        vmax=vmax,
    )
    plt.colorbar(label="K_grid")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"K_grid (Epoch {epoch})")
    plt.xticks([0, 0.25, 0.5, 0.75, 1])

In [None]:
def plot_u_grid(u_model, epoch):
    """Plot the predicted u next to the synthetic reference on the module grid.

    Both panels use the same colour limits, so the single shared colorbar is a
    valid legend for each of them.

    Args:
        u_model: network evaluated on the module-level `inputs` grid.
        epoch: epoch number shown in the panel titles.
    """
    # Evaluation only: no autograd graph is needed for plotting.
    with torch.no_grad():
        u_grid = u_model(inputs).detach().cpu()
        u_gt = synthetic_data(inputs).detach().cpu()

    # Common colour range across prediction and reference.
    vmin = min(u_grid.min().item(), u_gt.min().item())
    vmax = max(u_grid.max().item(), u_gt.max().item())
    extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (ax1, u_grid, f"u_grid (Epoch {epoch})"),
        (ax2, u_gt, f"u_gt (Epoch {epoch})"),
    )
    for ax, field, title in panels:
        im = ax.imshow(
            field.squeeze().T,
            origin="lower",
            extent=extent,
            cmap="viridis",
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title(title)
        ax.set_xticks([0, 0.25, 0.5, 0.75, 1])

    # Shared colorbar; valid for both panels because they share vmin/vmax.
    fig.colorbar(im, ax=[ax1, ax2], label="u_grid")

In [None]:
def plot_error_grid(u_model, epoch):
    """Plot the pointwise relative error |u - u_gt| / |u_gt| on the current axes.

    Note: where u_gt is 0 (the x = 0 column for synthetic_data) the ratio is
    inf/nan.

    Args:
        u_model: network evaluated on the module-level `inputs` grid.
        epoch: epoch number shown in the title.
    """
    prediction = u_model(inputs).detach().cpu()
    reference = synthetic_data(inputs).detach().cpu()

    abs_err = (prediction - reference).norm(p=2, dim=-1)
    ref_mag = reference.norm(p=2, dim=-1)
    rel_err = abs_err / ref_mag

    bounds = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]
    plt.imshow(rel_err.squeeze().T, origin="lower", extent=bounds, cmap="viridis")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Relative Error (Epoch {epoch})")
    plt.xticks([0, 0.25, 0.5, 0.75, 1])
    plt.colorbar(label="Relative Error")

In [None]:
u_model_lr = 1e-4
K_model_lr = 1e-4
# When False, only u_model is optimized and K_model keeps its initial weights
# (so the recorded K video will not change). Set True to also fit K_model.
TRAIN_K = False

u_model_params = u_model.parameters()
K_model_params = K_model.parameters()

param_groups = [{"params": u_model_params, "lr": u_model_lr}]
if TRAIN_K:
    param_groups.append({"params": K_model_params, "lr": K_model_lr})

optimizer = torch.optim.Adam(param_groups)

In [None]:
converter = Matplotlib(dpi=120)
writer = Video("K_training.mp4", fps=24)

bulk_loss_history = []
Gamma0_loss_history = []
Gamma1_loss_history = []
Gamma2_loss_history = []
Gamma3_loss_history = []
data_loss_history = []
total_loss_history = []

n_epochs = 10_000
# Log and record a video frame of the K field every `record_every` epochs.
record_every = 10

# Pass the configured converter so its dpi setting is applied to recorded frames.
with Recorder(writer, converter=converter) as rec:
    for n in range(n_epochs):
        optimizer.zero_grad()
        bulk_loss = eval_bulk_loss(u_model, K_model, inputs)
        Gamma0_loss = eval_Gamma0_loss(u_model, K_model, inputs)
        Gamma1_loss = eval_Gamma1_loss(u_model, inputs)
        Gamma2_loss = eval_Gamma2_loss(u_model, K_model, inputs)
        Gamma3_loss = eval_Gamma3_loss(u_model, K_model, inputs)
        data_loss = eval_data_loss(u_model, K_model, inputs)
        total_loss = (
            bulk_loss
            + Gamma0_loss
            + Gamma1_loss
            + Gamma2_loss
            + Gamma3_loss
            + data_loss
        )
        # .item() synchronises with the device; read each value once and reuse it below.
        bulk_loss_history.append(bulk_loss.item())
        Gamma0_loss_history.append(Gamma0_loss.item())
        Gamma1_loss_history.append(Gamma1_loss.item())
        Gamma2_loss_history.append(Gamma2_loss.item())
        Gamma3_loss_history.append(Gamma3_loss.item())
        total_loss_history.append(total_loss.item())
        data_loss_history.append(data_loss.item())
        total_loss.backward()
        optimizer.step()
        if n % record_every == 0:
            print(
                f"Epoch {n}: Total Loss = {total_loss_history[-1]:.4e}, Bulk Loss = {bulk_loss_history[-1]:.4e}, Gamma0 Loss = {Gamma0_loss_history[-1]:.4e}, Gamma1 Loss = {Gamma1_loss_history[-1]:.4e}, Gamma2 Loss = {Gamma2_loss_history[-1]:.4e}, Gamma3 Loss = {Gamma3_loss_history[-1]:.4e}, Data Loss = {data_loss_history[-1]:.4e}"
            )
            # Snapshot K for the video without building an autograd graph.
            with torch.no_grad():
                K_grid = K_model(inputs).cpu()
            plot_K_grid(K_grid, n)
            rec.record()

In [None]:
# Loss history on a log scale: every recorded component plus the total.
fig, ax = plt.subplots()

loss_curves = {
    "Bulk Loss": bulk_loss_history,
    "Gamma0 Loss": Gamma0_loss_history,
    "Gamma1 Loss": Gamma1_loss_history,
    "Gamma2 Loss": Gamma2_loss_history,
    "Gamma3 Loss": Gamma3_loss_history,
    "Data Loss": data_loss_history,
}
for label, history in loss_curves.items():
    ax.plot(history, label=label)
ax.plot(total_loss_history, label="Total Loss", color="k", linestyle="--")

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss History Components")
ax.set_yscale("log")
ax.legend()

plt.show()

In [None]:
# Final K field on the training grid, drawn in a fresh figure.
K_grid = K_model(inputs).detach().cpu()
plt.figure()
plot_K_grid(K_grid, epoch=n_epochs)
plt.show()

In [None]:
# Final prediction vs. synthetic reference.
plot_u_grid(u_model=u_model, epoch=n_epochs)
plt.show()

In [None]:
# Final pointwise relative error; shown explicitly like the neighbouring plot cells.
plot_error_grid(u_model=u_model, epoch=n_epochs)
plt.show()

In [None]:
# Compare u_model with the synthetic solution along horizontal slices y = const.
y_values = np.linspace(y_domain[0], y_domain[1], 10)
color_map = matplotlib.colormaps["tab10"]
# The x sample points are identical for every slice, so build them once.
x_values = torch.linspace(x_domain[0], x_domain[1], 100)

# Evaluate each slice without building autograd graphs. The loop variable is
# named y_slice so it does not overwrite the global coordinate tensor `y`.
results = []
with torch.no_grad():
    for y_slice in y_values:
        plt_inputs = torch.stack(
            (x_values, torch.full_like(x_values, y_slice)), dim=-1
        ).to(DEVICE)
        u_model_output = u_model(plt_inputs).cpu().numpy()
        synthetic_data_output = synthetic_data(plt_inputs).cpu().numpy()
        results.append(
            (
                x_values.numpy(),
                u_model_output.squeeze(),
                synthetic_data_output.squeeze(),
                y_slice,
            )
        )

# One colour per slice: solid = model, dashed = synthetic reference.
fig, ax = plt.subplots()
for i, (x_vals, u_slice, gt_slice, y_slice) in enumerate(results):
    color = color_map(i % color_map.N)
    ax.plot(x_vals, u_slice, label=f"u_model (y={y_slice:.2f})", color=color)
    ax.plot(
        x_vals,
        gt_slice,
        label=f"synthetic_data (y={y_slice:.2f})",
        linestyle="--",
        color=color,
    )

ax.set_xlabel("x")
ax.set_ylabel("Output")
ax.set_title("Comparison of u_model and synthetic data")
# Place the legend outside the plot area.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

plt.show()

In [None]:
# Plot K_model along horizontal slices y = const to see how K varies with x.
y_values = np.linspace(y_domain[0], y_domain[1], 10)
color_map = matplotlib.colormaps["tab10"]
# The x sample points are identical for every slice, so build them once.
x_values = torch.linspace(x_domain[0], x_domain[1], 100)

# Evaluate each slice without building autograd graphs. The loop variable is
# named y_slice so it does not overwrite the global coordinate tensor `y`.
results = []
with torch.no_grad():
    for y_slice in y_values:
        plt_inputs = torch.stack(
            (x_values, torch.full_like(x_values, y_slice)), dim=-1
        ).to(DEVICE)
        K_grid_output = K_model(plt_inputs).cpu().numpy()
        results.append((x_values.numpy(), K_grid_output.squeeze(), y_slice))

fig, ax = plt.subplots()
for i, (x_vals, K_slice, y_slice) in enumerate(results):
    color = color_map(i % color_map.N)
    ax.plot(x_vals, K_slice, label=f"y={y_slice:.2f}", color=color)

ax.set_xlabel("x")
ax.set_ylabel("K_grid")
ax.set_title("K_grid as a function of x for various y")
ax.legend()

plt.show()