In [1]:
#!/usr/bin/env python

import argparse
import os
import sys
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import torch
from scipy.io import savemat, loadmat
from scipy.integrate import solve_ivp


from torch.utils.data import DataLoader, TensorDataset
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator
from mpl_toolkits.mplot3d import Axes3D

# NOTE: "__file__" is a quoted string literal here, not the module attribute, so
# os.path.abspath resolves it against the current working directory. THIS_DIR is
# therefore the working directory and PARENT_DIR its parent.
THIS_DIR = os.path.dirname(os.path.abspath("__file__"))
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath("__file__")))
# Make the parent directory importable (presumably so that the
# `generalized_quad_embs` imports below resolve -- confirm the repo layout).
sys.path.append(PARENT_DIR)

import generalized_quad_embs.modules_quad_stable as module
import generalized_quad_embs.plots_helper as plot
from generalized_quad_embs import utils
from generalized_quad_embs.constants import DATA_DIR


# Pick the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("No GPU found!" if device.type == "cpu" else "Great, a GPU is there")
print("=" * 50)

# Plotting setting
# Keep the default colour cycle around so later plots can pick colours by index.
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]

# Global font sizes: default text, axes title/labels, tick labels, legend.
plt.rc("font", size=20)
plt.rc("axes", titlesize=20, labelsize=20)
plt.rc("xtick", labelsize=20)
plt.rc("ytick", labelsize=20)
plt.rc("legend", fontsize=15)

/Users/goyalp/Manuscripts/Year_2023/generalized_quad_embs
/Users/goyalp/Manuscripts/Year_2023/generalized_quad_embs/data
No GPU found!


In [2]:
## Define the parameters
@dataclass
class Parameters:
    """Container for the settings used in this example.

    Only the annotated names (``confi_model``, ``epoch``, ``path`` and
    ``loss_weights``) are dataclass fields and hence ``__init__`` arguments.
    The un-annotated names are plain class attributes with the defaults shown;
    later cells overwrite some of them on the instance.
    """

    # --- plain class attributes (no annotation => not dataclass fields) ---
    canonical_dim = 3  # canonical dimension
    train_index = 600  # number of points used for training
    latent_dim = 6  # latent (canonical) dimension
    hidden_dim = 16  # number of neurons in a hidden layer
    batch_size = 64  # mini-batch size
    learning_rate = 5e-3  # learning rate
    encoder = "MLP"  # encoder type label
    # --- dataclass fields ---
    confi_model: str = None  # model configuration; None until assigned from the parsed arguments
    epoch: int = None  # number of epochs; controlled externally
    path: str = (
        None  # path where the results will be saved; also controlled externally
    )
    loss_weights: tuple = (10.0, 1.0)  # pair of loss weights
    
        


In [3]:
def plot_singular_values():
    """Plot the relative singular-value decay and the cumulative energy.

    The function takes no arguments. It reads the singular values from the
    module-level array ``_S`` and writes four files whose names are built by
    appending to the module-level ``params.path``:
    ``svd_plot.{png,pdf}`` and ``energy_plot.{png,pdf}``.

    The energy captured by the first ``k`` modes is ``sum(_S[:k]) / sum(_S)``.
    It is computed with ``np.cumsum`` so every entry, including the last one
    (which must equal 1), is filled in.
    """
    # NOTE: this changes the global default font size for all later figures.
    plt.rc("font", size=30)  # controls default text size

    # Figure 1: singular values normalised by the largest one, on a log axis.
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.semilogy(_S / _S[0])
    ax.set(
        ylabel="singular values (rel)", xlabel="$k$", xlim=(-1, 50), ylim=(1e-6, 2e0)
    )
    ax.grid()
    plt.tight_layout()
    fig.savefig(params.path + "svd_plot.png", dpi=300)
    fig.savefig(params.path + "svd_plot.pdf")

    # Figure 2: cumulative energy versus number of retained modes.
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    total = np.sum(_S)
    energy = np.cumsum(_S) / total
    energy_first_three = np.sum(_S[:3]) / total

    ax.plot(np.arange(1, len(_S) + 1), energy)
    # Dashed guides marking k = 3 and the energy captured by three modes.
    ax.plot([3, 3], [0.05, 1.0], "g--")
    ax.plot([-1, 50], [energy_first_three, energy_first_three], "g--")

    ax.set(ylabel="energy captured", xlabel="$k$", xlim=(-1, 20), ylim=(0.05, 1.05))
    ax.set_xticks([0, 3, 5, 10, 15, 20])
    # NOTE(review): 0.527 looks like the three-mode energy of one specific
    # data set hard-coded as a tick -- confirm before reusing with other data.
    ax.set_yticks([0.25, 0.527, 0.75, 1.0])
    ax.grid()
    plt.tight_layout()
    fig.savefig(params.path + "energy_plot.png", dpi=300)
    fig.savefig(params.path + "energy_plot.pdf")


def POD_coeffs_plots():
    """Plot every POD coefficient over time, ground truth versus learned.

    Reads the module-level names ``data``, ``decoded_latent_sol``, ``t`` and
    ``params``. Column ``k`` of ``data`` goes on the first figure, column ``k``
    of ``decoded_latent_sol`` on the second. Both figures are saved as PNG and
    PDF under ``params.path``; the file names carry the index of the last
    coefficient plotted.
    """
    fig_truth, ax_truth = plt.subplots(1, 1, figsize=(5, 3))
    fig_learned, ax_learned = plt.subplots(1, 1, figsize=(5, 3))
    both_axes = (ax_truth, ax_learned)

    for k in range(data.shape[-1]):
        ax_truth.plot(t, data[:, k].cpu())
        ax_learned.plot(t, decoded_latent_sol[:, k].cpu())
        for axis in both_axes:
            axis.set_ylim(-4.5, 4.5)

    for axis in both_axes:
        axis.set(xlabel="time", ylabel="POD coeffs")

    tight = {"bbox_inches": "tight", "pad_inches": 0.1}
    outputs = (
        (fig_truth, "pod_coeffs_ground_truth"),
        (fig_learned, "pod_coeffs_learned"),
    )
    for figure, stem in outputs:
        # `k` still holds the last loop index and is part of the file name.
        target = params.path + f"{stem}_{k}"
        figure.savefig(target + ".png", dpi=300, **tight)
        figure.savefig(target + ".pdf", **tight)


def plot_decoded_solution(save_plots=True):
    """Plot ground truth next to three reconstructions on the full grid.

    Two figures are produced, one for ``q`` and one for ``p``. Each has four
    panels sharing both axes: ground truth, linear decoder, quadratic decoder
    and convolutional decoder. Every panel gets its own colour range (its own
    min/max) and its own horizontal colourbar.

    The function reads these module-level names: ``q_data``, ``q_pod``,
    ``q_rec_quad``, ``q_rec``, ``p_data``, ``p_pod``, ``p_rec_quad``,
    ``p_rec`` (objects exposing ``.T`` and ``.numpy()``), the mesh
    coordinates ``T`` and ``X``, ``params.path`` and the index ``i`` used in
    the output file names.

    Args:
        save_plots (bool, optional): It indicates whether to save plots or not. Defaults to True.
    """

    def _add_colorbar(fig, mesh, ax, lo, hi):
        """Attach a horizontal colourbar with four evenly spaced ticks."""
        cbar = fig.colorbar(
            mesh,
            ax=ax,
            orientation="horizontal",
            pad=0.3,
            format="%.1e",
            ticks=[
                lo,
                (2 / 3) * lo + (1 / 3) * hi,
                (1 / 3) * lo + (2 / 3) * hi,
                hi,
            ],
        )
        cbar.ax.tick_params(labelsize=13)

    def _comparison_figure(panels, ylabel, file_stem):
        """Draw one 1x4 comparison figure and optionally save it as PNG."""
        fig, axes = plt.subplots(1, 4, figsize=(18, 3), sharex=True, sharey=True)

        for col, (ax, (title, field)) in enumerate(zip(axes, panels)):
            values = field.T[:, ::1]
            as_array = values.numpy()
            lo, hi = np.min(as_array), np.max(as_array)
            mesh = ax.pcolormesh(T[:, ::1], X[:, ::1], values, vmin=lo, vmax=hi)
            if col == 0:
                # Only the left-most panel carries the shared y label.
                ax.set(title=title, xlabel="time", ylabel=ylabel)
            else:
                ax.set(title=title, xlabel="time")
            _add_colorbar(fig, mesh, ax, lo, hi)

        fig.subplots_adjust(wspace=0.3, hspace=0)

        if save_plots:
            # `i` is read from module scope, as in the file names used so far.
            fig.savefig(
                params.path + f"{file_stem}_{i}.png",
                dpi=300,
                bbox_inches="tight",
                pad_inches=0.1,
            )

    _comparison_figure(
        [
            ("ground truth", q_data),
            (" linear-decoder", q_pod),
            (" quad-decoder", q_rec_quad),
            (" convo-decoder", q_rec),
        ],
        ylabel="$q$",
        file_stem="q_compare",
    )
    _comparison_figure(
        [
            ("ground truth", p_data),
            (" linear-decoder", p_pod),
            (" quad-decoder", p_rec_quad),
            (" convo-decoder", p_rec),
        ],
        ylabel="$p$",
        file_stem="p_compare",
    )


In [4]:
# Seed the random number generators via the project helper.
utils.reproducibility_seed(seed=100)

params = Parameters()
parser = argparse.ArgumentParser()

parser.add_argument(
    "--confi_model",
    type=str,
    default="quad",
    # A tuple instead of a set keeps the order in help and error text stable.
    choices=("linear", "quad", "cubic", "linear_nostability", "quad_opinf"),
    help="Enforcing model hypothesis",
)

parser.add_argument("--epochs", type=int, default=400, help="Number of epochs")

# Parse an empty argument list so the defaults above are used, rather than
# whatever sys.argv holds inside the notebook kernel.
args = parser.parse_args([])

params.confi_model = args.confi_model
params.epoch = args.epochs

# Results go to a per-model sub-directory; the trailing "/" matters because
# file names are appended to params.path by plain string concatenation.
params.path = "./../Results/Burgers/" + params.confi_model + "/"

if not os.path.exists(params.path):
    # exist_ok guards against the directory appearing between check and creation.
    os.makedirs(params.path, exist_ok=True)
    print("The new directory is created as " + params.path)

color_idx, method_name = utils.define_color_method(params)

In [6]:
# import h5py
# file_mat = "/Users/goyalp/Manuscripts/Year_2023/generalized_quad_embs/data/Burgers_dirichilet_data.mat"
# with h5py.File(file_mat, 'r') as f:
#     print(list(f.keys()))  # to inspect contents

In [7]:
# ## Loading data
# The file is looked up under DATA_DIR (imported from
# generalized_quad_embs.constants) so the notebook does not depend on one
# user's absolute home-directory path.
data = loadmat(os.path.join(str(DATA_DIR), "Burgers_dirichilet_data.mat"))
# Swap the last two axes of the stored arrays; the resulting shape is printed below.
X_all = data["X_data"].transpose(0, 2, 1)[:, ::1, :]
dX_all = data["dX_data"].transpose(0, 2, 1)[:, ::1, :]

print(X_all.shape)
# The shift is currently zero. It is kept as an explicit array because the
# surface-plot helper adds x_shift back before plotting.
x_shift = 0 * np.ones((1, X_all.shape[1], 1))
X_all = X_all - x_shift

# Hold out four trajectories for testing and train on the remaining ones.
idxs = list(np.arange(0, 13))
testing_idxs = list([2, 5, 8, 11])
# A list comprehension keeps the training indices in ascending order,
# independent of set iteration order.
train_idxs = [idx for idx in idxs if idx not in testing_idxs]

X_testing = X_all[testing_idxs]
X_training = X_all[train_idxs]

dX_testing = dX_all[testing_idxs]
dX_training = dX_all[train_idxs]


t = data["t"].T
print(f"Training trajectories: {X_training.shape}")
print(f"Testing trajectories:  {X_testing.shape}\n")

num_inits = X_training.shape[0]
print(f"shape of X for training:\t{X_training.shape}")
print(f"Training samples: {num_inits}")

(13, 256, 1001)
Training trajectories: (9, 256, 1001)
Testing trajectories:  (4, 256, 1001)

shape of X for training:	(9, 256, 1001)
Training samples: 9


In [8]:
# Compute SVD in order to prepare low-dimensional data

# Stack all training trajectories side by side into one snapshot matrix.
temp_X = np.hstack(X_training)
temp_dX = np.hstack(dX_training)

tol = 1e-1
reduced_orders = []

# The decomposition is computed once (it used to be computed twice).
# NOTE(review): V is not used by the cells that follow; if nothing else needs
# it, passing full_matrices=False would avoid building the large square V.
[U, S, V] = np.linalg.svd(temp_X)

# Smallest r such that the first r singular values leave less than `tol` of
# the total sum unexplained. If no r qualifies, r = len(S) + 1, which is what
# the earlier while-loop ended with.
energy = np.cumsum(S) / np.sum(S)
below_tol = np.flatnonzero(1 - energy < tol)
r = int(below_tol[0]) + 1 if below_tol.size else len(S) + 1

print(f"Domainant model for first: {r}")
print(f"Energy captured by the snapshots: {100*sum(S[:r])/sum(S)}%")

Domainant model for first: 11
Energy captured by the snapshots: 90.8740314568382%


In [9]:
# Override the class defaults of Parameters for this experiment.
params.canonical_dim = 256
params.latent_dim = 4


