From 0d98c2e32aeeab5c062bda837ad4ec4e4da0d402 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 22 Dec 2022 20:52:27 +0900 Subject: [PATCH 1/2] add ignore_index --- src/open_clip/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 9f26dd4f8..0ff1d440a 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -126,7 +126,7 @@ def __init__( self, caption_loss_weight, clip_loss_weight, - pad_id=-100, + pad_id=0, # pad_token for open_clip custom tokenizer local_loss=False, gather_with_grad=False, cache_labels=False, @@ -145,7 +145,7 @@ def __init__( self.clip_loss_weight = clip_loss_weight self.caption_loss_weight = caption_loss_weight - self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id, ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) From 9b01d1d2b9438cdcf38ee72673aac3fadc7bb331 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 22 Dec 2022 20:58:34 +0900 Subject: [PATCH 2/2] just need to pick right index --- src/open_clip/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 0ff1d440a..555cf545d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -145,7 +145,7 @@ def __init__( self.clip_loss_weight = clip_loss_weight self.caption_loss_weight = caption_loss_weight - self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id, ignore_index=pad_id) + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale)