In [None]:
import os
# Show the working directory: the dataset and weight paths used in later cells
# are relative ("./training_scripts/..."), so they resolve against this folder.
current_working_directory = os.getcwd()
print(current_working_directory)

import numpy as np 

import torch 
import torch.nn as nn 
import torch.optim as optim

# import torch_optimizer as optim_custom
from torch.utils.data import Dataset, DataLoader
from bernstein_torch import bernstein_coeff_order10_new
import scipy.io as sio

# from models.mlp_qp_vis_aware_2 import MLP, vis_aware_track_net, PointNet
# import pol_matrix_comp
from tqdm import trange,tqdm

from models.mlp_terrain import MLP, mlp_projection_filter
# from scipy.io import loadmat

In [2]:
# Build the P matrix (and the Pdot / Pddot matrices returned with it) on a
# uniform time grid of `num` samples spanning [0, t_fin].
t_fin = 20.0
num = 100
tot_time = torch.linspace(0, t_fin, num)
tot_time_copy = tot_time.reshape(num, 1)  # column vector, shape (num, 1)

t_start, t_end = tot_time_copy[0], tot_time_copy[-1]
P, Pdot, Pddot = bernstein_coeff_order10_new(10, t_start, t_end, tot_time_copy)

# Block-diagonal copies: each matrix repeated twice along the diagonal.
P_diag, Pdot_diag, Pddot_diag = (
    torch.block_diag(mat, mat) for mat in (P, Pdot, Pddot)
)

# Number of columns of P (presumably the polynomial coefficients per trajectory
# -- confirm against bernstein_coeff_order10_new).
nvar = P.shape[1]


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using {device} device")

In [None]:

# Load the training arrays. np.load on an .npz archive returns a lazily-read
# NpzFile that keeps the file open, so it is used as a context manager and the
# arrays are pulled out before the archive is closed.
with np.load("./training_scripts/dataset/data/train_data_terrain_250_c2_v3.npz") as data:
    init_state = data['init_state_data']
    c_samples_input = data['c_samples_data']

print(init_state.shape)

# Network input: initial state and coefficient samples concatenated column-wise.
inp = np.hstack((init_state, c_samples_input))

# Scalar statistics taken over every entry of `inp` (no axis argument); they are
# handed to the model constructor in a later cell.
inp_mean, inp_std = inp.mean(), inp.std()


In [5]:
# Custom Dataset Loader
class TrajDataset(Dataset):
    """Map-style dataset over three index-aligned sample arrays.

    Args:
        inp: Stacked network input, one row per sample.
        init_state: Initial-state rows.
        c_samples_input: Coefficient-sample rows.

    Indexing returns ``(inp, init_state, c_samples_input)`` for that row, each
    converted to a float32 tensor.
    """

    def __init__(self, inp, init_state, c_samples_input):
        self.inp = inp
        self.init_state = init_state
        self.c_samples_input = c_samples_input

    def __len__(self):
        # The dataset length is defined by the stacked input array.
        return len(self.inp)

    def __getitem__(self, idx):
        sources = (self.inp, self.init_state, self.c_samples_input)
        return tuple(torch.tensor(source[idx]).float() for source in sources)

# Batch size. The model is later constructed with train_loader.batch_size as its
# batch dimension, so only full batches are fed to it (drop_last=True).
batch_size = 12000

# Using PyTorch Dataloader
train_dataset = TrajDataset(inp, init_state, c_samples_input)

# With drop_last=True a dataset smaller than one batch produces zero iterations,
# which would make the training loop silently do nothing; fail early instead.
if len(train_dataset) < batch_size:
    raise ValueError(
        f"Dataset has {len(train_dataset)} samples, fewer than batch_size={batch_size}; "
        "with drop_last=True the loader would yield no batches."
    )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)



In [None]:

# Differentiable Layer
num_batch = train_loader.batch_size

# Move every basis matrix to the compute device. Pddot is moved as well because
# it is passed to the model together with P and Pdot.
# NOTE(review): assumes mlp_projection_filter treats all three alike -- confirm.
P = P.to(device)
Pdot = Pdot.to(device)
Pddot = Pddot.to(device)
P_diag = P_diag.to(device)
Pdot_diag = Pdot_diag.to(device)
Pddot_diag = Pddot_diag.to(device)

num_dot = num
num_ddot = num_dot
num_constraint = 2*num+2*num_dot+2*num_ddot

# MLP input/output dimensions: the input width is the column count of `inp`.
enc_inp_dim = np.shape(inp)[1]
mlp_inp_dim = enc_inp_dim
hidden_dim = 1024
mlp_out_dim = 6*nvar#+3*num_constraint
print(mlp_out_dim)

mlp =  MLP(mlp_inp_dim, hidden_dim, mlp_out_dim)
model = mlp_projection_filter(P, Pdot, Pddot, mlp, num_batch, inp_mean, inp_std, t_fin).to(device)
# model.load_state_dict(torch.load('./training_scripts/weights/mlp_learned_proj_terrain_250_c4_02_check_2.pth'))
model.train()



In [None]:

# Training loop: AdamW over `epochs` passes of train_loader. Per-epoch averages
# of the total, primal and fixed-point losses are collected in avg_*_loss.
epochs = 50
step, beta = 0, 1.0 # 3.5  (not read by any visible cell; kept for compatibility)
optimizer = optim.AdamW(model.parameters(), lr = 2e-4, weight_decay=6e-5)
losses = []
last_loss = torch.inf  # best (lowest) epoch-average training loss seen so far
model_checkpoint = 0
avg_train_loss, avg_primal_loss, avg_fixed_point_loss = [], [], []
for epoch in range(epochs):

    # Train Loop
    losses_train, primal_losses, fixed_point_losses = [], [], []

    # Batch variables use their own names so the notebook-level arrays `inp`,
    # `init_state` and `c_samples_input` are not overwritten; earlier cells can
    # then be re-run after training.
    for (inp_batch, init_state_batch, c_samples_batch) in tqdm(train_loader):

        # Input and Output
        inp_batch = inp_batch.to(device)
        init_state_batch = init_state_batch.to(device)
        c_samples_batch = c_samples_batch.to(device)

        # Split the coefficient samples into three consecutive nvar-wide slices.
        c_v_samples_input = c_samples_batch[:, 0: nvar]
        c_pitch_samples_input = c_samples_batch[:, nvar: 2*nvar]
        c_roll_samples_input = c_samples_batch[:, 2*nvar: 3*nvar]

        (c_v_samples, c_pitch_samples, c_roll_samples,
         accumulated_res_fixed_point, accumulated_res_primal,
         accumulated_res_primal_temp, accumulated_res_fixed_point_temp) = model(
            inp_batch, init_state_batch,
            c_v_samples_input, c_pitch_samples_input, c_roll_samples_input)
        primal_loss, fixed_point_loss, loss = model.mlp_loss(
            accumulated_res_primal, accumulated_res_fixed_point,
            c_v_samples, c_v_samples_input,
            c_pitch_samples, c_pitch_samples_input,
            c_roll_samples, c_roll_samples_input)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_train.append(loss.detach().cpu().numpy())
        primal_losses.append(primal_loss.detach().cpu().numpy())
        fixed_point_losses.append(fixed_point_loss.detach().cpu().numpy())

    if not losses_train:
        # Happens when the loader yields no batches (e.g. drop_last=True with a
        # dataset smaller than one batch); there is nothing to average or save.
        raise RuntimeError("train_loader produced no batches; nothing was trained.")

    epoch_train_loss = np.average(losses_train)
    epoch_primal_loss = np.average(primal_losses)
    epoch_fixed_point_loss = np.average(fixed_point_losses)

    if epoch % 2 == 0:
        print(f"Epoch: {epoch + 1}, Train Loss: {epoch_train_loss:.3f}, primal: {epoch_primal_loss:.3f}, fixed_point: {epoch_fixed_point_loss:.3f} ")

    step += 0.07 #0.15

    # Track the best model by the epoch-average loss rather than by the last
    # mini-batch loss, and store a plain float so no autograd graph is retained.
    if epoch_train_loss <= last_loss:
        torch.save(model.state_dict(), "./training_scripts/weights/mlp_learned_proj_terrain_250_c2_05_v3_lowest.pth")
        last_loss = float(epoch_train_loss)

    # Periodic snapshot every 15 epochs (epochs 0, 15, 30, ...).
    if epoch % 15 == 0:
        torch.save(model.state_dict(), f"./training_scripts/weights/mlp_learned_proj_terrain_250_c2_05_v3_check_{model_checkpoint}.pth")
        model_checkpoint += 1

    avg_train_loss.append(epoch_train_loss)
    avg_primal_loss.append(epoch_primal_loss)
    avg_fixed_point_loss.append(epoch_fixed_point_loss)
    

In [8]:
# Save the model's current state_dict to its own file; the training cell writes
# separate "_lowest" and "_check_<n>" files.
torch.save(model.state_dict(), './training_scripts/weights/mlp_learned_proj_terrain_250_c2_05_v3.pth')

In [None]:
# Persist the per-epoch loss histories together with the input mean/std in a
# single .npz archive (np.savez appends the extension to the given path).
h_avg_train_loss = np.array(avg_train_loss)
h_avg_primal_loss = np.array(avg_primal_loss)
h_avg_fixed_point_loss = np.array(avg_fixed_point_loss)
h_mean, h_std = inp_mean, inp_std

history_arrays = {
    "avg_train_loss": h_avg_train_loss,
    "avg_primal_loss": h_avg_primal_loss,
    "avg_fixed_point_loss": h_avg_fixed_point_loss,
    "mean": h_mean,
    "std": h_std,
}
np.savez("./training_scripts/weights/data_out_terrain_250_c2_05_v3", **history_arrays)