From d6716d6f55feb3fb8837b589f8100bb1e902f40e Mon Sep 17 00:00:00 2001 From: Karan Desai Date: Thu, 15 Jul 2021 11:40:38 -0700 Subject: [PATCH 1/3] Fix beam search bug, thanks @ashkamath and @alcinos --- virtex/utils/beam_search.py | 116 ++++++++++++------------------------ 1 file changed, 39 insertions(+), 77 deletions(-) diff --git a/virtex/utils/beam_search.py b/virtex/utils/beam_search.py index df65571f..c41fd636 100644 --- a/virtex/utils/beam_search.py +++ b/virtex/utils/beam_search.py @@ -3,23 +3,27 @@ `AllenNLP `_. Thanks to the developers of AllenNLP! + +**Update (v1.2):** The "backpointer" trick in Beam Search (as implemented in +AllenNLP) does not work well with autoregressive models (transformers). It is +now removed and it improves qualitative predictions and captioning metrics +(CIDEr/SPICE) for VirTex. Updated captioning results are on ArXiv v3. Refer +`CHANGELOG `_ and +`Release Page `_ for more +details. + +Huge thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for +helping me fix this bug! """ -from typing import Callable, Dict, List, Tuple +from typing import Callable, Tuple import warnings import torch -# Short names for commonly annotated types. -StateType = Dict[str, torch.Tensor] -StepFunctionType = Callable[..., torch.Tensor] - - class AutoRegressiveBeamSearch(object): r""" Implements the beam search algorithm for decoding the most likely captions. - This only works for auto-regressive models (Transformer-like) and not - recurrent models (LSTM-like). Parameters ---------- @@ -50,7 +54,7 @@ def __init__( self.per_node_beam_size = per_node_beam_size or beam_size def search( - self, start_predictions: torch.Tensor, step: StepFunctionType + self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Given a starting state and a step function, apply beam search to find @@ -76,11 +80,10 @@ def search( step : Callable[..., torch.Tensor] A function that is responsible for computing the next most likely tokens, given the past predictions. Predictions from all previous - time-steps are required, not just the last time-step, because our - model is auto-regressive instead of recurrent. The function should - The function is expected to return a tensor of shape - ``(group_size, target_vocab_size)`` containing - the log probs of the tokens for the next step. + time-steps are required, not just the last time-step. The function + should The function is expected to return a tensor of shape + ``(group_size, target_vocab_size)`` containing the token log probs + for the next step. Returns ------- @@ -89,17 +92,15 @@ def search( has shape ``(batch_size, beam_size, max_steps)`` and ``log_probs`` has shape ``(batch_size, beam_size)``. """ + batch_size = start_predictions.size()[0] - # List of `(batch_size, beam_size)` tensors. One for each time step. + # List of `(batch_size, beam_size, length)` tensors. # Does not include the start symbols, which are implicit. - predictions: List[torch.Tensor] = [] - - # List of (batch_size, beam_size) tensors. One for each time step. None - # for the first. Stores the index n for the parent prediction, i.e. - # predictions[t-1][i][n], that it came from. - backpointers: List[torch.Tensor] = [] - + predictions: torch.Tensor = torch.empty( + (batch_size, self.beam_size, 0), + dtype=torch.long, device=start_predictions.device + ) # Calculate the first timestep. This is done outside the main loop # because we are going from a single decoder input (the output from the # encoder) to the top `beam_size` decoder outputs. On the other hand, @@ -111,14 +112,6 @@ def search( num_classes = start_class_log_probs.size()[1] - # Make sure `per_node_beam_size` is not larger than `num_classes`. - if self.per_node_beam_size > num_classes: - raise ValueError( - f"Target vocab size ({num_classes:d}) too small " - f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n" - f"Please decrease beam_size or per_node_beam_size." - ) - # shape: (batch_size, beam_size), (batch_size, beam_size) start_top_log_probs, start_predicted_classes = start_class_log_probs.topk( self.beam_size @@ -138,8 +131,8 @@ def search( # shape: (batch_size, beam_size) last_log_probs = start_top_log_probs - # shape: [(batch_size, beam_size)] - predictions.append(start_predicted_classes) + # shape: (batch_size, beam_size, sequence_length) + predictions = torch.cat([predictions, start_predicted_classes.unsqueeze(-1)], dim=-1) # Log probability tensor that mandates that the end token is selected. # shape: (batch_size * beam_size, num_classes) @@ -150,17 +143,16 @@ def search( for timestep in range(self.max_steps - 1): # shape: (batch_size * beam_size,) - last_predictions = predictions[-1].reshape(batch_size * self.beam_size) + last_predictions = predictions[:, :, -1].reshape(batch_size * self.beam_size) # If every predicted token from the last step is `self._end_index`, # then we can stop early. if (last_predictions == self._end_index).all(): break - # Take a step. This get the predicted log probs of the next classes. - predictions_so_far = torch.stack(predictions).permute(1, 2, 0).view( + predictions_so_far = predictions.view( batch_size * self.beam_size, -1 - ) + ) # shape: (batch_size * beam_size, num_classes) class_log_probs = step(predictions_so_far) @@ -203,31 +195,26 @@ def search( reshaped_predicted_classes = predicted_classes.reshape( batch_size, self.beam_size * self.per_node_beam_size ) + # Append the predictions to the current beam. + reshaped_beam = ( + predictions.view(batch_size * self.beam_size, 1, -1) + .repeat(1, self.per_node_beam_size, 1) + .reshape(batch_size, self.beam_size * self.per_node_beam_size, -1) + ) + reshaped_beam = torch.cat([reshaped_beam, reshaped_predicted_classes.unsqueeze(-1)], dim=-1) + # Keep only the top `beam_size` beam indices. # shape: (batch_size, beam_size), (batch_size, beam_size) restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk( self.beam_size ) - # Use the beam indices to extract the corresponding classes. - # shape: (batch_size, beam_size) - restricted_predicted_classes = reshaped_predicted_classes.gather( - 1, restricted_beam_indices + predictions = reshaped_beam.gather( + 1, restricted_beam_indices.unsqueeze(-1).repeat(1,1,reshaped_beam.shape[-1]) ) - predictions.append(restricted_predicted_classes) # shape: (batch_size, beam_size) last_log_probs = restricted_beam_log_probs - # The beam indices come from a `beam_size * per_node_beam_size` - # dimension where the indices with a common ancestor are grouped - # together. Hence dividing by `per_node_beam_size` gives the - # ancestor. (Note that this is integer division as the tensor is a - # LongTensor.) - # shape: (batch_size, beam_size) - backpointer = restricted_beam_indices // self.per_node_beam_size - - backpointers.append(backpointer) - if not torch.isfinite(last_log_probs).all(): warnings.warn( "Infinite log probs encountered. Some final captions may not " @@ -237,29 +224,4 @@ def search( RuntimeWarning, ) - # Reconstruct the captions. - # shape: [(batch_size, beam_size, 1)] - reconstructed_predictions = [predictions[-1].unsqueeze(2)] - - # shape: (batch_size, beam_size) - cur_backpointers = backpointers[-1] - - for timestep in range(len(predictions) - 2, 0, -1): - # shape: (batch_size, beam_size, 1) - cur_preds = ( - predictions[timestep].gather(1, cur_backpointers).unsqueeze(2) - ) - reconstructed_predictions.append(cur_preds) - - # shape: (batch_size, beam_size) - cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers) - - # shape: (batch_size, beam_size, 1) - final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2) - - reconstructed_predictions.append(final_preds) - - # shape: (batch_size, beam_size, max_steps) - all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) - - return all_predictions, last_log_probs + return predictions, last_log_probs \ No newline at end of file From a27b0a44a8b55c24738b58497f79ccb8e42553ab Mon Sep 17 00:00:00 2001 From: Karan Desai Date: Thu, 15 Jul 2021 11:42:09 -0700 Subject: [PATCH 2/3] Add nucleus sampling and CaptionDecoderFactory. --- configs/_base_bicaptioning_R_50_L1_H1024.yaml | 6 + virtex/config.py | 14 ++ virtex/factories.py | 47 ++++++- virtex/models/captioning.py | 95 ++++++------- virtex/utils/beam_search.py | 106 ++++++++------- virtex/utils/nucleus_sampling.py | 127 ++++++++++++++++++ 6 files changed, 289 insertions(+), 106 deletions(-) create mode 100644 virtex/utils/nucleus_sampling.py diff --git a/configs/_base_bicaptioning_R_50_L1_H1024.yaml b/configs/_base_bicaptioning_R_50_L1_H1024.yaml index ab40b92b..02036dbd 100644 --- a/configs/_base_bicaptioning_R_50_L1_H1024.yaml +++ b/configs/_base_bicaptioning_R_50_L1_H1024.yaml @@ -35,14 +35,20 @@ DATA: MODEL: NAME: "virtex" + VISUAL: NAME: "torchvision::resnet50" PRETRAINED: false FROZEN: false + TEXTUAL: NAME: "transdec_postnorm::L1_H1024_A16_F4096" DROPOUT: 0.1 + DECODER: + NAME: "beam_search" + BEAM_SIZE: 5 + OPTIM: OPTIMIZER_NAME: "sgd" SGD_MOMENTUM: 0.9 diff --git a/virtex/config.py b/virtex/config.py index 9e69b89c..3bfc4ab5 100644 --- a/virtex/config.py +++ b/virtex/config.py @@ -158,6 +158,20 @@ def __init__( # Dropout probability for embedding, hidden features in textual head. _C.MODEL.TEXTUAL.DROPOUT = 0.1 + _C.MODEL.DECODER = CN() + # What algorithm to use for decoding. Supported values: {"beam_search", + # "nucleus_sampling"}. + _C.MODEL.DECODER.NAME = "beam_search" + # Number of beams to decode (1 = greedy decoding). Ignored when decoding + # through nucleus sampling. + _C.MODEL.DECODER.BEAM_SIZE = 5 + # Size of nucleus for sampling predictions. Ignored when decoding through + # beam search. + _C.MODEL.DECODER.NUCLEUS_SIZE = 0.9 + # Maximum length of decoded caption. Decoding may end earlier when [EOS] + # token is sampled. + _C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH + # --------------------------------------------------------------------- # Optimization hyper-parameters, default values are for pretraining # our best model on bicaptioning task (COCO Captions). diff --git a/virtex/factories.py b/virtex/factories.py index 522f69d0..b3ec554b 100644 --- a/virtex/factories.py +++ b/virtex/factories.py @@ -18,21 +18,24 @@ signature of underlying class; or config hierarchy. Refer description of specific factories for more details. """ -from functools import partial import re +from functools import partial from typing import Any, Callable, Dict, Iterable, List import albumentations as alb from torch import nn, optim -from virtex.config import Config import virtex.data as vdata +import virtex.models as vmodels +from virtex.config import Config from virtex.data import transforms as T from virtex.data.tokenizers import SentencePieceBPETokenizer -import virtex.models as vmodels from virtex.modules import visual_backbones, textual_heads from virtex.optim import Lookahead, lr_scheduler +from virtex.utils.beam_search import AutoRegressiveBeamSearch +from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling + class Factory(object): r""" @@ -460,9 +463,9 @@ def from_config(cls, config: Config) -> nn.Module: # for matching kwargs here. if _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}: kwargs = { - "max_decoding_steps": _C.DATA.MAX_CAPTION_LENGTH, "sos_index": _C.DATA.SOS_INDEX, "eos_index": _C.DATA.EOS_INDEX, + "decoder": CaptionDecoderFactory.from_config(_C), } elif _C.MODEL.NAME == "token_classification": @@ -482,6 +485,42 @@ def from_config(cls, config: Config) -> nn.Module: return cls.create(_C.MODEL.NAME, visual, textual, **kwargs) +class CaptionDecoderFactory(Factory): + r""" + Factory to create decoders from predicting captions from VirTex model. + + Possible choices: ``{"beam_search", "nucleus_sampling"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "beam_search": AutoRegressiveBeamSearch, + "nucleus_sampling": AutoRegressiveNucleusSampling, + } + + @classmethod + def from_config(cls, config: Config) -> nn.Module: + r""" + Create a model directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + kwargs = { + "eos_index": _C.DATA.EOS_INDEX, + "max_steps": _C.MODEL.DECODER.MAX_DECODING_STEPS, + } + if _C.MODEL.DECODER.NAME == "beam_search": + kwargs["beam_size"] = _C.MODEL.DECODER.BEAM_SIZE + elif _C.MODEL.DECODER.NAME == "nucleus_sampling": + kwargs["nucleus_size"] = _C.MODEL.DECODER.NUCLEUS_SIZE + + return cls.create(_C.MODEL.DECODER.NAME, **kwargs) + + class OptimizerFactory(Factory): r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``.""" diff --git a/virtex/models/captioning.py b/virtex/models/captioning.py index 0093e714..9753f22c 100644 --- a/virtex/models/captioning.py +++ b/virtex/models/captioning.py @@ -4,12 +4,10 @@ import torch from torch import nn -from torch.nn import functional as F from virtex.data.tokenizers import SentencePieceBPETokenizer from virtex.modules.textual_heads import TextualHead from virtex.modules.visual_backbones import VisualBackbone -from virtex.utils.beam_search import AutoRegressiveBeamSearch class CaptioningModel(nn.Module): @@ -31,10 +29,6 @@ class CaptioningModel(nn.Module): textual: virtex.modules.textual_heads.TextualHead A :class:`~virtex.modules.textual_heads.TextualHead` which makes final predictions conditioned on visual features. - beam_size : int, optional (default = 5) - The width of the beam used for beam search. - max_decoding_steps: int, optional (default = 30) - The maximum number of decoding steps for beam search. sos_index: int, optional (default = 1) The index of the end token (``[SOS]``) in vocabulary. eos_index: int, optional (default = 2) @@ -44,17 +38,20 @@ class CaptioningModel(nn.Module): ``False`` -- only forward captioning is performed. When ``True``, a clone of textual head is created, which does not share weights with "forward" model except input and output embeddings. + decoder: Any, optional (default = None) + An instance of :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch` + or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling` + for decoding captions during inference (unused during training). """ def __init__( self, visual: VisualBackbone, textual: TextualHead, - beam_size: int = 5, - max_decoding_steps: int = 30, + caption_backward: bool = False, sos_index: int = 1, eos_index: int = 2, - caption_backward: bool = False, + decoder: Any = None, ): super().__init__() self.visual = visual @@ -75,17 +72,14 @@ def __init__( # These boundary indices are needed for beam search. self.sos_index = sos_index self.eos_index = eos_index - self.beam_search = AutoRegressiveBeamSearch( - self.eos_index, beam_size=beam_size, max_steps=max_decoding_steps - ) + self.beam_search = decoder self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx) def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: r""" Given a batch of images and captions, compute log likelihood loss per - caption token during training. During inference, given a batch of - images, decode the most likely caption in forward direction through - beam search decoding. + caption token during training. During inference (with images), predict + a caption through either beam search decoding or nucleus sampling. Parameters ---------- @@ -140,9 +134,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: backward_caption_tokens = batch["noitpac_tokens"] backward_output_logits = self.backward_textual( - visual_features, - backward_caption_tokens, - caption_lengths, + visual_features, backward_caption_tokens, caption_lengths ) backward_loss = self.loss( backward_output_logits[:, :-1] @@ -159,35 +151,41 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: if not self.training: # During validation (while pretraining), get best prediction - # at every time-step. + # at every timestep. output_dict["predictions"] = torch.argmax(output_logits, dim=-1) else: + if self.decoder is None: + raise ValueError("Decoder for predicting captions is missing!") + # During inference, get beam search predictions for forward # model. Predictions from forward transformer will be shifted - # right by one time-step. + # right by one timestep. start_predictions = visual_features.new_full( (batch_size,), self.sos_index ).long() # Add image features as a default argument to match callable # signature accepted by beam search class (partial captions only). - beam_search_step = functools.partial( - self.beam_search_step, visual_features - ) - all_top_k_predictions, _ = self.beam_search.search( - start_predictions, beam_search_step + decoding_step = functools.partial(self.decoding_step, visual_features) + + predicted_caption, _ = self.decoder.search( + start_predictions, decoding_step ) - best_beam = all_top_k_predictions[:, 0, :] - output_dict = {"predictions": best_beam} + output_dict = {"predictions": predicted_caption} return output_dict - def beam_search_step( + def decoding_step( self, visual_features: torch.Tensor, partial_captions: torch.Tensor ) -> torch.Tensor: r""" Given visual features and a batch of (assumed) partial captions, predict - the distribution over vocabulary tokens for next time-step. This method - is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`. + the logits over output vocabulary tokens for next timestep. This method + is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch` + and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`. + + .. note:: + + For nucleus sampling, ``beam_size`` will always be 1 (not relevant). Parameters ---------- @@ -202,8 +200,8 @@ def beam_search_step( Returns ------- torch.Tensor - A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- output - distribution over tokens for next time-step. + A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits + over output vocabulary tokens for next timestep. """ # Expand and repeat image features while doing beam search. @@ -222,26 +220,13 @@ def beam_search_step( if len(caption_lengths.size()) == 2: caption_lengths = caption_lengths.sum(1) else: - # Add a time-step. shape: (batch_size, 1) + # Add a timestep. shape: (batch_size, 1) partial_captions = partial_captions.unsqueeze(1) # shape: (batch_size * beam_size, partial_caption_length, vocab_size) - output_logits = self.textual( - visual_features, partial_captions, caption_lengths - ) - # Keep features for last time-step only, we only care about those. - output_logits = output_logits[:, -1, :] - - # Return logprobs as required by `AutoRegressiveBeamSearch`. - # shape: (batch_size * beam_size, vocab_size) - next_logprobs = F.log_softmax(output_logits, dim=1) - - # Set logprobs of last predicted tokens as high negative value to avoid - # repetition in caption. - for index in range(batch_size * beam_size): - next_logprobs[index, partial_captions[index, -1]] = -10000 - - return next_logprobs + logits = self.textual(visual_features, partial_captions, caption_lengths) + # Return logits from the last timestep. + return logits[:, -1, :] def log_predictions( self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer @@ -272,19 +257,17 @@ def __init__( self, visual: VisualBackbone, textual: TextualHead, - beam_size: int = 5, - max_decoding_steps: int = 30, sos_index: int = 1, eos_index: int = 2, + decoder: Any = None, ): super().__init__( visual, textual, - beam_size=beam_size, - max_decoding_steps=max_decoding_steps, sos_index=sos_index, eos_index=eos_index, caption_backward=False, + decoder=decoder, ) @@ -298,19 +281,17 @@ def __init__( self, visual: VisualBackbone, textual: TextualHead, - beam_size: int = 5, - max_decoding_steps: int = 30, sos_index: int = 1, eos_index: int = 2, + decoder: Any = None, ): super().__init__( visual, textual, - beam_size=beam_size, - max_decoding_steps=max_decoding_steps, sos_index=sos_index, eos_index=eos_index, caption_backward=True, + decoder=decoder, ) diff --git a/virtex/utils/beam_search.py b/virtex/utils/beam_search.py index c41fd636..d369dda5 100644 --- a/virtex/utils/beam_search.py +++ b/virtex/utils/beam_search.py @@ -19,6 +19,7 @@ import warnings import torch +from torch.nn import functional as F class AutoRegressiveBeamSearch(object): @@ -27,7 +28,7 @@ class AutoRegressiveBeamSearch(object): Parameters ---------- - end_index: int + eos_index: int The index of the end token (``[EOS]``) in vocabulary. max_steps: int, optional (default = 50) The maximum number of decoding steps. @@ -43,34 +44,26 @@ class AutoRegressiveBeamSearch(object): def __init__( self, - end_index: int, + eos_index: int, max_steps: int = 50, beam_size: int = 5, per_node_beam_size: int = 2, ) -> None: - self._end_index = end_index + self._eos_index = eos_index self.max_steps = max_steps self.beam_size = beam_size self.per_node_beam_size = per_node_beam_size or beam_size def search( - self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] + self, + start_predictions: torch.Tensor, + step: Callable[..., torch.Tensor], + only_return_best: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Given a starting state and a step function, apply beam search to find the most likely target captions. - .. note:: - - If your step function returns ``-inf`` for some log probs - (like if you're using a masked log-softmax) then some of the "best" - captions returned may have ``-inf`` log probability. Specifically - this happens when the beam size is smaller than the number of actions - with finite log probability (non-zero probability) returned by the - step function. Therefore if you're using a mask you may want to - check the results from ``search`` and potentially discard captions - with non-finite log probability. - Parameters ---------- start_predictions : torch.Tensor @@ -80,16 +73,20 @@ def search( step : Callable[..., torch.Tensor] A function that is responsible for computing the next most likely tokens, given the past predictions. Predictions from all previous - time-steps are required, not just the last time-step. The function - should The function is expected to return a tensor of shape - ``(group_size, target_vocab_size)`` containing the token log probs - for the next step. + timesteps are required, not just the last timestep. The function is + expected to return a tensor of shape ``(group_size, target_vocab_size)`` + containing the token logits for the next step. + only_return_best: bool, optional (default = True) + Whether to only return the best beam (with highest logprobs). Set this + to ``False`` to return all the beams. If this is ``True``, then the + returned tensor is of shape ``(batch_size, sequence_length)``, else + will be ``(batch_size, beam_size, sequence_length)``. Returns ------- Tuple[torch.Tensor, torch.Tensor] - Tuple of ``(predictions, log_probs)``, where ``predictions`` - has shape ``(batch_size, beam_size, max_steps)`` and ``log_probs`` + Tuple of ``(predictions, logprobs)``, where ``predictions`` + has shape ``(batch_size, beam_size, max_steps)`` and ``logprobs`` has shape ``(batch_size, beam_size)``. """ @@ -108,53 +105,66 @@ def search( # beam to `beam_size`^2 candidates from which we will select the top # `beam_size` elements for the next iteration. # shape: (batch_size, num_classes) - start_class_log_probs = step(start_predictions) + start_class_logits = step(start_predictions) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + start_class_logprobs = F.log_softmax(start_class_logits, dim=1) - num_classes = start_class_log_probs.size()[1] + num_classes = start_class_logprobs.size()[1] # shape: (batch_size, beam_size), (batch_size, beam_size) - start_top_log_probs, start_predicted_classes = start_class_log_probs.topk( + start_top_logprobs, start_predicted_classes = start_class_logprobs.topk( self.beam_size ) if ( self.beam_size == 1 - and (start_predicted_classes == self._end_index).all() + and (start_predicted_classes == self._eos_index).all() ): warnings.warn( "Empty captions predicted. You may want to increase beam " "size or ensure your step function is working properly.", RuntimeWarning, ) - return start_predicted_classes.unsqueeze(-1), start_top_log_probs + return start_predicted_classes.unsqueeze(-1), start_top_logprobs # The log probs for the last time step. # shape: (batch_size, beam_size) - last_log_probs = start_top_log_probs + last_logprobs = start_top_logprobs # shape: (batch_size, beam_size, sequence_length) predictions = torch.cat([predictions, start_predicted_classes.unsqueeze(-1)], dim=-1) # Log probability tensor that mandates that the end token is selected. # shape: (batch_size * beam_size, num_classes) - log_probs_after_end = start_class_log_probs.new_full( + logprobs_after_end = start_class_logprobs.new_full( (batch_size * self.beam_size, num_classes), float("-inf") ) - log_probs_after_end[:, self._end_index] = 0.0 + logprobs_after_end[:, self._eos_index] = 0.0 for timestep in range(self.max_steps - 1): # shape: (batch_size * beam_size,) last_predictions = predictions[:, :, -1].reshape(batch_size * self.beam_size) - # If every predicted token from the last step is `self._end_index`, + # If every predicted token from the last step is `self._eos_index`, # then we can stop early. - if (last_predictions == self._end_index).all(): + if (last_predictions == self._eos_index).all(): break predictions_so_far = predictions.view( batch_size * self.beam_size, -1 ) # shape: (batch_size * beam_size, num_classes) - class_log_probs = step(predictions_so_far) + class_logits = step(predictions_so_far) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + class_logprobs = F.log_softmax(class_logits, dim=1) + + # Set logprobs of last predicted tokens as high negative value to avoid + # repetition in caption. + for index in range(batch_size * self.beam_size): + class_logprobs[index, predictions_so_far[index, -1]] = -10000 # shape: (batch_size * beam_size, num_classes) last_predictions_expanded = last_predictions.unsqueeze(-1).expand( @@ -165,13 +175,13 @@ def search( # one-hot distribution, forcing the beam to predict the end token # this timestep as well. # shape: (batch_size * beam_size, num_classes) - cleaned_log_probs = torch.where( - last_predictions_expanded == self._end_index, - log_probs_after_end, - class_log_probs, + cleaned_logprobs = torch.where( + last_predictions_expanded == self._eos_index, + logprobs_after_end, + class_logprobs, ) # shape (both): (batch_size * beam_size, per_node_beam_size) - top_log_probs, predicted_classes = cleaned_log_probs.topk( + top_logprobs, predicted_classes = cleaned_logprobs.topk( self.per_node_beam_size ) # Here we expand the last log probs to `(batch_size * beam_size, @@ -179,16 +189,16 @@ def search( # probs for this timestep. This lets us maintain the log # probability of each element on the beam. # shape: (batch_size * beam_size, per_node_beam_size) - expanded_last_log_probs = ( - last_log_probs.unsqueeze(2) + expanded_last_logprobs = ( + last_logprobs.unsqueeze(2) .expand(batch_size, self.beam_size, self.per_node_beam_size) .reshape(batch_size * self.beam_size, self.per_node_beam_size) ) # shape: (batch_size * beam_size, per_node_beam_size) - summed_top_log_probs = top_log_probs + expanded_last_log_probs + summed_top_logprobs = top_logprobs + expanded_last_logprobs # shape: (batch_size, beam_size * per_node_beam_size) - reshaped_summed = summed_top_log_probs.reshape( + reshaped_summed = summed_top_logprobs.reshape( batch_size, self.beam_size * self.per_node_beam_size ) # shape: (batch_size, beam_size * per_node_beam_size) @@ -205,7 +215,7 @@ def search( # Keep only the top `beam_size` beam indices. # shape: (batch_size, beam_size), (batch_size, beam_size) - restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk( + restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk( self.beam_size ) predictions = reshaped_beam.gather( @@ -213,9 +223,9 @@ def search( ) # shape: (batch_size, beam_size) - last_log_probs = restricted_beam_log_probs + last_logprobs = restricted_beam_logprobs - if not torch.isfinite(last_log_probs).all(): + if not torch.isfinite(last_logprobs).all(): warnings.warn( "Infinite log probs encountered. Some final captions may not " "make sense. This can happen when the beam size is larger than" @@ -224,4 +234,10 @@ def search( RuntimeWarning, ) - return predictions, last_log_probs \ No newline at end of file + # Optionally select best beam and its logprobs. + if only_return_best: + # shape: (batch_size, sequence_length) + predictions = predictions[:, 0, :] + last_logprobs = last_logprobs[:, 0] + + return predictions, last_logprobs diff --git a/virtex/utils/nucleus_sampling.py b/virtex/utils/nucleus_sampling.py new file mode 100644 index 00000000..50b6f02c --- /dev/null +++ b/virtex/utils/nucleus_sampling.py @@ -0,0 +1,127 @@ +r""" +Nucleus Sampling was introduced in the paper +`The Curious Case of Neural Text Degeneration `_. +If you take it from here, make sure to cite them: + +.. code-block:: text + + @inproceedings{, + title={The Curious Case of Neural Text Degeneration}, + author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi}, + journal={ICLR}, + year={2020} + } + +Some core parts of this code are adapted with minor modifications from Thomas Wolf's +gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 +""" + +from typing import Callable, List, Tuple + +import torch +import torch.nn.functional as F + + +class AutoRegressiveNucleusSampling(object): + r""" + Implements the nucleus sampling for decoding captions. This class only works + for auto-regressive models (Transformer-like), not recurrent models (LSTM-like). + + Parameters + ---------- + eos_index: int + The index of the end token (``[EOS]``) in vocabulary. + max_steps: int, optional (default = 50) + The maximum number of decoding steps. + nucleus_size: float, optional (default = 0.9) + Size of top-K nucleus for sampling. + """ + + def __init__( + self, + eos_index: int, + max_steps: int = 50, + nucleus_size: float = 0.9, + ): + super().__init__() + self._eos_index = eos_index + self.max_steps = max_steps + self.nucleus_size = nucleus_size + + def search( + self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] + ) -> Tuple[torch.Tensor, None]: + + batch_size = start_predictions.size()[0] + + # List of `(batch_size, )` tensors. One for each timestep. + # This includes the start-of-sentence tokens, unlike the implementation + # in `AutoregressiveBeamSearch`. We will remove them in the end. + predictions: List[torch.Tensor] = [start_predictions] + + for timestep in range(self.max_steps): + # Get the predictions from last timestep (most recent). + # shape: (batch_size, ) + last_predictions = predictions[-1] + + # If every predicted token from the last step is end-of-sentence token, + # then we can stop early. + if (last_predictions == self._eos_index).all(): + break + + # Combine step predictions made so far into one tensor. This is our + # "partial" caption input to the transformer. + # shape: (batch_size, timestep + 1) + predictions_so_far = torch.stack(predictions).permute(1, 0) + + # Take a step, get the distribution of logits from next timestep. + # shape: (batch_size, num_classes) + current_logits = step(predictions_so_far) + + # Sort logits in descending order to determine the nucleus. + sorted_logits, sorted_idx = torch.sort(current_logits, descending=True) + + # Get cumulative softmax probabilites. For every instance in batch, a + # variable amount of tokens (N) will consitute the nucleus. + # shape: (batch_size, num_classes) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Determine indices of tokens at the tail of distribution. These will be + # removed from the nucleus. + sorted_idx_to_remove = cumulative_probs > self.nucleus_size + + # Shift the indices to the right to keep the first token outside nucleus. + sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone() + sorted_idx_to_remove[..., 0] = 0 + + # Set logits to large negative value to avoid sampling them. Iterate over + # the batch of examples. + for t in range(current_logits.size()[0]): + idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]] + current_logits[t][idx_to_remove] = -1e12 + + # Set logits for last predicted token to a large negative value to + # avoid repetition. + current_logits[t][last_predictions[t]] = -1e12 + + # Sample from the filtered distribution. + # shape: (batch_size, num_classes) + current_probs = F.softmax(current_logits, dim=-1) + + # shape: (batch_size, ) + current_predictions = torch.multinomial(current_probs, 1) + current_predictions = current_predictions.view(batch_size) + + # Set current predicted tokens to be end-of-sentence for instances where + # last prediction was also end-of-sentence token. + current_predictions[last_predictions == self._eos_index] = self._eos_index + + predictions.append(current_predictions) + + # Remove start-of-sentence token from predictions, and collect them together. + # shape: (batch_size, max_steps) .. or could be less than max_steps. + all_predictions = torch.stack(predictions[1:]).permute(1, 0) + + # We don't return any logprobs of generated sequence with nucleus sampling, + # unlike `AutoregressiveBeamSearch`. + return all_predictions, None From f33ef5a178ebc72a80c5381d625fd348b96b376e Mon Sep 17 00:00:00 2001 From: Karan Desai Date: Thu, 15 Jul 2021 14:39:54 -0700 Subject: [PATCH 3/3] Update version and add CHANGELOG. --- CHANGELOG.md | 21 ++++++++++++++++++--- docs/conf.py | 2 +- scripts/eval_captioning.py | 6 +++++- setup.py | 2 +- virtex/models/captioning.py | 2 +- 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e54814c..24b885b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,22 @@ -ArXiv v1 -> v2 CHANGELOG -========================= +CHANGELOG +========= -[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is out CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0). +This CHANGELOG file records changes between different arXiv versions of our paper, and the version of this codebase which should be used to reproduce the results in the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases). + +ArXiv v1 -> v2 +============== + +**Code version:** `v1.2`. + +Fix image captioning results with a modified beam search implementation. _Rest of the downstream task results and pre-trained models are unchanged._ + + +ArXiv v1 -> v2 +============== + +**Code version:** `v1.0` or `v1.1`. + +[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0). While the core motivation and approach is the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html), however links to old models in `v0.9` will be active till June 30, 2021. We encourage you to use the new models! diff --git a/docs/conf.py b/docs/conf.py index fdd9cafe..0e1a3518 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,7 +24,7 @@ author = "Karan Desai" # The full version, including alpha/beta/rc tags -release = "1.1" +release = "1.2" # -- General configuration --------------------------------------------------- diff --git a/scripts/eval_captioning.py b/scripts/eval_captioning.py index 8da98284..b50b9b30 100644 --- a/scripts/eval_captioning.py +++ b/scripts/eval_captioning.py @@ -21,7 +21,7 @@ evaluate pretrained model on COCO Captions val2017 split.""" ) parser.add_argument( - "--data-root", default=None, + "--images", "--data-root", default=None, help="""Path to a directory containing image files to generate captions for. Default: COCO val2017 image directory as expected relative to project root.""" ) @@ -89,6 +89,10 @@ def main(_A: argparse.Namespace): } ) + logger.info("Displaying first 25 caption predictions:") + for pred in predictions[:25]: + logger.info(f"{pred['image_id']} :: {pred['caption']}") + # Save predictions as a JSON file if specified. if _A.output is not None: os.makedirs(os.path.dirname(_A.output), exist_ok=True) diff --git a/setup.py b/setup.py index fc715695..ca9986ca 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def get_model_zoo_configs() -> List[str]: setup( name="virtex", - version="1.1.0", + version="1.2.0", author="Karan Desai and Justin Johnson", description="VirTex: Learning Visual Representations with Textual Annotations", package_data={"virtex.model_zoo": get_model_zoo_configs()}, diff --git a/virtex/models/captioning.py b/virtex/models/captioning.py index 9753f22c..87b97db5 100644 --- a/virtex/models/captioning.py +++ b/virtex/models/captioning.py @@ -72,7 +72,7 @@ def __init__( # These boundary indices are needed for beam search. self.sos_index = sos_index self.eos_index = eos_index - self.beam_search = decoder + self.decoder = decoder self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx) def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: