In [None]:
############################# Import Section #################################

## Imports related to PyTorch
import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import transforms, utils
from torch.utils.data import TensorDataset, DataLoader,Dataset
import torch.nn.functional as F

## Generic imports
import os
import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
from sklearn.metrics import classification_report
from copy import deepcopy
import math
import random
from tqdm import tqdm

## Dependencies classes and functions
from utils import gridRing
from utils import asMinutes
from utils import timeSince
from utils import getWeights
from utils import save_checkpoint
from utils import getListOfFolders

## Import Model
from DyanOF import OFModel,creatRealDictionary,fista

############################# Import Section #################################

In [None]:
def load_data(df,chunk_size=1):
    """Convert an indexed DataFrame of sequences into one stacked float tensor.

    Rows of ``df`` that share an index label are taken together as one
    sequence. Each sequence is converted to a ``torch.FloatTensor``, split
    along its first dimension with ``torch.chunk`` and every resulting piece
    is stacked along a new leading dimension.

    Args:
        df: pandas DataFrame whose index labels group rows into sequences.
        chunk_size: number of pieces each sequence is split into. Despite the
            name this is the ``chunks`` argument of ``torch.chunk`` (a count),
            not the length of a piece. ``1`` keeps each sequence whole.

    Returns:
        A tensor with one leading entry per piece; for multi-row sequences
        this is ``(num_pieces, rows_per_piece, num_columns)``.

    Raises:
        RuntimeError: raised by ``torch.stack`` when the pieces do not all
            share one shape (sequences of different lengths, or a length that
            ``chunk_size`` does not divide evenly) or when there are none.
    """
    pieces = []
    for key in tqdm(df.index.unique()):
        sequence = torch.FloatTensor(df.loc[key].values)
        # torch.chunk returns a tuple of views; extend flattens it into the list.
        pieces.extend(torch.chunk(sequence,chunk_size))
    return torch.stack(pieces, 0)

def load_dataset(dataset_name = "lorenz",file_path=r'C:\Users\lpott\Desktop\DYAN\Code\data',chunk_size=1):
    """Read the pickled train/test inputs of a dataset and tensorise them.

    The files ``<dataset_name>/<dataset_name>_train_inputs.pickle`` and
    ``<dataset_name>/<dataset_name>_test_inputs.pickle`` are looked up under
    ``file_path``; each unpickled object is handed to ``load_data``.

    Args:
        dataset_name: sub-directory and file-name prefix of the dataset.
        file_path: root directory that contains the dataset folder.
        chunk_size: forwarded unchanged to ``load_data``.

    Returns:
        Tuple ``(X_train, X_test)`` of the two tensors built by ``load_data``.
    """
    def _unpickle(relative_name):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # file_path at data you trust.
        with open(os.path.join(file_path,relative_name), "rb") as handle:
            return pickle.load(handle)

    # Both files are read before any conversion starts, train first.
    train_df = _unpickle(f"{dataset_name}/{dataset_name}_train_inputs.pickle")
    test_df = _unpickle(f"{dataset_name}/{dataset_name}_test_inputs.pickle")

    X_train = load_data(train_df,chunk_size)
    X_test = load_data(test_df,chunk_size)
    return X_train,X_test

