In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
from torch.autograd import Variable
import pickle
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from collections import defaultdict
import torch.nn.functional as F
from tqdm import tqdm

# Seed NumPy's global RNG so the np.random draws made later in the notebook
# (missing-data indices, model initialisation) are repeatable.
np.random.seed(0)
# Base name of the checkpoint file; ".pt" is appended where the model state is loaded/saved.
output_model_name = "E2E_KF_SST_1Point_mdr95"

# Load SST data at the selected point locations

In [None]:
# Load the pre-processed SST dictionary from disk.
# NOTE(review): the path is absolute and machine-specific, so this cell only runs where
# that file exists; consider a configurable data directory.
# NOTE(review): pickle.load can execute arbitrary code contained in the file; only load
# data from a trusted source.
with open(
    "/sanssauvegarde/homes/s17ouala/Complement/Koopman reduction/LearningKoopman/Identification/SST_Data_Points/SST_data_points_processed.pickle",
    "rb",
) as handle:
    # The loaded object is indexed with string keys in the next cell.
    dict_data = pickle.load(handle)

In [None]:
# Location metadata stored next to the time series.
namesOfLocations = dict_data["namesOfLocations"]
LonOfLocations = dict_data["LonOfLocations"]
LatOfLocations = dict_data["LatOfLocations"]
idx_points = dict_data["idx_points"]

# Number of leading columns (variables) kept from each series.
dim = 1

# Keep every 5th sample of the raw and of the "_Fil" series, and name the row at
# which the training part ends / the test part starts.
_sat_subsampled = dict_data["data_points_sat"][::5]
_sat_fil_subsampled = dict_data["data_points_sat_Fil"][::5]
_split_row = int(10000 / 5) + 1

# Reference series: rows 1.._split_row-1 for training, the rest for testing.
Y_Train = _sat_subsampled[1:_split_row, :dim]
Y_Test = _sat_subsampled[_split_row:, :dim]

# Independent copies that will receive the missing-data (NaN) mask.
X_Train = np.copy(Y_Train)
X_Test = np.copy(Y_Test)

# The "_Fil" series is split at the same row, but its training part starts at row 0,
# so it holds one more row than Y_Train.
X_Train_Dyn = _sat_fil_subsampled[:_split_row, :dim]
X_Test_Dyn = _sat_fil_subsampled[_split_row:, :dim]

In [None]:
# Display the location names as the cell output.
namesOfLocations

In [None]:
# Longitude and latitude of the first stored location (index 0).
print(LonOfLocations[0], LatOfLocations[0])

In [None]:
# Fraction of time steps whose observation is removed.
missing_data_rate = 0.95

# Draw, without replacement, the row indices to blank out in each set.
# Both draws use NumPy's global RNG, so their order affects reproducibility.
idx_to_nan_train = np.random.choice(
    np.arange(X_Train.shape[0]),
    size=int(missing_data_rate * X_Train.shape[0]),
    replace=False,
)
idx_to_nan_test = np.random.choice(
    np.arange(X_Test.shape[0]),
    size=int(missing_data_rate * X_Test.shape[0]),
    replace=False,
)

In [None]:
# Mark the selected rows as missing. X_Train / X_Test are copies, so Y_Train / Y_Test
# keep the complete series.
X_Train[idx_to_nan_train, :] = np.nan
X_Test[idx_to_nan_test, :] = np.nan

In [None]:
# Larger default font for every figure created from here on.
plt.rcParams.update({"font.size": 22})

In [None]:
# Complete training series (faint line) with the retained observations on top (markers);
# the NaN rows of X_Train are simply not drawn.
plt.figure(figsize=(10, 6))
plt.plot(Y_Train, lw=3, alpha=0.1, label="True state")
plt.plot(X_Train, "o", markersize=16, label="Training observations")
plt.xlabel("Time step (adimensional)")
plt.ylabel("SST °C")
plt.legend(loc=9, bbox_to_anchor=(0.5, 1.1))
plt.tight_layout()
# Write the same figure in three formats to the current working directory.
plt.savefig("training_data_sst.png")
plt.savefig("training_data_sst.pdf")
plt.savefig("training_data_sst.svg")

In [None]:
# Number of time steps in the "_Dyn" training and test series.
train_len = X_Train_Dyn.shape[0]
test_len = X_Test_Dyn.shape[0]


# Init KF model with a linear model of dim 6

In [None]:
# Size of the augmented (latent) state used by the filter.
dim_aug = 6
# Number of time steps per training batch.
Batch_size = 500
size_int_new = X_Train_Dyn.shape[0]
dim = X_Train_Dyn.shape[-1]
# Whole number of batches that fit in the series (any remainder is dropped by int()).
nb_batch = int(size_int_new / Batch_size)

# Notebook-level configuration dictionary. The E2E_EM_KS constructor reads
# params["dim_observations"] from it; the other entries are not read by any code
# visible in this notebook.
# NOTE(review): a later cell rebinds `params` to a different dictionary, so the model
# must be instantiated before that cell runs.
params = {}
params["transition_layers"] = 1
params["bi_linear_layers"] = dim_aug
params["dim_hidden_linear"] = dim_aug
params["dim_input"] = dim
params["dim_output"] = dim_aug
params["dim_latent"] = dim_aug - dim
params["dim_observations"] = dim_aug
params["dim_hidden"] = dim_aug
params["Nb_parts"] = 100
params["dt_integration"] = 1.0
params["n_train"] = 1000
params["transition_layers_B"] = 10
params["dim_hidden_B"] = 100
params["Batch_size"] = Batch_size
# NOTE(review): `dt` is not referenced elsewhere in this view; Koopman_Propagator has
# its own dt=1.0 default.
dt = 1.0

