From 6043e0998c7a2702ab195aa3941aefd3aee732d5 Mon Sep 17 00:00:00 2001 From: Deepak Kumar <34968705+dk25021999@users.noreply.github.com> Date: Tue, 31 Aug 2021 14:35:05 +0530 Subject: [PATCH 1/2] Update vilbert.py Optional ITM loss in pre-training added. --- mmf/models/vilbert.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mmf/models/vilbert.py b/mmf/models/vilbert.py index 1b9b3c44d..adf0a51f1 100644 --- a/mmf/models/vilbert.py +++ b/mmf/models/vilbert.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from mmf.common.registry import registry from mmf.models import BaseModel +from mmf.models.transformers.heads.itm import ITM from mmf.modules.hf_layers import replace_with_jit from mmf.utils.configuration import get_mmf_cache_dir from mmf.utils.modeling import get_optimizer_parameters_for_bert @@ -1061,6 +1062,7 @@ def __init__(self, config): self.visual_target = config.visual_target self.num_negative = config.num_negative self.loss_fct = CrossEntropyLoss(ignore_index=-1) + if self.visual_target == 0: self.vis_criterion = nn.KLDivLoss(reduction="none") @@ -1099,6 +1101,8 @@ def forward( image_label: Optional[Tensor] = None, image_target: Optional[Tensor] = None, output_all_attention_masks: bool = False, + itm_loss: bool = False, + next_sentence_label: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, ) -> Dict[str, Tensor]: masked_img_loss: Optional[Tensor] = None ( @@ -1226,6 +1230,14 @@ def forward( prediction_scores_t.view(-1, self.vocab_size), masked_lm_labels.view(-1) ) output["masked_lm_loss"] = masked_lm_loss.unsqueeze(0) + + if itm_loss is not False: + itm_head = ITM({"type": "itm", "hidden_size": self.vocab_size}) + seq_output = torch.cat(sequence_output_t, sequence_output_v) + multimodal_alignment_loss = itm_head(seq_output, processed_sample_list = next_sentence_label) + if multimodal_alignment_loss is not None: + output["itm_loss"] = multimodal_alignment_loss["losses"]["itm_loss"] + # next_sentence_loss = self.loss_fct( # 
seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) # ) From 218057265a3fc175f656b5ebe8fb44ef5ccca2e9 Mon Sep 17 00:00:00 2001 From: Deepak Kumar <34968705+dk25021999@users.noreply.github.com> Date: Wed, 1 Sep 2021 09:56:10 +0530 Subject: [PATCH 2/2] Update vilbert.py Move ITM head initialization into __init__ instead of the forward pass. --- mmf/models/vilbert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmf/models/vilbert.py b/mmf/models/vilbert.py index adf0a51f1..07c76c628 100644 --- a/mmf/models/vilbert.py +++ b/mmf/models/vilbert.py @@ -1063,6 +1063,8 @@ def __init__(self, config): self.num_negative = config.num_negative self.loss_fct = CrossEntropyLoss(ignore_index=-1) + if itm_loss is not False: + itm_head = ITM({"type": "itm", "hidden_size": self.vocab_size}) if self.visual_target == 0: self.vis_criterion = nn.KLDivLoss(reduction="none") @@ -1232,7 +1234,6 @@ def forward( output["masked_lm_loss"] = masked_lm_loss.unsqueeze(0) if itm_loss is not False: - itm_head = ITM({"type": "itm", "hidden_size": self.vocab_size}) seq_output = torch.cat(sequence_output_t, sequence_output_v) multimodal_alignment_loss = itm_head(seq_output, processed_sample_list = next_sentence_label) if multimodal_alignment_loss is not None: