Skip to content

Commit

Permalink
Cleanup
Browse files — browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Sep 13, 2022
1 parent 1670b9d commit ccda304
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 44 deletions.
11 changes: 11 additions & 0 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@
'resize_max': 1600,
},
},
'disk': {
'output': 'feats-disk',
'model': {
'name': 'disk',
'max_keypoints': 5000,
},
'preprocessing': {
'grayscale': False,
'resize_max': 1600,
},
},
# Global descriptors
'dir': {
'output': 'global-feats-dir',
Expand Down
76 changes: 32 additions & 44 deletions hloc/extractors/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

disk_path = Path(__file__).parent / "../../third_party/disk"
sys.path.append(str(disk_path))
from disk import DISK as _DISK
from disk import DISK as _DISK # noqa E402


class DISK(BaseModel):
Expand All @@ -22,17 +22,18 @@ class DISK(BaseModel):
required_inputs = ['image']

def _init(self, conf):
    """Build the DISK network and load its pretrained weights.

    Args:
        conf: configuration dict with keys 'model_name' (checkpoint file
            located under the third-party disk directory), 'desc_dim'
            (descriptor dimensionality), 'mode' ('nms' or 'rng') and,
            for 'nms', 'nms_window_size' and 'max_keypoints'.

    Raises:
        KeyError: if the checkpoint layout or the mode is not recognized.
    """
    state_dict = torch.load(
        disk_path / conf['model_name'], map_location='cpu')
    # Checkpoints from different training stages nest the weights under
    # different top-level keys.
    if 'extractor' in state_dict:
        weights = state_dict['extractor']
    elif 'disk' in state_dict:
        weights = state_dict['disk']
    else:
        raise KeyError('Incompatible weight file!')
    self.model = _DISK(window=8, desc_dim=conf['desc_dim'])
    self.model.load_state_dict(weights)

    # Bind the keypoint-selection strategy once so that _forward only
    # has to call self.extract(image).
    if conf['mode'] == 'nms':
        # NOTE(review): the keyword arguments below are hidden behind a
        # fold in the diff view; reconstructed from upstream — confirm.
        self.extract = partial(
            self.model.features,
            kind='nms',
            window_size=conf['nms_window_size'],
            cutoff=0.,
            n=conf['max_keypoints'])
    elif conf['mode'] == 'rng':
        self.extract = partial(self.model.features, kind='rng')
    else:
        raise KeyError(
            f'mode must be `nms` or `rng`, got `{conf["mode"]}`')

def _forward(self, data):
    """Extract DISK keypoints, descriptors and scores for one image.

    Args:
        data: dict with key 'image' holding a batched image tensor;
            assumes a batch of size 1 — TODO confirm against caller.

    Returns:
        dict with 'keypoints' (1, N, 2), 'descriptors' (1, desc_dim, N)
        and 'scores' (1, N), sorted by decreasing detection score.
    """
    image = data['image']
    # make sure that the dimensions of the image are multiple of 16,
    # as required by the DISK network (rounding may pad or crop).
    orig_h, orig_w = image.shape[-2:]
    new_h = round(orig_h / 16) * 16
    new_w = round(orig_w / 16) * 16
    image = F.pad(image, (0, new_w - orig_w, 0, new_h - orig_h))

    batched_features = self.extract(image)

    for features in batched_features.flat:
        features = features.to(torch.device('cpu'))

        # filter points detected in the padded areas
        kpts = features.kp
        valid = torch.all(kpts <= kpts.new_tensor([orig_w, orig_h]) - 1, 1)
        kpts = kpts[valid]
        descriptors = features.desc[valid]
        scores = features.kp_logp[valid]

        # order the keypoints by decreasing score
        indices = torch.argsort(scores, descending=True)
        kpts = kpts[indices]
        descriptors = descriptors[indices]
        scores = scores[indices]

        # single image per batch: return on the first (only) entry
        return {
            'keypoints': kpts[None],
            'descriptors': descriptors.t()[None],
            'scores': scores[None],
        }

0 comments on commit ccda304

Please sign in to comment.