diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 0dc72882eb..61b3eb4cda 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -446,7 +446,7 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw similarity_matrix = torch.cat(logits_list, dim=1) # temperature scaling and CE similarity_matrix = similarity_matrix / temperature - loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size # avoid duplicate + loss = nn.CrossEntropyLoss()(similarity_matrix, labels) else: all_tensors = [] for tensor in split_tensors: @@ -499,7 +499,6 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw # next positive is neg+1 length += tensor.size(0) - 1 loss /= len(split_tensors) - loss /= world_size # avoid duplicate return loss