In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from pointnav_vo.config.vo_config.default import get_config as get_vo_config
from pointnav_vo.vo import VOTransformerRegressionGeometricInvarianceEngine

# Discrete action ids (as stored in the dataset's `actions` tensor) -> readable names.
ACTIONS = dict(enumerate(("STOP", "MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT")))

## Extract statistics from the dataset

In [2]:
# Load the ViT visual-odometry config and switch the train/eval datasets to
# their "with noise" variants. The config file can be overridden with the
# VO_CONFIG_PATH environment variable instead of editing this absolute path.
VO_CONFIG_PATH = os.environ.get(
    'VO_CONFIG_PATH',
    '/datasets/home/memmel/PointNav-VO/configs/vo/vo_pointnav_vit.yaml',
)
config = get_vo_config(VO_CONFIG_PATH, [])
config.defrost()
config.VO.DATASET.TRAIN = config.VO.DATASET.TRAIN_WITH_NOISE
config.VO.DATASET.EVAL = config.VO.DATASET.EVAL_WITH_NOISE

# Collision and geometric-invariance settings; both are also encoded in the
# output filenames further down.
config.VO.TRAIN.collision = '-1'  # '-1': keep transitions with collisions
config.VO.GEOMETRY.invariance_types = ["inverse_joint_train"]
config.freeze()

# Whether the data pass should also accumulate per-channel RGB mean/std.
with_mean_std = True

In [3]:
# Build the VO engine from `config` in training mode.
engine = VOTransformerRegressionGeometricInvarianceEngine(config=config, run_type='train')
# NOTE(review): `_set_up_dataloader` is a private engine method and the meaning of
# its positional arguments (0, 1) is not visible here — confirm against the engine
# implementation. Presumably this is what creates `engine.train_loader`, which the
# next cell iterates.
engine._set_up_dataloader(0,1)

2022-03-02 19:41:43,268 Visual Odometry configs:
BASE_TASK_CONFIG_PATH: configs/challenge_pointnav2021.local.rgbd.yaml
CHECKPOINT_FOLDER: {{LOG_DIR}}/checkpoints
DEBUG: True
ENGINE_NAME: vo_transformer_regression_geo_inv_engine
EVAL:
  EVAL_CKPT_PATH: eval_ckpt.pth
  EVAL_WITH_CKPT: True
INFO_DIR: {{LOG_DIR}}/infos
LOG_DIR: train_log/vit/
LOG_FILE: {{LOG_DIR}}/train.log
LOG_INTERVAL: 1
N_GPUS: -1
RESUME_STATE_FILE: resume_train_ckpt.pth
RESUME_TRAIN: False
TASK_CONFIG:
  DATASET:
    CONTENT_SCENES: ['*']
    DATA_PATH: /scratch/memmel/dataset/habitat_datasets/pointnav/gibson/v2/{split}/{split}.json.gz
    SCENES_DIR: /scratch/memmel/dataset/Gibson
    SPLIT: train
    TYPE: PointNav-v1
  ENVIRONMENT:
    ITERATOR_OPTIONS:
      CYCLE: True
      GROUP_BY_SCENE: True
      MAX_SCENE_REPEAT_EPISODES: -1
      MAX_SCENE_REPEAT_STEPS: 10000
      NUM_EPISODE_SAMPLE: -1
      SHUFFLE: False
      STEP_REPETITION_RANGE: 0.2
    MAX_EPISODE_SECONDS: 10000000
    MAX_EPISODE_STEPS: 500
  PYRO

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:00<00:00, 1938.05it/s]
2022-03-02 19:41:47,152 ... done. lenght 5435



In [None]:
# Iterate once over the training loader and collect the per-transition targets
# (action and pose deltas) used by the cells below. When `with_mean_std` is set,
# also accumulate per-channel RGB statistics, scaled from [0, 255] to [0, 1].
#
# NOTE(review): `mean` / `std` accumulate *per-batch* statistics and are later
# divided by the batch count `ctr`. This is an approximation: batches of
# different size get equal weight, and averaging per-batch stds ignores the
# spread between batch means, so the dataset-wide std is slightly underestimated.
data = {
    'actions': [],
    'delta_xs': [],
    'delta_ys': [],
    'delta_zs': [],
    'delta_yaws': [],
    'dz_regress_masks': [],
}
mean, std = torch.zeros(3), torch.zeros(3)
ctr = 0      # number of batches that contributed to `mean` / `std`
samples = 0  # number of RGB frames seen (both frames of every pair)

# A slow worker can make the loader raise
# `RuntimeError: DataLoader timed out after 300 seconds`; retry a bounded number
# of times before giving up.
MAX_LOADER_RETRIES = 3

train_iter = iter(engine.train_loader)
with tqdm(total=0) as pbar:
    n_retries = 0
    while True:
        try:
            batch_data = next(train_iter)
        except StopIteration:
            break
        except RuntimeError as err:
            # Retry via the loop rather than inside this handler, so that a
            # StopIteration or another timeout on the retry is handled by the
            # clauses above instead of escaping the cell.
            n_retries += 1
            pbar.write(f'{err} (retry {n_retries}/{MAX_LOADER_RETRIES})')
            if n_retries > MAX_LOADER_RETRIES:
                raise
            continue
        n_retries = 0

        (data_types,
        raw_rgb_pairs,
        raw_depth_pairs,
        raw_discretized_depth_pairs,
        raw_top_down_view_pairs,

        actions,
        delta_xs,
        delta_ys,
        delta_zs,
        delta_yaws,
        dz_regress_masks,

        chunk_idxs,
        entry_idxs,
        ) = batch_data

        data['actions'].append(actions)
        data['delta_xs'].append(delta_xs)
        data['delta_ys'].append(delta_ys)
        data['delta_zs'].append(delta_zs)
        data['delta_yaws'].append(delta_yaws)
        data['dz_regress_masks'].append(dz_regress_masks)

        if with_mean_std:
            # Each 4-D pair holds two frames concatenated along its last axis;
            # split them in half and stack both frames along the first axis.
            # `.cpu()` is a blocking copy (a non-blocking device->host copy may
            # be read before it has finished).
            rgb = torch.cat(
                [
                    torch.cat(
                        (pair[:, :, :, :pair.shape[-1] // 2],
                         pair[:, :, :, pair.shape[-1] // 2:]),
                        dim=0,
                    ).float().cpu()
                    for pair in raw_rgb_pairs
                ],
                dim=0,
            )

            # Reduce over every axis except the last one (3 values per batch).
            mean += rgb.mean(dim=(0, 1, 2)) / 255.
            std += rgb.std(dim=(0, 1, 2)) / 255.
            ctr += 1
            samples += rgb.shape[0]

        pbar.update(1)

assert data['actions'], 'train loader yielded no batches'
# Concatenate the per-batch tensors into a single tensor per key.
data = {key: torch.cat(chunks) for key, chunks in data.items()}


796it [03:48,  5.21it/s]

In [None]:
# Shallow copy: a new dict, but the tensors themselves are shared with `data`.
data_tmp = dict(data)

In [None]:
# Sanity check: which action ids occur, and how many transitions were collected.
all_actions = data_tmp['actions']
all_actions.unique(), all_actions.shape

In [None]:
# Directory for every figure / text file written below.
outpath = './plots'
os.makedirs(outpath, exist_ok=True)

# Filename suffix describing the dataset variant, e.g. "_collision_invjoint".
suffix_parts = []
if config.VO.TRAIN.collision == '-1':
    suffix_parts.append('_collision')
if 'inverse_joint_train' in config.VO.GEOMETRY.invariance_types:
    suffix_parts.append('_invjoint')
fname_append = ''.join(suffix_parts)

In [None]:
# Write the collected statistics to <outpath>/statistics<fname_append>.txt.
# - mean / std: per-channel RGB values in [0, 1], averaged over batches
#   (only available when the data pass ran with `with_mean_std`).
# - "samples ALL": number of RGB frames seen, counting both frames of every pair.
# - "samples <ACTION>": number of transitions with that action id.
stats_path = os.path.join(outpath, 'statistics' + fname_append + '.txt')
with open(stats_path, 'w') as f:
    # Guard against ctr == 0, which would otherwise write NaNs.
    if ctr > 0:
        f.write(f'mean {(mean / ctr).tolist()}\n')
        f.write(f'std {(std / ctr).tolist()}\n')
        f.write(f'samples ALL {int(samples)}\n')
    for act, act_name in ACTIONS.items():
        # int() so the file holds "123" rather than the tensor repr "tensor(123)".
        n_transitions = int((data_tmp['actions'] == act).sum())
        f.write(f'samples {act_name} {n_transitions}\n')

In [None]:
# Bar chart of the number of collected transitions per action present in the data.
# Bars are ordered by action id (torch.unique sorts); x ticks are hidden.
BAR_COLORS = ['tab:blue', 'tab:green', 'tab:orange']

bar_actions = data_tmp['actions'].unique().tolist()
bar_counts = [int((data_tmp['actions'] == act).sum()) for act in bar_actions]

fig, ax = plt.subplots()
# One bar per present action, so any number of actions works
# (hard-coded positions break for anything but exactly three).
barlist = ax.bar(x=range(len(bar_actions)), height=bar_counts)
# Colour bars by position; bars beyond len(BAR_COLORS) keep matplotlib's default.
for bar, color in zip(barlist, BAR_COLORS):
    bar.set_color(color)

ax.set_xticks([])
ax.set_title(f'{"w/" if config.VO.TRAIN.collision == "-1" else "w/o"} collisions')

fig.savefig(os.path.join(outpath, 'bar' + fname_append + '.pdf'))

In [None]:
# Scatter pairs of pose-delta components against each other (one subplot per
# pair), with one point cloud per action.
label_dict = {
    'delta_xs': r'$\xi^x_{C_t\rightarrow C_{t+1}}$',
    'delta_zs': r'$\xi^z_{C_t\rightarrow C_{t+1}}$',
    'delta_yaws': r'$\theta_{C_t\rightarrow C_{t+1}}$',
}

plots = [('delta_xs', 'delta_zs'), ('delta_xs', 'delta_yaws'), ('delta_zs', 'delta_yaws')]

# Give the size to this figure only; mutating the global rcParams would silently
# resize figures from any cell (re-)run afterwards.
fig, axs = plt.subplots(1, len(plots), figsize=(12, 4), squeeze=False)
axs = axs[0]

for ax, (kx, ky) in zip(axs, plots):
    for act in data_tmp['actions'].unique():
        selector = data_tmp['actions'] == act
        ax.scatter(data_tmp[kx][selector],
                   data_tmp[ky][selector],
                   label=ACTIONS[act.item()], s=2)

    ax.set_xlabel(label_dict[kx], fontsize=20)
    ax.set_ylabel(label_dict[ky], fontsize=20)
    ax.tick_params(labelsize=10)

# Each scatter carries its action name as a label; add e.g. `axs[-1].legend()`
# to display them.
fig.tight_layout()
fig.savefig(os.path.join(outpath, 'distributions' + fname_append + '.pdf'))