In [None]:
class E2E_EM_KS(torch.nn.Module):
    def __init__(self):
        """Create the parameters of the linear-Gaussian state-space model.

        Sizes are taken from the notebook-level ``params["dim_observations"]``
        (state size) and ``dim`` (observation size), which must be defined
        before the model is instantiated.
        """
        super(E2E_EM_KS, self).__init__()
        n_state = params["dim_observations"]

        def as_parameter(values, trainable):
            # Wrap a NumPy array as a float32 torch Parameter.
            return torch.nn.Parameter(
                data=torch.from_numpy(values).float(), requires_grad=trainable
            )

        # Every np.random.uniform call is kept, even where its result is multiplied
        # by 0, so NumPy's global RNG is consumed exactly as before.

        # Generator matrix (trainable); Koopman_Propagator uses its antisymmetric part.
        self.A = as_parameter(np.random.uniform(size=(n_state, n_state)), True)

        # Background (initial) covariance: 0.1 in every entry, not trained.
        self.B = as_parameter(
            0.1 + 0 * np.random.uniform(size=(n_state, n_state)), False
        )

        # Base observation-noise covariance: 0.1 in every entry, not trained;
        # it is scaled by |weight_R| wherever it is used.
        self.R = as_parameter(0.1 + 0 * np.random.uniform(size=(dim, dim)), False)

        # Model-noise factor (trainable); the filter uses Q @ Q.T as covariance.
        self.Q = as_parameter(0.1 * np.eye(n_state), True)

        # Background (initial) state: zeros, not trained.
        self.X0 = as_parameter(0 * np.random.uniform(size=(n_state)), False)

        # Trainable scalar multiplying R, initialised to 1.
        self.weight_R = as_parameter(1 + 0 * np.random.uniform(size=(1)), True)

    def Koopman_Propagator(self, dt=1.0):
        A = (self.A[:, :] - self.A[:, :].T) / 2
        Phi = torch.matrix_exp(A * dt)
        return Phi

    def RMSE(self, E):
        "Returns the Root Mean Squared Error"
        return torch.sqrt(torch.mean(E**2))

    def climat_background(self, X_true):
        "Returns a climatology (mean and covariance)"
        X_true = X_true.detach().numpy()
        xb = np.mean(X_true, 1)
        B = np.cov(X_true)
        xb = torch.from_numpy(xb).float()
        B = torch.from_numpy(B).float()
        return xb, B

    def gaspari_cohn(self, r):
        corr = 0
        if 0 <= r and r < 1:
            corr = 1 - 5 / 3 * r**2 + 5 / 8 * r**3 + 1 / 2 * r**4 - 1 / 4 * r**5
        elif 1 <= r and r < 2:
            corr = (
                4
                - 5 * r
                + 5 / 3 * r**2
                + 5 / 8 * r**3
                - 1 / 2 * r**4
                + 1 / 12 * r**5
                - 2 / (3 * r)
            )
        return corr

    def cov_prob(self, Xs, Ps, X_true):
        "Returns the number of true state in the 95% confidence interval"
        X_true = X_true.detach().numpy()
        Ps = Ps.detach().numpy()
        Xs = Xs.detach().numpy()
        n, T = np.shape(X_true)
        cov_prob = 0
        for i_n in range(n):
            cov_prob += (
                sum(
                    (
                        np.squeeze(Xs[i_n, :])
                        - 1.96 * np.sqrt(np.squeeze(Ps[i_n, i_n, :]))
                        <= X_true[i_n, :]
                    )
                    & (
                        np.squeeze(Xs[i_n, :])
                        + 1.96 * np.sqrt(np.squeeze(Ps[i_n, i_n, :]))
                        >= X_true[i_n, :]
                    )
                )
                / T
            )
        return cov_prob

    def _likelihood0(self, Xf, Pf, Yo, H, get_innovation_statistics):
        T = Xf.shape[1]
        l = 0
        # print(Yo.shape)
        sig_all = np.zeros((T, Yo.shape[0], Yo.shape[0])) * np.nan
        inno_all = np.zeros((T, Yo.shape[0])) * np.nan
        for t in range(T):
            i_var_obs = np.where(~np.isnan(Yo.detach().cpu().data.numpy()[:, t]))[0]
            if len(i_var_obs) > 0:
                sig = torch.mm(
                    torch.mm(H[i_var_obs, :, t], Pf[:, :, t]), H[i_var_obs, :, t].T
                ) + torch.abs(self.weight_R) * (self.R[np.ix_(i_var_obs, i_var_obs)])
                innov = Yo[i_var_obs, t] - torch.mv(H[:, :][i_var_obs, :, t], Xf[:, t])
                # l -= .5 * np.log(np.linalg.det(sig))
                # print(torch.slogdet(sig))
                # print(torch.det((sig).inverse()))
                # print(innov.shape)
                # sign, l_tmp = torch.slogdet(sig)# to print
                tmppp = 0.5 * torch.logdet(sig)  # to print
                # print('logdet',tmppp)
                # print('positive value',  0.5*torch.mm(torch.mm(innov.unsqueeze(0),(sig).inverse()),innov.unsqueeze(0).T))
                # print('sum log eig',0.5*sum(torch.log(torch.eig(sig)[0][:,0])))
                # print('eigs : ',sig.shape, torch.eig(sig)[0][:,0])
                # tmppp2 = .5 * sign * l_tmp
                # tmppp = 0.5*sum(torch.log(torch.eig(sig,eigenvectors=True)[0][:,0]))#   .5 * sign * l_tmp
                l += -tmppp - 0.5 * torch.mm(
                    torch.mm(innov.unsqueeze(0), (sig).inverse()), innov.unsqueeze(0).T
                )
                # sign, l_tmp = torch.slogdet(sig)# to print
                # l -= .5 * sign * l_tmp
                # l -=  .5 * torch.mm(innov.unsqueeze(0),torch.solve(innov.unsqueeze(1),sig)[0])[0,0]# to print
                # l -= tmpp
                # print('innov :', innov.shape)
                # print('inno_all : ',inno_all.shape)
                # print('inno_all[i,i_var_obs] : ',inno_all[t,i_var_obs].shape)
                if get_innovation_statistics:
                    inno_all[t, i_var_obs] = innov.detach().numpy()
                    sig_all[t, :, :][
                        np.ix_(i_var_obs, i_var_obs)
                    ] = sig.detach().numpy()
        return l[0, 0], inno_all, sig_all

    def _likelihood(self, Xf, Pf, Yo, H, get_innovation_statistics):
        T = Xf.shape[1]
        l = 0
        sig_all = np.zeros((T, Yo.shape[0], Yo.shape[0])) * np.nan
        inno_all = np.zeros((T, Yo.shape[0])) * np.nan
        for t in range(T):
            i_var_obs = np.where(~np.isnan(Yo.detach().cpu().data.numpy()[:, t]))[0]
            if len(i_var_obs) > 0:
                sig = torch.mm(
                    torch.mm(H[i_var_obs, :, t], Pf[:, :, t]), H[i_var_obs, :, t].T
                ) + torch.abs(self.weight_R) * (self.R[np.ix_(i_var_obs, i_var_obs)])
                innov = Yo[i_var_obs, t] - torch.mv(H[i_var_obs, :, t], Xf[:, t])
                # l -= .5 * np.log(np.linalg.det(sig))
                # sign, l_tmp = torch.slogdet(sig)# to print
                l -= 0.5 * torch.log(torch.det(sig))  # .5 * sign * l_tmp
                # print('inov : ',innov.shape)
                # print('sig : ',sig.shape)
                tmpp = (
                    0.5
                    * torch.mm(
                        innov.unsqueeze(0).T, torch.solve(innov.unsqueeze(0), sig)[0]
                    )[0, 0]
                )  # to print
                l -= tmpp
                inno_all[t, i_var_obs] = innov.detach().numpy()
                sig_all[t, :, :][np.ix_(i_var_obs, i_var_obs)] = sig.detach().numpy()
        return l, inno_all, sig_all

    def _EKF(self, Nx, No, T, Yo, H, alpha):
        Xa = torch.zeros((Nx, T + 1))
        Xf = torch.zeros((Nx, T))
        Pa = torch.zeros((Nx, Nx, T + 1))
        Pf = torch.zeros((Nx, Nx, T))
        F_all = torch.zeros((Nx, Nx, T))
        H_all = torch.zeros((No, Nx, T))
        Kf_all = torch.zeros((Nx, No, T))
        K = torch.zeros((Nx, No))

        x = self.X0
        Xa[:, 0] = x
        P = torch.abs(self.B)
        Pa[:, :, 0] = P
        for t in range(T):
            # Forecast
            # print(F.shape)
            # print(x.shape)
            F = self.Koopman_Propagator()
            x = torch.mv(F, x)
            # diag_Q = self.Q
            # Q = Q*Q' with Q = reshape(n,n) of output of model_Q for full covariance support
            # Q      = torch.diag(diag_Q)#self.Q_const#
            P = torch.mm(torch.mm(F, P), F.T) + torch.mm(
                self.Q, self.Q.T
            )  # self.Q# + Q#
            P = 0.5 * (P + P.T)
            Pf[:, :, t] = P
            Xf[:, t] = x
            F_all[:, :, t] = F

            # Update
            i_var_obs = np.where(~np.isnan(Yo.detach().cpu().data.numpy()[:, t]))[0]
            # print(t,i_var_obs,T)
            if len(i_var_obs) > 0:
                d = Yo[i_var_obs, t] - torch.mv(H[i_var_obs, :], x)
                S = torch.mm(
                    torch.mm(H[i_var_obs, :], P), H[i_var_obs, :].T
                ) + torch.abs(self.weight_R) * (
                    self.R[np.ix_(i_var_obs, i_var_obs)]
                )  # torch.abs(self.R)[np.ix_(i_var_obs,i_var_obs)]/alpha
                K = torch.mm(torch.mm(P, H[i_var_obs, :].T), (S.inverse()))
                P = torch.mm((torch.eye(Nx) - torch.mm(K, H[i_var_obs, :])), P)
                x = x + torch.mv(K, d)
                Kf_all[:, i_var_obs, t] = K
            Pa[:, :, t + 1] = P
            Xa[:, t + 1] = x
            H_all[:, :, t] = H

        return Xa, Pa, Xf, Pf, F_all, H_all, Kf_all

    def _EKS(self, Nx, No, T, Yo, H, alpha):
        Xa, Pa, Xf, Pf, F, H, Kf = self._EKF(Nx, No, T, Yo, H, alpha)
        Xs = torch.zeros((Nx, T + 1))
        Ps = torch.zeros((Nx, Nx, T + 1))
        K_all = torch.zeros((Nx, Nx, T))

        x = Xa[:, -1]
        Xs[:, -1] = x
        P = Pa[:, :, -1]
        Ps[:, :, -1] = P
        for t in range(T - 1, -1, -1):
            K = torch.mm(torch.mm(Pa[:, :, t], F[:, :, t].T), (Pf[:, :, t].inverse()))
            x = Xa[:, t] + torch.mv(K, x - Xf[:, t])
            P = Pa[:, :, t] - torch.mm(torch.mm(K, Pf[:, :, t] - P), K.T)

            Ps[:, :, t] = P
            Xs[:, t] = x
            K_all[:, :, t] = K

        # pykalman
        Ps_lag = torch.zeros((Nx, Nx, T))
        # Ps_lag[:,:,-1] = ((np.eye(Nx)-Kf[:,:,-1].dot(H[:,:,-1])).dot(F[:,:,-1]).dot(Pa[:,:,-2]))
        for t in range(1, T):
            Ps_lag[:, :, t] = torch.mm(Ps[:, :, t], K_all[:, :, t - 1].T)
        return Xs, Ps, Ps_lag, Xa, Pa, Xf, Pf, H

    def EKS(self, params):
        Yo = params["observations"]
        xb = params["background_state"]
        B = params["background_covariance"]
        Q = params["model_noise_covariance"]
        R = params["observation_noise_covariance"]
        F = params["model_dynamics"]
        # jacF = params['model_jacobian']
        H = params["observation_operator"]
        # jacH = params['observation_jacobian']
        Nx = params["state_size"]
        No = params["observation_size"]
        T = params["temporal_window_size"]
        Xt = params["true_state"]
        alpha = params["inflation_factor"]

        Xs, Ps, Ps_lag, Xa, Pa, Xf, Pf, H = _EKS(Nx, No, T, Yo, H, alpha)
        l, inno_all, sig_all = _likelihood(Xf, Pf, Yo, R, H)
        cov_p = cov_prob(Xs, Ps, Xt)

        res = {
            "smoothed_states": Xs,
            "smoothed_covariances": Ps,
            "smoothed_lagged_covariances": Ps_lag,
            "analysis_states": Xa,
            "analysis_covariance": Pa,
            "forecast_states": Xf,
            "forecast_covariance": Pf,
            "RMSE": RMSE(Xs - Xt),
            "params": params,
            "loglikelihood": l,
            "cov_prob": cov_p,
        }
        return res

    def _ML_crit_numpy(self, Xs, Ps, Ps_lag, Yo, H):
        Xs = Xs.detach().numpy()
        Ps = Ps.detach().numpy()

        Ps_lag = Ps_lag.detach().numpy()
        Yo = Yo.detach().numpy()
        H = H.detach().numpy()
        F = self.Koopman_Propagator().detach().numpy()

        No = Yo.shape[0]
        T = Yo.shape[1]
        Nx = Xs.shape[0]

        No = Yo.shape[0]
        T = Yo.shape[1]
        Nx = Xs.shape[0]

        xb = Xs[:, 0]
        B = Ps[:, :, 0]
        R = 0
        nobs = 0
        sumSig = 0

        # Dreano et al. 2017, Eq. (34)
        for t in range(T):
            if not np.isnan(Yo[0, t]):
                nobs += 1
                R += np.outer(
                    Yo[:, t] - np.dot(H, Xs[:, t + 1]),
                    Yo[:, t] - np.dot(H, Xs[:, t + 1]),
                )
                R += H.dot(Ps[:, :, t + 1]).dot(H.T)
        R = 0.5 * (R + R.T)
        sigma1 = 0.5 * (R + R.T)
        # R /= nobs

        # for Shumway 1982
        mat_A = 0
        mat_B = 0
        mat_C = 0

        for t in range(T):

            # Dreano et al. 2017, Eq. (33)
            sumSig += Ps[:, :, t + 1]
            sumSig += np.outer(
                Xs[:, t + 1] - np.dot(F, Xs[:, t]), Xs[:, t + 1] - np.dot(F, Xs[:, t])
            )  # CAUTION: error in Dreano equations
            sumSig += F.dot(Ps[:, :, t]).dot(F.T)
            sumSig -= Ps_lag[:, :, t].dot(F.T) + F.dot(
                Ps_lag[:, :, t].T
            )  # CAUTION: transpose at the end (error in Dreano equations)
            sumSig = 0.5 * (sumSig + sumSig.T)

        sigma2 = sumSig  # /T
        sigma3 = Ps[:, :, 0]

        R = self.R.detach().numpy()
        Q = self.Q.detach().numpy()
        B = self.B.detach().numpy()

        ll = (
            -0.5 * np.log(np.linalg.norm(np.abs(B)))
            - 0.5 * np.trace((np.dot(np.linalg.inv(np.abs(B)), sigma3)))
            - 0.5 * T * np.log(np.linalg.norm(np.abs(Q)))
            - 0.5 * np.trace((np.dot(np.linalg.inv(np.abs(Q)), sigma2)))
            - 0.5 * T * np.log(np.linalg.norm(np.abs(R)))
            - 0.5 * np.trace((np.dot(np.linalg.inv(np.abs(R)), sigma1)))
        )

        return ll

    def _Get_3_sigmas(self, Xs, Ps, Ps_lag, Yo, H):
        No = Yo.shape[0]
        T = Yo.shape[1]
        Nx = Xs.shape[0]

        xb = Xs[:, 0]
        B = Ps[:, :, 0]
        nobs = 0
        tmp = 0
        sumSig = 0
        # Dreano et al. 2017, Eq. (34)
        for t in range(T):
            i_var_obs = np.where(~np.isnan(Yo.detach().cpu().data.numpy()[:, t]))[0]
            if len(i_var_obs) > 0:
                nobs += 1
                # H = jacH(Xs[:,t+1])
                tmp += torch.outer(
                    Yo[i_var_obs, t] - torch.mv(H[i_var_obs, :], Xs[:, t + 1]),
                    Yo[i_var_obs, t] - torch.mv(H[i_var_obs, :], Xs[:, t + 1]),
                )
                tmp += torch.mm(torch.mm(H, Ps[:, :, t + 1]), H.T)
        sigma1 = 0.5 * (tmp + tmp.T)
        # tmp /= nobs

        # for Shumway 1982
        mat_A = 0
        mat_B = 0
        mat_C = 0

        for t in range(T):
            # Dreano et al. 2017, Eq. (33)
            # F = jacF(Xs[:,t+1])
            sumSig += Ps[:, :, t + 1]

            sumSig += torch.outer(
                Xs[:, t + 1] - torch.mv(self.Koopman_Propagator(), Xs[:, t]),
                Xs[:, t + 1] - torch.mv(self.Koopman_Propagator(), Xs[:, t]),
            )  # CAUTION: error in Dreano equations
            sumSig += torch.mm(
                torch.mm(self.Koopman_Propagator(), Ps[:, :, t]),
                self.Koopman_Propagator().T,
            )
            sumSig -= torch.mm(Ps_lag[:, :, t], self.Koopman_Propagator().T) + torch.mm(
                self.Koopman_Propagator(), Ps_lag[:, :, t].T
            )  # CAUTION: transpose at the end (error in Dreano equations)
            sumSig = 0.5 * (sumSig + sumSig.T)

        sigma2 = sumSig  # /T
        sigma3 = Ps[:, :, 0]
        self.B.data = Ps[:, :, 0]
        self.X0.data = xb
        ll = (
            -0.5 * torch.log(torch.norm(torch.abs(self.B)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.B).inverse(), sigma3)))
            - 0.5 * T * torch.log(torch.norm(torch.abs(self.Q)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.Q).inverse(), sigma2)))
            - 0.5 * T * torch.log(torch.norm(torch.abs(self.R)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.R).inverse(), sigma1)))
        )
        return ll
        # return sigma1, sigma2, sigma3

    def _ML_crit(self, sigma1, sigma2, sigma3, T):
        ll = (
            -0.5 * torch.log(torch.norm(torch.abs(self.B)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.B).inverse(), sigma3)))
            - 0.5 * T * torch.log(torch.norm(torch.abs(self.Q)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.Q).inverse(), sigma2)))
            - 0.5 * T * torch.log(torch.norm(torch.abs(self.R)))
            - 0.5 * torch.trace((torch.mm(torch.abs(self.R).inverse(), sigma1)))
        )
        return ll

    def forward(self, params, empty=False):
        if empty:
            return 0
        else:
            # jacF  = params['model_jacobian']
            H = params["observation_operator"]
            # jacH  = params['observation_jacobian']
            Yo = params["observations"]
            nIter = params["nb_EM_iterations"]
            Xt = params["true_state"]
            Nx = params["state_size"]
            No = params["observation_size"]
            T = params["temporal_window_size"]
            alpha = params["inflation_factor"]
            estimateQ = params["is_model_noise_covariance_estimated"]
            estimateR = params["is_observation_noise_covariance_estimated"]
            estimateX0 = params["is_background_estimated"]
            Xs, Ps, Ps_lag, Xa, Pa, Xf, Pf, H_all = self._EKS(Nx, No, T, Yo, H, alpha)
            # ll = self._ML_crit_numpy(Xs, Ps, Ps_lag, Yo, H)

            # LL_TORCH = self._Get_3_sigmas(Xs, Ps, Ps_lag, Yo, H)
            # LL_TORCH = self._ML_crit(S1,S2,S3,Yo.shape[1])

            loglik, inno_all, sig_all = self._likelihood(Xf, Pf, Yo, H_all, True)
            rmse_em = self.RMSE(torch.mm(H, Xs[:, 1:]) - Xt)
            cov_prob_em = self.cov_prob(Xs[:, 1:], Ps[:, :, 1:], Xt)

            res = {
                "forecast": Xf,
                "smoothed_states": Xs,
                "background_state": self.X0,
                "background_covariance": torch.abs(self.B),
                "model_noise_covariance": torch.mm(self.Q, self.Q.T),
                "observation_noise_covariance": torch.abs(self.weight_R) * (self.R),
                #'loglikelihood_crit'             : ll,
                "loglikelihood": loglik,
                "RMSE": rmse_em,
                "cov_prob": cov_prob_em,
                "cov_smooth": Ps,
                "cov_lag": Ps_lag,
                "innoAll": inno_all,
                "sigAll": sig_all,
                "params": params,
            }
            return res