In [None]:
class differential_dataset(Dataset):
    """Dataset serving one randomly positioned, fixed-length window per sequence.

    ``X`` is indexed as ``(sequence, time, feature)``. Every access draws a
    fresh random start position, so repeated reads of the same index return
    different windows.

    Attributes:
        X: the tensor of full sequences passed to the constructor.
        horizon: number of consecutive time steps in every returned window.
        D: size of the last (feature) dimension of ``X``.
        T: number of valid window start positions, ``X.shape[1]-horizon+1``.
        mu, std: per-feature mean and standard deviation computed over all
            sequences and time steps.
    """

    def __init__(self,X,horizon):
        """Store the sequences and compute per-feature statistics.

        Raises:
            ValueError: if ``horizon`` is smaller than 1 or longer than the
                time dimension of ``X`` (no window would fit).
        """
        if horizon < 1 or horizon > X.shape[1]:
            raise ValueError(
                f"horizon must be between 1 and X.shape[1]={X.shape[1]}, got {horizon}")

        self.X = X
        self.horizon = horizon
        self.D = X.shape[-1]
        self.T = X.shape[1]-self.horizon+1

        # Kept from the original notebook: echoes the window length and the
        # number of start positions when the dataset is built.
        print(self.horizon)
        print(self.T)
        self.mu = torch.tensor([torch.mean(X[:,:,i]) for i in range(self.D)])
        self.std = torch.tensor([torch.std(X[:,:,i]) for i in range(self.D)])

    def __len__(self):
        """Number of stored sequences."""
        return self.X.shape[0]

    def __getitem__(self,idx):
        """Return a random window for one index or for a collection of indices.

        Args:
            idx: an integer (Python or numpy), a list of integers, or a tensor
                of indices.

        Returns:
            ``(horizon, D)`` for a single integer index, otherwise
            ``(len(idx), horizon, D)``. Only the leading axis is ever dropped,
            so a feature (or time) axis of size 1 is preserved.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        single = isinstance(idx, (int, np.integer))
        rows = torch.as_tensor([int(idx)] if single else idx, dtype=torch.long)

        # One independent start position per requested sequence.
        start = torch.randint(low=0,high=self.T,size=(rows.shape[0],))
        # (B, horizon) matrix of absolute time indices: start + 0..horizon-1.
        steps = start.unsqueeze(1) + torch.arange(self.horizon)
        x = self.X[rows.unsqueeze(1), steps]

        return x.squeeze(0) if single else x

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position codes to a batch-first sequence tensor.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``; even feature columns hold sines and odd
    columns hold cosines of the position scaled by geometrically spaced
    frequencies.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: size of the feature dimension (odd values are supported).
            dropout: dropout probability applied to the encoded output.
            max_len: largest sequence length the table can serve.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so the cosine block is trimmed to fit.
        pe[:, 0, 1::2] = torch.cos(angles)[:, :d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
               (the sequence length is read from ``x.size(1)``).

        Returns:
            Tensor of the same shape: ``x`` plus the position codes, passed
            through dropout (identity when ``dropout`` is 0 or in eval mode).
        """
        # pe[:seq_len] is (seq_len, 1, d_model); permuting gives
        # (1, seq_len, d_model), which broadcasts over the batch axis.
        x = x + self.pe[:x.size(1)].permute(1,0,2)
        return self.dropout(x)

In [None]:
class encoder(nn.Module):
    """Transformer encoder that maps input features to a latent sequence.

    Pipeline: linear projection ``D -> embed_dim``, positional encoding, a
    stack of ``nn.TransformerEncoderLayer`` blocks each followed by ``tanh``,
    and a final linear projection ``embed_dim -> latent_dim``. No attention
    mask is applied in ``forward``.

    Args:
        D: number of input features per time step.
        embed_dim: transformer model width.
        latent_dim: number of latent features produced per time step.
        dimforward: feed-forward width inside every transformer layer.
        nhead: number of attention heads.
        encoder_layers: number of stacked transformer layers.
        device: default device used by ``generate_square_subsequent_mask``
            when no explicit device is given.
    """
    def __init__(self,D,embed_dim,latent_dim,dimforward,nhead,encoder_layers=1,device='cuda:0'):
        super(encoder,self).__init__()
        self.D = D
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.device = device
        self.nhead = nhead

        # Sub-modules are created in the original order so parameter ordering
        # (and therefore saved optimizer state) stays compatible.
        self.input_projection = nn.Linear(D,embed_dim)

        self.tencoder = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim,dim_feedforward=dimforward,nhead=nhead,batch_first=True,dropout=0) for _ in range(encoder_layers)])

        self.projection = nn.Linear(embed_dim,latent_dim)
        self.pos_encoder = PositionalEncoding(embed_dim,dropout=0)


    def forward(self,x):
        """Encode ``x`` of shape ``(batch, seq_len, D)`` into ``(batch, seq_len, latent_dim)``."""
        x = self.pos_encoder(self.input_projection(x))
        for layer in self.tencoder:
            x = torch.tanh(layer(x))

        return self.projection(x)

    def generate_square_subsequent_mask(self,sz,device=None):
        """Generates an upper-triangular matrix of -inf, with zeros on diag.

        Args:
            sz: side length of the square mask.
            device: where to create the mask; defaults to ``self.device``.
        """
        if device is None:
            device = self.device
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

In [None]:
class decoder(nn.Module):
    """Causally masked transformer that maps a latent sequence back to features.

    Pipeline: linear projection ``latent_dim -> embed_dim``, positional
    encoding, a stack of ``nn.TransformerEncoderLayer`` blocks applied with a
    square subsequent (causal) mask and each followed by ``tanh``, and a final
    linear projection ``embed_dim -> D``.

    Args:
        D: number of output features per time step.
        embed_dim: transformer model width.
        latent_dim: number of latent features per time step of the input.
        dimforward: feed-forward width inside every transformer layer.
        nhead: number of attention heads.
        decoder_layers: number of stacked transformer layers.
        device: default device used by ``generate_square_subsequent_mask``
            when no explicit device is given.
    """
    def __init__(self,D,embed_dim,latent_dim,dimforward,nhead,decoder_layers=1,device='cuda:0'):
        super(decoder,self).__init__()
        self.D = D
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.device = device
        self.nhead = nhead

        # Sub-modules are created in the original order so parameter ordering
        # (and therefore saved optimizer state) stays compatible.
        self.output_projection = nn.Linear(latent_dim,embed_dim)

        self.tdecoder = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim,nhead=nhead,dim_feedforward=dimforward,batch_first=True,dropout=.0) for _ in range(decoder_layers)])

        self.projection = nn.Linear(embed_dim,D)
        self.pos_encoder = PositionalEncoding(embed_dim,dropout=0)

    def forward(self,x):
        """Decode ``x`` of shape ``(batch, seq_len, latent_dim)`` into ``(batch, seq_len, D)``."""
        x = self.pos_encoder(self.output_projection(x))
        # Build the mask on the input's own device: self.device is only the
        # constructor argument and is not updated when the module is moved.
        mask = self.generate_square_subsequent_mask(x.shape[1],device=x.device)
        for layer in self.tdecoder:
            x = torch.tanh(layer(x,mask))

        return self.projection(x)

    def generate_square_subsequent_mask(self,sz,device=None):
        """Generates an upper-triangular matrix of -inf, with zeros on diag.

        Args:
            sz: side length of the square mask.
            device: where to create the mask; defaults to ``self.device``.
        """
        if device is None:
            device = self.device
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

