Merge branch 'decoding' to 'master'
Karan Desai committed Jul 15, 2021
2 parents 341aef7 + f33ef5a commit 1cb3954
Showing 10 changed files with 345 additions and 181 deletions.
21 changes: 18 additions & 3 deletions CHANGELOG.md
@@ -1,7 +1,22 @@
ArXiv v1 -> v2 CHANGELOG
=========================
CHANGELOG
=========

[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).
This CHANGELOG records changes between different arXiv versions of our paper, along with the version of this codebase that should be used to reproduce the results of the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases).

ArXiv v2 -> v3
==============

**Code version:** `v1.2`.

Fix image captioning results with a modified beam search implementation. _The rest of the downstream task results and pre-trained models are unchanged._


ArXiv v1 -> v2
==============

**Code version:** `v1.0` or `v1.1`.

[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).

While the core motivation and approach are the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performance across the board (within decimal points). New models are available in the [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html); links to the old `v0.9` models will remain active until June 30, 2021. We encourage you to use the new models!

6 changes: 6 additions & 0 deletions configs/_base_bicaptioning_R_50_L1_H1024.yaml
@@ -35,14 +35,20 @@ DATA:

MODEL:
NAME: "virtex"

VISUAL:
NAME: "torchvision::resnet50"
PRETRAINED: false
FROZEN: false

TEXTUAL:
NAME: "transdec_postnorm::L1_H1024_A16_F4096"
DROPOUT: 0.1

DECODER:
NAME: "beam_search"
BEAM_SIZE: 5

OPTIM:
OPTIMIZER_NAME: "sgd"
SGD_MOMENTUM: 0.9
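For illustration, here is a minimal sketch of the beam-search decoder that this `DECODER` block selects, assuming `AutoRegressiveBeamSearch` accepts the same keyword arguments that the new `CaptionDecoderFactory` (added below in `virtex/factories.py`) forwards to it; the token index and step count are illustrative placeholders, not project defaults.

```python
from virtex.utils.beam_search import AutoRegressiveBeamSearch

# Mirrors the DECODER block above: NAME: "beam_search", BEAM_SIZE: 5.
# eos_index and max_steps stand in for DATA.EOS_INDEX and
# MODEL.DECODER.MAX_DECODING_STEPS; the values here are hypothetical.
decoder = AutoRegressiveBeamSearch(eos_index=2, max_steps=30, beam_size=5)
```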
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -24,7 +24,7 @@
author = "Karan Desai"

# The full version, including alpha/beta/rc tags
release = "1.1"
release = "1.2"


# -- General configuration ---------------------------------------------------
6 changes: 5 additions & 1 deletion scripts/eval_captioning.py
@@ -21,7 +21,7 @@
evaluate pretrained model on COCO Captions val2017 split."""
)
parser.add_argument(
"--data-root", default=None,
"--images", "--data-root", default=None,
help="""Path to a directory containing image files to generate captions for.
Default: COCO val2017 image directory as expected relative to project root."""
)
@@ -89,6 +89,10 @@ def main(_A: argparse.Namespace):
}
)

logger.info("Displaying first 25 caption predictions:")
for pred in predictions[:25]:
logger.info(f"{pred['image_id']} :: {pred['caption']}")

# Save predictions as a JSON file if specified.
if _A.output is not None:
os.makedirs(os.path.dirname(_A.output), exist_ok=True)
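The prediction records logged above carry `image_id` and `caption` fields; when an output path is given, the script saves them as JSON. A small sketch of reading that file back, assuming the full predictions list is dumped as a JSON array and using a hypothetical output path:

```python
import json

# Hypothetical path previously passed as the script's output argument.
with open("/tmp/virtex_captions.json") as f:
    predictions = json.load(f)

# Each record has "image_id" and "caption", matching the logging loop above.
for pred in predictions[:5]:
    print(f"{pred['image_id']} :: {pred['caption']}")
```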
2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@ def get_model_zoo_configs() -> List[str]:

setup(
name="virtex",
version="1.1.0",
version="1.2.0",
author="Karan Desai and Justin Johnson",
description="VirTex: Learning Visual Representations with Textual Annotations",
package_data={"virtex.model_zoo": get_model_zoo_configs()},
14 changes: 14 additions & 0 deletions virtex/config.py
@@ -158,6 +158,20 @@ def __init__(
# Dropout probability for embedding, hidden features in textual head.
_C.MODEL.TEXTUAL.DROPOUT = 0.1

_C.MODEL.DECODER = CN()
# What algorithm to use for decoding. Supported values: {"beam_search",
# "nucleus_sampling"}.
_C.MODEL.DECODER.NAME = "beam_search"
# Number of beams to decode (1 = greedy decoding). Ignored when decoding
# through nucleus sampling.
_C.MODEL.DECODER.BEAM_SIZE = 5
# Size of nucleus for sampling predictions. Ignored when decoding through
# beam search.
_C.MODEL.DECODER.NUCLEUS_SIZE = 0.9
# Maximum length of decoded caption. Decoding may end earlier when [EOS]
# token is sampled.
_C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH

# ---------------------------------------------------------------------
# Optimization hyper-parameters, default values are for pretraining
# our best model on bicaptioning task (COCO Captions).
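As a counterpart to the beam-search sketch above, here is how the nucleus-sampling decoder described by these keys might be built directly, assuming `AutoRegressiveNucleusSampling` accepts the keyword arguments that `CaptionDecoderFactory` forwards to it (values below are illustrative placeholders):

```python
from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling

# MODEL.DECODER.NAME = "nucleus_sampling"; BEAM_SIZE is ignored for this
# decoder, just as NUCLEUS_SIZE is ignored for beam search.
decoder = AutoRegressiveNucleusSampling(
    eos_index=2,       # hypothetical DATA.EOS_INDEX
    max_steps=30,      # hypothetical MAX_DECODING_STEPS
    nucleus_size=0.9,  # default MODEL.DECODER.NUCLEUS_SIZE
)
```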
47 changes: 43 additions & 4 deletions virtex/factories.py
@@ -18,21 +18,24 @@
signature of underlying class; or config hierarchy. Refer description of
specific factories for more details.
"""
from functools import partial
import re
from functools import partial
from typing import Any, Callable, Dict, Iterable, List

import albumentations as alb
from torch import nn, optim

from virtex.config import Config
import virtex.data as vdata
import virtex.models as vmodels
from virtex.config import Config
from virtex.data import transforms as T
from virtex.data.tokenizers import SentencePieceBPETokenizer
import virtex.models as vmodels
from virtex.modules import visual_backbones, textual_heads
from virtex.optim import Lookahead, lr_scheduler

from virtex.utils.beam_search import AutoRegressiveBeamSearch
from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling


class Factory(object):
r"""
@@ -460,9 +463,9 @@ def from_config(cls, config: Config) -> nn.Module:
# for matching kwargs here.
if _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}:
kwargs = {
"max_decoding_steps": _C.DATA.MAX_CAPTION_LENGTH,
"sos_index": _C.DATA.SOS_INDEX,
"eos_index": _C.DATA.EOS_INDEX,
"decoder": CaptionDecoderFactory.from_config(_C),
}

elif _C.MODEL.NAME == "token_classification":
@@ -482,6 +485,42 @@ def from_config(cls, config: Config) -> nn.Module:
return cls.create(_C.MODEL.NAME, visual, textual, **kwargs)


class CaptionDecoderFactory(Factory):
r"""
Factory to create decoders for predicting captions from a VirTex model.
Possible choices: ``{"beam_search", "nucleus_sampling"}``.
"""

PRODUCTS: Dict[str, Callable] = {
"beam_search": AutoRegressiveBeamSearch,
"nucleus_sampling": AutoRegressiveNucleusSampling,
}

@classmethod
def from_config(cls, config: Config) -> nn.Module:
r"""
Create a decoder directly from config.
Parameters
----------
config: virtex.config.Config
Config object with all the parameters.
"""

_C = config
kwargs = {
"eos_index": _C.DATA.EOS_INDEX,
"max_steps": _C.MODEL.DECODER.MAX_DECODING_STEPS,
}
if _C.MODEL.DECODER.NAME == "beam_search":
kwargs["beam_size"] = _C.MODEL.DECODER.BEAM_SIZE
elif _C.MODEL.DECODER.NAME == "nucleus_sampling":
kwargs["nucleus_size"] = _C.MODEL.DECODER.NUCLEUS_SIZE

return cls.create(_C.MODEL.DECODER.NAME, **kwargs)


class OptimizerFactory(Factory):
r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``."""

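A usage sketch of the new factory: the first argument selects a class from `PRODUCTS` and the remaining keyword arguments are forwarded to its constructor, mirroring what `from_config` assembles from the config (the token index and step count below are illustrative placeholders, not project defaults).

```python
from virtex.factories import CaptionDecoderFactory

# Equivalent to what from_config builds for MODEL.DECODER.NAME = "beam_search".
decoder = CaptionDecoderFactory.create(
    "beam_search", eos_index=2, max_steps=30, beam_size=5
)

# The model factory above now passes this object to captioning models through
# the "decoder" kwarg, replacing the old max_decoding_steps argument.
```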