In [None]:
from tqdm import tqdm  # NOTE(review): already imported in the first cell; redundant here.

# Instantiate the model. Its constructor reads the notebook-level `params` and `dim`,
# so this cell must run while `params` still holds "dim_observations".
model_E2E_EM_KS = E2E_EM_KS()

In [None]:
# NOTE(review): `criterion` is not used by the training loop in this notebook, whose
# loss is the negative log-likelihood returned by the model.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.NAdam(model_E2E_EM_KS.parameters())

In [None]:
# Set the learning rate of every parameter group.
for param_group in optimizer.param_groups:
    param_group["lr"] = 0.01

In [None]:
# Histories appended to after every optimizer step of the training loop.
Xb_all, B_all, R_all, Q_all, F_all = [], [], [], [], []

In [None]:
# SSM parameters

# EM parameters
Nx = dim_aug  # state size
No = dim  # observation size
# N_iter = 100
Q_init = torch.eye(Nx)
R_init = torch.eye(No)
T = Batch_size  # default window length; overwritten before each model call
# parameters
# NOTE(review): this rebinds `params` to a new dictionary that no longer contains
# "dim_observations", which the E2E_EM_KS constructor reads; re-instantiating the model
# after this cell would raise KeyError.
# The "initial_*" and "model_noise_covariance_structure" entries are not read by the
# model code visible in this notebook.
params = {
    "initial_background_state": torch.zeros((dim_aug)),
    "initial_background_covariance": torch.eye(Nx),
    "initial_model_noise_covariance": Q_init,
    "initial_observation_noise_covariance": R_init,
    "state_size": Nx,
    "observation_size": No,
    "temporal_window_size": T,
    "model_noise_covariance_structure": "full",
    "is_model_noise_covariance_estimated": True,
    "is_observation_noise_covariance_estimated": True,
    "is_background_estimated": True,
    "inflation_factor": 1,
}
# Observation operator: identity on the first `dim` state components, zero elsewhere.
params["observation_operator"] = torch.zeros((dim, dim_aug))
params["observation_operator"][:dim, :dim] = torch.eye(dim)

