In [None]:
import numpy as np
import torch
import torch.optim
import simple_pinn
import matplotlib
import matplotlib.pyplot as plt
from pyrecorder.recorder import Recorder
from pyrecorder.writers.video import Video
from pyrecorder.converters.matplotlib import Matplotlib
from itertools import chain
import time

# Fix torch's RNG seed (on all devices) so network initialisation and any
# random collocation sampling are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Regular collocation grid on the unit square.  With "ij" indexing,
# inputs[i, j] == (x[i], y[j]), so inputs has shape (Nx_sample, Ny_sample, 2).
# requires_grad=True lets the PDE loss differentiate u with respect to inputs.
x_domain = y_domain = [0, 1]
Nx_sample = Ny_sample = 100

x = torch.linspace(*x_domain, Nx_sample, requires_grad=True).to(DEVICE)
y = torch.linspace(*y_domain, Ny_sample, requires_grad=True).to(DEVICE)

X, Y = torch.meshgrid(x, y, indexing="ij")
inputs = torch.stack([X, Y], dim=-1)

In [None]:
# Ground-truth conductivity values: K_0 where x < 0.25 and K_1 elsewhere
# (this is how the PDE source term in eval_bulk_loss is built).
K_0 = 1.0
K_1 = 10.0


def synthetic_data(inputs):
    """Manufactured solution u(x, y) = sin(2*pi*x) * sin(2*pi*y).

    Args:
        inputs: tensor whose last dimension holds the (x, y) coordinates.

    Returns:
        Tensor of shape ``inputs.shape[:-1] + (1,)``.
    """
    two_pi = 2.0 * torch.pi
    sin_x = torch.sin(two_pi * inputs[..., 0])
    sin_y = torch.sin(two_pi * inputs[..., 1])
    return (sin_x * sin_y)[..., None]

In [None]:
# Network for u(x, y).  Presumably SimplePINN(in_features, hidden_widths,
# out_features, ...), i.e. 2 -> 4 hidden layers of 32 (tanh) -> 1.
# TODO confirm against simple_pinn.  Alternative hidden layout: [64] * 8.
u_model = simple_pinn.SimplePINN(
    2, [32] * 4, 1, activation=torch.nn.Tanh()
).to(DEVICE)

# Network for K(x, y).  The third argument is the (2, 1) tensor
# [[K_0], [K_1]] of candidate conductivities; how SegmentationPINN combines
# them is defined in simple_pinn (not visible here).
K_model = simple_pinn.SegmentationPINN(
    2, [32] * 4, torch.tensor([K_0, K_1]).reshape(2, 1)
).to(DEVICE)

In [None]:
def eval_bulk_loss(u_model, K_model, inputs):
    """Mean-squared PDE residual at the collocation points ``inputs``.

    Residual: ``K * (u_xx + u_yy) + f`` with the manufactured source
    ``f = 8*pi^2 * K_true * sin(2*pi*x) * sin(2*pi*y)``, where ``K_true`` is
    ``K_0`` for ``x < 0.25`` and ``K_1`` otherwise.  Because the ground truth
    ``u = sin(2*pi*x) * sin(2*pi*y)`` satisfies ``u_xx + u_yy = -8*pi^2 * u``,
    the residual vanishes when u matches it and ``K == K_true``.

    K itself is not differentiated: the operator is ``K * Laplacian(u)``,
    which coincides with ``div(K grad u)`` wherever K is locally constant.

    Args:
        u_model: network for u, evaluated at ``inputs``.
        K_model: network for K, evaluated at ``inputs``.
        inputs: coordinates with (x, y) in the last dimension.  Must be part
            of the autograd graph (requires grad) so u can be differentiated
            with respect to it.

    Returns:
        Scalar tensor, differentiable w.r.t. both models' parameters.
    """
    u = u_model(inputs)
    K = K_model(inputs)

    def d_dinputs(f):
        # Gradient of sum(f) w.r.t. inputs.  create_graph keeps the result
        # differentiable for the next derivative and for loss.backward().
        return torch.autograd.grad(
            f, inputs, torch.ones_like(f), retain_graph=True, create_graph=True
        )[0]

    # A single autograd pass yields both first derivatives.
    grad_u = d_dinputs(u)
    u_x = grad_u[..., 0]
    u_y = grad_u[..., 1]
    u_xx = d_dinputs(u_x)[..., 0]
    u_yy = d_dinputs(u_y)[..., 1]

    x_coord = inputs[..., 0]
    y_coord = inputs[..., 1]
    K_true = torch.where(x_coord < 0.25, K_0, K_1)
    source = (
        8.0
        * torch.pi**2
        * K_true
        * torch.sin(2.0 * torch.pi * x_coord)
        * torch.sin(2.0 * torch.pi * y_coord)
    )

    # Assumes K carries a trailing singleton channel (e.g. (Nx, Ny, 1)) that
    # squeeze() drops to match the Laplacian's shape -- TODO confirm.
    return torch.mean(torch.square(K.squeeze() * (u_xx + u_yy) + source))

In [None]:
def eval_bc_loss(model, inputs, bc_value=0.0):
    """Mean-squared violation of the Dirichlet condition u = bc_value.

    Args:
        model: network evaluated at the boundary points.
        inputs: boundary coordinates (e.g. from ``assemble_bc_inputs`` or
            ``sample_bc_inputs``).
        bc_value: prescribed constant boundary value.  The default 0 matches
            the ground truth sin(2*pi*x) * sin(2*pi*y) on the unit square's
            edges.

    Returns:
        Scalar tensor ``mean((model(inputs) - bc_value) ** 2)``.
    """
    u_boundary = model(inputs)
    return torch.mean(torch.square(u_boundary - bc_value))

In [None]:
def eval_data_loss(u_model, K_model, inputs):
    """Mean-squared misfit between ``u_model`` and the synthetic ground truth.

    ``K_model`` is not used by this loss.
    """
    residual = u_model(inputs) - synthetic_data(inputs)
    return residual.square().mean()

In [None]:
def plot_K_grid(K_grid, epoch):
    """Draw ``K_grid`` as a heat map on the current pyplot figure.

    Args:
        K_grid: CPU tensor of K values on the evaluation grid, indexed [x, y];
            it is transposed so that y runs vertically.
        epoch: epoch number shown in the title.

    The colour range is fixed to [K_0, K_1].
    """
    extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]
    image = plt.imshow(
        K_grid.squeeze().T,
        origin="lower",
        extent=extent,
        cmap="viridis",
        vmin=K_0,
        vmax=K_1,
    )
    plt.colorbar(image, label="K_grid")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"K_grid (Epoch {epoch})")
    plt.xticks([0, 0.25, 0.5, 0.75, 1])

