Fix #4: uniform API calls for normalize and some comments.
kdexd committed Jun 23, 2020
1 parent 4e2936a commit ff9ad87
Showing 5 changed files with 41 additions and 22 deletions.
6 changes: 6 additions & 0 deletions configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml
@@ -1,3 +1,9 @@
+# ----------------------------------------------------------------------------
+# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows
+# the Detectron2 format and is unrelated to our VirTex configs. Params here
+# replicate the evaluation protocol of MoCo (https://arxiv.org/abs/1911.05722).
+# ----------------------------------------------------------------------------
+
 INPUT:
   # Input format will always be RGB, consistent with torchvision.
   FORMAT: "RGB"
8 changes: 8 additions & 0 deletions configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml
@@ -1,3 +1,9 @@
+# ----------------------------------------------------------------------------
+# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows
+# the Detectron2 format and is unrelated to our VirTex configs. Params here
+# replicate the evaluation protocol of MoCo (https://arxiv.org/abs/1911.05722).
+# ----------------------------------------------------------------------------
+
 INPUT:
   # Input format will always be RGB, consistent with torchvision.
   FORMAT: "RGB"
@@ -52,6 +58,8 @@ MODEL:
     POOLER_RESOLUTION: 14

   # ImageNet color mean for torchvision-like models (RGB order).
+  # These are in the [0, 255] range, as expected by Detectron2. The rest of our
+  # codebase uses the [0, 1] range; the two conventions are equivalent (scaled
+  # by a factor of 255) and consistent.
   PIXEL_MEAN: [123.675, 116.280, 103.530]
   PIXEL_STD: [58.395, 57.120, 57.375]

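The stats above relate to torchvision's documented ImageNet stats by a factor of 255. A quick illustrative check in plain Python (values copied from this config and torchvision, nothing else assumed):

```python
# Detectron2 expects pixel stats in [0, 255]; torchvision documents them in
# [0, 1]. Scaling by 255 reproduces the values in this config exactly.
TORCHVISION_MEAN = [0.485, 0.456, 0.406]  # RGB order, [0, 1] range
TORCHVISION_STD = [0.229, 0.224, 0.225]

print([round(m * 255, 3) for m in TORCHVISION_MEAN])  # [123.675, 116.28, 103.53]
print([round(s * 255, 3) for s in TORCHVISION_STD])   # [58.395, 57.12, 57.375]
```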
30 changes: 13 additions & 17 deletions scripts/pretrain_insup.py
@@ -36,6 +36,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from torchvision import models

+import virtex.data.transforms as T
 from virtex.data.datasets.downstream_datasets import ImageNetDataset
 import virtex.utils.distributed as vdist
 from virtex.utils.metrics import TopkAccuracy
@@ -167,29 +168,24 @@ def main_worker(gpu, ngpus_per_node, _A):
     )
     logger.info(f"Size of dataset: {len(train_dataset)}")
     val_dataset = ImageNetDataset(root=_A.data, split="val")
+    # Val dataset is used sparingly; don't keep it around in memory by caching.

-    normalize = alb.Normalize(
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
-        max_pixel_value=1.0,
-        always_apply=True,
-    )
     # Override image transform (class definition has transform according to
     # downstream linear classification protocol).
     # fmt: off
+    # Data augmentation as per ImageNet training from official PyTorch examples.
     train_dataset.image_transform = alb.Compose([
-        alb.RandomResizedCrop(224, 224, always_apply=True),
-        alb.HorizontalFlip(p=0.5),
-        alb.ToFloat(max_value=255.0, always_apply=True),
-        normalize,
+        T.RandomResizedSquareCrop(224, always_apply=True),
+        T.HorizontalFlip(p=0.5),
+        alb.Normalize(
+            mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, always_apply=True
+        )
     ])
     val_dataset.image_transform = alb.Compose([
-        alb.Resize(256, 256, always_apply=True),
-        alb.CenterCrop(224, 224, always_apply=True),
-        alb.ToFloat(max_value=255.0, always_apply=True),
-        normalize,
+        T.SquareResize(256, always_apply=True),
+        T.CenterSquareCrop(224, always_apply=True),
+        alb.Normalize(
+            mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, always_apply=True
+        )
     ])

     train_sampler = DistributedSampler(train_dataset, shuffle=True)
     val_sampler = DistributedSampler(val_dataset)
     train_loader = DataLoader(
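The old pipeline scaled images to [0, 1] with `ToFloat` and then normalized with `max_pixel_value=1.0`; the new one lets `alb.Normalize` (default `max_pixel_value=255.0`) handle both steps on uint8 input. A small sanity-check sketch, using only the albumentations calls already seen in this diff, confirms the two are numerically equivalent:

```python
import albumentations as alb
import numpy as np

image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Old: scale to [0, 1] first, then normalize with stats in the same range.
old_pipeline = alb.Compose([
    alb.ToFloat(max_value=255.0, always_apply=True),
    alb.Normalize(mean=mean, std=std, max_pixel_value=1.0, always_apply=True),
])
# New: normalize uint8 input directly; the default max_pixel_value=255.0
# folds the ToFloat step into Normalize.
new_pipeline = alb.Normalize(mean=mean, std=std, always_apply=True)

assert np.allclose(
    old_pipeline(image=image)["image"], new_pipeline(image=image)["image"]
)
```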
16 changes: 12 additions & 4 deletions virtex/data/transforms.py
@@ -132,13 +132,21 @@ def get_transform_init_args_names(self):
 class HorizontalFlip(ImageCaptionTransform):
     r"""
     Flip the image horizontally randomly (equally likely) and replace the
-    word "left" with "right" in the caption. This transform can also work on
-    images only (without the captions).
+    word "left" with "right" in the caption.
+
+    .. note::
+
+        This transform can also work on images only (without the captions).
+        Its behavior will be the same as albumentations
+        :class:`~albumentations.augmentations.transforms.HorizontalFlip`.

     Examples
     --------
-    >>> flip = ImageCaptionHorizontalFlip(p=0.5)
-    >>> out = flip(image=image, caption=caption)  # keys: {"image", "caption"}
+    >>> flip = HorizontalFlip(p=0.5)
+    >>> out1 = flip(image=image, caption=caption)  # keys: {"image", "caption"}
+    >>> # Also works with images (without caption).
+    >>> out2 = flip(image=image)  # keys: {"image"}
     """

     def apply(self, img, **params):
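As a usage illustration of the docstring above (a hedged sketch: the dummy image and caption are made up, and the output caption assumes the left-to-right replacement the docstring describes):

```python
import numpy as np
from virtex.data.transforms import HorizontalFlip

image = np.zeros((224, 224, 3), dtype=np.uint8)  # any HWC image works
flip = HorizontalFlip(p=1.0)  # p=1.0 forces the flip, making output deterministic

out = flip(image=image, caption="a cat to the left of a dog")
# Per the docstring, the flipped caption swaps "left" for "right":
# out["caption"] -> "a cat to the right of a dog"

out_img_only = flip(image=image)  # caption-free call, plain HorizontalFlip behavior
```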
3 changes: 2 additions & 1 deletion virtex/factories.py
@@ -140,7 +140,8 @@ class ImageTransformsFactory(Factory):
         ),
         "horizontal_flip": partial(T.HorizontalFlip, p=0.5),

-        # Color normalization: whenever selected, always applied.
+        # Color normalization: whenever selected, always applied. This accepts images
+        # in [0, 255], requires mean and std in [0, 1] and normalizes to `N(0, 1)`.
         "normalize": partial(
             alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0
         ),
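Spelled out, the comment above corresponds to albumentations' normalization arithmetic with its default `max_pixel_value=255.0` (a worked sketch, not code from this repo):

```python
# Normalize computes (pixel - mean * 255) / (std * 255), which is the same
# as first scaling to [0, 1] and then standardizing with [0, 1]-range stats.
pixel = 123.675                   # a red-channel value in [0, 255]
mean, std = 0.485, 0.229          # torchvision ImageNet stats in [0, 1]

out = (pixel - mean * 255) / (std * 255)
print(out)  # 0.0 -- this pixel sits exactly at the channel mean
```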
