In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from tqdm import tqdm

In [None]:
# Double precision throughout: the simulated NumPy data is float64, so new torch
# tensors and module weights use the same dtype.
torch.set_default_dtype(torch.float64)
# Seed torch so weight initialisation, the random initial K and DataLoader
# shuffling are reproducible under Restart & Run All.
torch.manual_seed(625)

# Generate data (simulated Duffing oscillator trajectories)


In [None]:
from scipy import signal
from scipy.fftpack import diff as psdiff
from scipy.integrate import solve_ivp


class AbstractODETarget:
    """Base class for autonomous ODE systems sampled with the forward Euler method.

    Subclasses must implement :meth:`rhs` and set ``x_min`` / ``x_max``, the
    bounds of the uniform distribution used to draw initial conditions.

    :param dt: Euler integration step.
    :param t_step: time between two recorded snapshots; each snapshot is obtained
        from the previous one by ``n_step = round(t_step / dt)`` Euler steps.
    :param dim: state dimension.
    """

    def __init__(self, dt=1e-3, t_step=0.25, dim=2):
        self.dim = dim
        self.dt = dt
        self.t_step = t_step
        # round() rather than int(): the quotient is often a hair below an integer
        # (e.g. 0.3 / 0.1 == 2.9999999999999996) and int() would drop an Euler step.
        self.n_step = int(round(t_step / dt))

    def generate_init_data(self, n_traj, traj_len, seed=None):
        """Simulate trajectories and return all snapshots stacked row-wise.

        :return: array of shape ``(n_traj * traj_len, dim)``; rows
            ``k * traj_len : (k + 1) * traj_len`` hold trajectory ``k`` in time order.
        :raises ValueError: if ``traj_len < 1``.
        """
        trajectories = self.generate_data(n_traj, traj_len, seed=seed)
        return trajectories.reshape(n_traj * traj_len, self.dim)

    def generate_next_data(self, data_x):
        """Advance every row of ``data_x`` by one snapshot interval ``t_step``."""
        return self.euler(data_x)

    def generate_data(self, n_traj, traj_len, seed=None):
        """Simulate ``n_traj`` trajectories of ``traj_len`` snapshots each.

        Initial conditions are drawn uniformly from ``[x_min, x_max)`` per
        coordinate. If ``seed`` is given, NumPy's *global* RNG is reseeded.

        :return: array of shape ``(n_traj, traj_len, dim)``.
        :raises ValueError: if ``traj_len < 1``.
        """
        if traj_len < 1:
            raise ValueError(f"traj_len must be >= 1, got {traj_len}")
        if seed is not None:
            np.random.seed(seed)

        x0 = np.random.uniform(size=(n_traj, self.dim), low=self.x_min, high=self.x_max)
        snapshots = [x0]
        for _ in range(traj_len - 1):
            snapshots.append(self.euler(snapshots[-1]))
        # Stack along a new time axis -> (n_traj, traj_len, dim).
        return np.stack(snapshots, axis=1)

    def rhs(self, x=None):
        """Right-hand side ``f(x)`` of ``dx/dt = f(x)``; subclasses must override.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement rhs(x)")

    def euler(self, x):
        """Integrate ``dx/dt = rhs(x)`` over one snapshot interval with forward Euler.

        :param x: current state(s), in whatever layout :meth:`rhs` accepts.
        :return: state after ``n_step`` Euler steps of size ``dt``, same shape as ``x``.
        """
        for _ in range(self.n_step):
            x = x + self.dt * self.rhs(x)
        return x


class DuffingOscillator(AbstractODETarget):
    r"""Unforced, damped Duffing oscillator written as a first-order system.

    State ``(x1, x2)`` = (position, velocity)::

        x1' = x2
        x2' = -delta * x2 - x1 * (beta + alpha * x1**2)

    i.e. ``x'' + delta x' + beta x + alpha x^3 = 0``. Here ``alpha`` multiplies the
    cubic term and ``beta`` the linear one, which is the reverse of the lettering
    on https://en.wikipedia.org/wiki/Duffing_equation. With the defaults
    (``beta=-1``, ``alpha=1``) this is the double-well oscillator.
    """

    def __init__(self, dt=1e-3, t_step=0.25, dim=2, delta=0.5, alpha=1.0, beta=-1.0):
        super().__init__(dt, t_step, dim)
        self.delta, self.alpha, self.beta = delta, alpha, beta
        # Bounds of the uniform initial-condition distribution used by the base class.
        self.x_min, self.x_max = -2, 2

    def rhs(self, x):
        """Evaluate the vector field row-wise for a 2-D array of states (columns 0 and 1 are read)."""
        position = x[:, [0]]
        velocity = x[:, [1]]
        acceleration = -self.delta * velocity - position * (self.beta + self.alpha * position**2)
        return np.concatenate([velocity, acceleration], axis=-1)

In [None]:
duffing = DuffingOscillator(dt=1e-3, t_step=0.25, dim=2, delta=0.5, alpha=1.0, beta=-1.0)
# Flattened snapshots of 1000 trajectories x 50 steps, and the same snapshots
# advanced by one t_step, as float64 tensors.
duffing_data_curr = duffing.generate_init_data(n_traj=1000, traj_len=50, seed=625)
duffing_data_next = duffing.generate_next_data(duffing_data_curr)
duffing_data_curr, duffing_data_next = (
    torch.tensor(arr, dtype=torch.float64) for arr in (duffing_data_curr, duffing_data_next)
)

In [None]:
# N trajectories of L snapshots each, with a 2-dimensional state.
N, L = 1000, 50
state_dim = d = 2

In [None]:
# Full trajectories, shape (N, L, state_dim); seed 625 reuses the initial conditions above.
duffing_data = torch.tensor(duffing.generate_data(n_traj=N, traj_len=L, seed=625))

In [None]:
# `duffing_data` is already a tensor: torch.tensor(tensor) emits a copy-construct
# UserWarning, so take the recommended explicit copy instead.
dataset = torch.utils.data.TensorDataset(duffing_data.clone().detach())

In [None]:
# Invariant for everything below: (n_traj, traj_len, state_dim).
assert duffing_data.shape == (N, L, state_dim), duffing_data.shape
duffing_data.shape

In [None]:
index = N - 1  # last generated trajectory (999 for N = 1000)
fig, ax = plt.subplots()
ax.plot(duffing_data[index, :, 0], duffing_data[index, :, 1])
ax.set(xlabel="$x_1$", ylabel="$x_2$", title=f"Duffing trajectory #{index} (phase plane)");

# Build Model Class

## Dictionary Class

In [None]:
class AbstractDictionary:
    """Mixin that builds the read-out matrix ``B`` mapping dictionary values back to states.

    ``B`` has one row per dictionary function and one column per state coordinate.
    It is an identity block placed so that ``psi @ B`` picks out the raw-state
    columns of ``psi`` when ``psi`` is laid out as ``[1, x, psi_train]``
    (``add_constant=True``) or ``[x, psi_train]`` (``add_constant=False``).

    :param n_psi_train: number of trainable dictionary functions.
    :param add_constant: whether the dictionary starts with a constant column.
    """

    def __init__(self, n_psi_train, add_constant=True):
        self.n_psi_train = n_psi_train
        self.add_constant = add_constant

    def generate_B(self, inputs):
        """Build (and cache as ``self.B``) the ``(n_psi, state_dim)`` read-out matrix.

        Also sets ``self.n_psi`` (total dictionary size) in both layouts;
        ``self.basis_func_number`` is kept as a backward-compatible alias.

        :param inputs: tensor whose last dimension is the state dimension; its
            dtype and device are used for ``B``.
        :return: the matrix ``B``.
        """
        target_dim = inputs.shape[-1]
        offset = 1 if self.add_constant else 0  # skip the constant column if present
        n_rows = self.n_psi_train + target_dim + offset

        self.n_psi = n_rows
        self.basis_func_number = n_rows

        B = torch.zeros((n_rows, target_dim), dtype=inputs.dtype, device=inputs.device)
        B[offset : offset + target_dim] = torch.eye(
            target_dim, dtype=inputs.dtype, device=inputs.device
        )
        self.B = B
        return B


class DicNN(nn.Module):
    """Trainable dictionary functions: a residual MLP ``R^inputs_dim -> R^n_psi_train``.

    Architecture: a bias-free input projection to ``layer_sizes[0]``, then one
    residual block ``h <- h + act(W h + b)`` per consecutive pair in
    ``layer_sizes``, then a linear output layer. Because of the residual sum,
    all entries of ``layer_sizes`` must be equal (NOTE: not validated here; a
    mismatch fails with a shape error in ``forward``).

    :param inputs_dim: size of the last dimension of the input.
    :param layer_sizes: hidden widths (any sequence of ints).
    :param n_psi_train: number of output features.
    :param activation_func: ``"tanh"`` or ``"relu"``.
    """

    # Supported activations, looked up once per forward pass.
    _ACTIVATIONS = {"tanh": torch.tanh, "relu": F.relu}

    def __init__(self, inputs_dim=1, layer_sizes=(64, 64), n_psi_train=22, activation_func="tanh"):
        super().__init__()
        # Copy so the module never aliases a caller's (possibly shared) list.
        layer_sizes = list(layer_sizes)
        self.inputs_dim = inputs_dim
        self.layer_sizes = layer_sizes
        self.n_psi_train = n_psi_train
        self.activation_func = activation_func

        self.input_layer = nn.Linear(inputs_dim, layer_sizes[0], bias=False)
        self.hidden_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        )
        self.output_layer = nn.Linear(layer_sizes[-1], n_psi_train)

    def forward(self, inputs):
        """Map ``(..., inputs_dim)`` to ``(..., n_psi_train)``.

        :raises ValueError: if the last input dimension is not ``inputs_dim``, or if
            there are hidden layers and ``activation_func`` is unsupported.
        """
        if inputs.shape[-1] != self.inputs_dim:
            # Fail loudly: returning None would only surface later as an obscure error.
            raise ValueError(
                f"Expected input dimension {self.inputs_dim}, but got {inputs.shape[-1]}"
            )

        activation = self._ACTIVATIONS.get(self.activation_func)
        if activation is None and len(self.hidden_layers) > 0:
            raise ValueError(f"Unsupported activation function: {self.activation_func!r}")

        hidden = self.input_layer(inputs)
        for layer in self.hidden_layers:
            hidden = hidden + activation(layer(hidden))
        return self.output_layer(hidden)


class PsiNN(nn.Module, AbstractDictionary):
    """Dictionary ``psi(x) = [1, x, dicNN(x)]`` used to lift states for Koopman regression.

    The constant column is present only when ``add_constant`` is True, and the
    trainable block only when ``n_psi_train != 0``. The inherited ``generate_B``
    builds the matrix that selects the raw-state columns of ``psi(x)``.

    ``super().__init__()`` does not run ``AbstractDictionary.__init__``, so its
    attributes are assigned explicitly here.
    """

    def __init__(
        self,
        inputs_dim=1,
        dic_trainable=DicNN,
        layer_sizes=[64, 64],
        n_psi_train=22,
        activation_func="tanh",
        add_constant=True,
    ):
        super().__init__()
        self.n_psi_train = n_psi_train
        self.add_constant = add_constant
        if n_psi_train == 0:
            self.dicNN = None
        else:
            self.dicNN = dic_trainable(inputs_dim, layer_sizes, n_psi_train, activation_func)

    def forward(self, inputs):
        """Concatenate ``[1?, inputs, dicNN(inputs)?]`` along the last dimension."""
        pieces = [inputs]
        if self.add_constant:
            pieces.insert(0, torch.ones_like(inputs)[..., [0]])
        if self.n_psi_train != 0:
            pieces.append(self.dicNN(inputs))
        return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=-1)

## Constant K Class

In [None]:
class ConstantMatrixMultiplier(nn.Module):
    """Right-multiplication ``psi -> psi @ K`` by a square, non-trainable matrix ``K``.

    ``K`` is registered as an ``nn.Parameter`` (so it is listed in ``parameters()``
    and ``state_dict``) but has ``requires_grad=False``, so backprop never updates it;
    its value is initialised from a standard normal draw.

    :param n_psi: side length of ``K``.
    :param dict_cons: unused; kept for signature compatibility.
    """

    def __init__(self, n_psi, dict_cons=True):
        super().__init__()
        self.n_psi = n_psi
        self.K = nn.Parameter(torch.randn(n_psi, n_psi), requires_grad=False)

    def forward(self, inputs):
        """Return ``inputs @ K``; ``inputs`` has last dimension ``n_psi``."""
        return inputs @ self.K

## Koopman Prediction Class

In [None]:
class Koopman_predictor(nn.Module):
    """Compose a dictionary with a Koopman operator: ``x -> model_K(dict(x))``.

    :param dict: module lifting states to dictionary values (the name shadows the
        builtin but is kept for compatibility; it is stored as ``self.dict``).
    :param model_K: module applying the Koopman operator to lifted values.
    """

    def __init__(self, dict, model_K):
        super().__init__()
        self.dict = dict
        self.model_K = model_K

    def forward(self, inputs):
        """Predict next-step dictionary values ``model_K(dict(inputs))``."""
        lifted = self.dict(inputs)
        return self.model_K(lifted)

# Build Data Loader


In [None]:
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    """Map-style dataset yielding ``(data[i], labels[i])`` pairs."""

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

# Build Model

In [None]:
# Dictionary network hyper-parameters.
state_dim = 2
layer_sizes = [256] * 3
n_psi_train = 22
activation_func = "tanh"

In [None]:
# Dictionary size: trainable features + raw state + constant column (PsiNN's default add_constant=True).
n_psi = n_psi_train + state_dim + 1

In [None]:
# psi(x) = [1, x, NN(x)], followed by a linear Koopman map psi -> psi @ K.
dict_nn = PsiNN(
    state_dim,
    layer_sizes=layer_sizes,
    n_psi_train=n_psi_train,
    activation_func=activation_func,
)
model_K = ConstantMatrixMultiplier(n_psi=n_psi)
Koopman_model = Koopman_predictor(dict=dict_nn, model_K=model_K)

In [None]:
loss_function = nn.MSELoss()
# K has requires_grad=False, so Adam effectively only trains the dictionary network.
optimizer = optim.Adam(Koopman_model.parameters(), lr=1e-2)

In [None]:
# Each batch is a length-1 sequence holding a (batch, traj_len, state_dim) slice of whole trajectories.
data_loader = DataLoader(dataset, shuffle=True, batch_size=100)

In [None]:
# Multiply the LR by 0.8 once the monitored loss has not improved for more than
# 20 consecutive scheduler.step() calls.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=20)

# Train Model

In [None]:
def compute_K(dict, data_x, data_y, reg):
    """Fit the Koopman matrix by ridge-regularised least squares in dictionary space.

    Computes ``K = (psi_x^T psi_x + reg * I)^+ psi_x^T psi_y`` with
    ``psi_x = dict(data_x)`` and ``psi_y = dict(data_y)``, i.e. the minimiser of
    ``||psi_x K - psi_y||_F^2 + reg * ||K||_F^2`` (the pseudo-inverse also copes
    with ``reg = 0`` and a singular Gram matrix).

    :param dict: callable mapping a 2-D batch of states to 2-D dictionary values.
    :param data_x: current states, shape ``(n_samples, state_dim)``.
    :param data_y: next states, same shape as ``data_x``.
    :param reg: non-negative ridge parameter.
    :return: ``(n_psi, n_psi)`` tensor with ``psi_x`` 's dtype and device.
    """
    psi_x = dict(data_x)
    psi_y = dict(data_y)
    psi_xt = psi_x.t()  # .t() also enforces that psi_x is 2-D

    # Match psi_x's dtype/device so the fit also works for float32 features or on GPU.
    idmat = torch.eye(psi_x.shape[1], dtype=psi_x.dtype, device=psi_x.device)

    xtx_inv = torch.pinverse(reg * idmat + torch.mm(psi_xt, psi_x))
    xty = torch.mm(psi_xt, psi_y)
    return torch.mm(xtx_inv, xty)

In [None]:
# Lifted data layout: (n_traj, traj_len, n_psi). no_grad avoids building an
# autograd graph over the whole dataset just to read a shape.
with torch.no_grad():
    psi_shape = Koopman_model.dict(duffing_data).shape
assert psi_shape == (N, L, n_psi), psi_shape
psi_shape

In [None]:
# Alternating optimisation. Each epoch:
#   1. trains the dictionary by gradient descent on a multi-step consistency loss
#      with K held fixed (K has requires_grad=False);
#   2. refits K in closed form with compute_K on all one-step snapshot pairs.
num_epochs = 60
loss_history = []
T = 2  # prediction horizon: psi_{i+j} ~ psi_i K^j is penalised for j = 1..T

# One-step pairs (x_t, x_{t+1}) flattened to (N * (L - 1), state_dim) for the
# least-squares refit of K. They do not change across epochs, so build them once.
duffing_data_curr = duffing_data[:, :-1, :].reshape(-1, duffing_data.shape[-1])
duffing_data_next = duffing_data[:, 1:, :].reshape(-1, duffing_data.shape[-1])

Koopman_model.train()  # pred_soln switches the model to eval mode; undo that on re-runs

for epoch in range(num_epochs):
    loop = tqdm(data_loader, leave=True)
    epoch_losses = []
    for (x_batch,) in loop:
        # x_batch: (batch, traj_len, state_dim) -> psi: (batch, traj_len, n_psi)
        psi = Koopman_model.dict(x_batch)
        K = Koopman_model.model_K.K
        n_start = psi.shape[1] - T  # start indices i such that i + T stays inside the trajectory

        target = psi.new_zeros(())
        for j in range(1, T + 1):
            # Residual of predicting psi_{i+j} as psi_i K^j for all start indices i at once.
            # The norm is taken over (batch, features) for each i, then summed over i.
            residual = psi[:, j : j + n_start, :] - psi[:, :n_start, :] @ torch.matrix_power(K, j)
            target = target + torch.linalg.vector_norm(residual, dim=(0, 2)).sum()

        # MSE of a scalar against zero is simply its square.
        loss = loss_function(target, torch.zeros_like(target))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
        current_lr = optimizer.param_groups[0]["lr"]
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(loss=loss.item(), lr=current_lr)

    average_epoch_loss = sum(epoch_losses) / len(epoch_losses)
    loss_history.append(average_epoch_loss)

    with torch.no_grad():
        # Refit K on all one-step pairs. No autograd graph is needed because the
        # result is written straight into K.data.
        K_weights = compute_K(Koopman_model.dict, duffing_data_curr, duffing_data_next, reg=0.01)
        Koopman_model.model_K.K.data = K_weights

        # One-step prediction error in dictionary space with the refitted K.
        output_curr = Koopman_model.dict(duffing_data[:, :-1, :])
        output_next = Koopman_model.dict(duffing_data[:, 1:, :])
        output_pred = Koopman_model.model_K(output_curr)
        loss_total = loss_function(output_next, output_pred)

    # The scheduler monitors the post-refit one-step loss.
    scheduler.step(loss_total.item())

    # tqdm.write keeps the progress-bar layout intact.
    tqdm.write(f"Epoch {epoch + 1}/{num_epochs} finished with updated loss: {loss_total.item()}")

In [None]:
for param_name, param_value in Koopman_model.model_K.named_parameters():
    print(f"Layer: {param_name} | Size: {param_value.size()} | Values : {param_value} \n")

In [None]:
# Eigen-decomposition of the fitted K, ordered by decreasing real part.
eigenvalues, eigenvectors = torch.linalg.eig(Koopman_model.model_K.K.data)
sorted_indices = eigenvalues.real.argsort(descending=True)
sorted_eigenvalues = eigenvalues.index_select(0, sorted_indices)
sorted_eigenvectors = eigenvectors.index_select(1, sorted_indices)  # eigenvectors are columns

In [None]:
# Spectrum of the one-step map in the complex plane; eigenvalues outside the
# unit circle correspond to growing modes.
fig, ax = plt.subplots()
ax.scatter(sorted_eigenvalues.real, sorted_eigenvalues.imag, label="eigenvalues of K")
theta = np.linspace(0.0, 2.0 * np.pi, 400)
ax.plot(np.cos(theta), np.sin(theta), "k--", linewidth=0.8, label="unit circle")
ax.set(xlabel="Re", ylabel="Im", title="Spectrum of the fitted Koopman matrix K")
ax.set_aspect("equal")
ax.legend();

In [None]:
# The eigenvalues are complex: passing them to plot() directly makes NumPy discard
# the imaginary part (ComplexWarning). Plot the real part and modulus explicitly.
fig, ax = plt.subplots()
ax.plot(sorted_eigenvalues.real, marker="o", label=r"$\mathrm{Re}\,\lambda$")
ax.plot(sorted_eigenvalues.abs(), marker="x", label=r"$|\lambda|$")
ax.set(xlabel="index (sorted by Re, descending)", ylabel="value", title="Koopman eigenvalues")
ax.legend();

# Evaluation

In [None]:
# def pred_soln(Koopman_model, x0, Nt):
#     Koopman_model.eval()
#     x_pred_list = [x0]

#     psi_x0 = Koopman_model.dict(x0)
#     psi_x_pred_list = [psi_x0]

#     B = Koopman_model.dict.generate_B(x0)

#     for _ in range(Nt):
#         psi_pred = Koopman_model.model_K(psi_x_pred_list[-1])
#         x_pred = torch.matmul(psi_pred, B)
#         x_pred_list.append(x_pred.detach())
#         psi_x_pred_list.append(psi_pred.detach())

#     return torch.stack(x_pred_list, dim=1)

In [None]:
def pred_soln(Koopman_model, x0, Nt):
    """Roll the learned Koopman model forward ``Nt`` steps from ``x0``.

    Each step lifts the current state with ``Koopman_model`` (dictionary followed
    by the Koopman operator) and maps the predicted dictionary vector back to state
    space with ``B = Koopman_model.dict.generate_B(x0)``.

    The rollout runs under ``torch.no_grad`` so no autograd graph accumulates across
    steps. The model is evaluated in eval mode, and its previous train/eval mode is
    restored afterwards.

    :param Koopman_model: module exposing ``.dict`` (with ``generate_B``) whose call
        returns predicted dictionary values.
    :param x0: initial state(s); assumed shape ``(batch, state_dim)``.
    :param Nt: number of prediction steps.
    :return: for 2-D ``x0``, a tensor of shape ``(batch, Nt + 1, state_dim)`` whose
        entry 0 along dim 1 is ``x0``.
    """
    was_training = Koopman_model.training
    Koopman_model.eval()
    try:
        B = Koopman_model.dict.generate_B(x0)
        x_pred_list = [x0]
        x_pred = x0
        with torch.no_grad():
            for _ in range(Nt):
                psi_pred = Koopman_model(x_pred)
                x_pred = torch.matmul(psi_pred, B)
                x_pred_list.append(x_pred)
    finally:
        Koopman_model.train(was_training)

    return torch.stack(x_pred_list, dim=1)

In [None]:
# One held-out trajectory (different seed), flattened to shape (traj_len, state_dim).
duffing_data_test = duffing.generate_init_data(1, 50, seed=521)
duffing_data_test.shape

In [None]:
test_length: int = 50  # snapshots in the predicted trajectory, including x0

In [None]:
# Start from the first snapshot of the held-out trajectory, shaped (1, state_dim).
duffing_test_pred_iter = pred_soln(
    Koopman_model,
    torch.tensor(duffing_data_test[:1], dtype=torch.float64),
    test_length - 1,
)

In [None]:
# duffing_test_pred = pred_soln(dict=dict_nn,
#                               model_K=model_K,
#                               x0=torch.tensor(duffing_data_curr[0]).double().reshape(1,-1),
#                               Nt=49)

In [None]:
fig, ax = plt.subplots()
ax.scatter(duffing_data_test[:, 0], duffing_data_test[:, 1], label="True")
ax.scatter(
    duffing_test_pred_iter[:, :, 0].flatten(),
    duffing_test_pred_iter[:, :, 1].flatten(),
    label="Predicted",
)
ax.set(xlabel="$x_1$", ylabel="$x_2$", title="Held-out Duffing trajectory: true vs Koopman rollout")
ax.legend();