In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import time
import glob
import numpy as np
from bagpy import bagreader
import rosbag
import pandas as pd
import wandb

Failed to load Python extension for LZ4 support. LZ4 compression will not be available.


In [73]:
from screwing_dataset import ScrewingDataset
from screwing_model import ScrewingModel
from screwing_model_seq import ScrewingModelSeq


In [74]:
from torch.utils.data import DataLoader, ConcatDataset
import torch
import torch.nn as nn
import torch.optim as optim

In [65]:
from training import batched_pos_err, batched_ori_err, weighted_MSE_loss


In [15]:
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [18]:

# ---- Evaluation configuration ----
batch_size = 1  # should be a power of two
window_size = 30  # window size handed to ScrewingDataset

# Model architecture: must match the checkpoint `model_name`, otherwise
# load_state_dict fails on mismatched parameter shapes.
input_dim = 19
hidden_dim = 10
num_layers = 3
output_dim = 5

# TODO change arbitrary weight.
# Passed to weighted_MSE_loss; presumably the weight of the orientation term
# relative to the position term — confirm in training.weighted_MSE_loss.
ori_rel_weight = 2

# Number of episodes to load (a small subset, for quick testing).
num_eps = 1

# Training hyperparameters (kept for reference).
num_epochs = 200
learning_rate = 0.003

base_dset_dir = os.path.expanduser('~/datasets/screwing/')
# Experiment sub-directory: the fixed timestamp of a recorded run (NOT the
# current time), with a trailing '/' so file names can be appended directly.
# Previous experiment: '2022-03-10_23-17-39/'
xprmnt_dir = "2022-03-11_17-07-13/"

log_interval = 1

train_ratio = .75  # fraction of dataset samples assigned to the training split

model_save_dir = '../../../models/'
model_name = 'model_2022-03-27_17-18-31.pt'

In [36]:
# Rebuild the sequence model with the configured architecture and load the trained weights.
model = ScrewingModelSeq(input_dim, hidden_dim, num_layers, output_dim)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
# (and vice versa) instead of failing on deserialization.
state_dict = torch.load(os.path.join(model_save_dir, model_name), map_location=device)
model.load_state_dict(state_dict)
# Switch off training-only behaviour before running inference.
model.eval()
model = model.to(device)

In [75]:

# Build one ScrewingDataset per recorded episode and concatenate them.
# Episode files are named '<id>_*.bag', '<id>_pos.npy' and '<id>_proj_ori.npy'.
episode_dir = base_dset_dir + xprmnt_dir  # both config values end with '/'

bag_path_list = sorted(glob.glob(episode_dir + '*.bag'))
total_num_eps = len(bag_path_list)

# wandb.config can only be updated once a run exists (wandb.init); without an
# active run the call raises, which breaks Restart & Run All on a fresh kernel.
# NOTE(review): this records num_eps (episodes actually loaded) under a key that
# reads like the dataset total (total_num_eps) — confirm which is intended.
if wandb.run is not None:
    wandb.config.update({'total_dset_eps_num': num_eps})

num_workers = 8

dset_list = []
# Only the first num_eps episodes are loaded (quick testing);
# use range(total_num_eps) to load every episode.
for i in range(num_eps):
    episode_prefix = episode_dir + str(i)

    # Sorted so the choice is deterministic if several bags match.
    bag_matches = sorted(glob.glob(episode_prefix + '_*.bag'))
    if not bag_matches:
        raise FileNotFoundError(f'No bag file matching {episode_prefix}_*.bag')
    bag_path = bag_matches[0]

    pos_path_name = episode_prefix + '_pos.npy'
    proj_ori_path = episode_prefix + '_proj_ori.npy'
    pos_ori_path_list = [pos_path_name, proj_ori_path]

    dset_list.append(ScrewingDataset(bag_path, pos_ori_path_list, window_size))

concat_dset = ConcatDataset(dset_list)

[INFO]  Data folder /home/serialexperimentsleon/datasets/screwing/2022-03-11_17-07-13/0_2022-03-11-17-07-18 already exists. Not creating.


In [76]:
# Reproducible train/validation split of the concatenated dataset.
length = len(concat_dset)
train_size = int(train_ratio * length)
torch_seed = 0
torch.manual_seed(torch_seed)  # seeds any later stochastic torch code in this kernel
# A dedicated generator ties the split to torch_seed alone, regardless of how
# much of the global RNG stream other cells have consumed, and leaves the
# global stream untouched.
split_generator = torch.Generator().manual_seed(torch_seed)
train_dset, valid_dset = torch.utils.data.random_split(
    concat_dset, [train_size, length - train_size], generator=split_generator
)
train_dset_length = len(train_dset)
valid_dset_length = len(valid_dset)

In [77]:

# Validation loader: no shuffling, so batches come out in split order.
valid_lder = DataLoader(valid_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [84]:
# Sanity check: run the model on the first validation batch and compare its
# prediction at sequence step `t` with the target. Per-sample values are
# printed as lists so the cell also works when batch_size > 1 (`.item()` only
# accepts single-element tensors).
t = 0  # sequence step whose prediction is inspected
with torch.no_grad():
    first_batch = next(iter(valid_lder), None)
    if first_batch is None:
        raise RuntimeError('valid_lder yielded no batches; nothing to inspect')
    x, y, times, T = first_batch
    x = x.to(device)
    y = y.float().to(device)

    # Forward propagation; the model was moved to `device`, so its output is there too.
    outputs = model(x)
    # outputs is indexed as (batch, step, feature) — presumably B x L x O.
    output_t = outputs[:, t, :]

    loss = weighted_MSE_loss(output_t, y, ori_rel_weight)
    print(times[:, t].tolist())
    print(batched_ori_err(output_t, y, device))
    print(batched_pos_err(output_t, y).tolist())
    print(loss.item())
    print(T.tolist())


0.45473598110118224
tensor([0.3854], device='cuda:0')
0.32584792375564575
0.4114879369735718
13.464705228805542


In [None]:

# def test_metrics(model, ori_rel_weight, seq_length, val_loader): #TODO add early stopping criterion
#     # logging_step = 0
#     # quantiles of interest: median and 95% CI
#     q = torch.as_tensor([0.025, 0.5, 0.975]).to(device) 
#     q_timing = torch.as_tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]).to(device)
#     seq_length = seq_length
#     ## switch model to eval
#     model.eval()

#     with torch.no_grad():

#         for batch_idx,(x,y, times) in enumerate(val_loader):
#             x = x.to(device)
#             y = y.float().to(device)

#             # Forward propagation happens here
#             outputs = model(x).to(device)
#             for t in range(seq_length):

#                 output_t = outputs[:, t, :]
#                 times_t = times[:, t] 

#                 loss = weighted_MSE_loss(output_t, y, ori_rel_weight)

#                 ## evaluate and append analysis metrics
#                 total_valid_ori_error.append(batched_ori_err(output_t, y))
#                 total_valid_pos_error.append(batched_pos_err(output_t, y))
#                 total_valid_loss.append(loss)

#                 # if batch_idx % log_interval == 0:
#                 #     wandb.log({"loss": loss, 'epoch': epoch, 'batch_idx': batch_idx})
                
#                 total_valid_pos_error = torch.cat(total_valid_pos_error).to(device)
#                 total_valid_ori_error = torch.cat(total_valid_ori_error).to(device)
#                 total_valid_loss = torch.as_tensor(total_valid_loss).to(device)

#                 ## statistical metrics from the test evaluations

#                 ## pos error
#                 pos_err_mean = torch.mean(total_valid_pos_error)
#                 pos_err_std = torch.std(total_valid_pos_error)
#                 pos_err_max = torch.max(total_valid_pos_error)
#                 pos_err_min = torch.min(total_valid_pos_error)

#                 ## 95% confidence interval and median
#                 # q = torch.as_tensor([0.025, 0.5, 0.975]) 
#                 pos_err_95_median = torch.quantile(total_valid_pos_error, q, dim=0, keepdim=False, interpolation='nearest')

#                 ## ori error
#                 ori_err_mean = torch.mean(total_valid_ori_error)
#                 ori_err_std = torch.std(total_valid_ori_error)
#                 ori_err_max = torch.max(total_valid_ori_error)
#                 ori_err_min = torch.min(total_valid_ori_error)

#                 ## 95% confidence interval
#                 ori_err_95_median = torch.quantile(total_valid_ori_error, q, dim=0, keepdim=False, interpolation='nearest')

#                 ## loss 
#                 loss_mean = torch.mean(total_valid_loss)
#                 loss_std = torch.std(total_valid_loss)
#                 loss_max = torch.max(total_valid_loss)
#                 loss_min = torch.min(total_valid_loss)

#                 ## 95% confidence interval
#                 loss_95_median = torch.quantile(total_valid_loss, q, dim=0, keepdim=False, interpolation='nearest')

#                 wandb.log({ 
#                 'valid_pos_err_mean_' + str(t) : pos_err_mean,
#                 'valid_pos_err_std_' + str(t) : pos_err_std,
#                 'valid_pos_err_max_' + str(t) : pos_err_max,
#                 'valid_pos_err_min_' + str(t) : pos_err_min,
#                 'valid_pos_err_95_lower_' + str(t) : pos_err_95_median[0].item(),
#                 'valid_pos_err_median_' + str(t) : pos_err_95_median[1].item(),
#                 'valid_pos_err_95_upper_' + str(t) : pos_err_95_median[2].item(),
#                 'valid_ori_err_mean_' + str(t) : ori_err_mean,
#                 'valid_ori_err_std_' + str(t) : ori_err_std,
#                 'valid_ori_err_max_' + str(t) : ori_err_max,
#                 'valid_ori_err_min_' + str(t) : ori_err_min,
#                 'valid_ori_err_95_lower_' + str(t) : ori_err_95_median[0].item(),
#                 'valid_ori_err_median_' + str(t) : ori_err_95_median[1].item(),
#                 'valid_ori_err_95_upper_' + str(t) : ori_err_95_median[2].item(),
#                 'valid_loss_mean_' + str(t) : loss_mean,
#                 'valid_loss_std_' + str(t) : loss_std,
#                 'valid_loss_max_' + str(t) : loss_max,
#                 'valid_loss_min_' + str(t) : loss_min,
#                 'valid_loss_95_lower_' + str(t) : loss_95_median[0].item(),
#                 'valid_loss_median_' + str(t) : loss_95_median[1].item(),
#                 'valid_loss_95_upper_' + str(t) : loss_95_median[2].item()
#                 }, step = logging_step-1)
#                 ## log some summary metrics from the validation/eval run

#                 ## log a figure of model output  
