In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

from htools import *

In [5]:
cd_root()

Current directory: /Users/hmamin/img_wang


In [85]:
def contrastive_loss(x1, x2, y, m=1., p=2, reduction='mean'):
    """
    # TODO: find out what a reasonable value for m (margin) is.
    
    Note: 
    
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    
    Parameters
    ----------
    x1: torch.Tensor
        Shape (bs, n_features).
    x2: torch.Tensor
        Shape (bs, n_features).
    y: torch.Tensor
        Labels. Unlike the paper, we use the convention that a label of 1 
        means images are similar. This is consistent with all our existing
        datasets and just feels more intuitive.
    m: float
        Margin that prevents dissimilar pairs from affecting the loss unless
        they are sufficiently far apart. I believe the reasonable range of
        values depends on the size of the feature dimension.
    p: int
        The p that determines the p-norm used to calculate the initial 
        distance measure between x1 and x2. The default of 2 therefore uses
        euclidean distance.
    reduction: str
        One of ('sum', 'mean', 'none'). Standard pytorch loss reduction. Keep
        in mind 'none' will probably not allow backpropagation since it
        returns a rank 2 tensor.
        
    Returns
    -------
    torch.Tensor: Scalar measuring the contrastive loss. If no reduction is
    applied, this will instead be a tensor of shape (bs,).
    """
    reduction = identity if reduction == 'none' else getattr(torch, reduction)
    dw = F.pairwise_distance(x1, x2, p, keepdim=True) 
    # Loss_similar + Loss_different
    res = y*dw.pow(p).div(2) + (1-y)*torch.clamp_min(m-dw, 0).pow(p).div(2)
    return reduction(res)

In [89]:
class ContrastiveLoss(nn.Module):
    
    def __init__(self, m=1., p=2, reduction='mean'):
        super().__init__()
        self.m = m
        self.p = p
        self.reduction = reduction
        self.loss = partial(contrastive_loss, m=m, p=p, reduction=reduction)
        
    def forward(self, x1, x2, y_true):
        return self.loss(x1, x2, y_true)

In [94]:
bs = 2
x1 = torch.randn(bs, 5)
x2 = torch.randn(bs, 5)
x1[0] += torch.arange(0, 100, 20)
x1[1] -= 50
x2[0] += torch.arange(0, 100, 20)
x2[1] += 25

print(x1)
print(x2)

tensor([[  1.8909,  20.2486,  39.7307,  60.3099,  81.7360],
        [-50.2483, -49.1922, -50.7428, -50.1480, -48.9504]])
tensor([[ 0.9450, 19.3297, 39.5396, 59.4475, 79.7256],
        [24.4308, 24.5980, 26.0160, 25.4062, 22.6329]])


In [95]:
y = torch.tensor([1, 0]).unsqueeze(-1)
y

tensor([[1],
        [0]])

In [96]:
loss = ContrastiveLoss()
loss(x1, x2, y)

tensor(1.6403)

In [97]:
contrastive_loss(x1, x2, y, m=1, reduction='mean')

tensor(1.6403)

In [99]:
# Looks like I'll need to make some adjustments if I want this to work well
# with non-binary targets.
y_reg = torch.tensor([.8, .2]).unsqueeze(-1)
loss(x1, x2, y_reg)

tensor(1388.6364)

In [100]:
F.cosine_similarity(x1, x2)

tensor([ 0.9999, -0.9994])

In [101]:
F.pairwise_distance(x1, x2)

tensor([  2.5615, 166.5728])

In [118]:
# Not exactly sure what target is supposed to be.
F.cosine_embedding_loss(x1, x2, torch.tensor(1), reduction='none')

tensor([7.2360e-05, 1.9994e+00])

## TODO

- confirm good value of margin
- try alternate formula I found using softmax and cosine distance
- try to make code work with non-binary targets (proba instead of 0/1)
- think about how this might work for my problem where I have 3 pairs per row (if a row contains x_new, x1, x2, x3, we have x_new:x1, x_new:x2, and x_new:x3)