In [None]:
def plot_u_grid(u_model, epoch):
    """Show the predicted u and the ground truth side by side on the global grid.

    Both panels use the same colour limits, so the single shared colour bar
    is a valid legend for each of them.  If each ``imshow`` picked its own
    limits, a bar built from one panel would misdescribe the other.

    Args:
        u_model: network for u, evaluated on the global ``inputs`` grid.
        epoch: epoch number shown in the panel titles.
    """
    with torch.no_grad():
        u_grid = u_model(inputs).detach().cpu()
        u_gt = synthetic_data(inputs).detach().cpu()

    vmin = min(u_grid.min().item(), u_gt.min().item())
    vmax = max(u_grid.max().item(), u_gt.max().item())
    extent = [x_domain[0], x_domain[1], y_domain[0], y_domain[1]]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (ax1, u_grid, f"u_grid (Epoch {epoch})"),
        (ax2, u_gt, f"u_gt (Epoch {epoch})"),
    )
    for ax, field, title in panels:
        # Fields are indexed [x, y]; transpose so y runs vertically.
        image = ax.imshow(
            field.squeeze().T,
            origin="lower",
            extent=extent,
            cmap="viridis",
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title(title)
        ax.set_xticks([0, 0.25, 0.5, 0.75, 1])

    # Shared colour bar -- valid for both panels thanks to the common limits.
    fig.colorbar(image, ax=[ax1, ax2], label="u_grid")

In [None]:
def plot_error_grid(u_model, epoch):
    """Heat map of the prediction error, normalised by the ground truth's peak.

    Plots ``|u - u_gt| / max|u_gt|`` on the global grid.  A pointwise ratio
    ``|u - u_gt| / |u_gt|`` is avoided on purpose: ``u_gt`` is exactly zero on
    the grid lines x = 0 and y = 0, and nearly zero at x = 1, y = 1 and near
    0.5.  There the ratio becomes inf/NaN or enormous and swamps the colour
    scale.

    Args:
        u_model: network for u, evaluated on the global ``inputs`` grid.
        epoch: epoch number shown in the title.
    """
    with torch.no_grad():
        u_grid = u_model(inputs).detach().cpu()
        u_gt = synthetic_data(inputs).detach().cpu()

    abs_error = torch.norm(u_grid - u_gt, p=2, dim=-1)
    # Guard against an identically zero reference field.
    scale = torch.norm(u_gt, p=2, dim=-1).max().clamp_min(
        torch.finfo(u_gt.dtype).tiny
    )
    relative_error = abs_error / scale

    plt.imshow(
        relative_error.squeeze().T,
        origin="lower",
        extent=[x_domain[0], x_domain[1], y_domain[0], y_domain[1]],
        cmap="viridis",
    )
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Relative Error (Epoch {epoch})")
    plt.xticks([0, 0.25, 0.5, 0.75, 1])
    plt.colorbar(label="Relative Error")

In [None]:
# Per-step loss logs; TrainModel.train appends one float to each list per step.
bulk_loss_history, bc_loss_history, data_loss_history, total_loss_history = (
    [] for _ in range(4)
)

In [None]:
def sample_bulk_inputs(N, device=None):
    """Draw N*N uniformly random collocation points in the unit square [0, 1)^2.

    Sampling is this simple only because the domain is the unit square.

    Args:
        N: points per side; the result has shape (N, N, 2), the same layout
            as the regular ``inputs`` grid.
        device: target device; defaults to the notebook-wide ``DEVICE``.

    Returns:
        Leaf tensor of (x, y) points with ``requires_grad=True`` so PDE
        derivatives can be taken with respect to it.
    """
    if device is None:
        device = DEVICE
    # Allocate directly on the target device.  This avoids a host->device copy
    # and keeps the result a leaf (.to() on a grad-requiring tensor that
    # changes device returns a non-leaf).
    return torch.rand(N, N, 2, device=device, requires_grad=True)

In [None]:
def sample_bc_inputs(N, device=None):
    """Draw N random points on each of the four edges of the unit square.

    Args:
        N: points per edge.
        device: target device; defaults to the notebook-wide ``DEVICE``.

    Returns:
        Leaf tensor of shape (4*N, 2) with ``requires_grad=True``.  Rows are
        stacked as [y = 0 edge, y = 1 edge, x = 0 edge, x = 1 edge].
    """
    if device is None:
        device = DEVICE
    # Build everything on the target device (no per-edge host->device copies).
    # The four torch.rand calls keep the original order.
    zeros = torch.zeros(N, device=device)
    ones = torch.ones(N, device=device)
    bottom = torch.stack([torch.rand(N, device=device), zeros], dim=-1)  # y = 0
    top = torch.stack([torch.rand(N, device=device), ones], dim=-1)  # y = 1
    left = torch.stack([zeros, torch.rand(N, device=device)], dim=-1)  # x = 0
    right = torch.stack([ones, torch.rand(N, device=device)], dim=-1)  # x = 1
    return torch.cat([bottom, top, left, right]).requires_grad_()

In [None]:
def assemble_bc_inputs(inputs):
    """Collect the four edges of an (A, B, 2) grid tensor into one (2B + 2A, 2) tensor.

    For an "ij"-indexed grid such as ``inputs`` (inputs[i, j] == (x_i, y_j)),
    the rows come in the order: x = x_min edge, x = x_max edge, y = y_min edge,
    y = y_max edge.  Corner points lie on two edges and therefore appear twice.
    """
    first_x_edge, last_x_edge = inputs[0], inputs[-1]
    first_y_edge, last_y_edge = inputs[:, 0], inputs[:, -1]
    edges = (first_x_edge, last_x_edge, first_y_edge, last_y_edge)
    return torch.cat(edges)

In [None]:
# u_model_lr = 1e-5
# K_model_lr = 1e-4

# u_model_params = u_model.parameters()
# K_model_params = K_model.parameters()

# optimizer = torch.optim.Adam(
#     [
#         {"params": u_model_params, "lr": u_model_lr},
#         # {"params": K_model_params, "lr": K_model_lr},
#     ]
# )

# # optimizer = torch.optim.LBFGS(
# #     u_model.parameters(), lr=0.01, max_iter=20, line_search_fn="strong_wolfe"
# # )

In [None]:
# bulk_inputs = sample_bulk_inputs(100)
# bc_inputs = sample_bc_inputs(100)

# Boundary loss uses the four edges of the regular grid; the PDE residual
# uses the whole grid.
bc_inputs = assemble_bc_inputs(inputs)
bulk_inputs = inputs


def closure():
    """Zero the gradients, evaluate the total loss, backpropagate and return it.

    Presumably meant for ``optimizer.step(closure)``-style optimizers such as
    LBFGS (cf. the commented-out optimizer cell).  Unlike the objective in
    TrainModel.train, this sum includes the data loss.

    NOTE(review): ``optimizer`` is read as a module-level global.  The only
    assignment to that name in this notebook is commented out (the optimizer
    in ``TrainModel.train`` is a local), so calling this as-is raises
    NameError.
    """
    optimizer.zero_grad()
    bulk = eval_bulk_loss(u_model, K_model, bulk_inputs)
    boundary = eval_bc_loss(u_model, bc_inputs)
    data = eval_data_loss(u_model, K_model, inputs)
    loss = bulk + boundary + data
    loss.backward()
    return loss

In [None]:
class TrainModel:
    """Staged trainer for the module-level ``u_model`` / ``K_model`` pair.

    Every call to ``train`` builds a fresh Adam optimizer, so Adam's moment
    estimates restart at each stage.  Losses are appended to the module-level
    ``*_loss_history`` lists.  The following are read from module globals:
    the models, the collocation sets (``bulk_inputs``, ``bc_inputs``,
    ``inputs``) and, when recording, the pyrecorder ``rec``.

    Attributes:
        epochs_total: optimisation steps taken across all ``train`` calls.
        record: if True, draw K and capture a video frame via ``rec`` at each
            logging step.
    """

    def __init__(self, record=False):
        self.epochs_total = 0
        self.record = record

    def train(self, n_epochs, lr_u, lr_K=None, log_every=10, use_data_loss=False):
        """Run ``n_epochs`` full-batch Adam steps on the PDE + boundary loss.

        Args:
            n_epochs: number of steps in this stage.
            lr_u: learning rate for ``u_model``'s parameters.
            lr_K: learning rate for ``K_model``'s parameters; defaults to ``lr_u``.
            log_every: print progress (and record a frame if ``self.record``)
                every ``log_every`` steps of this stage; 0 or None disables it.
            use_data_loss: if True, also optimise the supervised data loss.
                Otherwise it is only evaluated, without autograd, for logging.
        """
        if lr_K is None:
            lr_K = lr_u
        # Built unconditionally so an explicitly passed lr_K also gets an optimizer.
        optimizer = torch.optim.Adam(
            [
                {"params": u_model.parameters(), "lr": lr_u},
                {"params": K_model.parameters(), "lr": lr_K},
            ]
        )

        for n in range(n_epochs):
            bulk_loss = eval_bulk_loss(u_model, K_model, bulk_inputs)
            bc_loss = eval_bc_loss(u_model, bc_inputs)
            # No autograd graph is needed when the data loss is only monitored.
            with torch.set_grad_enabled(use_data_loss):
                data_loss = eval_data_loss(u_model, K_model, inputs)
            total_loss = bulk_loss + bc_loss
            if use_data_loss:
                total_loss = total_loss + data_loss

            bulk_loss_history.append(bulk_loss.item())
            bc_loss_history.append(bc_loss.item())
            data_loss_history.append(data_loss.item())
            total_loss_history.append(total_loss.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if log_every and n % log_every == 0:
                print(
                    f"lr = {lr_u}, {lr_K}; Epoch {n}: Total Loss = {total_loss.item():.4e}, Bulk Loss = {bulk_loss.item():.4e}, BC Loss = {bc_loss.item():.4e}, Data Loss = {data_loss.item():.4e}"
                )
                if self.record:
                    # K on the full grid is only needed for the video frame.
                    with torch.no_grad():
                        K_grid = K_model(inputs).detach().cpu()
                    plot_K_grid(K_grid, self.epochs_total)
                    rec.record()
                    # NOTE(review): purpose of this pause is not evident from the code; confirm it is needed.
                    time.sleep(0.1)
            self.epochs_total += 1

In [None]:
# Positional True is the `record` flag: TrainModel then captures K frames via the global `rec`.
train_model = TrainModel(True)

In [None]:
# Staged learning-rate schedule: (steps, learning rate for both networks).
# Each stage starts a fresh Adam optimizer inside TrainModel.train.
LR_SCHEDULE = [
    (5_000, 1e-2),
    (10_000, 1e-3),
    (10_000, 1e-4),
    (20_000, 1e-5),
    (20_000, 1e-6),
    (20_000, 1e-7),
]

converter = Matplotlib(dpi=120)
writer = Video("K_training_v2.mp4", fps=24)
# Pass the converter explicitly; otherwise Recorder falls back to its own
# default converter and the dpi=120 setting above is silently ignored.
# TrainModel.train reads the recorder through the global name `rec`.
with Recorder(writer, converter=converter) as rec:
    for stage_steps, stage_lr in LR_SCHEDULE:
        train_model.train(stage_steps, stage_lr)

In [None]:
# Optimisation steps actually taken across all training stages so far; used
# only to label the result plots below (a hard-coded count goes stale
# whenever the schedule changes).
n_epochs = train_model.epochs_total

In [None]:
# converter = Matplotlib(dpi=120)
# writer = Video("K_training.mp4", fps=24)

# n_epochs = 20_000

# bulk_inputs = sample_bulk_inputs(100)
# bc_inputs = sample_bc_inputs(100)

# with Recorder(writer) as rec:
#     for n in range(n_epochs):
#         optimizer.zero_grad()
#         bulk_loss = eval_bulk_loss(u_model, K_model, bulk_inputs)
#         bc_loss = eval_bc_loss(u_model, bc_inputs)
#         total_loss = bulk_loss + bc_loss
#         bulk_loss_history.append(bulk_loss.item())
#         bc_loss_history.append(bc_loss.item())
#         total_loss_history.append(total_loss.item())
#         total_loss.backward()
#         optimizer.step()
#         if n % 100 == 0:
#             print(
#                 f"Epoch {n}: Total Loss = {total_loss.item():.4e}, Bulk Loss = {bulk_loss.item():.4e}, BC Loss = {bc_loss.item():.4e}"
#             )
#             bulk_inputs = sample_bulk_inputs(100)
#             bc_inputs = sample_bc_inputs(100)
#             K_grid = K_model(inputs).detach().cpu()
#             # plot_K_grid(K_grid, n)
#             # rec.record()
#             # time.sleep(0.1)

In [None]:
fig, ax = plt.subplots()

# Individual loss terms on log-log axes (the total loss is not drawn).
for history, name in (
    (bulk_loss_history, "Bulk Loss"),
    (bc_loss_history, "BC Loss"),
    (data_loss_history, "Data Loss"),
):
    ax.plot(history, label=name)

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss History Components")
ax.loglog()
ax.legend()

plt.show()

In [None]:
# Learned K on the full evaluation grid after training.
K_grid = K_model(inputs).detach().cpu()
plt.figure()
plot_K_grid(K_grid, n_epochs)
plt.show()

In [None]:
# Predicted u next to the ground truth.
plot_u_grid(u_model, epoch=n_epochs)
plt.show()

In [None]:
# Error map of the trained u against the ground truth.
plot_error_grid(u_model, epoch=n_epochs)
plt.show()

In [None]:
# Define the y values for sampling
y_values = np.linspace(y_domain[0], y_domain[1], 10)

# Initialize an empty list to store the results
results = []

# Define a color map for assigning colors to y values
color_map = matplotlib.colormaps.get_cmap("tab10")

# Iterate over the y values and evaluate the model and synthetic data
for i, y in enumerate(y_values):
    # Create a tensor with the x values
    x_values = torch.linspace(x_domain[0], x_domain[1], 100)

    # Create the inputs tensor
    plt_inputs = torch.stack((x_values, torch.full_like(x_values, y)), dim=-1).to(
        DEVICE
    )

    # Evaluate the model and synthetic data
    u_model_output = u_model(plt_inputs).detach().cpu().numpy()
    synthetic_data_output = synthetic_data(plt_inputs).detach().cpu().numpy()

    # Append the results to the list
    results.append(
        (x_values.numpy(), u_model_output.squeeze(), synthetic_data_output.squeeze(), y)
    )

# Plot the results
for i, (x_values, u_model_output, synthetic_data_output, y) in enumerate(results):
    color = color_map(i % color_map.N)  # Assign a unique color to each y value
    plt.plot(x_values, u_model_output, label=f"u_model (y={y:.2f})", color=color)
    plt.plot(
        x_values,
        synthetic_data_output,
        label=f"synthetic_data (y={y:.2f})",
        linestyle="--",
        color=color,
    )

plt.xlabel("x")
plt.ylabel("Output")
plt.title("Comparison of u_model and synthetic data")

# Place the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

plt.show()

In [None]:
# Define the y values for sampling
y_values = np.linspace(y_domain[0], y_domain[1], 10)

# Initialize an empty list to store the results
results = []

# Define a color map for assigning colors to y values
color_map = matplotlib.colormaps.get_cmap("tab10")

# Iterate over the y values and evaluate the model and synthetic data
for i, y in enumerate(y_values):
    # Create a tensor with the x values
    x_values = torch.linspace(x_domain[0], x_domain[1], 100)

    # Create the inputs tensor
    plt_inputs = torch.stack((x_values, torch.full_like(x_values, y)), dim=-1).to(
        DEVICE
    )

    # Evaluate the model and synthetic data
    K_grid_output = K_model(plt_inputs).detach().cpu().numpy()

    # Append the results to the list
    results.append((x_values.numpy(), K_grid_output.squeeze(), y))

# Plot the results
for i, (x_values, K_grid_output, y) in enumerate(results):
    color = color_map(i % color_map.N)  # Assign a unique color to each y value
    plt.plot(x_values, K_grid_output, label=f"y={y:.2f}", color=color)

plt.xlabel("x")
plt.ylabel("K_grid")
plt.title("K_grid as a function of x for various y")
plt.legend()

plt.show()