[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/itmorn/AI.handbook/blob/main/DL/torch/nn/LossFunction/CosineEmbeddingLoss.ipynb)

# CosineEmbeddingLoss
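Measures how similar two input vectors $x_1$ and $x_2$ are, based on their cosine similarity. A label of $y = 1$ means the pair should be similar; $y = -1$ means the pair should be pushed apart.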

$$\text{loss}(x, y) =
        \begin{cases}
        1 - \cos(x_1, x_2), & \text{if } y = 1 \\
        \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
\end{cases}$$
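To make the two branches concrete, here is a minimal, dependency-free sketch of the per-pair formula. The helper names `cosine` and `cosine_embedding_loss` are hypothetical (not PyTorch APIs), and the sketch does not guard against zero-length vectors.

```python
import math

def cosine(x1, x2):
    """Cosine similarity of two equal-length sequences of floats (hypothetical helper)."""
    dot = sum(a * b for a, b in zip(x1, x2))
    norm1 = math.sqrt(sum(a * a for a in x1))
    norm2 = math.sqrt(sum(b * b for b in x2))
    return dot / (norm1 * norm2)  # no zero-vector guard in this sketch

def cosine_embedding_loss(x1, x2, y, margin=0.0):
    """Per-pair loss following the piecewise formula (y must be 1 or -1)."""
    c = cosine(x1, x2)
    if y == 1:
        return 1 - c                  # pull similar pairs toward cos = 1
    if y == -1:
        return max(0.0, c - margin)   # penalize dissimilar pairs only when cos > margin
    raise ValueError("y must be 1 or -1")

# Orthogonal vectors: cos = 0
print(cosine_embedding_loss([1.0, 0.0], [0.0, 1.0], 1))    # 1.0
print(cosine_embedding_loss([1.0, 0.0], [0.0, 1.0], -1))   # 0.0
# Vectors 45 degrees apart: cos is about 0.7071, so 0.7071 - 0.5 remains
print(round(cosine_embedding_loss([1.0, 0.0], [1.0, 1.0], -1, margin=0.5), 4))  # 0.2071
```

Note that a dissimilar pair ($y=-1$) costs nothing once its cosine similarity drops to the margin or below, while a similar pair ($y=1$) is always penalized unless the vectors point in exactly the same direction.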


**Definition**:  
torch.nn.CosineEmbeddingLoss(margin=0.0, size_average=None, reduce=None, reduction='mean')

**Parameters**:  
- margin (float, optional) – Should be a number from -1 to 1; values from 0 to 0.5 are suggested. Defaults to 0.
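The margin only affects pairs labeled $y=-1$. A small sketch, assuming two hand-picked unit vectors whose cosine similarity is 0.6, shows the loss $\max(0, 0.6 - \text{margin})$ for several margins. It should print roughly 1.1000, 0.6000, 0.1000 and 0.0000; in particular, with margin = 1 the negative branch can never be positive because $\cos \le 1$.

```python
import torch
import torch.nn as nn

# Two unit vectors with cosine similarity exactly 0.6 (a 3-4-5 triangle)
x1 = torch.tensor([[1.0, 0.0]])  # shape (N=1, D=2)
x2 = torch.tensor([[0.6, 0.8]])
y = torch.tensor([-1.0])         # label the pair as dissimilar

for margin in [-0.5, 0.0, 0.5, 1.0]:
    loss = nn.CosineEmbeddingLoss(margin=margin)(x1, x2, y)
    # Expected: max(0, 0.6 - margin)
    print(f"margin={margin:+.1f}  loss={loss.item():.4f}")
```

A larger margin tolerates more residual similarity between "dissimilar" pairs; a negative margin demands that they point in opposing directions.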

For the remaining parameters (size_average, reduce, reduction), see the CrossEntropyLoss section.
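Of those, `reduction` is the one to use in practice (`size_average` and `reduce` are deprecated in favor of it). A quick sketch on random data shows that `'none'` returns one loss per pair with shape `(N,)`, while `'mean'` (the default) and `'sum'` collapse that vector to a scalar. The inputs here are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x1 = torch.randn(4, 8)                    # N=4 pairs, D=8 features
x2 = torch.randn(4, 8)
y = torch.tensor([1.0, -1.0, 1.0, -1.0])  # labels in {-1, 1}

per_pair = nn.CosineEmbeddingLoss(reduction="none")(x1, x2, y)   # one loss per pair
mean_loss = nn.CosineEmbeddingLoss(reduction="mean")(x1, x2, y)  # scalar (default)
sum_loss = nn.CosineEmbeddingLoss(reduction="sum")(x1, x2, y)    # scalar

print(per_pair.shape)  # torch.Size([4])
print(torch.allclose(mean_loss, per_pair.mean()), torch.allclose(sum_loss, per_pair.sum()))  # True True
```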

```python
import torch
import torch.nn as nn

torch.manual_seed(666)
margin = 1  # note: with margin=1 the y=-1 term max(0, cos - 1) is always 0
loss = nn.CosineEmbeddingLoss(margin=margin)
input1 = torch.randn(3, 5, requires_grad=True)  # shape (N, D)
input2 = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3).sign()  # random labels in {-1, 1}
print("input1:\n", input1, "\n")
print("input2:\n", input2, "\n")
print("target:\n", target, "\n")

output = loss(input1, input2, target)
print("output:\n", output, "\n")

# Manual computation: cosine similarity of each row pair
prod_sum = (input1 * input2).sum(dim=1)
mag_square1 = (input1 * input1).sum(dim=1)
mag_square2 = (input2 * input2).sum(dim=1)
denom = (mag_square1 * mag_square2).sqrt()
cos = prod_sum / denom

# Apply the y=1 and y=-1 branches of the formula, then average
zeros = torch.zeros_like(cos)
pos = 1 - cos
neg = (cos - margin).clip(min=0)
output_pos = torch.where(target == 1, pos, zeros)
output_neg = torch.where(target == -1, neg, zeros)
output = output_pos + output_neg
output.mean()  # matches the built-in result
```

```
input1:
 tensor([[-2.1188,  0.0635, -1.4555, -0.0126, -0.1548],
        [-0.0927,  2.5916,  0.4542, -0.6890, -0.9962],
        [ 0.1856,  0.1476,  0.8628,  0.2379, -0.5260]], requires_grad=True) 

input2:
 tensor([[-0.1043, -0.5187,  0.1231,  0.0755,  0.7091],
        [-1.0812, -0.6668, -0.8967,  0.7272,  1.4582],
        [-0.0018,  0.6660,  1.4064, -0.1019, -0.1370]], requires_grad=True) 

target:
 tensor([-1.,  1.,  1.]) 

output:
 tensor(0.5986, grad_fn=<MeanBackward0>) 
```

```
tensor(0.5986, grad_fn=<MeanBackward0>)
```
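The manual computation above can be written more compactly with `torch.nn.functional.cosine_similarity`. The sketch below (the helper `manual_cosine_embedding_loss` is hypothetical, not a PyTorch API) also checks that gradients with respect to both inputs agree with the built-in loss, not just the forward value. The labels and `margin=0.2` are chosen here so that both branches contribute a nonzero loss. Because `cosine_similarity` handles near-zero norms with its own epsilon, tiny numerical differences are possible, so the comparison uses a tolerance; both lines should print `True`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def manual_cosine_embedding_loss(x1, x2, target, margin=0.0):
    """Hypothetical helper: mean-reduced loss built from F.cosine_similarity."""
    cos = F.cosine_similarity(x1, x2, dim=1)  # shape (N,)
    pos = 1 - cos                             # branch for y = 1
    neg = (cos - margin).clamp(min=0)         # branch for y = -1
    # Assumes every label is exactly 1 or -1
    return torch.where(target == 1, pos, neg).mean()

torch.manual_seed(666)
x1 = torch.randn(3, 5, requires_grad=True)
x2 = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1.0, 1.0, -1.0])
margin = 0.2

builtin = nn.CosineEmbeddingLoss(margin=margin)(x1, x2, target)
manual = manual_cosine_embedding_loss(x1, x2, target, margin)

# Compare forward values, then gradients w.r.t. both inputs
g_builtin = torch.autograd.grad(builtin, (x1, x2))
g_manual = torch.autograd.grad(manual, (x1, x2))
print(torch.allclose(builtin, manual, atol=1e-5))
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(g_builtin, g_manual)))
```

Using a single `torch.where` is enough here because the labels are assumed to be exactly 1 or -1; the two-mask version in the notebook cell additionally yields 0 for any other label value.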