In [None]:
class TDYANT(nn.Module):
    """Encoder / sparse-dictionary / decoder model for reconstruction and forecasting.

    ``forward`` normalises the input with the ``mu``/``std`` buffers, encodes
    the first ``horizon`` steps, fits a sparse code for that latent with
    ``fista`` against the first ``horizon`` rows of the dictionary returned by
    ``creatRealDictionary``, multiplies the full-length dictionary by that
    code and decodes the result.

    Args:
        Drr, Dtheta: initial tensors wrapped as the learnable parameters
            ``rr`` and ``theta`` handed to ``creatRealDictionary``.
        N: accepted for interface compatibility; not used by this class.
        D: number of features per time step.
        embed_dim, latent_dim, dimforward, nhead: forwarded to the encoder
            and decoder.
        encoder_layers, decoder_layers: depth of the encoder and decoder.
        device: value passed on to the sub-modules, ``creatRealDictionary``
            and ``fista``.
        clamp: ``0`` leaves the latent unbounded; any other value scales
            ``tanh(latent)`` by it.
    """
    def __init__(self, 
                 Drr, 
                 Dtheta,
                 N,
                 D,
                 embed_dim,
                 latent_dim,
                 dimforward,
                 nhead,
                 encoder_layers=1,
                 decoder_layers=1,
                 device='cuda:0',
                clamp=2):
        super().__init__()

        # Learnable dictionary parameters come first, then the sub-networks,
        # so the parameter order of the module is unchanged.
        self.rr = nn.Parameter(Drr)
        self.theta = nn.Parameter(Dtheta)

        self.device = device
        self.latent_dim = latent_dim
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.dimforward = dimforward

        self.encoder_ = encoder(D,embed_dim,latent_dim,dimforward,nhead,encoder_layers,device)
        self.decoder_ = decoder(D,embed_dim,latent_dim,dimforward,nhead,decoder_layers,device)
        self.clamp = clamp

        # Normalisation statistics; identity until overwritten by the caller.
        self.register_buffer('mu', torch.zeros((D,)))
        self.register_buffer('std', torch.ones((D,)))

    def forward(self, x,horizon):
        """Reconstruct the first ``horizon`` steps of ``x`` and predict the rest.

        Args:
            x: tensor indexed as ``(batch, time, feature)``.
            horizon: number of leading time steps given to the encoder.

        Returns:
            Tuple ``(x_recon_hat, x_ahead_hat, latent, y_recon, y_ahead)``:
            the un-normalised decoder output split at ``horizon``, the encoder
            latent, and the dictionary output ``y`` split at ``horizon``.
        """
        x = self._normalize(x)
        observed = x[:,:horizon,:]
        # The observed and future slices partition the time axis, so the total
        # number of steps is simply the length of that axis.
        total_steps = x.shape[1]

        latent = self.encoder_(observed)
        if self.clamp != 0:
            latent = torch.tanh(latent)*self.clamp

        dictionary = creatRealDictionary(total_steps,self.rr,self.theta,device=self.device)
        # NOTE(review): 0.01 and 100 are presumably the sparsity weight and
        # iteration count of fista — confirm against DyanOF.
        sparsecode = fista(dictionary[:horizon],latent,0.01,100,self.device)

        # x is the outer layer , y is the inner layer
        y = torch.matmul(dictionary,sparsecode)
        decoded = self._unnormalize(self.decoder_(y))

        x_recon_hat,x_ahead_hat = decoded[:,:horizon,:],decoded[:,horizon:,:]
        y_recon,y_ahead = y[:,:horizon,:],y[:,horizon:,:]
        return x_recon_hat,x_ahead_hat,latent,y_recon,y_ahead

    def _normalize(self, x):
        """Standardise ``x`` with the ``mu``/``std`` buffers (two leading axes added for broadcasting)."""
        return (x - self.mu[None, None])/self.std[None, None]

    def _unnormalize(self, x):
        """Inverse of ``_normalize``: rescale by ``std`` and shift by ``mu``."""
        return self.std[None, None]*x + self.mu[None, None]

In [None]:
def dynamic_loss(x,model,horizon,alpha=1):
    """Compute the training objective and two monitoring quantities.

    Args:
        x: tensor indexed as ``(batch, time, feature)``.
        model: callable ``model(x, horizon)`` returning
            ``(x_recon_hat, x_ahead_hat, latent, y_recon, y_ahead)``.
        horizon: time index at which ``x`` is split into the reconstructed
            part (before) and the predicted part (from there on).
        alpha: weight of the prediction term.

    Returns:
        ``(loss, reconstruction_loss, prediction_loss, MSE_LOSS_latent,
        NORM_LOSS)`` where ``loss = reconstruction + alpha * prediction``.
        The last two values are computed without gradient tracking: the MSE
        between ``y_recon`` and ``latent``, and the mean norm of ``y_recon``
        taken along dimension 2.
    """
    x_recon_hat,x_ahead_hat,latent,y_recon,y_ahead = model(x,horizon)

    target_recon = x[:,:horizon,:]
    target_ahead = x[:,horizon:,:]
    reconstruction_loss = F.mse_loss(x_recon_hat,target_recon)
    prediction_loss = F.mse_loss(x_ahead_hat,target_ahead)
    loss = reconstruction_loss + alpha*prediction_loss

    # Diagnostics only; they must not contribute gradients.
    with torch.no_grad():
        MSE_LOSS_latent = F.mse_loss(y_recon,latent)
        NORM_LOSS = y_recon.norm(p='fro',dim=2).mean()

    return loss,reconstruction_loss,prediction_loss,MSE_LOSS_latent,NORM_LOSS

In [None]:
## HyperParameters for the Network
# Number of poles; N (four times this) is the value later given to gridRing.
NumOfPoles = 60

# Length, in time steps, of every window drawn from the data.
Time_Length = 88

N = 4*NumOfPoles

In [None]:
## Load saved model 
# Set load_ckpt to True to resume training from ckpt_file.
load_ckpt = False
checkptname = "lorenz_prediction_short" #"lorenz_prediction\lorenz_prediction50_best"
dataset_name = "lorenz"
# Root folder of the pickled datasets. The DYAN_DATA_DIR environment variable
# overrides the machine-specific default so the notebook can run elsewhere.
file_path = os.environ.get("DYAN_DATA_DIR", r'C:\Users\lpott\Desktop\DYAN\Code\data')
# Epoch suffix of the checkpoint file to resume from.
ckpt_epoch = 170
ckpt_file = f"data/{dataset_name}/"+checkptname+str(ckpt_epoch)+'.pth' # for Kitti Dataset: 'KittiModel.pth'


In [None]:
# Load both splits, then report their shapes as a quick sanity check.
X_train,X_test = load_dataset(dataset_name=dataset_name,file_path=file_path,chunk_size=1)
for label, split in (("X_tr shape: ",X_train), ("X_te shape: ",X_test)):
    print(label,split.shape)

In [None]:
## Initializing r, theta
# gridRing returns the complex poles P; their magnitude and phase become the
# initial values of the model's rr and theta parameters.
P,Pall = gridRing(N)
Drr = torch.from_numpy(abs(P)).float() #.to(device)
Dtheta = torch.from_numpy(np.angle(P)).float() #.to(device)

In [None]:
import time

In [None]:
# MODEL PARAMETERS
embed_dim = 256          # transformer width
latent_dim= 128          # latent features per time step
dimfeedforward = 1024    # feed-forward width inside each transformer layer
nhead=32                 # attention heads
D = X_train.shape[-1]    # features per time step, taken from the data
clamp = 0                # 0 disables the tanh bound on the latent
encoder_layers=1
decoder_layers=2
device = torch.device("cuda:0")

# TRAINING PARAMETERS
BATCH_SIZE = 128

LR = 0.0001
EPOCH = 300
print_every = 5          # validate / report every this many epochs
saveEvery = 10           # checkpoint cadence, in epochs

# TIME PARAMETERS
# Time_Length (total window length) is defined with the network
# hyperparameters above; the former no-op self-assignment was removed.
horizon = 16             # leading steps that are reconstructed; the rest are predicted
alpha = 2                # weight of the prediction term in the loss

In [None]:
# Sequence counts of each split; used to average the per-batch loss sums.
N_train, N_test = X_train.shape[0], X_test.shape[0]

In [None]:
# Random-window datasets: every access yields a fresh window of Time_Length steps.
train_windows = differential_dataset(X_train,Time_Length)
train_dl = DataLoader(train_windows,batch_size=BATCH_SIZE,shuffle=True)

test_windows = differential_dataset(X_test,Time_Length)
test_dl = DataLoader(test_windows,batch_size=BATCH_SIZE)

# Deterministic evaluation set: always the last Time_Length steps of each test sequence.
test_tail = TensorDataset(X_test[:,-Time_Length:,:])
test_evaluator_dl = DataLoader(test_tail,batch_size=BATCH_SIZE)

In [None]:
# Pull a single training batch, show its shape and plot the first two
# features of its first window.
i = next(iter(train_dl))
print(i.shape)
for feature in (0, 1):
    plt.plot(i[0,:,feature])
# plt.plot(i[0,:,2])

In [None]:
## Create the model
model = TDYANT(Drr=Drr,
               Dtheta=Dtheta,
               N=N,
               D=D,
               embed_dim=embed_dim,
               latent_dim=latent_dim,
               dimforward=dimfeedforward,
               nhead=nhead,
               encoder_layers=encoder_layers,
               decoder_layers=decoder_layers,
               device=device,
               clamp=clamp)
model = model.to(device)

# Replace the identity normalisation buffers with the training-set statistics.
train_stats = train_dl.dataset
model.mu = train_stats.mu.to(device)
model.std = train_stats.std.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR,weight_decay=1e-8)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[150,200], gamma=0.1) # if Kitti: milestones=[100,150]

In [None]:
# From here on the device is referred to by its string name.
# NOTE(review): only model.device is updated; the device stored on
# model.encoder_ / model.decoder_ keeps the constructor value.
device = "cuda:0"
model = model.cuda()
model.device = "cuda:0"

In [None]:
torch.cuda.empty_cache()
start_epoch = 1

## If want to continue training from a checkpoint
if load_ckpt:
    print("LOADING CHECKPT")
    # map_location makes the load independent of the device the checkpoint
    # was saved from.
    loadedcheckpoint = torch.load(ckpt_file, map_location=device)
    start_epoch = loadedcheckpoint['epoch']
    model.load_state_dict(loadedcheckpoint['state_dict'])
    optimizer.load_state_dict(loadedcheckpoint['optimizer'])
    # Older checkpoints carry no scheduler state; restore it only if present
    # so the learning-rate milestones continue from the right epoch.
    if 'scheduler' in loadedcheckpoint:
        scheduler.load_state_dict(loadedcheckpoint['scheduler'])

print("Training from epoch: ", start_epoch)
print('-' * 25)


In [None]:
# Overrides the print_every = 5 set in the parameters cell: the training loop
# below runs the validation pass whenever epoch % print_every == 0, so with 1
# it validates after every epoch.
print_every = 1

In [None]:
## Start the Training
for epoch in range(start_epoch, EPOCH+1):
   
    model.train()
    
    train_epoch_loss = []; test_epoch_loss = []
    train_norm_loss = []; test_norm_loss = []
    train_latent_loss = []; test_latent_loss = [];
    train_reconstruction_loss = []; test_reconstruction_loss = [];
    train_prediction_loss = []; test_prediction_loss = []
    
    start = time.time()
    for x in tqdm(train_dl):
        x = x.to(device)
        optimizer.zero_grad()
        loss,reconstruction_loss,prediction_loss,latent_loss,norm_loss = dynamic_loss(x,model,horizon,alpha=alpha)
        loss.backward()
        optimizer.step()
        train_epoch_loss.append(loss.item()*x.shape[0])
        train_latent_loss.append(latent_loss.item()*x.shape[0])
        train_norm_loss.append(norm_loss.item()*x.shape[0])
        train_reconstruction_loss.append(reconstruction_loss.item()*x.shape[0])
        train_prediction_loss.append(prediction_loss.item()*x.shape[0])
        end = time.time()
    torch.cuda.empty_cache()

    if (epoch)%print_every == 0:
        model.eval()
        with torch.no_grad():
            for x in tqdm(test_evaluator_dl):
                x = x[0].to(device)
#                 x = X_test[:,-72:,:].to(device)
                loss,reconstruction_loss,prediction_loss,latent_loss,norm_loss = dynamic_loss(x,model,horizon,alpha=alpha)
                test_epoch_loss.append(loss.item()*x.shape[0])
                test_latent_loss.append(latent_loss.item()*x.shape[0])
                test_norm_loss.append(norm_loss.item()*x.shape[0])
                test_reconstruction_loss.append(reconstruction_loss.item()*x.shape[0])
                test_prediction_loss.append(prediction_loss.item()*x.shape[0])
                
    print('Epoch: ', epoch)
    print("| train time: %.6f" % (end-start))
    print('| train loss: %.6f' % (np.sum(train_epoch_loss)/N_train))
    print('| train reconstruction loss: %.6f' % (np.sum(train_reconstruction_loss)/N_train))
    print('| train prediction loss: %.6f' % (np.sum(train_prediction_loss)/N_train))
    print('| train LATENT loss: %.6f' % (np.sum(train_latent_loss)/N_train))
    print('| LATENT NORM: %.6f' % (np.sum(train_norm_loss)/N_train))
    if (epoch)%print_every == 0:
        print('| val loss: %.6f' % (np.sum(test_epoch_loss)/N_test))
        print('| val reconstruction loss: %.6f' % (np.sum(test_reconstruction_loss)/N_test))
        print('| val prediction loss: %.6f' % (np.sum(test_prediction_loss)/N_test))
        print('| val LATENT loss: %.6f' % (np.sum(test_latent_loss)/N_test))
        print('| val LATENT NORM: %.6f' % (np.sum(test_norm_loss)/N_test))
    print("\n")

#     print("Classification Report:")
#     print(classification_report(labels,predictions,zero_division=1))
    
    if (epoch+1) % saveEvery ==0 :
        print("Saving Checkpoint")
        save_checkpoint({'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer' : optimizer.state_dict(),
                        },f"data/{dataset_name}/"+checkptname+str(epoch)+'.pth')

In [None]:
# Inspection cell: shows the shape of whatever tensor `x` was last bound to
# by the loops of the training cell.
x.shape

In [None]:
# print("Saving Checkpoint")
# checkptname = "nonlinear_prediction_75" #"lorenz_prediction\lorenz_prediction50_best"
# save_checkpoint({'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'optimizer' : optimizer.state_dict(),
#                 },f"data/{dataset_name}/"+checkptname+str(epoch)+'.pth')

In [None]:
## If want to continue training from a checkpoint
ckpt_file = f"data/{dataset_name}/"+"lorenz_prediction"+str(250)+'.pth' # for Kitti Dataset: 'KittiModel.pth'
# checkptname = "lorenz_prediction\lorenz_400_best_noclamp85.pth"
# This cell always loads (the former `if(True):` wrapper added nothing).
print("LOADING CHECKPT")
# map_location makes the load independent of the device the checkpoint was
# saved from.
loadedcheckpoint = torch.load(ckpt_file, map_location=device)
start_epoch = loadedcheckpoint['epoch']
model.load_state_dict(loadedcheckpoint['state_dict'])
optimizer.load_state_dict(loadedcheckpoint['optimizer'])

In [None]:
# Batched evaluation on the fixed test windows. Only the test accumulators are
# reset here, so the training statistics of the last epoch stay available.
model.eval()
test_epoch_loss = []
test_norm_loss = []
test_latent_loss = []
test_reconstruction_loss = []
test_prediction_loss = []
with torch.no_grad():
    for x in tqdm(test_evaluator_dl):
        x = x[0].to(device)
        loss,reconstruction_loss,prediction_loss,latent_loss,norm_loss = dynamic_loss(x,model,horizon,alpha=alpha)
        # Weight each batch mean by its size so the final division by N_test
        # yields a per-sequence average.
        test_epoch_loss.append(loss.item()*x.shape[0])
        test_latent_loss.append(latent_loss.item()*x.shape[0])
        test_norm_loss.append(norm_loss.item()*x.shape[0])
        test_reconstruction_loss.append(reconstruction_loss.item()*x.shape[0])
        test_prediction_loss.append(prediction_loss.item()*x.shape[0])
print('| val loss: %.6f' % (np.sum(test_epoch_loss)/N_test))
print('| val reconstruction loss: %.6f' % (np.sum(test_reconstruction_loss)/N_test))
print('| val prediction loss: %.6f' % (np.sum(test_prediction_loss)/N_test))
print('| val LATENT loss: %.6f' % (np.sum(test_latent_loss)/N_test))
print('| val LATENT NORM: %.6f' % (np.sum(test_norm_loss)/N_test))

In [None]:
# Same evaluation as the batched version, but the whole test set goes through
# the model in a single forward pass.
model.eval()
with torch.no_grad():
    full_test = X_test[:,-Time_Length:,:].to(device)
    loss,reconstruction_loss,prediction_loss,latent_loss,norm_loss = dynamic_loss(full_test,model,horizon,alpha=alpha)
    report = (
        ('| val loss: %.6f', loss),
        ('| val reconstruction loss: %.6f', reconstruction_loss),
        ('| val prediction loss: %.6f', prediction_loss),
        ('| val LATENT loss: %.6f', latent_loss),
        ('| val LATENT NORM: %.6f', norm_loss),
    )
    for template, value in report:
        print(template % (value))

In [None]:
# Keep the device string stored on the model in step with where its weights live.
model.device = "cuda:0"
model = model.to("cuda:0")

In [None]:
# Run the model on one test sequence (kept as a batch of size 1) and overlay
# the ground truth, the reconstruction and the forecast.
n = 1
x = X_test[[n],-Time_Length:,:]
model.eval()
with torch.no_grad():
    x_recon_hat,x_ahead_hat,latent,y_recon,y_ahead = model(x.to(device),horizon)

steps_all = np.arange(Time_Length)
steps_recon = np.arange(horizon)
steps_ahead = horizon+np.arange(Time_Length-horizon)

plt.figure(figsize=(10,10))
plt.plot(steps_all,x[0])
plt.plot(steps_recon,x_recon_hat[0].cpu(),'--')
plt.plot(steps_ahead,x_ahead_hat[0].cpu(),'r.')

