In [1]:
#default_exp contrastive_loss

In [2]:
#export
from rsna_retro.imports import *

Loading imports


In [3]:
torch.cuda.set_device(1)

## Loss 

### Taken from here: https://github.com/adambielski/siamese-triplet/blob/master/losses.py

In [4]:
#export
class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin=0.5):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative=None, size_average=True):
        if negative is None: negative = positive.flip(dims=[0])
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


## Sanity Checking Triple loss

In [5]:
tloss = TripletLoss(margin=0.5)

In [6]:
anch = torch.randn(4, 8) # bs x features
pos = anch + 0.1 # bs x feat
neg = pos.flip(dims=[0])
tloss(anch,pos,neg), tloss(anch,neg,pos)

(tensor(0.), tensor(10.6520))

In [7]:
tloss(anch,pos)

tensor(0.)

## Sanity Check Contrastive Loss

In [8]:
#export
# https://github.com/adambielski/siamese-triplet/blob/master/losses.py
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss
    Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise
    """
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        distances = (output2 - output1).pow(2).sum(-1)  # squared distances
#         distances = F.pairwise_distance(output1, output2, keepdim=True).pow(2).sum(-1)  # squared distances
        losses = (target.float() * distances +
                        (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses.mean() if size_average else losses.sum()

In [9]:
# Original code
# https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb
class PairContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim = False)
        if len(euclidean_distance.shape)>1: euclidean_distance = euclidean_distance.sum(-1)
        c_pos = (label) * torch.pow(euclidean_distance, 2)
        c_neg = (1-label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = (c_neg + c_pos)
        return loss_contrastive.mean()

        return loss_contrastive

In [10]:
#export
# https://stackoverflow.com/questions/47107589/how-do-i-use-a-bytetensor-in-a-contrastive-cosine-loss-function
class CosineContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)
        loss_cos_con = torch.mean((label) * torch.div(torch.pow((1.0-cos_sim), 2), 4) +
                                    (1-label) * torch.pow(cos_sim * torch.lt(cos_sim, self.margin), 2))
        return loss_cos_con

In [11]:
def cosine_sim(output1, output2):
    num = output1.T @ output2
    denom = torch.norm(output1) * torch.norm(output2)
    out = num / denom
    return out

In [12]:
ls = {
    '1_contloss': ContrastiveLoss(),
    '2_pairloss': PairContrastiveLoss(),
    '3_cosloss': CosineContrastiveLoss(),
#     '4_xentloss': XentContrastiveLoss()
}

In [13]:
def loss_out(t,a,l):
    return {k:v(t,a,l) for k,v in ls.items()}

In [14]:
targ = torch.randn(4, 8) # bs x features
aug = targ + 0.1 # bs x feat
targ.shape, aug.shape

(torch.Size([4, 8]), torch.Size([4, 8]))

In [15]:
labels = torch.zeros(aug.shape[0])
labels[0] = 1

# single target to rest of batch
loss_out(targ[:1], aug, labels)

{'1_contloss': tensor(0.0200),
 '2_pairloss': tensor(0.0200),
 '3_cosloss': tensor(0.0552)}

In [16]:
# 1-to-1 match targ -> aug. Loss should be 8x higher
labels = torch.ones(aug.shape[0])
loss_out(targ, aug, labels)

{'1_contloss': tensor(0.0800),
 '2_pairloss': tensor(0.0800),
 '3_cosloss': tensor(1.8311e-05)}

In [17]:
# single targ -> all rand. Super High Loss
labels = torch.ones(aug.shape[0])
loss_out(targ[:1], aug, labels)

{'1_contloss': tensor(11.8284),
 '2_pairloss': tensor(11.8284),
 '3_cosloss': tensor(0.2118)}

### Batched contrastive loss

In [18]:
#export
def batched_labels(output1, output2, onehot=True):
    bs = output1.shape[0]
    rp = [1]*len(output1.shape)
    o1 = output1.repeat(*rp,bs).view(bs,*output1.shape)
    labels = torch.arange(bs, device=output1.device)
    if onehot: labels = torch.eye(o1.shape[0], device=output1.device)[labels]
    return o1, output2, labels

In [19]:
#export
def batched_labels(output1, output2, onehot=True):
    bs,feat = output1.shape
    o1 = output1.view(bs,1,feat)
    labels = torch.arange(bs, device=output1.device)
    if onehot: labels = torch.eye(o1.shape[0], device=output1.device)[labels]
    return o1, output2, labels

In [20]:
closs = ContrastiveLoss()
# closs = CosineContrastiveLoss()

In [21]:
bs,feat=6,18 # 
targ = torch.randn(bs, feat)*5 # bs x features
aug = targ+0.1 # bs x features
# aug = torch.randn(bs, feat) # bs x features
o1, o2, l = batched_labels(targ, aug)

In [22]:
o1.shape

torch.Size([6, 1, 18])

In [23]:
cs = F.cosine_similarity(o1, o2, dim=-1)
F.cross_entropy(cs, torch.arange(bs))

tensor(1.1464)

In [24]:
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.zeros(bs)
    labels[i] = 1 # set current target as the only positive label
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(0.0300)

In [25]:
closs(o1,o2,l)

tensor(0.0300)

In [26]:
# Sanity checking - setting wrong target as positive label
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.zeros(bs)
    labels[0] = 1 # set current target as the only positive label
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(118.5505)

In [27]:
l3 = l[torch.zeros(o1.shape[0]).long()]
closs(o1,o2,l3)

tensor(118.5505)

In [28]:
# Sanity checking - setting wrong target as positive label
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.ones(bs)
    labels[0] = 1 # set current target as the only positive label
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(767.6289)

In [29]:
l4 = torch.ones(o1.shape[0], o1.shape[0])
closs(o1,o2,l4)

tensor(767.6289)

## XentContrastiveLoss

In [30]:
#export
# https://arxiv.org/pdf/2002.05709.pdf
class XentOldContrastiveLoss(nn.Module):
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)/self.temp
        xent_loss = F.cross_entropy(cos_sim,labels.long())
        return xent_loss

In [31]:
bs,feat=4,16 # 
targ = torch.randn(bs, feat) # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
aug = targ + 0.1 # bs x feat
output1, output2, labels = batched_labels(targ, aug, True)

In [32]:
_, lmax = labels.max(dim=-1); lmax
xent_loss = XentOldContrastiveLoss(1.0)
xent_loss(output1, output2, lmax)

tensor(0.7361)

In [33]:
cos_sim = F.cosine_similarity(output1, output2, dim=-1)

In [34]:
lsoft = F.log_softmax(cos_sim, dim=-1); lsoft

tensor([[-0.7000, -2.2669, -1.4111, -1.8586],
        [-2.2112, -0.6351, -1.8216, -1.6156],
        [-1.5602, -2.0221, -0.8183, -1.5310],
        [-1.9654, -1.7637, -1.4475, -0.7911]])

#### Manual

In [35]:
lsoft1 = torch.log(torch.exp(cos_sim)/torch.exp(cos_sim).sum(dim=-1)); lsoft1

tensor([[-0.7000, -2.1996, -1.5289, -1.9389],
        [-2.2785, -0.6351, -2.0068, -1.7632],
        [-1.4423, -1.8369, -0.8183, -1.4934],
        [-1.8852, -1.6161, -1.4851, -0.7911]])

In [36]:
# nll loss
l_nll = torch.mean(torch.sum(-labels * lsoft1, dim=1)); l_nll

tensor(0.7361)

In [37]:
F.nll_loss(lsoft, lmax)

tensor(0.7361)

## Without Xent

In [38]:
cos_sim

tensor([[ 0.9973, -0.5696,  0.2862, -0.1613],
        [-0.5812,  0.9949, -0.1916,  0.0144],
        [ 0.2550, -0.2069,  0.9969,  0.2842],
        [-0.1879,  0.0139,  0.3300,  0.9865]])

In [39]:
cexp = torch.exp(cos_sim)
cexp

tensor([[2.7110, 0.5657, 1.3314, 0.8510],
        [0.5592, 2.7045, 0.8256, 1.0145],
        [1.2905, 0.8131, 2.7098, 1.3287],
        [0.8287, 1.0140, 1.3910, 2.6817]])

In [40]:
cexp.sum(dim=-1)

tensor([5.4592, 5.1039, 6.1420, 5.9154])

In [41]:
neg_denom = (cexp*(1-labels)).sum(dim=-1); neg_denom

tensor([2.7481, 2.3994, 3.4322, 3.2337])

In [42]:
pos = cos_sim[range(lmax.shape[0]),lmax]; pos

tensor([0.9973, 0.9949, 0.9969, 0.9865])

In [43]:
cexp.sum(dim=-1) - torch.exp(pos)

tensor([2.7481, 2.3994, 3.4322, 3.2337])

In [44]:
lsoft2 = torch.log(cexp/neg_denom); lsoft2

tensor([[-0.0136, -1.4448, -0.9470, -1.3350],
        [-1.5921,  0.1197, -1.4248, -1.1592],
        [-0.7559, -1.0821, -0.2363, -0.8895],
        [-1.1988, -0.8613, -0.9032, -0.1872]])

In [45]:
l_nll = torch.mean(torch.sum(-lmax * lsoft2, dim=1)); l_nll

tensor(5.2509)

## Test

In [46]:
temp = 1.0
cos_sim = F.cosine_similarity(output1, output2, dim=-1)/temp
cexp = torch.exp(cos_sim)
cexp

tensor([[2.7110, 0.5657, 1.3314, 0.8510],
        [0.5592, 2.7045, 0.8256, 1.0145],
        [1.2905, 0.8131, 2.7098, 1.3287],
        [0.8287, 1.0140, 1.3910, 2.6817]])

In [47]:
x = (cexp * labels).sum(dim=-1); x

tensor([2.7110, 2.7045, 2.7098, 2.6817])

In [48]:
(cexp*(1-labels)).sum(dim=-1)

tensor([2.7481, 2.3994, 3.4322, 3.2337])

In [49]:
denom = cexp.sum(dim=-1) - x; denom

tensor([2.7481, 2.3994, 3.4322, 3.2337])

In [50]:
lsoft = -torch.log(cexp/denom)
lsoft

tensor([[ 0.0136,  1.4448,  0.9470,  1.3350],
        [ 1.5921, -0.1197,  1.4248,  1.1592],
        [ 0.7559,  1.0821,  0.2363,  0.8895],
        [ 1.1988,  0.8613,  0.9032,  0.1872]])

In [51]:
-torch.log(x)+torch.log(denom)

tensor([ 0.0136, -0.1197,  0.2363,  0.1872])

In [52]:
neg_denom = (cexp*(1-labels)).sum(dim=-1); neg_denom
lsoft1 = torch.log(cexp/neg_denom)
lsoft2 = torch.sum(-labels * lsoft1, dim=-1)
lsoft2

tensor([ 0.0136, -0.1197,  0.2363,  0.1872])

## Metric Learning

In [53]:

def sim_mat(x, y=None):
    """
    returns a matrix where entry (i,j) is the dot product of x[i] and x[j]
    """
    if y is None:
        y = x
    return torch.matmul(x, y.t())

def convert_to_pairs(indices_tuple, labels):
    """
    This returns anchor-positive and anchor-negative indices,
    regardless of what the input indices_tuple is
    Args:
        indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
                        within a batch
        labels: a tensor which has the label for each element in a batch
    """
    if indices_tuple is None:
        return get_all_pairs_indices(labels)
    elif len(indices_tuple) == 4:
        return indices_tuple
    else:
        a, p, n = indices_tuple
        return a, p, a, n

def get_all_pairs_indices(labels, ref_labels=None):
    """
    Given a tensor of labels, this will return 4 tensors.
    The first 2 tensors are the indices which form all positive pairs
    The second 2 tensors are the indices which form all negative pairs
    """
    if ref_labels is None:
        ref_labels = labels
    labels1 = labels.unsqueeze(1)
    labels2 = ref_labels.unsqueeze(0)
    matches = (labels1 == labels2).byte()
    diffs = matches ^ 1
    if ref_labels is labels:
        matches -= torch.eye(matches.size(0)).byte().to(labels.device)
    a1_idx = matches.nonzero()[:, 0].flatten()
    p_idx = matches.nonzero()[:, 1].flatten()
    a2_idx = diffs.nonzero()[:, 0].flatten()
    n_idx = diffs.nonzero()[:, 1].flatten()
    return a1_idx, p_idx, a2_idx, n_idx



In [54]:
class NTXentLoss2(nn.Module):

    def __init__(self, temperature, **kwargs):
        super().__init__(**kwargs)
        self.temperature = temperature
        self.normalize_embeddings = True

    def forward(self, embeddings, labels, indices_tuple):
        cosine_similarity = sim_mat(embeddings)
        if not self.normalize_embeddings:
            embedding_norms_mat = self.embedding_norms.unsqueeze(0)*self.embedding_norms.unsqueeze(1)
            cosine_similarity = cosine_similarity / (embedding_norms_mat)
        cosine_similarity = cosine_similarity / self.temperature

        a1, p, a2, n = convert_to_pairs(indices_tuple, labels)

        if len(a1) > 0 and len(a2) > 0:
            pos_pairs = cosine_similarity[a1, p].unsqueeze(1)
            neg_pairs = cosine_similarity[a2, n]
            n_per_p = (a2.unsqueeze(0) == a1.unsqueeze(1)).float()
            neg_pairs = neg_pairs*n_per_p
            neg_pairs[n_per_p==0] = float('-inf')

            max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0])
            numerator = torch.exp(pos_pairs - max_val).squeeze(1)
            denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
            log_exp = torch.log((numerator/denominator) + 1e-20)
            return torch.mean(-log_exp)
        return 0

In [55]:
ml = losses.NTXentLoss(1.0)

AttributeError: 'list' object has no attribute 'NTXentLoss'

In [56]:
stacked = torch.cat((targ, rand_aug), dim=0)

In [57]:
labels = torch.arange(targ.shape[0]).repeat(2)

In [58]:
get_all_pairs_indices(labels)

(tensor([0, 1, 2, 3, 4, 5, 6, 7]),
 tensor([4, 5, 6, 7, 0, 1, 2, 3]),
 tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
         4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7]),
 tensor([1, 2, 3, 5, 6, 7, 0, 2, 3, 4, 6, 7, 0, 1, 3, 4, 5, 7, 0, 1, 2, 4, 5, 6,
         1, 2, 3, 5, 6, 7, 0, 2, 3, 4, 6, 7, 0, 1, 3, 4, 5, 7, 0, 1, 2, 4, 5, 6]))

In [3]:
#export
from pytorch_metric_learning import losses
class XentLoss(losses.NTXentLoss):
    def forward(self, output1, output2):
        stacked = torch.cat((output1, output2), dim=0)
        labels = torch.arange(output1.shape[0]).repeat(2)
        return super().forward(stacked, labels, None)

In [77]:
ml(stacked, labels, None)

tensor(1.8078)

## Put it into code

In [61]:
#export
# https://arxiv.org/pdf/2002.05709.pdf
class XentContrastiveLoss(nn.Module):
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)/self.temp
        cexp = torch.exp(cos_sim)
        neg_denom = (cexp*(1-labels)).sum(dim=-1)
        lsoft = torch.log(cexp/neg_denom)
        lsoft = torch.sum(-labels * lsoft, dim=-1)
        print(lsoft)
        return lsoft.mean()

In [201]:
#export
class XentContrastiveLoss2(nn.Module):
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)/self.temp
        cexp = torch.exp(cos_sim)
        x = (cexp * labels).sum(dim=-1)
        denom = cexp.sum(dim=-1) - x
        lsoft = -torch.log(x/denom)
        print(lsoft)
        return lsoft.mean()

## Batch

In [63]:
#export
class BatchContrastiveLoss(nn.Module):
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func
        self.onehot = not isinstance(loss_func, XentOldContrastiveLoss)
        
    def forward(self, output1, output2):
        output1, output2, labels = batched_labels(output1, output2, self.onehot)
        return self.loss_func(output1, output2, labels)

In [78]:
temp = 0.1
ls2 = {
#     '1_contloss': ContrastiveLoss(margin=1.0),
    '4_xentloss': XentContrastiveLoss(temp=temp),
    '5_xentloss2': XentContrastiveLoss2(temp=temp),
    '6_oldxentloss': XentOldContrastiveLoss(temp=temp)
}

def batch_loss_out(t,a):
    return {k:BatchContrastiveLoss(v)(t,a) for k,v in ls2.items()}

In [79]:
bs,feat=16,128 # 
targ = torch.randn(bs, feat) # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
aug = targ + 0.05 # bs x feat

In [98]:
batch_loss_out(targ,targ)

{'4_xentloss': tensor(-6.8875),
 '5_xentloss2': tensor(-6.8876),
 '6_oldxentloss': tensor(0.0010)}

In [81]:
batch_loss_out(targ,rand_aug)

{'4_xentloss': tensor(3.4285),
 '5_xentloss2': tensor(3.4285),
 '6_oldxentloss': tensor(3.4732)}

In [77]:
bs,feat=256,256 # 
targ = torch.randn(bs, feat) # bs x features
aug = targ+0.1 # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
batch_loss_out(targ,-aug)

{'4_xentloss': tensor(7.5412),
 '5_xentloss2': tensor(7.5412),
 '6_oldxentloss': tensor(7.5418)}

## Export

In [4]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_metadata.ipynb.
Converted 01_preprocess.ipynb.
Converted 01_preprocess_mean_std.ipynb.
Converted 02_train.ipynb.
Converted 02_train_01_save_features.ipynb.
Converted 03_train3d.ipynb.
Converted 04_trainfull3d_deprecated.ipynb.
Converted 04_trainfull3d_labels.ipynb.
Converted 05_train_adjacent.ipynb.
Converted 06_seutao_features.ipynb.
Converted 07_adni.ipynb.
This cell doesn't have an export destination and was ignored:
e
This cell doesn't have an export destination and was ignored:
#
This cell doesn't have an export destination and was ignored:
e
Converted 07_adni_01.ipynb.
Converted 08_contrastive_loss-Copy1.ipynb.
Converted 08_contrastive_loss.ipynb.
Converted 08_imagewang.ipynb.
Converted 08_train_self_supervised.ipynb.
Converted 08_train_self_supervised_train_1.ipynb.
Converted 08_train_self_supervised_train_2_nocombined.ipynb.
Converted 08_train_self_supervised_train_2_nocombined_contrast.ipynb.
Converted 08_train_self_supervised_train_3.ipynb.
Converted 08_train_sel