From ff9ad87f4a74a5949e37a8e52838c74e474f89de Mon Sep 17 00:00:00 2001
From: Karan Desai
Date: Tue, 23 Jun 2020 14:50:08 -0400
Subject: [PATCH] Fix #4: uniform API calls for normalize and some comments.

---
 .../_base_faster_rcnn_R_50_C4_BN.yaml        |  6 ++++
 .../detectron2/_base_mask_rcnn_R_50_FPN.yaml |  8 +++++
 scripts/pretrain_insup.py                    | 30 ++++++++-----------
 virtex/data/transforms.py                    | 16 +++++++---
 virtex/factories.py                          |  3 +-
 5 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml b/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml
index a965cd86..49a39e42 100644
--- a/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml
+++ b/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml
@@ -1,3 +1,9 @@
+# ----------------------------------------------------------------------------
+# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows the
+# Detectron2 format and is unrelated to our VirTex configs. Params here
+# replicate the evaluation protocol of MoCo (https://arxiv.org/abs/1911.05722).
+# ----------------------------------------------------------------------------
+
 INPUT:
   # Input format will always be RGB, consistent with torchvision.
   FORMAT: "RGB"
diff --git a/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml b/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml
index ffedf59d..3751ccb8 100644
--- a/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml
+++ b/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml
@@ -1,3 +1,9 @@
+# ----------------------------------------------------------------------------
+# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows the
+# Detectron2 format and is unrelated to our VirTex configs. Params here
+# replicate the evaluation protocol of MoCo (https://arxiv.org/abs/1911.05722).
+# ----------------------------------------------------------------------------
+
 INPUT:
   # Input format will always be RGB, consistent with torchvision.
   FORMAT: "RGB"
@@ -52,6 +58,8 @@ MODEL:
     POOLER_RESOLUTION: 14
 
   # ImageNet color mean for torchvision-like models (RGB order).
+  # These are in the [0-255] range, as expected by Detectron2. The rest of our
+  # codebase uses the [0-1] range; the two conventions are equivalent up to scaling.
   PIXEL_MEAN: [123.675, 116.280, 103.530]
   PIXEL_STD: [58.395, 57.120, 57.375]
 
diff --git a/scripts/pretrain_insup.py b/scripts/pretrain_insup.py
index 49de0431..81b0506b 100644
--- a/scripts/pretrain_insup.py
+++ b/scripts/pretrain_insup.py
@@ -36,6 +36,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from torchvision import models
 
+import virtex.data.transforms as T
 from virtex.data.datasets.downstream_datasets import ImageNetDataset
 import virtex.utils.distributed as vdist
 from virtex.utils.metrics import TopkAccuracy
@@ -167,29 +168,24 @@ def main_worker(gpu, ngpus_per_node, _A):
     )
     logger.info(f"Size of dataset: {len(train_dataset)}")
     val_dataset = ImageNetDataset(root=_A.data, split="val")
-    # Val dataset is used sparsely, don't keep it around in memory by caching.
-    normalize = alb.Normalize(
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
-        max_pixel_value=1.0,
-        always_apply=True,
-    )
-    # Override image transform (class definition has transform according to
-    # downstream linear classification protocol).
     # fmt: off
+    # Data augmentation as per ImageNet training in the official PyTorch examples.
     train_dataset.image_transform = alb.Compose([
-        alb.RandomResizedCrop(224, 224, always_apply=True),
-        alb.HorizontalFlip(p=0.5),
-        alb.ToFloat(max_value=255.0, always_apply=True),
-        normalize,
+        T.RandomResizedSquareCrop(224, always_apply=True),
+        T.HorizontalFlip(p=0.5),
+        alb.Normalize(
+            mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, always_apply=True
+        )
     ])
     val_dataset.image_transform = alb.Compose([
-        alb.Resize(256, 256, always_apply=True),
-        alb.CenterCrop(224, 224, always_apply=True),
-        alb.ToFloat(max_value=255.0, always_apply=True),
-        normalize,
+        T.SquareResize(256, always_apply=True),
+        T.CenterSquareCrop(224, always_apply=True),
+        alb.Normalize(
+            mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, always_apply=True
+        )
     ])
+
     train_sampler = DistributedSampler(train_dataset, shuffle=True)
     val_sampler = DistributedSampler(val_dataset)
     train_loader = DataLoader(
diff --git a/virtex/data/transforms.py b/virtex/data/transforms.py
index 9c5ba831..cbef025b 100644
--- a/virtex/data/transforms.py
+++ b/virtex/data/transforms.py
@@ -132,13 +132,21 @@ def get_transform_init_args_names(self):
 class HorizontalFlip(ImageCaptionTransform):
     r"""
     Flip the image horizontally randomly (equally likely) and replace the
-    word "left" with "right" in the caption. This transform can also work on
-    images only (without the captions).
+    word "left" with "right" in the caption.
+
+    .. note::
+
+        This transform can also work on images only (without the captions).
+        Its behavior will be the same as albumentations
+        :class:`~albumentations.augmentations.transforms.HorizontalFlip`.
 
     Examples
     --------
-    >>> flip = ImageCaptionHorizontalFlip(p=0.5)
-    >>> out = flip(image=image, caption=caption)  # keys: {"image", "caption"}
+    >>> flip = HorizontalFlip(p=0.5)
+    >>> out1 = flip(image=image, caption=caption)  # keys: {"image", "caption"}
+    >>> # Also works with images only (without caption).
+    >>> out2 = flip(image=image)  # keys: {"image"}
+
     """
 
     def apply(self, img, **params):
diff --git a/virtex/factories.py b/virtex/factories.py
index c3ad2d76..c079dad2 100644
--- a/virtex/factories.py
+++ b/virtex/factories.py
@@ -140,7 +140,8 @@ class ImageTransformsFactory(Factory):
         ),
         "horizontal_flip": partial(T.HorizontalFlip, p=0.5),
 
-        # Color normalization: whenever selected, always applied.
+        # Color normalization: whenever selected, always applied. This accepts images
+        # in [0, 255], requires mean and std in [0, 1], and normalizes to `N(0, 1)`.
         "normalize": partial(
             alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0
         ),
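
Note (not part of the patch): a minimal sketch of the uniform normalize call this patch settles on. The mean/std values below are assumed to match T.IMAGENET_COLOR_MEAN / T.IMAGENET_COLOR_STD, i.e. the standard ImageNet statistics in [0, 1] that the removed code passed explicitly.

# Illustrative sketch only, not part of the diff.
import albumentations as alb
import numpy as np

IMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406)  # assumed value of T.IMAGENET_COLOR_MEAN
IMAGENET_COLOR_STD = (0.229, 0.224, 0.225)   # assumed value of T.IMAGENET_COLOR_STD

# alb.Normalize divides by max_pixel_value (255.0 by default) before subtracting
# the mean and dividing by the std, so it consumes uint8 [0, 255] images directly
# and the separate alb.ToFloat step from the old code is no longer needed.
transform = alb.Compose([
    alb.Normalize(mean=IMAGENET_COLOR_MEAN, std=IMAGENET_COLOR_STD, always_apply=True)
])

image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
out = transform(image=image)["image"]
print(out.dtype, out.mean(), out.std())  # float32, roughly zero mean / unit variance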