In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [21]:
x = torch.rand(5,2)
x[3:] += 1
x

tensor([[0.0417, 0.8189],
        [0.0397, 0.6720],
        [0.4841, 0.6323],
        [1.7131, 1.6262],
        [1.2009, 1.8970]])

In [25]:
preds = x.argmax(1)
preds

tensor([1, 1, 1, 0, 1])

In [40]:
y_eq_preds = preds.clone()
y_neq_preds = 1 - preds.clone()
y_overlap1_preds = torch.tensor([0,0,0,1,1])
y_overlap2_preds = torch.tensor([0,0,0,0,1])
print(preds)
print(y_eq_preds)
print(y_neq_preds)
print(y_overlap1_preds)
print(y_overlap2_preds)

tensor([1, 1, 1, 0, 1])
tensor([1, 1, 1, 0, 1])
tensor([0, 0, 0, 1, 0])
tensor([0, 0, 0, 1, 1])
tensor([0, 0, 0, 0, 1])


In [41]:
make_sim_labels = lambda y: (preds == y).long() * 2 - 1

### Cosine embedding loss against the mean embedding for the class

In [27]:
mean_x_by_class = torch.stack([x[y==i].mean(0) for i in y.unique()])
mean_x_by_class

tensor([[0.1885, 0.7077],
        [1.4570, 1.7616]])

In [28]:
pred_embeddings = mean_x_by_class.index_select(0, preds)
y_embeddings = mean_x_by_class.index_select(0, y)
print(pred_embeddings)
print(y_embeddings)

tensor([[1.4570, 1.7616],
        [1.4570, 1.7616],
        [1.4570, 1.7616],
        [0.1885, 0.7077],
        [1.4570, 1.7616]])
tensor([[0.1885, 0.7077],
        [0.1885, 0.7077],
        [0.1885, 0.7077],
        [1.4570, 1.7616],
        [1.4570, 1.7616]])


In [62]:
# actual embeddings vs mean embeddings by class (where true labels have different degree of overlap with the predictions)
for y_ in [y_eq_preds, y_neq_preds, y_overlap1_preds, y_overlap2_preds]:
    print(F.cosine_embedding_loss(x, mean_x_by_class.index_select(0, y_), make_sim_labels(y_), reduction='none'))

tensor([0.1980, 0.1931, 0.0007, 0.1480, 0.0080])
tensor([0.9782, 0.9798, 0.9237, 0.9928, 0.9541])
tensor([0.9782, 0.9798, 0.9237, 0.9928, 0.0080])
tensor([0.9782, 0.9798, 0.9237, 0.1480, 0.0080])


In [63]:
# pred_embeddings vs mean embeddings by class
for y_ in [y_eq_preds, y_neq_preds, y_overlap1_preds, y_overlap2_preds]:
    print(F.cosine_embedding_loss(pred_embeddings, mean_x_by_class.index_select(0, y_), make_sim_labels(y_), reduction='none'))

tensor([0., 0., 0., 0., 0.])
tensor([0.9087, 0.9087, 0.9087, 0.9087, 0.9087])
tensor([0.9087, 0.9087, 0.9087, 0.9087, 0.0000])
tensor([0.9087, 0.9087, 0.9087, 0.0000, 0.0000])


In [67]:
F.cosine_embedding_loss(pred_embeddings, mean_x_by_class.index_select(0, y_), make_sim_labels(y_), reduction='sum')

tensor(2.7260)

### Hinge embedding loss

In [61]:
# actual embeddings vs mean embeddings by class
for y_ in [y_eq_preds, y_neq_preds, y_overlap1_preds, y_overlap2_preds]:
    print(F.hinge_embedding_loss((x - mean_x_by_class.index_select(0, y_)).norm(2, dim=1), make_sim_labels(y_), reduction='mean'))

tensor(1.4097)
tensor(0.6136)
tensor(0.6716)
tensor(0.8855)


In [55]:
(x - mean_x_by_class.index_select(0, y_)).norm(2, dim=1)

tensor([0.1841, 0.1530, 0.3050, 1.7799, 0.2897])

In [50]:
make_sim_labels(y_)

tensor([-1, -1, -1,  1,  1])