In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
sns.set_theme()

import torch 
import torch.nn as nn 
import torch.optim as optim

# import torch_optimizer as optim_custom
from torch.utils.data import Dataset, DataLoader
from bernstein_torch import bernstein_coeff_order10_new
import scipy.io as sio

# from models.mlp_qp_vis_aware_2 import MLP, vis_aware_track_net, PointNet
# import pol_matrix_comp
# from tqdm import trange

from mlp_obst_avoidance import MLP, mlp_projection_filter
# from scipy.io import loadmat

In [2]:
# Bernstein basis matrices (value, first and second derivative) evaluated on a
# uniform grid of `num` time points spanning [0, t_fin].
t_fin = 20.0
num = 100
tot_time = torch.linspace(0, t_fin, num)
tot_time_copy = tot_time.unsqueeze(1)  # column vector, shape (num, 1)

P, Pdot, Pddot = bernstein_coeff_order10_new(
    10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy
)

# Block-diagonal matrices made of two identical copies of each basis matrix.
P_diag, Pdot_diag, Pddot_diag = (torch.block_diag(mat, mat) for mat in (P, Pdot, Pddot))

# Number of basis coefficients per trajectory (columns of P).
nvar = P.shape[1]


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device} device")

In [4]:
# Load the training dataset: initial states and sampled coefficient vectors.
DATA_PATH = "./training_scripts/dataset/data/train_data_obs_150_100_250.npz"

# np.load on an .npz returns an NpzFile holding an open file handle; using it as a
# context manager closes that handle once the arrays have been read into memory.
with np.load(DATA_PATH) as data:
    init_state = data['init_state_data']
    c_samples_input = data['c_samples_data']

# Network input: initial state concatenated column-wise with the coefficient samples.
inp = np.hstack((init_state, c_samples_input))

# Global (scalar) mean / std over the whole input array, used to normalise inputs.
inp_mean, inp_std = inp.mean(), inp.std()
assert inp_std > 0, "input std is zero; normalisation would divide by zero"


In [None]:
# Peek at the first 5 samples and their first 11 coefficient columns
# (presumably the first nvar block — TODO confirm nvar == 11).
print(c_samples_input[:5, :11])

In [None]:

# Build the MLP + projection-filter model and load its trained weights.

# Number of copies of the test input stacked into one batch.
num_batch = 2

# Move the basis matrices to `device`. Pddot is passed to the model below, so it
# must be moved too (it was previously left on the CPU while P and Pdot were moved).
P = P.to(device)
Pdot = Pdot.to(device)
Pddot = Pddot.to(device)
P_diag = P_diag.to(device)
Pdot_diag = Pdot_diag.to(device)
Pddot_diag = Pddot_diag.to(device)

num_dot = num
num_ddot = num_dot
num_constraint = 2*num + 2*num_dot + 2*num_ddot

# Network dimensions.
enc_inp_dim = np.shape(inp)[1]
mlp_inp_dim = enc_inp_dim
hidden_dim = 1024
mlp_out_dim = 6*nvar  # an earlier variant also added 3*num_constraint
print(mlp_out_dim)

WEIGHTS_PATH = "./training_scripts/weights/mlp_learned_proj_obs_150_100_250.pth"

mlp = MLP(mlp_inp_dim, hidden_dim, mlp_out_dim)
model = mlp_projection_filter(P, Pdot, Pddot, mlp, num_batch, inp_mean, inp_std, t_fin).to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.eval();  # trailing ';' suppresses the module repr in notebook output


In [None]:


# Run the trained model on one dataset sample and compare the trajectories it
# outputs against the trajectories implied by the input coefficients (blue).

# idx = np.random.randint(0, np.shape(inp)[0])
idx = 100
print(idx)

# Replicate the chosen sample num_batch times -> shape (num_batch, input_dim).
inp_test = torch.vstack([torch.tensor(inp[idx]).float().to(device)] * num_batch)
inp_norm = (inp_test - inp_mean) / inp_std

# Leading columns of the network input hold the initial state.
# NOTE(review): assumes init_state_data has exactly 6 columns — confirm.
# Named init_state_test so the NumPy `init_state` from the data cell is not shadowed.
init_state_test = inp_test[:, 0:6]

# The original bound the `.to(device)` result to a misspelled, unused name
# (`c_samples_iniput_test`); here the device copy is the one actually used.
c_samples_input_test = torch.vstack(
    [torch.tensor(c_samples_input[idx]).float().to(device)] * num_batch
)

# Coefficient blocks of width nvar for v, pitch and roll, in that column order.
c_v_samples_input = c_samples_input_test[:, 0:nvar]
c_pitch_samples_input = c_samples_input_test[:, nvar:2 * nvar]
c_roll_samples_input = c_samples_input_test[:, 2 * nvar:3 * nvar]


def _eval_basis(coeffs):
    """Evaluate coefficient rows on the model's basis matrix: (model.P @ coeffs.T).T."""
    return torch.mm(model.P, coeffs.T).T


def _to_numpy(t):
    """Detach a tensor and copy it to host memory as a NumPy array."""
    return t.detach().cpu().numpy()


def _plot_against_input(fig_num, title, predicted, reference):
    """Plot each predicted row as a line, then overlay each reference row in solid blue."""
    plt.figure(fig_num)
    plt.plot(_to_numpy(predicted.T))
    plt.plot(_to_numpy(reference.T), '-b')
    plt.title(title)


with torch.no_grad():
    (c_v_samples, c_pitch_samples, c_roll_samples,
     accumulated_res_fixed_point, accumulated_res_primal,
     accumulated_res_primal_temp, accumulated_res_fixed_point_temp) = model.decoder_function(
        inp_norm, init_state_test, c_v_samples_input, c_pitch_samples_input, c_roll_samples_input
    )

    # Stack the returned residual sequences and keep batch element 0
    # (presumably one entry per solver iteration — TODO confirm).
    res_primal_history = torch.stack(accumulated_res_primal_temp)[:, 0]
    res_fixed_point_history = torch.stack(accumulated_res_fixed_point_temp)[:, 0]

    v_samples = _eval_basis(c_v_samples)
    pitch_samples = _eval_basis(c_pitch_samples)
    roll_samples = _eval_basis(c_roll_samples)

    v_samples_input = _eval_basis(c_v_samples_input)
    pitch_samples_input = _eval_basis(c_pitch_samples_input)
    roll_samples_input = _eval_basis(c_roll_samples_input)

plt.figure(4)
plt.plot(_to_numpy(res_primal_history))
plt.title("Primal residual (batch element 0)")

plt.figure(5)
plt.plot(_to_numpy(res_fixed_point_history))
plt.title("Fixed-point residual (batch element 0)")

_plot_against_input(6, "v: model output vs. input (blue)", v_samples, v_samples_input)
_plot_against_input(7, "pitch: model output vs. input (blue)", pitch_samples, pitch_samples_input)
_plot_against_input(8, "roll: model output vs. input (blue)", roll_samples, roll_samples_input)

plt.show()