In [None]:
# Cut the training series into nb_batch consecutive windows of Batch_size steps.
# reshape requires exactly nb_batch * Batch_size * dim elements.
X_train_batched = torch.from_numpy(X_Train.reshape(nb_batch, Batch_size, dim)).float()
Y_train_batched = torch.from_numpy(Y_Train.reshape(nb_batch, Batch_size, dim)).float()

In [None]:
# NOTE(review): same learning-rate assignment as in the earlier cell.
for param_group in optimizer.param_groups:
    param_group["lr"] = 0.01

In [None]:
# Quick look at one batch: observations (markers) over the complete series (line).
plt.plot(X_train_batched[3, :, 0], "*")
plt.plot(
    Y_train_batched[3, :, 0],
)

In [None]:
# Evaluate the model on the whole test series before any checkpoint is loaded or any
# training is run; the result is reused by a later innovation-statistics cell.
b = 3  # NOTE(review): not used in this cell
# The model indexes observations as [variable, time], hence the transpose.
params["observations"] = torch.from_numpy(X_Test.T).float()
params["nb_EM_iterations"] = 1
params["temporal_window_size"] = X_Test.shape[0]
params["true_state"] = torch.from_numpy(Y_Test.T).float()
res_EM_EKS_init = model_E2E_EM_KS(params)

In [None]:
# Inspect the gradient stored on Q (presumably None until a backward pass has run).
model_E2E_EM_KS.Q.grad

In [None]:
# Display the optimizer configuration.
optimizer

In [None]:
# Resume from a previous checkpoint when one exists. On a fresh run the file is only
# created by the torch.save cell further down, so loading it unconditionally would
# raise FileNotFoundError and stop "Run All".
import os

_checkpoint_path = output_model_name + ".pt"
if os.path.exists(_checkpoint_path):
    model_E2E_EM_KS.load_state_dict(torch.load(_checkpoint_path))
else:
    print(
        "No checkpoint found at "
        + _checkpoint_path
        + "; keeping the current parameters."
    )

In [None]:
# Train by maximising the innovation log-likelihood, one optimizer step per batch.
params["ntrain"] = [2000, 10000]  # only the first entry (number of epochs) is used here
for t in range(params["ntrain"][0]):
    for b in range(nb_batch):
        # Forward pass on batch b; the model indexes data as [variable, time].
        params["observations"] = X_train_batched[b, :, :].T
        params["nb_EM_iterations"] = 1
        params["temporal_window_size"] = Batch_size
        params["true_state"] = Y_train_batched[b, :, :].T
        res_EM_EKS = model_E2E_EM_KS(params)
        loss = -res_EM_EKS["loglikelihood"]
        # Zero gradients, perform a backward pass, and update the weights.
        # Every forward pass builds a new graph, so it does not need to be retained.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the parameter values after this step.
        Xb_all.append(model_E2E_EM_KS.X0.detach().clone().numpy())
        B_all.append(torch.abs(model_E2E_EM_KS.B).detach().clone().numpy())
        R_all.append(
            (torch.abs(model_E2E_EM_KS.weight_R) * (model_E2E_EM_KS.R))
            .detach()
            .clone()
            .numpy()
        )
        Q_all.append(torch.abs(model_E2E_EM_KS.Q).detach().clone().numpy())
        F_all.append(
            ((model_E2E_EM_KS.A - model_E2E_EM_KS.A.T) / 2).detach().clone().numpy()
        )
    # Loss of the last batch of the epoch, printed as a plain number.
    print(t, b, loss.item())
    # Lower learning rate after epoch 400.
    if t > 400:
        for param_group in optimizer.param_groups:
            param_group["lr"] = 0.001

In [None]:
# Persist the trained parameters next to the notebook (file name: output_model_name + ".pt").
torch.save(model_E2E_EM_KS.state_dict(), output_model_name + ".pt")
# Antisymmetric part of A, i.e. the matrix exponentiated by Koopman_Propagator.
Dyn_Mat = ((model_E2E_EM_KS.A - model_E2E_EM_KS.A.T) / 2).detach().clone()
# Imaginary parts of its eigenvalues, shown as the cell output.
torch.linalg.eig(Dyn_Mat)[0].imag

# Testing on the training set

In [None]:
# Run the smoother on training batch b.
b = 3
params["observations"] = X_train_batched[b, :, :].T
params["nb_EM_iterations"] = 1
params["temporal_window_size"] = Batch_size
params["true_state"] = Y_train_batched[b, :, :].T
res_EM_EKS = model_E2E_EM_KS(params)

In [None]:
# Smoothed state at column 1 (column 0 holds the smoothed initial state).
res_EM_EKS["smoothed_states"][:, 1]

In [None]:
# First state component (columns 1: line up with the observation times), with the
# true series and the available observations.
plt.plot(res_EM_EKS["smoothed_states"][0, 1:].detach())
plt.plot(params["true_state"][0, :])
plt.plot(params["observations"][0, :], "*")

In [None]:
xs = res_EM_EKS["smoothed_states"].detach()
cov_prob = res_EM_EKS["cov_prob"]
Ps = res_EM_EKS["cov_smooth"].detach()

length_plot = params["temporal_window_size"]  # -200
# Column 0 of the smoothed arrays is the initial state; the estimate that matches
# observation t is column t + 1, so the plotted slice starts at 1.
xs_plot = xs[0, 1 : length_plot + 1]
std_plot = torch.sqrt(Ps[0, 0, 1 : length_plot + 1])

plt.figure(figsize=(10, 10))
plt.plot(xs_plot)
# Band of one standard deviation around the smoothed mean.
plt.fill_between(
    np.arange(length_plot),
    xs_plot - std_plot,
    xs_plot + std_plot,
    label="Smoothed PDF",
)
plt.plot(params["observations"].T, "o", label="observations")
plt.plot(params["true_state"].T, label="True state")
plt.title("cov prob : " + str(cov_prob))
plt.legend()

In [None]:
# Antisymmetric part of A (still attached to the autograd graph, unlike Dyn_Mat).
Dyn_sys = (model_E2E_EM_KS.A - model_E2E_EM_KS.A.T) / 2

In [None]:
# Imaginary parts of its eigenvalues, shown as the cell output.
torch.linalg.eig(Dyn_sys)[0].imag

In [None]:
# Free-run forecast: start from the smoothed state at column 100 and repeatedly apply
# the learned propagator, without assimilating any observation.
n_forecast_steps = 2000
y_pred2 = np.zeros((n_forecast_steps, dim_aug))
with torch.no_grad():
    # The propagator is constant, so it is built once. It is not called `F`, to avoid
    # shadowing the `torch.nn.functional as F` import.
    propagator = model_E2E_EM_KS.Koopman_Propagator()
    tmp = res_EM_EKS["smoothed_states"][:, 100:101].T
    state = tmp[0].clone()
    y_pred2[0, :] = state.cpu().numpy()
    for i in range(1, n_forecast_steps):
        state = torch.mv(propagator, state)
        y_pred2[i, :] = state.cpu().numpy()

In [None]:
# First 200 forecast steps against the smoothed states from column 100 onward.
plt.plot(y_pred2[:200, 0])
plt.plot(res_EM_EKS["smoothed_states"][0, 100:].detach())
# NOTE(review): plot() without arguments draws nothing.
plt.plot()

# Test set results

In [None]:
# Run the smoother on the whole test series as a single window.
b = 3  # NOTE(review): not used in this cell
# The model indexes observations as [variable, time], hence the transpose.
params["observations"] = torch.from_numpy(X_Test.T).float()
params["nb_EM_iterations"] = 1
params["temporal_window_size"] = X_Test.shape[0]
params["true_state"] = torch.from_numpy(Y_Test.T).float()
res_EM_EKS = model_E2E_EM_KS(params)

In [None]:
xs = res_EM_EKS["smoothed_states"].detach()
cov_prob = res_EM_EKS["cov_prob"]
Ps = res_EM_EKS["cov_smooth"].detach()

length_plot = params["temporal_window_size"]  # -200
# Column 0 of the smoothed arrays is the initial state; the estimate that matches
# observation t is column t + 1, so the plotted slice starts at 1.
xs_plot = xs[0, 1 : length_plot + 1]
std_plot = torch.sqrt(Ps[0, 0, 1 : length_plot + 1])

plt.figure(figsize=(10, 6))
plt.plot(params["true_state"].T, label="True state", lw=3, alpha=0.3)
plt.plot(params["observations"].T, "o", label="observations", markersize=16)
plt.plot(xs_plot, color="r", label="PDF mean")
# Band of one standard deviation around the smoothed mean.
plt.fill_between(
    np.arange(length_plot),
    xs_plot - std_plot,
    xs_plot + std_plot,
    color="tab:green",
    lw=6,
    label="PDF Standard deviation",
)
plt.legend(loc=9, bbox_to_anchor=(0.7, 1.2))
plt.xlabel("Time step (adimensional)")
plt.ylabel("SST °C")
# Write the same figure in three formats to the current working directory.
plt.savefig("smoothing_sst_pdf.png")
plt.savefig("smoothing_sst_pdf.pdf")
plt.savefig("smoothing_sst_pdf.svg")

In [None]:
import scipy.stats as stats

# Compare two zero-mean Gaussians for the innovations: one with the average standard
# deviation given by the filter's innovation covariances, one with the sample
# standard deviation of the recorded innovations.
mu = 0

# (time, variable) positions at which an innovation was recorded.
_obs_time, _obs_var = np.where(~np.isnan(res_EM_EKS["innoAll"]))

# Diagonal entries of the innovation covariance at those positions.
_innovation_variances = res_EM_EKS["sigAll"][_obs_time, _obs_var, _obs_var]
sigma_comp = np.mean(np.sqrt(_innovation_variances))
x_comp = np.linspace(mu - 3 * sigma_comp, mu + 3 * sigma_comp, 100)

sigma_samp = np.std(res_EM_EKS["innoAll"][_obs_time, _obs_var])
x_samp = np.linspace(mu - 3 * sigma_samp, mu + 3 * sigma_samp, 100)


plt.plot(x_comp, stats.norm.pdf(x_comp, mu, sigma_comp), label="formula Q/R")
plt.plot(x_samp, stats.norm.pdf(x_samp, mu, sigma_samp), label="sample STD")
plt.legend()
plt.show()

In [None]:
# Same test-set run, but with every observation from time step 200 onward removed,
# so the smoother has to extrapolate over the rest of the window.
b = 3
# clone(): from_numpy shares memory with X_Test and .float() returns the same tensor
# when the dtype is already float32, so writing NaNs without a copy could also blank
# X_Test itself.
params["observations"] = torch.from_numpy(X_Test.T).float().clone()
params["observations"][:, 200:] = np.nan
params["nb_EM_iterations"] = 1
params["temporal_window_size"] = X_Test.shape[0]
params["true_state"] = torch.from_numpy(Y_Test.T).float()
res_EM_EKS = model_E2E_EM_KS(params)

In [None]:
xs = res_EM_EKS["smoothed_states"].detach()
cov_prob = res_EM_EKS["cov_prob"]
Ps = res_EM_EKS["cov_smooth"].detach()

length_plot = params["temporal_window_size"]  # -200
# Column 0 of the smoothed arrays is the initial state; the estimate that matches
# observation t is column t + 1, so the plotted slice starts at 1.
xs_plot = xs[0, 1 : length_plot + 1]
std_plot = torch.sqrt(Ps[0, 0, 1 : length_plot + 1])

plt.figure(figsize=(10, 10))
plt.plot(xs_plot)
# Band of one standard deviation around the smoothed mean.
plt.fill_between(
    np.arange(length_plot),
    xs_plot - std_plot,
    xs_plot + std_plot,
    label="Standard deviation",
)
plt.plot(params["observations"].T, "o", label="observations")
plt.plot(params["true_state"].T, label="True state")
plt.legend()

In [None]:
import scipy.stats as stats

# Compare zero-mean Gaussians for the innovations: the filter-predicted standard
# deviation of the current model, the sample standard deviation of its innovations,
# and the filter-predicted standard deviation of the initial model (res_EM_EKS_init).
mu = 0
sigma_comp = np.mean(
    np.sqrt(
        res_EM_EKS["sigAll"][
            np.where(~np.isnan(res_EM_EKS["innoAll"]))[0],
            np.where(~np.isnan(res_EM_EKS["innoAll"]))[1],
            np.where(~np.isnan(res_EM_EKS["innoAll"]))[1],
        ]
    )
)
x_comp = np.linspace(mu - 3 * sigma_comp, mu + 3 * sigma_comp, 100)

sigma_samp = np.std(res_EM_EKS["innoAll"][np.where(~np.isnan(res_EM_EKS["innoAll"]))])
x_samp = np.linspace(mu - 3 * sigma_samp, mu + 3 * sigma_samp, 100)


sigma_comp_init = np.mean(
    np.sqrt(
        res_EM_EKS_init["sigAll"][
            np.where(~np.isnan(res_EM_EKS_init["innoAll"]))[0],
            np.where(~np.isnan(res_EM_EKS_init["innoAll"]))[1],
            np.where(~np.isnan(res_EM_EKS_init["innoAll"]))[1],
        ]
    )
)
x_comp_init = np.linspace(mu - 3 * sigma_comp_init, mu + 3 * sigma_comp_init, 100)


# Each density is drawn against the grid it was evaluated on.
plt.plot(
    x_comp_init,
    stats.norm.pdf(x_comp_init, mu, sigma_comp_init),
    label="formula Q/R init",
)
plt.plot(x_comp, stats.norm.pdf(x_comp, mu, sigma_comp), label="formula Q/R")
plt.plot(x_samp, stats.norm.pdf(x_samp, mu, sigma_samp), label="sample STD")
plt.legend()
plt.show()

In [None]:
# Replace the optimizer with Adam over three parameter groups: A uses the default
# learning rate (1e-3), weight_R and Q use 1e-2. The non-trainable parameters
# (B, R, X0) are not passed to the optimizer.
optimizer = torch.optim.Adam(
    [
        {"params": model_E2E_EM_KS.A},
        {"params": model_E2E_EM_KS.weight_R, "lr": 1e-2},
        {"params": model_E2E_EM_KS.Q, "lr": 1e-2},
    ],
    lr=1e-3,
)

In [None]:
# Tail of a training loop that was left in its own cell: as written (indented `if`
# blocks after an unindented statement) it could not be parsed. Dedented so it runs.
# It reports the last (epoch, batch, loss) of the training cell and applies the
# learning-rate schedule for that epoch, so `t`, `b` and `loss` must already exist.
# NOTE(review): this overwrites the learning rate of every parameter group of the
# current optimizer; presumably it belongs inside a second training loop -- confirm.
print(t, b, loss)
if 400 > t > 300:
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01
if t > 400:
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.001