diff --git a/README.md b/README.md
index 3cdd591..a5eb664 100644
--- a/README.md
+++ b/README.md
@@ -54,19 +54,23 @@ python src/emma_perception/commands/run_server.py

 ### Extracting features

-#### For the pretrained datasets
+For training, we need to extract the visual features for each image.
+
+Here's the command you can use to extract features from images. You can change the path to the folder of images, the output directory, and any of the other arguments as needed.

 ```bash
+python src/emma_perception/commands/extract_visual_features.py --images_dir <path_to_images> --output_dir <path_to_features>
 ```
+
+`argparse` arguments for the command
+
-#### For the Alexa Arena
-
-```bash
-```
+#### Extracting features for the Alexa Arena
+
+If you want to extract features with the model we fine-tuned on the Alexa Arena, just add `--is_arena` to the above command.

 ### Developer tooling
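The extraction command writes cached features under `--output_dir` via the cache callback with a `.pt` suffix (see the script changes below). Here is a minimal sketch for inspecting the output after a run; the assumption that each cache file is a plain `torch.save` payload, one per image, is mine and worth verifying on a single file:

```python
# Hedged sketch: peek at one cached feature file produced by the extraction
# command. Assumes files under --output_dir are ordinary `torch.save` payloads
# (one per image, ".pt" suffix); the exact key layout is not documented here.
from pathlib import Path

import torch

output_dir = Path("storage/data/cache")  # the script's default --output_dir

for feature_file in sorted(output_dir.glob("*.pt")):
    payload = torch.load(feature_file, map_location="cpu")
    print(feature_file.name, type(payload))
    break  # one file is enough for a sanity check
```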
diff --git a/src/emma_perception/commands/download_checkpoints.py b/src/emma_perception/commands/download_checkpoints.py
index 7b57002..99880d5 100644
--- a/src/emma_perception/commands/download_checkpoints.py
+++ b/src/emma_perception/commands/download_checkpoints.py
@@ -19,7 +19,7 @@ def download_arena_checkpoint(


 def download_vinvl_checkpoint(
-    *, hf_repo_id: str = HF_REPO_ID, file_name: str = ARENA_CHECKPOINT_NAME
+    *, hf_repo_id: str = HF_REPO_ID, file_name: str = VINVL_CHECKPOINT_NAME
 ) -> Path:
     """Download the pre-trained VinVL checkpoint."""
     file_path = download_file(repo_id=hf_repo_id, repo_type="model", filename=file_name)
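The change above fixes a copy-paste default: `download_vinvl_checkpoint` previously defaulted to the Arena checkpoint file name. For illustration, a sketch of what the corrected helper resolves to, assuming `download_file` is a thin wrapper over `huggingface_hub.hf_hub_download`; the constant values shown are placeholders, not the project's real repo ID or file name:

```python
# Sketch of the fixed helper under the assumption that download_file wraps
# huggingface_hub.hf_hub_download. Constants below are placeholders.
from pathlib import Path

from huggingface_hub import hf_hub_download

HF_REPO_ID = "<hf-org>/<perception-checkpoints>"  # placeholder
VINVL_CHECKPOINT_NAME = "<vinvl_checkpoint>.ckpt"  # placeholder


def download_vinvl_checkpoint(
    *, hf_repo_id: str = HF_REPO_ID, file_name: str = VINVL_CHECKPOINT_NAME
) -> Path:
    """Download the pre-trained VinVL checkpoint and return its local path."""
    file_path = hf_hub_download(repo_id=hf_repo_id, repo_type="model", filename=file_name)
    return Path(file_path)
```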
diff --git a/src/emma_perception/commands/extract_visual_features.py b/src/emma_perception/commands/extract_visual_features.py
index 984aa5c..51331c3 100644
--- a/src/emma_perception/commands/extract_visual_features.py
+++ b/src/emma_perception/commands/extract_visual_features.py
@@ -1,59 +1,45 @@
 import argparse
-from typing import Union

 from maskrcnn_benchmark.config import cfg
 from pytorch_lightning import Trainer
 from scene_graph_benchmark.config import sg_cfg

 from emma_perception.callbacks.callbacks import VisualExtractionCacheCallback
-from emma_perception.datamodules.visual_extraction_dataset import (
-    ImageDataset,
-    PredictDataModule,
-    VideoFrameDataset,
+from emma_perception.commands.download_checkpoints import (
+    download_arena_checkpoint,
+    download_vinvl_checkpoint,
 )
+from emma_perception.datamodules.visual_extraction_dataset import ImageDataset, PredictDataModule
 from emma_perception.models.vinvl_extractor import VinVLExtractor, VinVLTransform


 def parse_args() -> argparse.Namespace:
     """Defines arguments."""
-    parser = argparse.ArgumentParser(prog="PROG")
-
+    parser = argparse.ArgumentParser()
     parser = Trainer.add_argparse_args(parser)  # type: ignore[assignment]
-    parser.add_argument("-i", "--input_path", required=True, help="Path to input dataset")
-    parser.add_argument("-b", "--batch_size", type=int, default=2)
-    parser.add_argument("-w", "--num_workers", type=int, default=0)
-    parser.add_argument("-cs", "--cache_suffix", default=".pt", help="Extension of cached files")
-    parser.add_argument("--config_file", metavar="FILE", help="path to VinVL config file")
-    parser.add_argument("--return_predictions", action="store_true")
-
-    parser.add_argument(
-        "--downsample",
-        type=int,
-        default=0,
-        help="Downsampling factor for videos. If 0 then no downsampling is performed",
-    )
     parser.add_argument(
-        "-c", "--cache", default="storage/data/cache", help="Path to store visual features"
+        "-i",
+        "--images_dir",
+        required=True,
+        help="Path to a folder of images to extract features from",
     )
     parser.add_argument(
-        "-d", "--dataset", required=True, choices=["images", "frames"], help="Dataset type"
+        "--is_arena",
+        action="store_true",
+        help="If we are extracting features from the Arena images, use the Arena checkpoint",
     )
+    parser.add_argument("-b", "--batch_size", type=int, default=2)
+    parser.add_argument("-w", "--num_workers", type=int, default=0)
     parser.add_argument(
-        "-a",
-        "--ann_csv",
-        help="Path to annotation csv file. Used for video datasets to select only the frames that have annotations",
+        "-c", "--output_dir", default="storage/data/cache", help="Path to store visual features"
     )
     parser.add_argument(
-        "-at",
-        "--ann_type",
-        choices=["epic_kitchens"],
-        default="epic_kitchens",
-        help="Annotation parser for video datasets",
+        "--num_gpus",
+        type=int,
+        default=None,
+        help="Number of GPUs to use for visual feature extraction",
     )
-
     parser.add_argument(
         "opts",
         default=None,
@@ -75,34 +61,23 @@ def main() -> None:
     cfg.merge_from_list(args.opts)
-    cfg.freeze()

-    extractor = VinVLExtractor(cfg=cfg)
-    transform = VinVLTransform(cfg=cfg)
-
-    dataset: Union[ImageDataset, VideoFrameDataset]
-    if args.dataset == "images":
-        dataset = ImageDataset(input_path=args.input_path, preprocess_transform=transform)
-    elif args.dataset == "frames":
-        dataset = VideoFrameDataset(
-            input_path=args.input_path,
-            ann_csv=args.ann_csv,
-            ann_type=args.ann_type,
-            preprocess_transform=transform,
-            downsample=args.downsample,
-        )
+    if args.is_arena:
+        cfg.MODEL.WEIGHT = download_arena_checkpoint().as_posix()
     else:
-        raise OSError(f"Unsupported dataset type {args.dataset}")
+        cfg.MODEL.WEIGHT = download_vinvl_checkpoint().as_posix()
+    cfg.freeze()

+    dataset = ImageDataset(
+        input_path=args.images_dir, preprocess_transform=VinVLTransform(cfg=cfg)
+    )
     dm = PredictDataModule(
         dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers
     )

+    extractor = VinVLExtractor(cfg=cfg)
     trainer = Trainer(
-        gpus=args.gpus,
-        callbacks=[
-            VisualExtractionCacheCallback(cache_dir=args.cache, cache_suffix=args.cache_suffix)
-        ],
-        profiler="advanced",
+        gpus=args.num_gpus,
+        callbacks=[VisualExtractionCacheCallback(cache_dir=args.output_dir, cache_suffix=".pt")],
     )
-    trainer.predict(extractor, dm, return_predictions=args.return_predictions)
+    trainer.predict(extractor, dm)


 if __name__ == "__main__":
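Since the script now fixes the dataset, callback, and checkpoint wiring, the same pipeline can be driven programmatically. Below is a sketch using only the constructor signatures visible in the diff above; it assumes `cfg` has already been merged with `sg_cfg` and the VinVL yaml (as `main()` does beforehand), and the image folder path is a placeholder:

```python
# Sketch: running feature extraction without the CLI, using the same pieces
# main() wires together. Assumes cfg was already merged with sg_cfg and the
# VinVL yaml earlier, as the script does before this point.
from maskrcnn_benchmark.config import cfg
from pytorch_lightning import Trainer

from emma_perception.callbacks.callbacks import VisualExtractionCacheCallback
from emma_perception.commands.download_checkpoints import download_vinvl_checkpoint
from emma_perception.datamodules.visual_extraction_dataset import ImageDataset, PredictDataModule
from emma_perception.models.vinvl_extractor import VinVLExtractor, VinVLTransform

cfg.MODEL.WEIGHT = download_vinvl_checkpoint().as_posix()
cfg.freeze()

dataset = ImageDataset(
    input_path="storage/images",  # placeholder folder of images
    preprocess_transform=VinVLTransform(cfg=cfg),
)
dm = PredictDataModule(dataset=dataset, batch_size=2, num_workers=0)

trainer = Trainer(
    gpus=1,
    callbacks=[VisualExtractionCacheCallback(cache_dir="storage/data/cache", cache_suffix=".pt")],
)
trainer.predict(VinVLExtractor(cfg=cfg), dm)
```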
diff --git a/src/emma_perception/constants/vinvl_x152c4.yaml b/src/emma_perception/constants/vinvl_x152c4.yaml
index 38eece7..fbb0e7e 100644
--- a/src/emma_perception/constants/vinvl_x152c4.yaml
+++ b/src/emma_perception/constants/vinvl_x152c4.yaml
@@ -17,6 +17,7 @@ MODEL:
     SCORE_THRESH: 0.2 # 0.0001
     DETECTIONS_PER_IMG: 36 # 600
     MIN_DETECTIONS_PER_IMG: 10
+    NMS_FILTER: 1
   ROI_BOX_HEAD:
     NUM_CLASSES: 1595
   ROI_ATTRIBUTE_HEAD:
@@ -52,6 +53,7 @@ TEST:
   TSV_SAVE_SUBSET: ["rect", "class", "conf", "feature"]
   OUTPUT_FEATURE: True
   GATHER_ON_CPU: True
+  IGNORE_BOX_REGRESSION: False
 OUTPUT_DIR: "./output/X152C5_test"
 DATA_DIR: "./datasets"
 DISTRIBUTED_BACKEND: "gloo"
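A quick way to check that the two new keys land where the indentation suggests, `MODEL.ROI_HEADS.NMS_FILTER` and `TEST.IGNORE_BOX_REGRESSION`, is to load the config the way the extraction script does. The yacs merge order below follows the usual `maskrcnn_benchmark`/`scene_graph_benchmark` flow, and both key paths are assumptions to verify:

```python
# Sketch: load the updated VinVL config and print the newly added keys.
# The merge order (base cfg, then sg_cfg, then the yaml file) mirrors the
# standard scene_graph_benchmark setup; adjust if main() differs.
from maskrcnn_benchmark.config import cfg
from scene_graph_benchmark.config import sg_cfg

cfg.set_new_allowed(True)
cfg.merge_from_other_cfg(sg_cfg)
cfg.set_new_allowed(False)
cfg.merge_from_file("src/emma_perception/constants/vinvl_x152c4.yaml")
cfg.freeze()

print(cfg.MODEL.ROI_HEADS.NMS_FILTER)  # expected: 1
print(cfg.TEST.IGNORE_BOX_REGRESSION)  # expected: False
```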