In [None]:
# render matplotlib figures inline in the notebook output
%matplotlib inline
# reload edited project modules automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
import os
# os.environ['PYOPENGL_PLATFORM'] = 'egl'

from pathlib import Path
import trimesh
import pyrender
import h5py
import numpy as np
import torch
import os, os.path as osp
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.feature import plot_matches
import cv2
from PIL import Image

from datasets.scannet.utils_3d import ProjectionHelper, adjust_intrinsic, make_intrinsic, load_intrinsic, load_pose
from datasets.scannet.utils_3d import load_depth, load_color

from scripts.sem_seg.prep_backproj_data import get_world_to_scene

In [None]:
def get_scan_name(scene_id, scan_id):
    """Build the scan folder name 'scene<scene_id>_<scan_id>'.

    The scene id is left-padded with zeros to at least 4 characters and the
    scan id to at least 2 characters; longer ids are kept as they are.
    """
    scene_part = str(scene_id).zfill(4)
    scan_part = str(scan_id).zfill(2)
    return 'scene' + scene_part + '_' + scan_part

# Globals shared by the cells below.
# Subvolume extent handed to ProjectionHelper and to the linear-index -> coords helper.
subvol_size = (32, 32, 32)
# Size of one voxel, handed to ProjectionHelper (presumably metres -- TODO confirm).
voxel_size = 0.05
# NOTE(review): not referenced by any cell visible in this notebook section.
voxel_dims = (1, 1, 1)
# Per-scan folders (pose/, depth/, color/, intrinsic/) are read from under this root.
root = Path('/mnt/data/scannet/scans')
# Size the depth/color frames are loaded at and the intrinsics are rescaled to;
# looks like (width, height) -- TODO confirm against load_depth / ProjectionHelper.
proj_img_size = (40, 30)

# HDF5 file read by the cells below (datasets 'x', 'world_to_grid', 'scene_id',
# 'scan_id', 'frames'). Opened read-only and left open because later cells index into it.
# NOTE(review): both paths are machine-specific absolute paths.
data_dir = Path('/mnt/data/scannet/backproj')
fname = 'train100-v6.h5'
f = h5py.File(data_dir / fname, 'r')



In [None]:
# shape of the stored subvolume dataset; the first axis is indexed per sample below
f['x'].shape

In [None]:
# Per-sample statistics, one entry per sample for which a projection exists:
#   projected - number of 3D<->2D correspondences returned by the projection
#   occupied  - number of voxels in the subvolume whose value is 1
#   overlap   - number of projected voxels whose subvolume value is 1
occupied, projected, overlap = [], [], []

# look at the first 1000 samples, but never index past the end of the dataset
num_samples = min(1000, f['x'].shape[0])

for ndx in tqdm(range(num_samples)):
    w2g, sceneid, scanid, frames = f['world_to_grid'][ndx], f['scene_id'][ndx], f['scan_id'][ndx], f['frames'][ndx]

    subvol_x = f['x'][ndx]
    # per-scene basics
    scan_name = get_scan_name(sceneid, scanid)
    # only the first frame listed for the sample is used
    frame_ndx = 0
    pose_path = root / scan_name / 'pose' / f'{frames[frame_ndx]}.txt'
    pose = load_pose(pose_path).numpy()
    depth_path = root / scan_name / 'depth' / f'{frames[frame_ndx]}.png'
    depth = load_depth(depth_path, proj_img_size)
    # NOTE(review): rgb is not used by the statistics below; kept because
    # cells outside this section may read it
    rgb_path = root / scan_name / 'color' / f'{frames[frame_ndx]}.jpg'
    rgb = load_color(rgb_path, proj_img_size)
    # get projection
    intrinsic_path = root / scan_name / 'intrinsic/intrinsic_color.txt'
    intrinsic = load_intrinsic(intrinsic_path)
    # adjust for smaller image size
    # (presumably [1296, 968] is the original color image size -- TODO confirm)
    intrinsic = adjust_intrinsic(intrinsic, [1296, 968], proj_img_size)

    # 0.4 / 4.0 are presumably the min / max depth considered -- TODO confirm
    projection = ProjectionHelper(
        intrinsic,
        0.4, 4.0,
        proj_img_size,
        subvol_size, voxel_size,
    )

    # projection expects origin of chunk in a corner
    # but w2g is wrt center of the chunk -> add half the subvolume size to its
    # "grid coords" to get the required grid coords
    # ie 0,0,0 becomes 16,16,16 for a 32^3 subvolume
    # add an additional translation to existing one
    t = torch.eye(4)
    t[:3, -1] = torch.tensor(subvol_size, dtype=t.dtype) / 2
    # NOTE(review): t is a torch tensor while w2g is read from the HDF5 file
    # (presumably a NumPy array) -- confirm this mixed-type product behaves
    # as intended on the torch version in use
    w2g_tmp = t @ w2g

    proj = projection.compute_projection(torch.Tensor(depth), torch.Tensor(pose), torch.Tensor(w2g_tmp))
    # samples without a projection contribute nothing to the statistics
    if proj is None:
        continue
    proj3d, proj2d = proj
    # element 0 holds the number of valid correspondences, the indices follow it
    num_inds = proj3d[0]

    ind3d = proj3d[1:1+num_inds]
    ind2d = proj2d[1:1+num_inds]

    # linear voxel indices -> integer (i, j, k) coordinates (last column of the helper output dropped)
    coords_3d = ProjectionHelper.lin_ind_to_coords_static(ind3d, subvol_size).T[:, :-1].long()
    i,j,k = coords_3d.T

    # store plain Python numbers in all three lists
    projected.append(proj3d[0].item())
    occupied.append(int((subvol_x == 1).sum()))
    overlap.append(int((subvol_x[i, j, k] == 1).sum()))



In [None]:
# turn the three per-sample count lists into torch.Tensor objects so that
# min / max / mean can be taken over them
projected, occupied, overlap = (
    torch.Tensor(counts) for counts in (projected, occupied, overlap)
)

In [None]:
def stats(x):
    """Print the minimum, maximum and mean of ``x`` on a single line.

    ``x`` only needs to provide ``min()``, ``max()`` and ``mean()`` methods.
    Nothing is returned.
    """
    smallest = x.min()
    largest = x.max()
    average = x.mean()
    print(smallest, largest, average)

In [None]:
# min / max / mean number of 3D<->2D correspondences per sample
stats(projected)

In [None]:
# min / max / mean number of voxels equal to 1 per subvolume
stats(occupied)

In [None]:
# min / max / mean number of projected voxels whose subvolume value is 1
stats(overlap)

In [None]:
# shape of the last subvolume read by the statistics loop
subvol_x.shape

In [None]:
# Sanity check: write ones into a 2-channel feature volume at randomly chosen
# voxels, addressing them through the (i, j, k) coordinates that
# ProjectionHelper derives from linear indices.
num_ind = 10
# derive the sizes from subvol_size so they stay in sync with the helper call below
num_voxels = int(np.prod(subvol_size))
inds = torch.randperm(num_voxels)[:num_ind]
print(inds)
# linear indices -> integer (i, j, k) coordinates (last column of the helper output dropped)
coords_3d = ProjectionHelper.lin_ind_to_coords_static(inds, subvol_size).T[:, :-1].long()
i,j,k = coords_3d.T
# empty features tensor CWHD
x = torch.zeros(2, *subvol_size)

# set values with ijk
# (an earlier variant wrote through the linear indices on a permuted + flattened
# tensor instead; reading back through linear indices is done in a later cell)
x[:, i, j, k] = torch.ones(2, num_ind)

In [None]:
# CWHD
# read the entries back through the same (i, j, k) coordinates they were written with
x[:, i, j, k]

In [None]:
# CDHW
# read the entries back through the linear indices instead: reverse the axis
# order, flatten to rows of 2 channel values and pick the rows listed in `inds`
# NOTE(review): permute(3, 2, 1, 0) puts the channel axis last, so the "CDHW"
# label above may be stale -- confirm
x.permute(3, 2, 1, 0).reshape(-1, 2)[inds, :]
# x.reshape(2, -1)[:, inds]

In [None]:
# small demo: assign through a flattened reshape of a 5x5 zero tensor and
# display the tensor afterwards to see whether the write shows up in it
x = torch.zeros(5, 5)
flat_view = x.reshape(-1)
flat_view[10] = 10
x

In [None]:
# 2D counterpart of the scatter check: write ones into a 2-channel image at
# randomly chosen pixels, addressed through the coordinates returned by
# ProjectionHelper for linear pixel indices.
img_size = (40, 30)
width, height = img_size
num_ind = 10

# CHW
x2d = torch.zeros(2, height, width)
inds2d = torch.randperm(width * height)[:num_ind]

coords_2d = ProjectionHelper.lin_ind2d_to_coords2d_static(inds2d, img_size)
i, j = coords_2d
# sum before writing
print(x2d.sum())

# CWH: swap the two spatial axes, then index with (i, j)
x2d = x2d.permute(0, 2, 1)
x2d[:, i, j] = torch.ones(2, num_ind)
# sum and shape after writing
print(x2d.sum(), x2d.shape)


In [None]:
# read the entries back through the same (i, j) coordinates they were written with
x2d[:, i, j]

In [None]:
# CHW
# read the entries back through the linear indices instead: swap the spatial
# axes back, flatten them and pick the columns listed in `inds2d`
x2d.permute(0, 2, 1).reshape(2, -1)[:, inds2d]