In [1]:
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import torch
import torch.utils.data as data 

from copy import deepcopy
from cv2 import aruco
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

# Custom imports
from contrastive_learning.datasets.state_dataset import StateDataset
from contrastive_learning.tests.plotting import plot_corners, plot_rvec_tvec


In [2]:
# Script to check the distribution of all the states in given dataset

In [37]:
# Set the needed variables for the mock dataset to retrieve data
cfg = OmegaConf.create()
cfg.data_dir = '/home/irmak/Workspace/DAWGE/src/dawge_planner/data/box_orientation_2_demos/test_demos'
cfg.pos_ref = 'global' 
cfg.pos_type = 'corners'
cfg.pos_dim = 8

In [38]:
print(cfg)

{'data_dir': '/home/irmak/Workspace/DAWGE/src/dawge_planner/data/box_orientation_2_demos/test_demos', 'pos_ref': 'global', 'pos_type': 'corners', 'pos_dim': 8}


In [39]:
dataset = StateDataset(cfg)
data_loader = data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

DATASET POS_REF: global
len(dataset): 802
self.action_min: [ 0.         -0.30000001], self.action_max: [0.15000001 0.30000001]


In [40]:
# Method to dump all the positions to test_demos
def dump_all_pos(bs, dataset, data_loader):
    pbar = tqdm(total=len(data_loader))
    all_curr_pos = np.zeros((len(dataset), cfg.pos_dim*2))
    for i,batch in enumerate(data_loader):
        curr_pos, _, _ = [b for b in batch] # These are normalized
        all_curr_pos[i*bs:(i+1)*bs, :] = curr_pos.detach().numpy()
        pbar.update(1)

    with open(os.path.join(cfg.data_dir, 'all_curr_pos.npy'), 'wb') as f:
        np.save(f, all_curr_pos)
        
    print('All positions saved to : {}'.format(os.path.join(cfg.data_dir, 'all_curr_pos.npy')))

In [41]:
dump_all_pos(bs=1, dataset=dataset, data_loader=data_loader)

100%|█████████████████████████████████████████████████████████████████████████████| 802/802 [00:00<00:00, 1210.80it/s]

All positions saved to : /home/irmak/Workspace/DAWGE/src/dawge_planner/data/box_orientation_2_demos/test_demos/all_curr_pos.npy





In [42]:
# Method to find the average MSE to the closest k neighbours of curr_pos to all_curr_pos.npy
def get_mse_to_dist(cfg, curr_pos, k=10):
    with open(os.path.join(cfg.data_dir, 'all_curr_pos.npy'), 'rb') as f:
        all_curr_pos = np.load(f)

    dist = np.linalg.norm(all_curr_pos - curr_pos, axis=1)
    dist.sort()

    return sum(dist[:k])


In [43]:
def get_mse_for_all(cfg):
    # It will traverse through the all_curr_pos.npy and find the mse distance for all demos
    with open(os.path.join(cfg.data_dir, 'all_curr_pos.npy'), 'rb') as f:
        all_curr_pos = np.load(f)
        
    print('all_curr_pos: {}'.format(all_curr_pos))
        
    all_mses = []
    for curr_pos in all_curr_pos:
        all_mses.append(get_mse_to_dist(cfg, curr_pos, k=10))
        
    return all_mses

In [44]:
all_mses = get_mse_for_all(cfg)

all_curr_pos: [[1.78109854e-01 4.73011374e-01 1.86591282e-01 ... 9.17613626e-01
  3.28352183e-01 9.37500000e-01]
 [1.78109854e-01 4.73011374e-01 1.85379639e-01 ... 9.20454562e-01
  3.22294027e-01 9.38920438e-01]
 [1.78109854e-01 4.73011374e-01 1.86591282e-01 ... 9.20454562e-01
  3.05331171e-01 9.43181813e-01]
 ...
 [8.44507277e-01 4.26136376e-03 8.73586416e-01 ... 3.55113633e-02
  7.80290782e-01 1.98863633e-02]
 [8.52988720e-01 2.84090918e-03 8.82067859e-01 ... 3.32386382e-02
  7.88529873e-01 1.76136363e-02]
 [8.63287568e-01 7.10227294e-04 8.92366707e-01 ... 3.09659094e-02
  7.96768963e-01 1.53409094e-02]]


In [45]:
plt.plot(all_mses)
plt.ylabel("Normalized Norm Dist")
plt.xlabel("Frame ID")
plt.savefig(os.path.join(cfg.data_dir, 'all_mses.png'))

In [None]:
# Plot this value for all the demos