Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Refactor instance retrieval optional normalization (#381)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #381

1. Rename SHOULD_TRAIN_PCA_OR_WHITENING to TRAIN_PCA_WHITENING.

2. Make L2 normalization optional.

3. Fix cfg access bugs.

4. Add some more experiments.

Reviewed By: prigoyal

Differential Revision: D30002757

fbshipit-source-id: 3ec5be799a1d9bf2fa75c736fce9b2552db7966c
  • Loading branch information
iseessel authored and facebook-github-bot committed Aug 9, 2021
1 parent b4e6aa4 commit db63a8f
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 40 deletions.
90 changes: 56 additions & 34 deletions tools/instance_retrieval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
import logging
import os
import sys
import uuid
from argparse import Namespace
from typing import Any, List

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from classy_vision.generic.util import copy_model_to_gpu, load_checkpoint
from fvcore.common.file_io import PathManager
Expand Down Expand Up @@ -92,7 +90,7 @@ def get_train_features(
):
train_features = []

def process_train_image(i, out_dir):
def process_train_image(i, out_dir, verbose=False):
if i % LOG_FREQUENCY == 0:
logging.info(f"Train Image: {i}"),

Expand All @@ -115,24 +113,35 @@ def process_train_image(i, out_dir):
vc = v.cuda()
# the model output is a list always.
activation_map = model(vc)[0].cpu()

if verbose:
print(f"Train Image raw activation map shape: { activation_map.shape }")

# once we have the features,
# we can perform: rmac | gem pooling | l2 norm
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
descriptors = get_rmac_descriptors(activation_map, spatial_levels)
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "l2_norm":
# we simply L2 normalize the features otherwise
descriptors = F.normalize(activation_map, p=2, dim=0)
descriptors = get_rmac_descriptors(
activation_map,
spatial_levels,
normalize=cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES,
)
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
descriptors = l2n(
gem(
activation_map,
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
add_bias=False,
)
descriptors = gem(
activation_map,
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
add_bias=True,
)
else:
descriptors = activation_map

# Optionally l2 normalize the features.
if (
cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES
and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac"
):
# RMAC performs normalization within the algorithm, hence we skip it here.
descriptors = l2n(descriptors, dim=1)

if fname_out:
save_file(descriptors.data.numpy(), fname_out, verbose=False)
train_features.append(descriptors.data.numpy())
Expand All @@ -146,7 +155,7 @@ def process_train_image(i, out_dir):

logging.info(f"Getting features for train images: {num_images}")
for i in range(num_images):
process_train_image(i, out_dir)
process_train_image(i, out_dir, verbose=(i == 0))

train_features = np.vstack([x.reshape(-1, x.shape[-1]) for x in train_features])
logging.info(f"Train features size: {train_features.shape}")
Expand All @@ -163,6 +172,7 @@ def process_eval_image(
model,
pca,
eval_dataset_name,
verbose=False,
):
if is_revisited_dataset(eval_dataset_name):
img = image_helper.load_and_prepare_revisited_image(fname_in, roi=roi)
Expand All @@ -176,30 +186,39 @@ def process_eval_image(
# the model output is a list always.
activation_map = model(vc)[0].cpu()

if verbose:
print(f"Eval image raw activation map shape: { activation_map.shape }")

# process the features: rmac | gem pooling | optional l2 norm
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
descriptors = get_rmac_descriptors(activation_map, spatial_levels, pca=pca)
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "l2_norm":
# we simply L2 normalize the features otherwise
descriptors = F.normalize(activation_map, p=2, dim=0)
# Optionally apply pca.
if pca:
descriptors = pca.apply(descriptors)

descriptors = get_rmac_descriptors(
activation_map,
spatial_levels,
pca=pca,
normalize=cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES,
)
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
descriptors = l2n(
gem(
activation_map,
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
add_bias=True,
)
descriptors = gem(
activation_map,
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
add_bias=True,
)
# Optionally apply pca.
if pca:
descriptors = pca.apply(descriptors)
else:
descriptors = activation_map

# Optionally l2 normalize the features.
if (
cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES
and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac"
):
# RMAC performs normalization within the algorithm, hence we skip it here.
descriptors = l2n(descriptors, dim=1)

# Optionally apply pca.
if pca and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac":
# RMAC performs pca within the algorithm, hence we skip it here.
descriptors = pca.apply(descriptors)

if fname_out:
save_file(descriptors.data.numpy(), fname_out, verbose=False)
return descriptors.data.numpy()
Expand Down Expand Up @@ -248,6 +267,7 @@ def get_dataset_features(
model,
pca,
eval_dataset_name,
verbose=(idx == 0),
)
features_dataset.append(db_feature)

Expand Down Expand Up @@ -286,6 +306,7 @@ def get_queries_features(
if idx % LOG_FREQUENCY == 0:
logging.info(f"Eval Query: {idx}"),
q_fname_in = eval_dataset.get_query_filename(idx)
# Optionally crop the query by the region-of-interest (ROI).
roi = (
eval_dataset.get_query_roi(idx)
if cfg.IMG_RETRIEVAL.CROP_QUERY_ROI
Expand All @@ -309,6 +330,7 @@ def get_queries_features(
model,
pca,
eval_dataset_name,
verbose=(idx == 0),
)
features_queries.append(query_feature)

Expand Down Expand Up @@ -345,7 +367,7 @@ def get_transforms(cfg, dataset_name):
def get_train_dataset(cfg, root_dataset_path, train_dataset_name, eval_binary_path):
# We only create the train dataset if we need PCA or whitening training.
# Otherwise not.
if cfg.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING:
if cfg.IMG_RETRIEVAL.TRAIN_PCA_WHITENING:
train_data_path = f"{root_dataset_path}/{train_dataset_name}"
assert PathManager.exists(train_data_path), f"Unknown path: {train_data_path}"

Expand Down Expand Up @@ -444,7 +466,7 @@ def instance_retrieval_test(args, cfg):
############################################################################
# Step 2: Extract the features for the train dataset, calculate PCA or
# whitening and save
if cfg.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING:
if cfg.IMG_RETRIEVAL.TRAIN_PCA_WHITENING:
logging.info("Extracting training features...")
# the features are already processed based on type: rmac | GeM | l2 norm
train_features = get_train_features(
Expand Down Expand Up @@ -551,7 +573,7 @@ def validate_and_infer_config(config: AttrDict):
), "Spatial levels must be greater than 0."
if config.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
assert (
config.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING
config.IMG_RETRIEVAL.TRAIN_PCA_WHITENING
), "PCA Whitening is built-in to the RMAC algorithm and is required"

return config
Expand Down
8 changes: 6 additions & 2 deletions vissl/config/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1257,10 +1257,10 @@ config:
# Whether or not to save the features that were extracted
SAVE_FEATURES: False
# Whether to apply PCA/whitening or not
SHOULD_TRAIN_PCA_OR_WHITENING: True
TRAIN_PCA_WHITENING: True
# gem | rmac | l2_norm
FEATS_PROCESSING_TYPE: ""
# valid only for GeM pooling of features
# Valid only for GeM pooling of features. Note that GEM_POOL_POWER=1 is equivalent to average pooling.
GEM_POOL_POWER: 4.0
# valid only if we are training whitening on the whitening dataset
WHITEN_IMG_LIST: ""
Expand All @@ -1276,6 +1276,10 @@ config:
# Relevant for Oxford, Paris, ROxford, and RParis datasets.
# Our experiments with RN-50/rmac show that ROI cropping degrades performance.
CROP_QUERY_ROI: False
# Whether or not to L2-normalize the features after they have been post-processed.
# Normalization is strongly recommended based on our experiments.
NORMALIZE_FEATURES: True


# ----------------------------------------------------------------------------------- #
# K-NEAREST NEIGHBOR (benchmark)
Expand Down
15 changes: 11 additions & 4 deletions vissl/utils/instance_retrieval_utils/rmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_rmac_region_coordinates(H, W, L):

# Credits: https://github.com/facebookresearch/deepcluster/blob/master/eval_retrieval.py # NOQA
# Adapted by: Priya Goyal (prigoyal@fb.com)
def get_rmac_descriptors(features, rmac_levels, pca=None):
def get_rmac_descriptors(features, rmac_levels, pca=None, normalize=True):
"""
RMAC descriptors. Coordinates are retrieved following Tolias et al.
Optionally L2 normalize the descriptors (controlled by `normalize`) and
optionally apply PCA on the descriptors
Expand All @@ -104,18 +104,25 @@ def get_rmac_descriptors(features, rmac_levels, pca=None):

rmac_descriptors = torch.cat(rmac_descriptors, 1)

rmac_descriptors = normalize_L2(rmac_descriptors, 2)
if normalize:
# Can optionally skip normalization -- not recommended.
# the original RMAC paper normalizes.
rmac_descriptors = normalize_L2(rmac_descriptors, 2)

if pca is None:
return rmac_descriptors

# PCA + whitening
npca = pca.n_components
rmac_descriptors = pca.apply(rmac_descriptors.view(nr * nim, nc))
rmac_descriptors = normalize_L2(rmac_descriptors, 1)

if normalize:
rmac_descriptors = normalize_L2(rmac_descriptors, 1)

rmac_descriptors = rmac_descriptors.view(nim, nr, npca)

# Sum aggregation and L2-normalization
rmac_descriptors = torch.sum(rmac_descriptors, 1)
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
if normalize:
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
return rmac_descriptors

0 comments on commit db63a8f

Please sign in to comment.