Skip to content
This repository has been archived by the owner on Feb 16, 2022. It is now read-only.

Commit

Permalink
Add scripts to extract features from images with GT bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
vedanuj committed Feb 7, 2020
1 parent b9eaeaf commit 9158751
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 1 deletion.
28 changes: 27 additions & 1 deletion data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,33 @@ Run from root directory
python script/extract_features.py --model_file data/detectron_model.pth --config_file data/detectron_config.yaml --image_dir <path_to_directory_with_images> --output_folder <path_to_output_extracted_features>
```

3. Convert the extracted images to an LMDB file
3. Extract features for images with GT bbox

Generate a `.npy` file in the following format, containing all the images and their bounding boxes:

```text
[
    {
        'file_name': 'name_of_image_file',
        'file_path': '<path_to_image_file_on_your_disk>',
        'bbox': array([
            [ x1, y1, width1, height1],
            [ x2, y2, width2, height2],
            ...
        ]),
        'num_box': 2
    },
    ...
]
```

Run from root directory

```text
python script/extract_features.py --model_file data/detectron_model.pth --config_file data/detectron_config.yaml --imdb_gt_file <path_to_imdb_npy_file_generated_above> --output_folder <path_to_output_extracted_features>
```

4. Convert the extracted images to an LMDB file

```text
python script/convert_to_lmdb.py --features_dir <path_to_extracted_features> --lmdb_file <path_to_output_lmdb_file>
Expand Down
250 changes: 250 additions & 0 deletions script/extract_features_from_gt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Requires vqa-maskrcnn-benchmark to be built and installed. See Readme
# for more details.
import argparse
import os

import cv2
import numpy as np
import torch
from PIL import Image

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict


class FeatureExtractor:
    """Extract Detectron (maskrcnn-benchmark) region features for images
    using ground-truth bounding boxes from an imdb ``.npy`` file.

    Each imdb entry is expected to be a dict with keys ``file_name``,
    ``file_path``, ``bbox`` and ``num_box`` (see data/README.md). One
    ``.npy`` file with features and box info is written per image into
    ``--output_folder``.
    """

    # Detectron input-resize limits (pixels): the short side is scaled up to
    # MIN_SIZE unless that would push the long side past MAX_SIZE.
    MAX_SIZE = 1333
    MIN_SIZE = 800

    def __init__(self):
        self.args = self.get_parser().parse_args()
        self.detection_model = self._build_detection_model()

        os.makedirs(self.args.output_folder, exist_ok=True)

    def get_parser(self):
        """Build and return the command-line argument parser for this script."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--model_file", default=None, type=str, help="Detectron model file"
        )
        parser.add_argument(
            "--config_file", default=None, type=str, help="Detectron config file"
        )
        parser.add_argument(
            "--imdb_gt_file",
            default=None,
            type=str,
            help="Imdb file containing file path and bboxes.",
        )
        parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
        parser.add_argument(
            "--num_features",
            type=int,
            default=100,
            help="Number of features to extract.",
        )
        parser.add_argument(
            "--output_folder", type=str, default="./output", help="Output folder"
        )
        parser.add_argument(
            "--feature_name",
            type=str,
            help="The name of the feature to extract",
            default="fc6",
        )
        parser.add_argument(
            "--confidence_threshold",
            type=float,
            default=0,
            help="Threshold of detection confidence above which boxes will be selected",
        )
        parser.add_argument(
            "--background",
            action="store_true",
            help="The model will output predictions for the background class when set",
        )
        parser.add_argument(
            "--partition", type=int, default=0, help="Partition to download."
        )
        return parser

    def _build_detection_model(self):
        """Load the Detectron model from ``--config_file``/``--model_file``,
        move it to GPU and switch to eval mode.

        Returns:
            The detection model in eval mode on CUDA.
        """
        cfg.merge_from_file(self.args.config_file)
        cfg.freeze()

        model = build_detection_model(cfg)
        # Load on CPU first so checkpoints saved from GPU still deserialize.
        checkpoint = torch.load(self.args.model_file, map_location=torch.device("cpu"))

        load_state_dict(model, checkpoint.pop("model"))

        model.to("cuda")
        model.eval()
        return model

    def get_batch_proposals(self, images, im_scales, im_infos, proposals):
        """Wrap the GT boxes of each image in a ``BoxList`` resized to the
        network input resolution.

        Args:
            images: ``ImageList`` produced by ``to_image_list``.
            im_scales: per-image resize scales (unused here; resizing is done
                via ``BoxList.resize`` from original to network size).
            im_infos: per-image dicts with original ``width``/``height``.
            proposals: imdb entries containing ``bbox`` and ``num_box``.

        Returns:
            list[BoxList]: one proposal BoxList per image in the batch.
        """
        proposals_batch = []
        for idx, img_info in enumerate(im_infos):
            # Only the first `num_box` rows of the bbox array are valid.
            # NOTE(review): BoxList defaults to 'xyxy' box mode, while the
            # README describes boxes as [x, y, width, height] — confirm the
            # imdb generator emits xyxy coordinates.
            boxes_tensor = torch.from_numpy(
                proposals[idx]["bbox"][: int(proposals[idx]["num_box"]), 0:]
            ).to("cuda")
            orig_image_size = (img_info["width"], img_info["height"])
            boxes = BoxList(boxes_tensor, orig_image_size)
            # image_sizes stores (height, width); BoxList wants (width, height).
            image_size = (images.image_sizes[idx][1], images.image_sizes[idx][0])
            boxes = boxes.resize(image_size)
            proposals_batch.append(boxes)
        return proposals_batch

    def _image_transform(self, path):
        """Load an image and preprocess it for the Detectron backbone.

        Converts to BGR, subtracts the Detectron pixel means, and resizes so
        the short side is MIN_SIZE (capped so the long side <= MAX_SIZE).

        Args:
            path: filesystem path of the image.

        Returns:
            tuple: (CHW float tensor, resize scale, {"width", "height"} of
            the original image).
        """
        img = Image.open(path)
        im = np.array(img).astype(np.float32)
        # Grayscale images have shape (H, W); replicate to 3 channels to
        # avoid "IndexError: too many indices for array".
        if len(im.shape) < 3:
            im = np.repeat(im[:, :, np.newaxis], 3, axis=2)
        # PIL loads RGB; the model expects BGR with these per-channel means.
        im = im[:, :, ::-1]
        im -= np.array([102.9801, 115.9465, 122.7717])
        im_shape = im.shape
        im_height = im_shape[0]
        im_width = im_shape[1]
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])

        # Scale based on minimum size
        im_scale = self.MIN_SIZE / im_size_min

        # Prevent the biggest axis from being more than max_size
        # If bigger, scale it down
        if np.round(im_scale * im_size_max) > self.MAX_SIZE:
            im_scale = self.MAX_SIZE / im_size_max

        im = cv2.resize(
            im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR
        )
        # HWC -> CHW for the torch model.
        img = torch.from_numpy(im).permute(2, 0, 1)

        im_info = {"width": im_width, "height": im_height}

        return img, im_scale, im_info

    def _process_feature_extraction(
        self, output, im_scales, im_infos, feature_name="fc6", conf_thresh=0
    ):
        """Split the model's batched output back into per-image features and
        per-image metadata.

        Args:
            output: raw model output; ``output[0]`` holds "proposals",
                "scores" and the named feature tensor for the whole batch.
            im_scales: per-image resize scales used to map boxes back to
                original image coordinates.
            im_infos: per-image original width/height dicts.
            feature_name: which feature tensor to extract (e.g. "fc6").
            conf_thresh: minimum class score for the per-box max-confidence
                computation.

        Returns:
            tuple: (list of per-image feature tensors, list of per-image info
            dicts with bbox, num_boxes, objects, image size and cls_prob).
        """
        batch_size = len(output[0]["proposals"])
        n_boxes_per_image = [len(boxes) for boxes in output[0]["proposals"]]
        score_list = output[0]["scores"].split(n_boxes_per_image)
        score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
        feats = output[0][feature_name].split(n_boxes_per_image)
        cur_device = score_list[0].device

        feat_list = []
        info_list = []

        for i in range(batch_size):
            # Map boxes from network resolution back to original pixels.
            dets = output[0]["proposals"][i].bbox / im_scales[i]
            scores = score_list[i]
            # Per-box best class confidence after per-class NMS. All GT boxes
            # are kept regardless, so this only informs downstream consumers.
            max_conf = torch.zeros((scores.shape[0])).to(cur_device)
            conf_thresh_tensor = torch.full_like(max_conf, conf_thresh)
            start_index = 1
            # Column 0 of the scores matrix is for the background class
            if self.args.background:
                start_index = 0
            for cls_ind in range(start_index, scores.shape[1]):
                cls_scores = scores[:, cls_ind]
                keep = nms(dets, cls_scores, 0.5)
                max_conf[keep] = torch.where(
                    # Better than max one till now and minimally greater than conf_thresh
                    (cls_scores[keep] > max_conf[keep])
                    & (cls_scores[keep] > conf_thresh_tensor[keep]),
                    cls_scores[keep],
                    max_conf[keep],
                )

            feat_list.append(feats[i])
            num_boxes = len(feats[i])
            bbox = output[0]["proposals"][i]
            # Resize boxes back to the original image dimensions.
            bbox = bbox.resize((im_infos[i]["width"], im_infos[i]["height"]))
            bbox = bbox.bbox
            # Predict the class label using the scores
            objects = torch.argmax(scores[:, start_index:], dim=1)

            info_list.append(
                {
                    "bbox": bbox.cpu().numpy(),
                    "num_boxes": num_boxes,
                    "objects": objects.cpu().numpy(),
                    "image_width": im_infos[i]["width"],
                    "image_height": im_infos[i]["height"],
                    "cls_prob": scores.cpu().numpy(),
                }
            )

        return feat_list, info_list

    def get_detectron_features(self, image_paths):
        """Run the detector on a batch of imdb entries with their GT boxes.

        Args:
            image_paths: list of imdb entries (dicts with "file_path",
                "bbox", "num_box").

        Returns:
            tuple: (per-image feature tensors, per-image info dicts) as
            produced by :meth:`_process_feature_extraction`.
        """
        img_tensor, im_scales, im_infos = [], [], []

        for image_path in image_paths:
            im, im_scale, im_info = self._image_transform(image_path["file_path"])
            img_tensor.append(im)
            im_scales.append(im_scale)
            im_infos.append(im_info)

        # Image dimensions should be divisible by 32, to allow convolutions
        # in detector to work
        current_img_list = to_image_list(img_tensor, size_divisible=32)
        current_img_list = current_img_list.to("cuda")

        proposals = self.get_batch_proposals(
            current_img_list, im_scales, im_infos, image_paths
        )

        with torch.no_grad():
            output = self.detection_model(current_img_list, proposals=proposals)

        feat_list = self._process_feature_extraction(
            output,
            im_scales,
            im_infos,
            self.args.feature_name,
            self.args.confidence_threshold,
        )

        return feat_list

    def _chunks(self, array, chunk_size):
        """Yield successive ``chunk_size``-sized slices of ``array``; the
        last chunk may be shorter."""
        for i in range(0, len(array), chunk_size):
            yield array[i : i + chunk_size]

    def _save_feature(self, file_name, feature, info):
        """Save one image's features and info dict to
        ``<output_folder>/<image_id>.npy``.

        Args:
            file_name: original image file name; its stem becomes the
                ``image_id`` and the output file name.
            feature: feature tensor for this image (moved to CPU numpy).
            info: metadata dict, mutated in place to add "image_id" and
                "features" before saving.
        """
        file_base_name = os.path.basename(file_name)
        file_base_name = file_base_name.split(".")[0]
        info["image_id"] = file_base_name
        info["features"] = feature.cpu().numpy()
        file_base_name = str(file_base_name) + ".npy"

        np.save(os.path.join(self.args.output_folder, file_base_name), info)

    def extract_features(self):
        """Extract and save features for every image listed in the imdb file.

        Images are processed in batches of ``--batch_size``. A failing batch
        (e.g. an unreadable image) is reported and skipped so one bad input
        does not abort a long extraction run.
        """
        # Bug fix: previously read `np.load(args.imdb_gt_file, ...)`, which
        # referenced an undefined module-level `args`; the parsed arguments
        # live on `self.args`.
        files = np.load(self.args.imdb_gt_file, allow_pickle=True)
        for chunk in self._chunks(files, self.args.batch_size):
            try:
                features, infos = self.get_detectron_features(chunk)
                for idx, c in enumerate(chunk):
                    self._save_feature(c["file_name"], features[idx], infos[idx])
            except Exception as e:
                # Best-effort batch skip. Narrowed from `except BaseException`,
                # which also swallowed KeyboardInterrupt/SystemExit, and now
                # reports the failure instead of hiding it.
                print("Skipping batch due to error: {}".format(e))
                continue


if __name__ == "__main__":
feature_extractor = FeatureExtractor()
feature_extractor.extract_features()

0 comments on commit 9158751

Please sign in to comment.