In [181]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [182]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

from htools import *
from img_wang.models import ClassificationHead

In [183]:
cd_root()

Current directory: /Users/hmamin/img_wang


In [85]:
def contrastive_loss(x1, x2, y, m=1., p=2, reduction='mean'):
    """
    # TODO: find out what a reasonable value for m (margin) is.
    
    Note: 
    
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    
    Parameters
    ----------
    x1: torch.Tensor
        Shape (bs, n_features).
    x2: torch.Tensor
        Shape (bs, n_features).
    y: torch.Tensor
        Labels. Unlike the paper, we use the convention that a label of 1 
        means images are similar. This is consistent with all our existing
        datasets and just feels more intuitive.
    m: float
        Margin that prevents dissimilar pairs from affecting the loss unless
        they are sufficiently far apart. I believe the reasonable range of
        values depends on the size of the feature dimension.
    p: int
        The p that determines the p-norm used to calculate the initial 
        distance measure between x1 and x2. The default of 2 therefore uses
        euclidean distance.
    reduction: str
        One of ('sum', 'mean', 'none'). Standard pytorch loss reduction. Keep
        in mind 'none' will probably not allow backpropagation since it
        returns a rank 2 tensor.
        
    Returns
    -------
    torch.Tensor: Scalar measuring the contrastive loss. If no reduction is
    applied, this will instead be a tensor of shape (bs,).
    """
    reduction = identity if reduction == 'none' else getattr(torch, reduction)
    dw = F.pairwise_distance(x1, x2, p, keepdim=True) 
    # Loss_similar + Loss_different
    res = y*dw.pow(p).div(2) + (1-y)*torch.clamp_min(m-dw, 0).pow(p).div(2)
    return reduction(res)

In [325]:
class ContrastiveLoss1d(nn.Module):
    
    def __init__(self, m=1., p=2, reduction='mean'):
        super().__init__()
        self.m = m
        self.p = p
        self.reduction = reduction
        self.loss = partial(contrastive_loss, m=m, p=p, reduction=reduction)
        
    def forward(self, x1, x2, y_true):
        return self.loss(x1, x2, y_true)

In [288]:
bs = 2
x1 = torch.randn(bs, 5)
x2 = torch.randn(bs, 5)
x1[0] += torch.arange(0, 100, 20)
x1[1] -= 50
x2[0] += torch.arange(0, 100, 20)
x2[1] += 25

print(x1)
print(x2)

tensor([[ -1.4772,  20.1164,  41.4341,  59.0182,  78.2542],
        [-49.2596, -50.2114, -50.6976, -49.3883, -49.9065]])
tensor([[-2.3021e-02,  1.9816e+01,  3.8695e+01,  5.8490e+01,  7.9823e+01],
        [ 2.3965e+01,  2.4299e+01,  2.4448e+01,  2.5086e+01,  2.3895e+01]])


In [289]:
y = torch.tensor([1, 0]).unsqueeze(-1)
y

tensor([[1],
        [0]])

In [326]:
loss = ContrastiveLoss1d(reduction='mean')
loss(x1, x2, y)

tensor(3.1123)

In [327]:
loss = ContrastiveLoss1d(reduction='none')
loss(x1, x2, y)

tensor([[6.2247],
        [0.0000]])

In [218]:
contrastive_loss(x1, x2, y, m=1, reduction='none')

tensor([[3.2806],
        [0.0000]])

In [324]:
class ContrastiveLoss2d(nn.Module):
    
    def __init__(self, m=1., p=2, reduction='mean'):
        super().__init__()
        self.m = m
        self.p = p
        self.loss = partial(contrastive_loss, m=m, p=p, reduction='none')
        
        if reduction == 'none':
            self.reduction = identity
        elif reduction == 'row':
            self.reduction = partial(torch.sum, dim=-1)
        else:
            self.reduction = getattr(torch, reduction)
        
    def forward(self, x1, x2, y_true):
        # x1 has shape (bs, feats). x2 has shape (bs, n_item, n_feats).
        # I.E. we're comparing 1 image to `n_item` variants.
        # y_true has shape (bs, n_item).
        # Basically multi-label classification with OHE labels.
        # Output is scalar if reduction is 'mean' or 'sum', same shape as y
        # if reduction is 'none', or shape (bs,) if reduction is 'row'.
        bs, n, dim = x2.shape
        res = self.loss(x1.repeat_interleave(n, dim=0), 
                        x2.view(-1, dim),
                        y_true.view(-1, 1))
        return self.reduction(res.view(bs, -1))

In [317]:
noise = torch.rand(2, 3, 5) * 10
noise[0, 0] /= 100
noise[-1, -1] += 500
x3 = x2[:, None, ...] + noise
print(x3.shape)
x3

torch.Size([2, 3, 5])


tensor([[[7.2080e-02, 1.9909e+01, 3.8723e+01, 5.8513e+01, 7.9828e+01],
         [4.2575e+00, 2.8752e+01, 4.4914e+01, 6.1409e+01, 8.9757e+01],
         [8.2325e+00, 2.3834e+01, 3.9332e+01, 5.8804e+01, 8.2293e+01]],

        [[2.5705e+01, 3.2525e+01, 2.4480e+01, 3.1325e+01, 2.7094e+01],
         [2.4503e+01, 2.9017e+01, 2.6627e+01, 2.6928e+01, 2.9855e+01],
         [5.2564e+02, 5.2743e+02, 5.2514e+02, 5.3094e+02, 5.3027e+02]]])

In [318]:
y2d = torch.tensor([[1, 0, 1],
                    [1, 1, 0]])
y2d

tensor([[1, 0, 1],
        [1, 1, 0]])

In [319]:
x1.shape, x3.shape, y2d.shape

(torch.Size([2, 5]), torch.Size([2, 3, 5]), torch.Size([2, 3]))

In [320]:
loss2d = ContrastiveLoss2d(reduction='row')
res = loss2d(x1, x3, y2d)
res

tensor([   70.6990, 30221.8262])

In [321]:
loss2d = ContrastiveLoss2d(reduction='none')
res = loss2d(x1, x3, y2d)
res

tensor([[6.2623e+00, 0.0000e+00, 6.4437e+01],
        [1.5280e+04, 1.4942e+04, 0.0000e+00]])

In [322]:
loss2d = ContrastiveLoss2d(reduction='mean')
res = loss2d(x1, x3, y2d)
res

tensor(5048.7544)

In [323]:
loss2d = ContrastiveLoss2d(reduction='sum')
res = loss2d(x1, x3, y2d)
res

tensor(30292.5254)

In [280]:
res.view(x1.shape[0], -1)

tensor([[3.1057e+00, 0.0000e+00, 1.2657e+02],
        [1.6185e+04, 1.5624e+04, 0.0000e+00]])

In [257]:
x1

tensor([[  1.8909,  20.2486,  39.7307,  60.3099,  81.7360],
        [-50.2483, -49.1922, -50.7428, -50.1480, -48.9504]])

In [225]:
x3

tensor([[[  0.9451,  19.4059,  39.5839,  59.4560,  79.7724],
         [ 10.8960,  27.8966,  42.8544,  65.2645,  89.4107],
         [  2.1223,  27.9709,  48.9245,  67.1623,  89.6087]],

        [[ 32.8121,  28.1083,  27.9012,  32.6898,  31.3549],
         [ 26.2665,  33.0588,  29.6756,  30.8534,  25.8869],
         [525.0332, 527.2878, 530.2631, 529.9833, 531.0088]]])

In [227]:
x3.view(-1, x3.shape[-1])

tensor([[  0.9451,  19.4059,  39.5839,  59.4560,  79.7724],
        [ 10.8960,  27.8966,  42.8544,  65.2645,  89.4107],
        [  2.1223,  27.9709,  48.9245,  67.1623,  89.6087],
        [ 32.8121,  28.1083,  27.9012,  32.6898,  31.3549],
        [ 26.2665,  33.0588,  29.6756,  30.8534,  25.8869],
        [525.0332, 527.2878, 530.2631, 529.9833, 531.0088]])

In [237]:
x1

tensor([[  1.8909,  20.2486,  39.7307,  60.3099,  81.7360],
        [-50.2483, -49.1922, -50.7428, -50.1480, -48.9504]])

In [252]:
x1.repeat_interleave(x3.shape[1], dim=0).shape

torch.Size([6, 5])

In [263]:
y2d.view(-1, 1)

tensor([[1],
        [0],
        [1],
        [1],
        [1],
        [0]])

In [254]:
x3.view(-1, x3.shape[-1]).shape

torch.Size([6, 5])

In [224]:
x3.shape

torch.Size([2, 3, 5])

In [223]:
F.pairwise_distance(x1[:, None, :], x3, keepdim=True)

torch.Size([2, 1, 5])

In [99]:
# Looks like I'll need to make some adjustments if I want this to work well
# with non-binary targets.
y_reg = torch.tensor([.8, .2]).unsqueeze(-1)
loss(x1, x2, y_reg)

tensor(1388.6364)

In [100]:
F.cosine_similarity(x1, x2)

tensor([ 0.9999, -0.9994])

In [101]:
F.pairwise_distance(x1, x2)

tensor([  2.5615, 166.5728])

In [118]:
# Not exactly sure what target is supposed to be.
F.cosine_embedding_loss(x1, x2, torch.tensor(1), reduction='none')

tensor([7.2360e-05, 1.9994e+00])

## TODO

- [] confirm good value of margin
- [X] try alternate formula I found using softmax and cosine distance
- [] try to make code work with non-binary targets (proba instead of 0/1)
- [] think about how this might work for my problem where I have 3 pairs per row (if a row contains x_new, x1, x2, x3, we have x_new:x1, x_new:x2, and x_new:x3)
- [] think how this will work with incendio (requires x to be passed in). Maybe just use another library? Or could I make this a layer, like my SimilarityHead below? Would probably need to separate the functionality between the layer and loss func, as I did in SimilarityHead: need to investigate what loss function would be appropriate in that case.

## CosineSimilarityHead

For other contrastive loss variant, I think it's simplest to do the cosine similarity and temperature scaling in the classification head and then leave log softmax for the loss function.

In [173]:
from img_wang.models import SmoothLogSoftmax, SmoothSoftmax

In [172]:
torch.exp(SmoothLogSoftmax()(x1)).sum(-1)

tensor([1., 1.])

In [178]:
out = SmoothSoftmax()(x1)
assert torch.isclose(out.sum(-1), 
                     torch.ones(x1.shape[0])).all()
assert torch.isclose(torch.exp(SmoothLogSoftmax()(x1)), out).all()

In [206]:
class SimilarityHead(ClassificationHead):
    
    def __init__(self, similarity=None, temperature='auto'):
        """This should be used with nn.NLLLoss since it ends with a 
        log_softmax operation.
        """
        super().__init__(last_act='log_softmax', temperature=temperature)
        self.similarity = similarity or nn.CosineSimilarity(dim=-1)
        warnings.warn('Remember to use nn.NLLLoss when using SimilarityHead.')
        
    def _forward(self, x_new, x_stack):
        return self.similarity(x_new[:, None, ...], x_stack)

In [193]:
x1.shape, x3.shape

(torch.Size([2, 5]), torch.Size([2, 3, 5]))

In [201]:
F.kl_div(x1, x3[:, 0, :], reduction='none').sum(-1)

tensor([-11278.1543,   8144.6191])

In [204]:
x1

tensor([[  1.8909,  20.2486,  39.7307,  60.3099,  81.7360],
        [-50.2483, -49.1922, -50.7428, -50.1480, -48.9504]])

In [205]:
F.kl_div(x1[:1], x1[:1], reduction='none')

tensor([[-2.3709e+00, -3.4910e+02, -1.4322e+03, -3.3900e+03, -6.3209e+03]])

In [185]:
head = SimilarityHead()
head

SimilarityHead(
  (last_act): SmoothLogSoftmax(
    (act): LogSoftmax(dim=-1)
  )
  (similarity): CosineSimilarity()
)

In [186]:
x1

tensor([[  1.8909,  20.2486,  39.7307,  60.3099,  81.7360],
        [-50.2483, -49.1922, -50.7428, -50.1480, -48.9504]])

In [189]:
res = head(x1, x3)
res

tensor([[-1.0976, -1.0997, -1.0985],
        [-1.0985, -1.0974, -1.0999]])

In [190]:
torch.exp(res)

tensor([[0.3337, 0.3330, 0.3334],
        [0.3334, 0.3337, 0.3329]])