In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
from extract_features_from_matterport import (
    MatterportFeature, MatterportDataset, Arguments, load_viewpoints, 
    init_detector,inference_detector,filter_panorama, cartesian_to_polar
)
from typing import Tuple, Union, Sequence, List
import pickle
import lmdb
import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from mmdet.datasets.lvis import LVISV1Dataset
from torch.nn import functional as F
import math
import matplotlib.colors as mcolors

# LVIS v1 category names; indexed below with predicted class ids.
classes = LVISV1Dataset.CLASSES

# Replace the kernel's argv with a dummy program name, presumably so project
# argument parsing does not see Jupyter's own flags — TODO confirm.
sys.argv= ['foo']
# NOTE(review): this binds the imported `Arguments` object itself without
# calling it; if it is a class, the assignments below set class-level
# attributes — confirm this is intended.
args = Arguments
args.matterport = Path('matterport-views.lmdb')
args.max_total_boxes = 100

# Build the dataset used by later cells to fetch stored panorama views.
viewpoints = load_viewpoints(args)
dataset = MatterportDataset(viewpoints, args)

# Extract features using EQLv2

In [None]:
# Index of the CUDA device the detector is created on.
device_id = 0
# Build the detector from the config and checkpoint paths carried by `args`.
model = init_detector(str(args.config), str(args.checkpoint), device=f'cuda:{device_id}')

In [None]:
# Load the stored views of one panorama and run the detector on each view.
# Alternative way to pick a sample: feats, scan, viewpoint = dataset[2]
scan = '17DRP5sb8fy'
viewpoint = '08c774f20c984008882da2b8547850eb'
key = f'{scan}_{viewpoint}'
feats = dataset.txn.get(key.encode('ascii'))
if feats is None:
    # Without this guard a missing record surfaces as an opaque TypeError
    # raised by pickle.loads(None).
    raise KeyError(f'no record for {key!r} in the Matterport views database')
# NOTE(review): unpickling executes arbitrary code; only open trusted databases.
feats = pickle.loads(feats)


# Per-view detector outputs; concatenated once all views are processed.
all_boxes = []
all_features = []
all_probs = []
all_view_ids = []
all_labels = []

for view_id, im in zip(feats['view_ids'], feats['image_feat']):

    results = inference_detector(model, np.array(im))
    # Every detected box must come with exactly one feature vector.
    assert len(results['bbox']) == len(results['features'])

    all_boxes.append(results['bbox'])
    all_probs.append(results['cls_score'])
    all_features.append(results['features'])
    all_labels.append(results['labels'])
    num_bbox = results['bbox'].shape[0]
    # Remember which view each detection came from.
    all_view_ids += [view_id] * num_bbox

# Flatten the per-view lists into panorama-level tensors.
image_feat = torch.cat(all_features)
bbox = torch.cat(all_boxes)
probs = torch.cat(all_probs)
view_ids = torch.Tensor(all_view_ids)
labels = torch.cat(all_labels)

In [None]:
# Indices of the detections retained for the panorama; used below to index
# `bbox`, `probs` and `view_ids`.
keep_ind = filter_panorama(
    bbox,
    probs,
    image_feat,
    view_ids,
    # Use the configured cap (set to 100 in the configuration cell) rather
    # than a duplicated literal, so the two values cannot drift apart.
    args.max_total_boxes,
    feats['image_w'],
    feats['image_h'],
    feats['fov'],
)

In [None]:
# Draw the retained detections on top of every view that has at least one.
# The palette is cycled: zipping the boxes directly with the 10-entry
# TABLEAU_COLORS mapping silently dropped every box after the tenth.
palette = list(mcolors.TABLEAU_COLORS)

for view_id in range(36):  # view ids 0..35
    mask = view_ids[keep_ind] == view_id
    if not mask.any():
        continue
    # Channel order is reversed for display (stored images are presumably BGR — TODO confirm).
    plt.imshow(np.array(feats['image_feat'][view_id])[:, :, ::-1])
    ax = plt.gca()

    kept = zip(bbox[keep_ind][mask], view_ids[keep_ind][mask], probs[keep_ind][mask])
    for idx, (obj, vid, label) in enumerate(kept):
        if vid != view_id:
            continue
        color = palette[idx % len(palette)]
        # Boxes are read as (x1, y1, x2, y2); Rectangle wants corner + size.
        h = obj[3] - obj[1]
        w = obj[2] - obj[0]
        pos = (obj[0], obj[1])
        rect = patches.Rectangle(
            (pos), w, h, linewidth=1,
            # `label` holds the class scores; the legend shows the top class.
            label=classes[label.argmax()],
            edgecolor=color,
            facecolor='none')
        ax.add_patch(rect)
    plt.axis('off')
    plt.legend()
    plt.show()

# Visualize pre-extracted features

In [None]:
class FeaturesReader:
    """Read-only access to pickled records stored in an LMDB database.

    The database must hold the pickled collection of record keys (as bytes)
    under either ``__keys__`` or ``keys``.
    """

    def __init__(
        self,
        path: Union[Path, str]):
        """Open the LMDB at ``path`` and load the set of available keys.

        Raises:
            RuntimeError: if neither ``__keys__`` nor ``keys`` is stored.
        """
        self._path = Path(path)

        # open database
        self._env = lmdb.open(
            str(path),
            readonly=True,
            readahead=False,
            max_readers=512,
            lock=False,
            map_size=int(1e9),
        )

        # get keys
        with self._env.begin(write=False, buffers=True) as txn:
            bkeys = txn.get("__keys__".encode())
            if bkeys is None:
                # Fall back to the alternative index name.
                bkeys = txn.get("keys".encode())
                if bkeys is None:
                    raise RuntimeError("Please preload keys in the LMDB")
            self._keys = set(k.decode() for k in pickle.loads(bkeys))

        self.key_split = "_"

    def __repr__(self) -> str:
        """Return ``<file stem>.<ctime>`` of the database path."""
        return f'{self._path.stem}.{int(self._path.lstat().st_ctime)}'

    @property
    def keys(self):
        """Set of record keys listed in the database index."""
        return self._keys

    def __len__(self):
        """Number of keys listed in the database index."""
        return len(self.keys)

    def __getitem__(self, keys: List[str]) -> List:
        """Fetch and unpickle the records for ``keys``.

        Args:
            keys: list of record keys; each must be a ``str`` present in
                :attr:`keys`.

        Returns:
            A list aligned with ``keys``. An entry is ``None`` when the key is
            indexed but no value is stored for it.

        Raises:
            TypeError: if a key is not a ``str`` or is not in the index.
        """
        items = [None] * len(keys)

        # Validate everything up front so no transaction is opened for a bad request.
        for key in keys:
            if not isinstance(key, str) or key not in self.keys:
                raise TypeError(f"invalid key: {key}")

        with self._env.begin(write=False) as txn:
            for i, key in enumerate(keys):
                item = txn.get(key.encode())
                if item is not None:
                    items[i] = pickle.loads(item)

        return items
    
    
# NOTE(review): absolute, machine-specific path — consider making it configurable.
pre_extracted_lmdb = '/scratch/jeanzay/work/src/eqlv2/matterport-eqlv2.lmdb'
# Opens the database read-only and loads its key index.
reader = FeaturesReader(pre_extracted_lmdb)

In [None]:
# Pick one record deterministically. `reader.keys` is a set of strings, whose
# iteration order changes between kernel sessions (string hash randomisation),
# so `next(iter(...))` could select a different panorama on every run.
key = min(reader.keys)
# e.g. 17DRP5sb8fy 08c774f20c984008882da2b8547850eb
features = reader[[key]]

In [None]:
# load corresponding images
feats = dataset.txn.get(key.encode('ascii'))
if feats is None:
    # Without this guard a missing record surfaces as an opaque TypeError
    # raised by pickle.loads(None).
    raise KeyError(f'no image record for {key!r} in the Matterport views database')
feats = pickle.loads(feats)

In [None]:
# `reader[[key]]` returns one entry per requested key; take the single record.
ft = features[0]

In [None]:
# Show which fields the pre-extracted record provides. The record is bound to
# `ft`; the previously used name `feature` is never defined.
print(key, ft.keys())

In [None]:
# Display the stored boxes as integers (assumes ft['boxes'] is a torch tensor — TODO confirm).
ft['boxes'].long()

In [None]:
# Draw the pre-extracted boxes on top of every view that has at least one.
# The palette is cycled: zipping the boxes directly with the 10-entry
# TABLEAU_COLORS mapping silently dropped every box after the tenth.
palette = list(mcolors.TABLEAU_COLORS)

for view_id in range(36):  # view ids 0..35
    mask = ft['view_ids'] == view_id
    if not mask.any():
        continue
    # Channel order is reversed for display (stored images are presumably BGR — TODO confirm).
    plt.imshow(np.array(feats['image_feat'][view_id])[:, :, ::-1])
    ax = plt.gca()

    for idx, (obj, label) in enumerate(zip(ft['boxes'][mask], ft['cls_probs'][mask])):
        color = palette[idx % len(palette)]
        # Boxes are read as (x1, y1, x2, y2); Rectangle wants corner + size.
        h = obj[3] - obj[1]
        w = obj[2] - obj[0]
        pos = (obj[0], obj[1])
        rect = patches.Rectangle(
            (pos), w, h, linewidth=1,
            # `label` holds the class scores; the legend shows the top class.
            label=classes[label.argmax()],
            edgecolor=color,
            facecolor='none')
        ax.add_patch(rect)
    plt.axis('off')
    plt.legend()
    plt.show()

# Comparing with filtered BUTD

In [None]:
# NOTE(review): `h5py`, `os` and `data_dir` are not imported/defined in the
# visible part of this notebook, so this cell raises NameError on a fresh kernel.
# Load every entry of the filtered BUTD HDF5 file into memory: the entry's
# contents go under 'fts' and each HDF5 attribute becomes an extra dict entry.
# NOTE(review): the loop variable `key` overwrites the notebook-level `key`.
fdata = {}
with h5py.File(os.path.join(data_dir, 'features', 'filtered_butd_bboxes.hdf5'), 'r') as f:
    for key in f:
        # `[...]` reads the whole entry.
        fts = f[key][...]
        item = {
            'fts': fts
        }
        # Copy the entry's attributes next to its contents.
        for k, v in f[key].attrs.items():
            item[k] = v
        fdata[key] = item

# Clean up classes from LVIS

In [None]:
# NOTE(review): this cell repeats the loading code of the "Comparing with
# filtered BUTD" section and performs no class clean-up yet. `h5py`, `os` and
# `data_dir` are also not imported/defined in the visible part of this notebook.
# NOTE(review): the loop variable `key` overwrites the notebook-level `key`.
fdata = {}
with h5py.File(os.path.join(data_dir, 'features', 'filtered_butd_bboxes.hdf5'), 'r') as f:
    for key in f:
        # `[...]` reads the whole entry.
        fts = f[key][...]
        item = {
            'fts': fts
        }
        # Copy the entry's attributes next to its contents.
        for k, v in f[key].attrs.items():
            item[k] = v
        fdata[key] = item

# Predict EQLv2 features from REVERIE bounding boxes

In [None]:
import json
from collections import defaultdict

# Run the detector on one view using the REVERIE boxes as given proposals.
# Alternative way to pick a sample: feats, scan, viewpoint = dataset[2]
scan = '17DRP5sb8fy'
viewpoint = '08c774f20c984008882da2b8547850eb'
view_id = 1
key = f'{scan}_{viewpoint}'
feats = dataset.txn.get(key.encode('ascii'))
feats = pickle.loads(feats)
with open(f"data/bbox/{scan}_{viewpoint}.json") as fid:
    data = json.load(fid)[viewpoint]

# Group the annotated 2D boxes by the view they are visible from.
bbox_per_view_id = defaultdict(list)
for details in data.values():
    # The loop variable must not be called `view_id`: that would overwrite the
    # view selected above with the last annotated view.
    for visible_view_id, box in zip(details['visible_pos'], details['bbox2d']):
        bbox_per_view_id[visible_view_id].append(box)

# `.get` avoids inserting an empty entry into the defaultdict.
boxes_for_view = bbox_per_view_id.get(view_id, [])
if not boxes_for_view:
    # An empty list would produce a 1-D tensor and fail on the column indexing below.
    raise ValueError(f'no REVERIE boxes visible from view {view_id} of {key}')

im = feats['image_feat'][view_id]
bbox = torch.Tensor(boxes_for_view).cuda()
# Turn (x, y, w, h) into (x1, y1, x2, y2) — presumably the stored layout is xywh; TODO confirm.
bbox[:, 2] += bbox[:, 0]
bbox[:, 3] += bbox[:, 1]

results = inference_detector(model, np.array(im), [bbox])

In [None]:
# List the fields returned by the detector for the given boxes.
results.keys()

In [None]:
# Map every predicted label index to its LVIS class name.
[classes[class_idx] for class_idx in results['labels']]

In [None]:
# Display the returned boxes as integers (assumes results['bbox'] is a torch tensor — TODO confirm).
results['bbox'].long()