# The projection is the identity, so the "reduced" data equal the full data
# and r (used in output file names later on) is the full dimension 256.
Projection_V = np.eye(256,256)
r = 256
print(f"Proj matrix shape: {Projection_V.shape}")
temp_Xr = Projection_V.T @ temp_X  # projected states (unchanged by the identity projection)
temp_dXr = Projection_V.T @ temp_dX  # projected time derivatives


# NOTE(review): the max-abs scaling factor is computed and then immediately
# overwritten with 1, so no scaling is applied. Only temp_Xr is divided by
# fac; temp_dXr is not -- revisit if fac is ever set to something other than 1.
fac = abs(temp_Xr).max()
fac = 1
temp_Xr = temp_Xr/fac

# Split the stacked columns back into one block per training trajectory.
Xr = np.zeros((num_inits, r, temp_Xr.shape[-1] // num_inits))
dXr = np.zeros((num_inits, r, temp_Xr.shape[-1] // num_inits))

# NOTE: the loop variable `i` and the temporaries stay defined at module level
# after this cell; plot_decoded_solution reads a module-level `i`.
for i in range(0, num_inits):
    temp = int(temp_Xr.shape[-1] // num_inits)  # number of columns per trajectory
    temp_x = temp_Xr[:, i * temp : (i + 1) * temp]
    temp_dx = temp_dXr[:, i * temp : (i + 1) * temp]
#     temp_dx = utils.ddt_uniform(temp_x, (t[1] - t[0]).item(), order=4)
    Xr[i] = temp_x
    dXr[i] = temp_dx


Proj matrix shape: (256, 256)


In [10]:
# Placeholder for removing edge points, where the derivatives are not accurate.
# The full-range slices below currently keep every time sample.
Xr = Xr[:, :, :]
dXr = dXr[:, :, :]
print(Xr.shape)
print(dXr.shape)

# Stack all trajectories and transpose so that each row is one sample.
Xr_v = np.hstack(Xr).T
dXr_v = np.hstack(dXr).T

# TensorDataset pairs row k of the states with row k of the derivatives
# without building a Python list of per-sample tuples up front.
train_dset = TensorDataset(
    torch.tensor(Xr_v).double().to(device).requires_grad_(),
    torch.tensor(dXr_v).double().to(device),
)

train_dl = DataLoader(train_dset, batch_size=params.batch_size, shuffle=True)
dataloaders = {"train": train_dl}

(9, 256, 1001)
(9, 256, 1001)


In [13]:
import torch.nn as nn
import torch.nn.functional as F

class AutoencoderConvo(nn.Module):
    """1-D convolutional autoencoder whose decoder input is quadratically lifted.

    Encoder: four stride-2 ``Conv1d`` layers, a flatten, then three linear
    layers (the middle one with a skip connection) down to ``latent_dim``.

    Decoder: the latent vector is concatenated with ``utils.kron(z, z)`` and
    passed through three linear layers, reshaped to ``(batch, 32, 16)``, then
    upsampled by four stride-2 ``ConvTranspose1d`` layers and a final
    ``Conv1d`` that reduces to one channel.

    The hard-coded 512 (= 32 channels * 16 samples) means the layer sizes
    match inputs with 256 samples per row.

    Args:
        latent_dim (int): size of the latent vector produced by ``encode``.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Encoder
        self.conv1 = nn.Conv1d(1, 4, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv1d(4, 8, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv1d(8, 16, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.linear_e1 = nn.Linear(512, 32)
        self.linear_e2 = nn.Linear(32, 32)
        self.linear_e3 = nn.Linear(32, latent_dim)

        # Decoder
        # Input width: the latent vector plus latent_dim**2 lifted features.
        self.linear_d1 = nn.Linear(latent_dim + latent_dim**2, 32)
        self.linear_d2 = nn.Linear(32, 32)
        self.linear_d3 = nn.Linear(32, 512)
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv1 = nn.ConvTranspose1d(32, 16, **upsample)
        self.deconv2 = nn.ConvTranspose1d(16, 8, **upsample)
        self.deconv3 = nn.ConvTranspose1d(8, 8, **upsample)
        self.deconv4 = nn.ConvTranspose1d(8, 8, **upsample)
        # Plain convolution that reduces to one channel without changing length.
        self.deconv5 = nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1)

    def encode(self, x):
        """Map a batch of rows ``x`` of shape (batch, length) to latent vectors."""
        n_batch, n_points = x.shape[0], x.shape[1]
        h = x.reshape(n_batch, 1, n_points)  # add a single channel axis
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = F.selu_(conv(h))
        h = torch.flatten(h, start_dim=1)
        h = F.selu_(self.linear_e1(h))
        h = h + F.selu_(self.linear_e2(h))  # skip connection
        return self.linear_e3(h)

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, latent_dim) back to rows."""
        # utils.kron presumably returns the row-wise Kronecker product with
        # latent_dim**2 columns, which is the width linear_d1 expects -- confirm.
        lifted = torch.cat((z, utils.kron(z, z)), dim=-1)
        h = self.linear_d1(lifted)
        h = h + F.selu_(self.linear_d2(h))  # skip connection
        h = F.selu_(self.linear_d3(h))
        h = h.reshape(h.shape[0], 32, 16)

        for deconv in (self.deconv1, self.deconv2, self.deconv3, self.deconv4):
            h = F.selu_(deconv(h))
        h = self.deconv5(h)
        return torch.flatten(h, start_dim=1)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decode(self.encode(x))

In [14]:
# Smoke test: take one mini-batch and push it through a freshly initialised
# autoencoder to confirm that the output shape equals the input shape.
x, dx = next(iter(train_dl))  # dx is not used in this cell
print(x.shape)
models_conv = AutoencoderConvo(latent_dim=4).double().to(device)
models_conv(x.double().to(device)).shape

torch.Size([64, 256])


torch.Size([64, 256])

In [15]:
# Leftover introspection: this only displays the function object and has no
# effect on any later cell.
torch.concat

<function torch._VariableFunctionsClass.concat>

In [16]:
# Build the project's default networks, then swap in the convolutional autoencoder.
models = module.network_models(params)
models["ae"] = AutoencoderConvo(latent_dim=4).double().to(device)

# One Adam parameter group per network, both with the same learning rate and
# weight decay: first the autoencoder, then the vector field.
optim = torch.optim.Adam(
    [
        {
            "params": models[net_name].parameters(),
            "lr": params.learning_rate,
            "weight_decay": 1e-5,
        }
        for net_name in ("ae", "vf")
    ]
)

Nonlinear autoencoder and quadratic system with Gaurantee stability!
Global Stability Gurantees!
B_term: False


In [17]:
# Display the layer summary of the autoencoder that will be trained.
models['ae']

AutoencoderConvo(
  (conv1): Conv1d(1, 4, kernel_size=(5,), stride=(2,), padding=(2,))
  (conv2): Conv1d(4, 8, kernel_size=(3,), stride=(2,), padding=(1,))
  (conv3): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
  (conv4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
  (linear_e1): Linear(in_features=512, out_features=32, bias=True)
  (linear_e2): Linear(in_features=32, out_features=32, bias=True)
  (linear_e3): Linear(in_features=32, out_features=4, bias=True)
  (linear_d1): Linear(in_features=20, out_features=32, bias=True)
  (linear_d2): Linear(in_features=32, out_features=32, bias=True)
  (linear_d3): Linear(in_features=32, out_features=512, bias=True)
  (deconv1): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
  (deconv2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
  (deconv3): ConvTranspose1d(8, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))


In [18]:
# Joint training of the autoencoder and the vector field via the project helper.
# Uncomment the cell magic below to profile the run.
# %%prun
models, err_t = module.train_quad_conv(models, train_dl, optim, params)

function from train_quad_conv
Training begins!
weights for loss function are: (10.0, 1.0)
Epoch 1/400 | loss_VF: 1.37e-03 | loss_AE: 1.10e-01 | learning rate: 5.00e-03
Epoch 2/400 | loss_VF: 7.84e-05 | loss_AE: 1.15e-01 | learning rate: 5.00e-03
Epoch 3/400 | loss_VF: 1.94e-05 | loss_AE: 8.93e-02 | learning rate: 5.00e-03
Epoch 4/400 | loss_VF: 5.44e-05 | loss_AE: 7.90e-02 | learning rate: 5.00e-03


KeyboardInterrupt: 

In [None]:
# Creates a standard-library profiler object.
# NOTE(review): `pr` is not used by any of the cells shown below -- presumably
# leftover debugging that can be removed.
import profile
pr = profile.Profile()


In [None]:
# Keyword arguments shared by every surface plot.
PROPERTY = {
    "cmap": cm.viridis,
    "antialiased": True,
    "rcount": 500,
    "ccount": 500,
    "linewidth": 0,
}


def plotting_chafee():
    """Draw ground truth, learned solution and log-error as 3-D surfaces.

    The function takes no arguments and reads these module-level names:
    ``space_grid``, ``time_grid``, ``X_testing``, the current test index
    ``k``, ``x_shift``, ``full_sol_OpInf``, ``params.path`` and ``r``.
    It saves ``simulation_test_{k}_order_{r}.png`` under ``params.path``.

    The third panel shows ``log10`` of the absolute difference between the
    ground truth and the learned solution, on a linear z-axis.
    """
    fig, ax = plt.subplots(1, 3, figsize=(16, 4), subplot_kw={"projection": "3d"})

    truth = X_testing[k]
    shift = x_shift.reshape(-1, 1)  # column vector, broadcast over time

    # Plot the surfaces.
    ax[0].plot_surface(space_grid, time_grid, (truth + shift).T, **PROPERTY)
    ax[1].plot_surface(
        space_grid, time_grid, (full_sol_OpInf + shift).T, **PROPERTY
    )
    ax[2].plot_surface(
        space_grid,
        time_grid,
        np.log10(abs((truth - full_sol_OpInf).T)),
        **PROPERTY,
    )

    ax[0].set(xlabel="$x$", ylabel="time", zlabel="$u(x,t)$", title="groundtruth")
    ax[1].set(
        xlabel="$x$",
        ylabel="time",
        # Raw string: "\h" is not a valid escape sequence in a normal literal.
        zlabel=r"$\hat{u}(x,t)$",
        title="learned model",
    )
    ax[2].set(
        xlabel="$x$",
        ylabel="time",
        zlabel="error in log-scale",
        title="absolute error",
    )

    for _ax in ax:
        _ax.xaxis.labelpad = 10
        _ax.yaxis.labelpad = 20
        _ax.zaxis.labelpad = 10

    # The error panel has wider tick labels, so give it extra room.
    ax[2].tick_params(axis="z", direction="out", pad=10)
    ax[2].zaxis.labelpad = 20

    ax[2].set_zscale("linear")

    plt.tight_layout(pad=0.2, w_pad=0.1, h_pad=0.1)

    fig.savefig(
        params.path + f"simulation_test_{k}_order_{r}.png",
        dpi=300,
        bbox_inches="tight",
        pad_inches=0,
    )


In [None]:
# Extract the autoencoder ("ae") and the vector-field network ("vf") from the
# model dictionary for use in the simulation cells below.
autoencoder, vf = models["ae"], models["vf"]

In [None]:
# Prepare learned models for integration
def learned_model(t, x):
    """Right-hand side of the learned latent dynamics, for use with ``solve_ivp``.

    The state is reshaped to rows of length ``params.latent_dim``, moved to
    ``device`` as a float64 tensor and passed to ``vf.vector_field``. The
    result is detached and returned as a NumPy array.

    Args:
        t (float): time; accepted for the ODE-solver interface but not used.
        x (array-like): latent state; its size must be a multiple of
            ``params.latent_dim``.

    Returns:
        numpy.ndarray: whatever ``vf.vector_field`` returns for the reshaped
        state, converted to NumPy.
    """
    latent_rows = x.reshape(-1, params.latent_dim)
    latent_state = torch.tensor(
        latent_rows, dtype=torch.float64, requires_grad=True
    ).to(device)
    time_derivative = vf.vector_field(latent_state).detach()
    return time_derivative.cpu().numpy()

In [None]:
# Simulate the learned latent model for every held-out trajectory, decode the
# result and compare it with the ground truth.
models['vf'] = models['vf'].to(device)

# Rebuild a time grid that starts at 0 with the spacing of the first two
# samples. NOTE: this overwrites the module-level `t`.
t = np.arange(0, len(t)) * (t[1] - t[0])

# testing learned model
# Grids for the surface plots; `space` hard-codes 256 points on [0, 1).
space = np.arange(0, 1, 1 / 256)
time = np.array(data["t"].T)
space_grid, time_grid = np.meshgrid(space, time)

Err_testing = []
Learned_sols = []
Truth_sols = []

with torch.no_grad():
    # NOTE: `k` and `full_sol_OpInf` must keep these names at module level
    # because plotting_chafee() reads them from there.
    for k in range(X_testing.shape[0]):
        # Project (and scale) the initial snapshot, then encode it.
        y0_test = Projection_V.T @ X_testing[k, :, 0]/fac
        encoded_initial = autoencoder.encode(torch.from_numpy(y0_test)[None].to(device)).reshape(-1,).detach().cpu()
        print(f'Initial condition encoded: {encoded_initial.shape}')

        # Integrate the latent dynamics and sample the solution on the grid t.
        sol_OpInf = solve_ivp(learned_model, [t[0], t[-1]], y0=encoded_initial, 
                              t_eval=t, rtol = 1e-8, atol = 1e-8)
        
        # Decode the latent trajectory; one row per time sample goes in.
        sol_OpInf = torch.tensor(sol_OpInf.y, dtype=torch.float64).to(device)
        decoded_latent_sol = autoencoder.decode(sol_OpInf.T).detach().cpu().numpy()
        print(f'decoded solution: {decoded_latent_sol.shape}')

        # Undo the scaling and the projection.
        full_sol_OpInf = fac*Projection_V @ decoded_latent_sol.T

        plotting_chafee()
        # Relative error in the default matrix norm of np.linalg.norm,
        # followed by the largest absolute pointwise error.
        err = (np.linalg.norm(full_sol_OpInf - X_testing[k])) / (
            np.linalg.norm(X_testing[k])
        )
        print(abs(full_sol_OpInf - X_testing[k]).max())
        print(f'test {k} error: {err}')
        
        Err_testing.append(err)
        Truth_sols.append(X_testing[k])
        Learned_sols.append(full_sol_OpInf)
    # Store the errors and both sets of solutions in one .mat file.
    savemat(params.path + f"simulation_error_order_{r}.mat",  {"errors": Err_testing, 
                                                                "learned_sols":Learned_sols,
                                                                "truth_sols": Truth_sols
                                                                })

In [None]:
# Display the relative test errors collected in the evaluation loop.
Err_testing

In [None]:
# Display the order `r` that is used in the output file names.
r

In [None]:
# Inspect the first encoder linear layer: its weights, their sum and their
# absolute sum.
models['ae'].linear_e1.weight, models['ae'].linear_e1.weight.sum(), models['ae'].linear_e1.weight.abs().sum()

In [None]:
# Same inspection for the first decoder linear layer; the trailing ";"
# suppresses the cell output.
models['ae'].linear_d1.weight, models['ae'].linear_d1.weight.sum(), models['ae'].linear_d1.weight.abs().sum();

In [None]:
def train_quad_finetuning(models, pre_trained_model, train_dl, optim, params):
    """Fine-tune the decoder of ``models["ae"]`` against a fixed encoder.

    For every mini-batch the input is encoded with
    ``pre_trained_model["ae"].encode`` under ``torch.no_grad()`` (so the
    encoder receives no gradient), decoded with ``models["ae"].decode`` and
    compared with the input. The reconstruction loss is the average of the
    mean-squared error and the mean absolute error, scaled by 10 before the
    backward pass. The learning rate is multiplied by 0.1 every 150 epochs.

    Args:
        models (dict): holds the autoencoder to fine-tune under key ``"ae"``.
        pre_trained_model (dict): holds the encoder that provides the latent
            codes under key ``"ae"``.
        train_dl (dataloader): sized iterable of ``(x, dxdt)`` batches; only
            ``x`` is used.
        optim (optimizer): optimizer over the parameters to be updated.
        params (dataclass): must provide ``epoch``, the number of epochs.

    Returns:
        (models, err_t): the model dictionary that was passed in and the list
        of scaled loss values, one per optimization step.

    Raises:
        ValueError: if ``train_dl`` contains no batches.
    """
    steps_per_epoch = len(train_dl)
    if steps_per_epoch == 0:
        # Without batches there is no loss to report and nothing to optimise.
        raise ValueError("train_dl is empty; nothing to fine-tune on.")

    # The scheduler is stepped once per batch, hence the epoch length factor.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optim, step_size=150 * steps_per_epoch, gamma=0.1
    )

    print("Training begins!")

    mse_loss = nn.MSELoss()

    err_t = []
    for i in range(params.epoch):

        for x, _ in train_dl:
            # The batch is a fixed target here: detach it so that backward()
            # does not propagate gradients into the dataset tensor.
            target = x.detach()

            with torch.no_grad():
                z = pre_trained_model["ae"].encode(target)
            x_hat = models["ae"].decode(z)

            # Reconstruction loss: half MSE plus half mean absolute error.
            loss_ae = 0.5 * mse_loss(x_hat, target)
            loss_ae = loss_ae + 0.5 * (x_hat - target).abs().mean()

            loss = 1e1 * loss_ae
            optim.zero_grad()
            loss.backward()
            optim.step()
            scheduler.step()
            err_t.append(loss.item())

        # Report the loss of the last batch of the epoch.
        lr = optim.param_groups[0]["lr"]
        print(
            f"Epoch {i+1}/{params.epoch} | loss_AE: {loss_ae.item():.2e} | learning rate: {lr:.2e}"
        )
    return models, err_t


In [None]:
import copy 

# Build a second model dictionary and give it a fresh convolutional
# autoencoder of the same architecture as the trained one.
models_finetune = module.network_models(params)
models_finetune['ae'] = AutoencoderConvo(latent_dim=4).double().to(device)

# Start fine-tuning from the trained weights; the deep copy keeps the two
# autoencoders from sharing parameter tensors.
sd = models['ae'].state_dict()
models_finetune['ae'].load_state_dict(copy.deepcopy(sd))

# Only the fine-tuned autoencoder's parameters are handed to the optimizer.
# The weight decay (1e-6) is smaller than in the first training stage (1e-5).
optim_finetune = torch.optim.Adam(
    [
        {
            "params": models_finetune["ae"].parameters(),
            "lr": params.learning_rate,
            "weight_decay": 1e-6,
        },
    ]
)



# `models` serves as the pre-trained model whose encoder provides the latent codes.
models_finetune, err_finetune = train_quad_finetuning(models_finetune, models, train_dl, optim_finetune, params)

In [None]:
# Same inspection as before fine-tuning: comparing the two outputs shows
# whether the original encoder weights were left untouched.
models['ae'].linear_e1.weight, models['ae'].linear_e1.weight.sum(), models['ae'].linear_e1.weight.abs().sum()

In [None]:
# Verify: list every parameter of the original autoencoder together with its
# requires_grad flag.
for name, param in models['ae'].named_parameters():
    print(name,param.requires_grad)

In [None]:
# Handle to the fine-tuned autoencoder; the evaluation below uses its decoder.
autoencoder_finetune = models_finetune['ae']

In [None]:
# Repeat the evaluation with the fine-tuned decoder: the initial condition is
# still encoded with the original autoencoder, but the latent trajectory is
# decoded with `autoencoder_finetune`.
models['vf'] = models['vf'].to(device)

# Rebuild a time grid that starts at 0 with the spacing of the first two
# samples. NOTE: this overwrites the module-level `t`.
t = np.arange(0, len(t)) * (t[1] - t[0])

# testing learned model
# Grids for the surface plots; `space` hard-codes 256 points on [0, 1).
space = np.arange(0, 1, 1 / 256)
time = np.array(data["t"].T)
space_grid, time_grid = np.meshgrid(space, time)

Err_testing = []
Learned_sols = []
Truth_sols = []


with torch.no_grad():
    # NOTE: `k` and `full_sol_OpInf` must keep these names at module level
    # because plotting_chafee() reads them from there.
    for k in range(X_testing.shape[0]):
        # Project (and scale) the initial snapshot, then encode it.
        y0_test = Projection_V.T @ X_testing[k, :, 0]/fac
        encoded_initial = autoencoder.encode(torch.from_numpy(y0_test)[None].to(device)).reshape(-1,).detach().cpu()
        print(f'Initial condition encoded: {encoded_initial.shape}')

        # Integrate the latent dynamics and sample the solution on the grid t.
        sol_OpInf = solve_ivp(learned_model, [t[0], t[-1]], y0=encoded_initial, 
                              t_eval=t, rtol = 1e-8, atol = 1e-8)
        
        # Decode with the fine-tuned decoder; one row per time sample goes in.
        sol_OpInf = torch.tensor(sol_OpInf.y, dtype=torch.float64).to(device)
        decoded_latent_sol = autoencoder_finetune.decode(sol_OpInf.T).detach().cpu().numpy()
        print(f'decoded solution: {decoded_latent_sol.shape}')

        # Undo the scaling and the projection.
        full_sol_OpInf = fac*Projection_V @ decoded_latent_sol.T

        # NOTE: plotting_chafee() writes to the same PNG file names as the
        # earlier evaluation cell, so those images are overwritten.
        plotting_chafee()
        # Relative error in the default matrix norm of np.linalg.norm,
        # followed by the largest absolute pointwise error.
        err = (np.linalg.norm(full_sol_OpInf - X_testing[k])) / (
            np.linalg.norm(X_testing[k])
        )
        print(abs(full_sol_OpInf - X_testing[k]).max())
        print(f'test {k} error: {err}')
        Err_testing.append(err)
        Truth_sols.append(X_testing[k])
        Learned_sols.append(full_sol_OpInf)
    # Stored under a separate "_polished_AE" file name, so the results of the
    # first evaluation are kept.
    savemat(params.path + f"simulation_error_order_{r}_polished_AE.mat",  {"errors": Err_testing, 
                                                                "learned_sols":Learned_sols,
                                                                "truth_sols": Truth_sols
                                                                })

Initial condition encoded: torch.Size([4])
decoded solution: (1001, 256)
0.2475831567963216
test 0 error: 0.008851457732680427
Initial condition encoded: torch.Size([4])
decoded solution: (1001, 256)
0.38868711397668676
test 1 error: 0.010535238678166069
Initial condition encoded: torch.Size([4])
decoded solution: (1001, 256)
0.42201450073533797
test 2 error: 0.009247118355709578
Initial condition encoded: torch.Size([4])
decoded solution: (1001, 256)
0.6096916724260248
test 3 error: 0.017053969200681543