plt.xlabel("Time (n)",fontsize=20)
plt.ylabel("State",fontsize=20)
plt.legend(["x","y","z","$x_{reconstructed}$","$y_{reconstructed}$","$z_{reconstructed}$","Prediction"],fontsize=20)

In [None]:
# Grid of the first nrow*ncol latent features of sample 0: encoder latent
# overlaid with its dictionary reconstruction y_recon.
nrow = 8; ncol = 8
# squeeze=False always returns a 2-D axes array, whatever nrow / ncol are.
fig, axs = plt.subplots(nrow, ncol, squeeze=False)
for k, ax in enumerate(axs.flat):
    ax.plot(latent[0,:,k].cpu().detach().numpy())
    ax.plot(y_recon[0,:,k].cpu().detach().numpy())
    ax.title.set_text(f"Feature {k+1}")

In [None]:
# Grid of the first nrow*ncol features of y_ahead (the part of the dictionary
# output beyond the horizon) for sample 0.
nrow = 8; ncol = 8
# squeeze=False always returns a 2-D axes array, whatever nrow / ncol are.
fig, axs = plt.subplots(nrow, ncol, squeeze=False)
for k, ax in enumerate(axs.flat):
    ax.plot(y_ahead[0,:,k].cpu().detach().numpy())
    ax.title.set_text(f"Feature {k+1}")

In [None]:
# Grid of the first nrow*ncol features of the full dictionary output for
# sample 0: y_recon followed by y_ahead along the time axis.
nrow = 8; ncol = 8
# squeeze=False always returns a 2-D axes array, whatever nrow / ncol are.
fig, axs = plt.subplots(nrow, ncol, squeeze=False)
for k, ax in enumerate(axs.flat):
    ax.plot(torch.concat((y_recon[0,:,k],y_ahead[0,:,k])).cpu().detach().numpy())
    ax.title.set_text(f"Feature {k+1}")
    ax.title.set_fontsize(30)

In [None]:
# 3-D trajectory of the selected test window: features 0, 1 and 2 of x[0]
# drawn as one black line.
plt.figure()
# IPython magic; note it runs after plt.figure() has already been called.
%matplotlib qt5
ax = plt.axes(projection='3d')
ax.plot3D(x[0,:,0],x[0,:,1],x[0,:,2],'k-') #c=np.linspace(0,1,Time_Length))
ax.set_xlabel('$X$', fontsize=20)
ax.set_ylabel('$Y$',fontsize=20)
ax.set_zlabel(r'$Z$', fontsize=20)

In [None]:
# 3-D comparison for the selected test window: ground truth (black line),
# reconstruction of the first `horizon` steps (blue stars) and the forecast
# of the remaining steps (red crosses).
plt.figure()
# IPython magic; note it runs after plt.figure() has already been called.
%matplotlib qt5
ax = plt.axes(projection='3d')
ax.plot3D(x[0,:,0],x[0,:,1],x[0,:,2],'k-') #c=np.linspace(0,1,Time_Length))
# Model outputs are moved to the CPU before being handed to matplotlib.
ax.plot3D(x_recon_hat[0,:,0].cpu(),x_recon_hat[0,:,1].cpu(),x_recon_hat[0,:,2].cpu(),'b*')
ax.plot3D(x_ahead_hat[0,:,0].cpu(),x_ahead_hat[0,:,1].cpu(),x_ahead_hat[0,:,2].cpu(),'rx')
ax.set_xlabel('$X$', fontsize=20)
ax.set_ylabel('$Y$',fontsize=20)
ax.set_zlabel(r'$Z$', fontsize=20)
plt.legend(["Actual","Reconstruction","Forecasted"])
plt.show()

In [None]:
plt.figure()
%matplotlib qt5
ax = plt.axes(projection='3d')
ax.scatter3D(x[n,horizon:,0],x[n,horizon:,1],x[n,horizon:,2],c=np.linspace(0,1,Time_Length-horizon))
ax.scatter3D(x_ahead_hat[n,:,0].cpu(),x_ahead_hat[n,:,1].cpu(),x_ahead_hat[n,:,2].cpu())
ax.set_xlabel('$X$', fontsize=20)
ax.set_ylabel('$Y$',fontsize=20)
ax.set_zlabel(r'$Z$', fontsize=20)
plt.show()