In [1]:
#default_exp contrastive_loss

In [2]:
#export
from rsna_retro.imports import *

Loading imports


In [3]:
# Pin the notebook to the second GPU only when one exists; on machines with
# fewer than two CUDA devices keep PyTorch's default device instead of raising.
if torch.cuda.device_count() > 1: torch.cuda.set_device(1)

## Loss 

### Taken from here: https://github.com/adambielski/siamese-triplet/blob/master/losses.py

In [4]:
#export
class TripletLoss(nn.Module):
    """
    Margin-based triplet loss on squared Euclidean distances.

    For each row the loss is
    ``relu(d(anchor, positive) - d(anchor, negative) + margin)``
    where ``d`` is the squared difference summed over dimension 1.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative=None, size_average=True):
        """
        anchor, positive, negative: embeddings that broadcast against each other.
        negative: when omitted, the positives in reversed batch order are used.
        size_average: return the mean of the per-row losses if True, else their sum.
        """
        if negative is None:
            # No explicit negatives: pair each anchor with a positive from the flipped batch.
            negative = positive.flip(dims=[0])
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        hinge = F.relu(pos_dist - neg_dist + self.margin)
        if size_average:
            return hinge.mean()
        return hinge.sum()


## Sanity Checking Triplet loss

In [5]:
tloss = TripletLoss(margin=0.5)

In [6]:
anch = torch.randn(4, 8) # bs x features
pos = anch + 0.1 # bs x feat
neg = pos.flip(dims=[0])
tloss(anch,pos,neg), tloss(anch,neg,pos)

(tensor(0.), tensor(24.8184))

In [7]:
tloss(anch,pos)

tensor(0.)

## Sanity Check Contrastive Loss

In [8]:
#export
# https://github.com/adambielski/siamese-triplet/blob/master/losses.py
class ContrastiveLoss(nn.Module):
    """
    Pairwise contrastive loss on Euclidean distance.

    ``target == 1`` marks a pair from the same class (penalised by its squared
    distance); ``target == 0`` marks a pair from different classes (penalised by
    ``relu(margin - distance) ** 2``).
    """
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin
        self.eps = 1e-9  # added under the sqrt so the distance is never sqrt(0)

    def forward(self, output1, output2, target, size_average=True):
        """Return the mean (or the sum when `size_average` is False) of the per-pair losses."""
        sq_dist = (output2 - output1).pow(2).sum(-1)  # squared distances over the last dim
        dist = (sq_dist + self.eps).sqrt()
        same_term = target.float() * sq_dist
        # `1 + -1 * target` rather than `1 - target`, kept exactly as the original wrote it.
        diff_term = (1 + -1 * target).float() * F.relu(self.margin - dist).pow(2)
        per_pair = same_term + diff_term
        if size_average:
            return per_pair.mean()
        return per_pair.sum()

In [9]:
# Original code
# https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb
class PairContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss built on `F.pairwise_distance`.

    `label == 1` pairs are penalised by their squared distance; `label == 0`
    pairs are penalised by `clamp(margin - distance, min=0) ** 2`.
    """
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss over all pairs (a scalar tensor)."""
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim = False)
        # Collapse any remaining trailing dimension so there is one distance per pair.
        if len(euclidean_distance.shape)>1: euclidean_distance = euclidean_distance.sum(-1)
        c_pos = (label) * torch.pow(euclidean_distance, 2)
        c_neg = (1-label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = (c_neg + c_pos)
        # The original had a second, unreachable `return loss_contrastive` after this line.
        return loss_contrastive.mean()

In [10]:
#export
# https://stackoverflow.com/questions/47107589/how-do-i-use-a-bytetensor-in-a-contrastive-cosine-loss-function
class CosineContrastiveLoss(nn.Module):
    """
    Contrastive loss on cosine similarity taken over the last dimension.

    `label == 1` pairs contribute `(1 - cos) ** 2 / 4`; `label == 0` pairs
    contribute `cos ** 2` wherever `cos < margin` and zero elsewhere.
    """
    def __init__(self, margin=2.0):
        super().__init__()
        # NOTE(review): cosine similarity never exceeds 1, so with the default
        # margin of 2.0 the `cos < margin` mask keeps every pair - confirm intent.
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean loss over all pairs."""
        sim = F.cosine_similarity(output1, output2, dim=-1)
        pos_term = label * torch.div(torch.pow(1.0 - sim, 2), 4)
        below_margin = torch.lt(sim, self.margin)
        neg_term = (1 - label) * torch.pow(sim * below_margin, 2)
        return torch.mean(pos_term + neg_term)

In [11]:
def cosine_sim(output1, output2):
    """Return `output1.T @ output2` divided by the product of the two tensors' norms."""
    scale = torch.norm(output1) * torch.norm(output2)
    return (output1.T @ output2) / scale

In [12]:
ls = {
    '1_contloss': ContrastiveLoss(),
    '2_pairloss': PairContrastiveLoss(),
    '3_cosloss': CosineContrastiveLoss(),
#     '4_xentloss': XentContrastiveLoss()
}

In [13]:
def loss_out(t,a,l):
    """Evaluate every loss registered in `ls` on the same (t, a, l) inputs, keyed by its name."""
    results = {}
    for name, loss_fn in ls.items():
        results[name] = loss_fn(t, a, l)
    return results

In [14]:
targ = torch.randn(4, 8) # bs x features
aug = targ + 0.1 # bs x feat
targ.shape, aug.shape

(torch.Size([4, 8]), torch.Size([4, 8]))

In [15]:
labels = torch.zeros(aug.shape[0])
labels[0] = 1

# single target to rest of batch
loss_out(targ[:1], aug, labels)

{'1_contloss': tensor(0.0200),
 '2_pairloss': tensor(0.0200),
 '3_cosloss': tensor(0.0243)}

In [16]:
# 1-to-1 match targ -> aug. Contrastive/pair loss should be 4x (= batch size) higher than the single-positive case above
labels = torch.ones(aug.shape[0])
loss_out(targ, aug, labels)

{'1_contloss': tensor(0.0800),
 '2_pairloss': tensor(0.0800),
 '3_cosloss': tensor(4.0329e-06)}

In [17]:
# single targ -> all rand. Super High Loss
labels = torch.ones(aug.shape[0])
loss_out(targ[:1], aug, labels)

{'1_contloss': tensor(18.8922),
 '2_pairloss': tensor(18.8921),
 '3_cosloss': tensor(0.2493)}

### Batched contrastive loss

In [18]:
#export
def batched_labels(output1, output2, onehot=True):
    """
    Tile `output1` against the batch and build the matching positive labels.

    Returns `(tiled, output2, labels)`. For a 2-D `output1` of shape
    (bs, features), `tiled` has shape (bs, bs, features) with
    `tiled[i, k] == output1[i]`. `labels` is the bs x bs identity matrix when
    `onehot` is True, otherwise `arange(bs)`; both live on `output1`'s device.
    `output2` is passed through unchanged.
    """
    n = output1.shape[0]
    unit_repeats = (1,) * output1.dim()
    # Repeat along one extra trailing axis, then fold the copies into a new leading batch axis.
    tiled = output1.repeat(*unit_repeats, n).view(n, *output1.shape)
    if onehot:
        targets = torch.eye(n, device=output1.device)
    else:
        targets = torch.arange(n, device=output1.device)
    return tiled, output2, targets

In [19]:
closs = ContrastiveLoss()
# closs = CosineContrastiveLoss()

In [20]:
bs,feat=6,18 # 
targ = torch.randn(bs, feat)*5 # bs x features
aug = targ+0.1 # bs x features
# aug = torch.randn(bs, feat) # bs x features
o1, o2, l = batched_labels(targ, targ)

In [21]:
o1.pow(2).shape

torch.Size([6, 6, 18])

In [22]:
o1.shape

torch.Size([6, 6, 18])

In [23]:
torch.norm(o1)

tensor(135.7687)

In [24]:
o1.shape, o2.shape

(torch.Size([6, 6, 18]), torch.Size([6, 18]))

In [25]:
cs = F.cosine_similarity(o1, o2, dim=-1)/0.1
F.cross_entropy(cs, torch.arange(bs))

tensor(0.0007)

In [26]:
torch.eye(bs)

tensor([[1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.]])

In [27]:
F.softmax(torch.eye(2), dim=-1)

tensor([[0.7311, 0.2689],
        [0.2689, 0.7311]])

In [28]:
torch.eye(2)

tensor([[1., 0.],
        [0., 1.]])

In [29]:
F.cross_entropy(torch.eye(2), torch.tensor([0,1]))

tensor(0.3133)

In [30]:
cs/0.5

tensor([[ 20.0000,   4.6216,  -3.3481,   4.0333,  -2.9250,   2.5821],
        [  4.6216,  20.0000,   0.6266,   0.6842,   2.6846,  -2.2457],
        [ -3.3481,   0.6266,  20.0000, -10.0835,  -2.5101,   2.1343],
        [  4.0333,   0.6842, -10.0835,  20.0000,  -7.1823,  -4.2012],
        [ -2.9250,   2.6846,  -2.5101,  -7.1823,  20.0000,   5.4538],
        [  2.5821,  -2.2457,   2.1343,  -4.2012,   5.4538,  20.0000]])

In [31]:
F.softmax(cs/0.1, dim=-1)

tensor([[1.0000e+00, 4.0384e-34, 0.0000e+00, 2.1319e-35, 0.0000e+00, 1.5048e-38],
        [4.0384e-34, 1.0000e+00, 8.5339e-43, 1.1393e-42, 2.5118e-38, 0.0000e+00],
        [0.0000e+00, 8.5339e-43, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.6041e-39],
        [2.1319e-35, 1.1393e-42, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 2.5118e-38, 0.0000e+00, 0.0000e+00, 1.0000e+00, 2.5895e-32],
        [1.5048e-38, 0.0000e+00, 1.6041e-39, 0.0000e+00, 2.5895e-32, 1.0000e+00]])

In [32]:
o1.shape, o2.shape, l.shape

(torch.Size([6, 6, 18]), torch.Size([6, 18]), torch.Size([6, 6]))

In [33]:
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.zeros(bs)
    labels[i] = 1 # set current target as the only positive label
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(0.0300)

In [34]:
closs(o1,o2,l)

tensor(0.)

In [35]:
# Sanity checking - setting wrong target as positive label
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.zeros(bs)
    labels[0] = 1 # always mark index 0 as the only positive, regardless of i (deliberately wrong for i > 0)
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(175.1980)

In [36]:
l3 = l[torch.zeros(o1.shape[0]).long()]
closs(o1,o2,l3)

tensor(174.6499)

In [37]:
# Sanity checking - labelling every pair as positive, including the mismatched ones
losses = []
for i in range(targ.shape[0]):
    bs = targ.shape[0]
    labels = torch.ones(bs)
    labels[0] = 1 # no-op: labels is already all ones
    losses.append(closs(targ[i], aug, labels))
torch.stack(losses).mean()

tensor(882.6721)

In [38]:
l4 = torch.ones(o1.shape[0], o1.shape[0])
closs(o1,o2,l4)

tensor(882.4920)

## XentContrastiveLoss

In [39]:
#export
# https://arxiv.org/pdf/2002.05709.pdf
class XentOldContrastiveLoss(nn.Module):
    """
    Cross-entropy over temperature-scaled cosine similarities.

    The similarities along the last dimension, divided by `temp`, are used as
    logits; `labels` holds the index of the positive for each row.
    """
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        """Return `F.cross_entropy` of the scaled similarities against `labels` cast to long."""
        logits = F.cosine_similarity(output1, output2, dim=-1) / self.temp
        return F.cross_entropy(logits, labels.long())

In [153]:
bs,feat=4,16 # 
targ = torch.randn(bs, feat) # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
aug = targ + 0.1 # bs x feat
output1, output2, labels = batched_labels(targ, aug, True)

In [154]:
_, lmax = labels.max(dim=-1); lmax
xent_loss = XentOldContrastiveLoss(1.0)
xent_loss(output1, output2, lmax)

tensor(0.7626)

In [155]:
cos_sim = F.cosine_similarity(output1, output2, dim=-1)

In [156]:
lsoft = F.log_softmax(cos_sim, dim=-1); lsoft

tensor([[-0.6937, -1.8886, -1.5841, -1.9388],
        [-1.9855, -0.7927, -1.3897, -1.8270],
        [-1.7515, -1.4645, -0.8636, -1.7507],
        [-1.9708, -1.7779, -1.6339, -0.7002]])

#### Manual

In [157]:
lsoft1 = torch.log(torch.exp(cos_sim)/torch.exp(cos_sim).sum(dim=-1)); lsoft1

tensor([[-0.6937, -1.9863, -1.7527, -1.9423],
        [-1.8878, -0.7927, -1.4606, -1.7327],
        [-1.5828, -1.3936, -0.8636, -1.5856],
        [-1.9674, -1.8721, -1.7990, -0.7002]])

In [158]:
# nll loss
l_nll = torch.mean(torch.sum(-labels * lsoft1, dim=1)); l_nll

tensor(0.7626)

In [159]:
F.nll_loss(lsoft, lmax)

tensor(0.7626)

## Without Xent

In [160]:
cos_sim

tensor([[ 0.9960, -0.1988,  0.1057, -0.2491],
        [-0.1980,  0.9948,  0.3978, -0.0395],
        [ 0.1069,  0.3939,  0.9948,  0.1077],
        [-0.2776, -0.0846,  0.0594,  0.9930]])

In [161]:
cexp = torch.exp(cos_sim)
cexp

tensor([[2.7075, 0.8197, 1.1115, 0.7795],
        [0.8204, 2.7042, 1.4885, 0.9613],
        [1.1128, 1.4827, 2.7041, 1.1137],
        [0.7576, 0.9189, 1.0612, 2.6994]])

In [162]:
cexp.sum(dim=-1)

tensor([5.4182, 5.9743, 6.4134, 5.4370])

In [163]:
neg_denom = (cexp*(1-labels)).sum(dim=-1); neg_denom

tensor([2.7107, 3.2702, 3.7092, 2.7376])

In [164]:
pos = cos_sim[range(lmax.shape[0]),lmax]; pos

tensor([0.9960, 0.9948, 0.9948, 0.9930])

In [165]:
cexp.sum(dim=-1) - torch.exp(pos)

tensor([2.7107, 3.2702, 3.7092, 2.7376])

In [166]:
lsoft2 = torch.log(cexp/neg_denom); lsoft2

tensor([[-1.1679e-03, -1.3837e+00, -1.2052e+00, -1.2561e+00],
        [-1.1952e+00, -1.9004e-01, -9.1303e-01, -1.0466e+00],
        [-8.9029e-01, -7.9096e-01, -3.1605e-01, -8.9943e-01],
        [-1.2748e+00, -1.2695e+00, -1.2515e+00, -1.4067e-02]])

In [167]:
l_nll = torch.mean(torch.sum(-lmax * lsoft2, dim=1)); l_nll

tensor(5.1635)

## Test

In [216]:
temp = 1.0
cos_sim = F.cosine_similarity(output1, output2, dim=-1)/temp
cexp = torch.exp(cos_sim)
cexp

tensor([[2.7075, 0.8197, 1.1115, 0.7795],
        [0.8204, 2.7042, 1.4885, 0.9613],
        [1.1128, 1.4827, 2.7041, 1.1137],
        [0.7576, 0.9189, 1.0612, 2.6994]])

In [217]:
x = (cexp * labels).sum(dim=-1); x

tensor([2.7075, 2.7042, 2.7041, 2.6994])

In [218]:
(cexp*(1-labels)).sum(dim=-1)

tensor([2.7107, 3.2702, 3.7092, 2.7376])

In [219]:
denom = cexp.sum(dim=-1) - x; denom

tensor([2.7107, 3.2702, 3.7092, 2.7376])

In [220]:
lsoft = -torch.log(cexp/denom)
lsoft

tensor([[1.1679e-03, 1.3837e+00, 1.2052e+00, 1.2561e+00],
        [1.1952e+00, 1.9004e-01, 9.1303e-01, 1.0466e+00],
        [8.9029e-01, 7.9096e-01, 3.1605e-01, 8.9943e-01],
        [1.2748e+00, 1.2695e+00, 1.2515e+00, 1.4067e-02]])

In [208]:
-torch.log(x)+torch.log(denom)

tensor([0.0012, 0.1900, 0.3160, 0.0141])

In [215]:
torch.log(torch.arange(10,1).float())

RuntimeError: upper bound and larger bound inconsistent with step sign

In [209]:
neg_denom = (cexp*(1-labels)).sum(dim=-1); neg_denom
lsoft1 = torch.log(cexp/neg_denom)
lsoft2 = torch.sum(-labels * lsoft1, dim=-1)
lsoft2

tensor([0.0012, 0.1900, 0.3160, 0.0141])

## Put it into code

In [61]:
#export
# https://arxiv.org/pdf/2002.05709.pdf
class XentContrastiveLoss(nn.Module):
    """
    Temperature-scaled cosine-similarity contrastive loss.

    For each row, the loss is `-log(exp(s_pos) / sum_neg exp(s_neg))` where `s`
    is the cosine similarity divided by `temp`. The denominator only sums the
    entries whose label is 0, so unlike a softmax cross-entropy the positive is
    excluded from it and the loss can be negative.
    """
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        """
        output1, output2: tensors that broadcast inside `F.cosine_similarity(dim=-1)`.
        labels: one-hot style weights with the same shape as the similarity matrix
            (1 for the positive, 0 for negatives).
        Returns the mean of the per-row losses.
        """
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)/self.temp
        cexp = torch.exp(cos_sim)
        # keepdim=True keeps the denominator as a column so every row is divided
        # by the sum over its OWN negatives. Without it the (bs,) vector broadcasts
        # along columns, which is only correct when the positives sit on the diagonal.
        neg_denom = (cexp*(1-labels)).sum(dim=-1, keepdim=True)
        lsoft = torch.log(cexp/neg_denom)
        # Keep only the positive entry of each row.
        lsoft = torch.sum(-labels * lsoft, dim=-1)
        return lsoft.mean()

In [201]:
#export
class XentContrastiveLoss2(nn.Module):
    """
    Alternative formulation of the temperature-scaled contrastive loss.

    Per row it computes `-log(pos / (total - pos))`, where `pos` is the
    label-weighted sum of `exp(cos_sim / temp)` and `total` is the sum over the
    whole row, so the positive is excluded from the denominator.
    """
    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, output1, output2, labels):
        """Return the mean of the per-row losses; `labels` weights the positive entries."""
        cos_sim = F.cosine_similarity(output1, output2, dim=-1)/self.temp
        cexp = torch.exp(cos_sim)
        pos = (cexp * labels).sum(dim=-1)
        # Row total minus the positive mass leaves the sum over the negatives.
        denom = cexp.sum(dim=-1) - pos
        lsoft = -torch.log(pos/denom)
        return lsoft.mean()

## Batch

In [63]:
#export
class BatchContrastiveLoss(nn.Module):
    """
    Adapter that lets a three-argument contrastive loss be called on two batches.

    The batches are expanded with `batched_labels`, which also generates the
    labels, and the result is forwarded to the wrapped `loss_func`.
    """
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func
        # XentOldContrastiveLoss is handed index labels; every other loss gets one-hot rows.
        wants_indices = isinstance(loss_func, XentOldContrastiveLoss)
        self.onehot = not wants_indices

    def forward(self, output1, output2):
        """Return `loss_func` evaluated on the tiled `output1`, `output2` and the generated labels."""
        tiled, other, labels = batched_labels(output1, output2, self.onehot)
        return self.loss_func(tiled, other, labels)

In [78]:
temp = 0.1
ls2 = {
#     '1_contloss': ContrastiveLoss(margin=1.0),
    '4_xentloss': XentContrastiveLoss(temp=temp),
    '5_xentloss2': XentContrastiveLoss2(temp=temp),
    '6_oldxentloss': XentOldContrastiveLoss(temp=temp)
}

def batch_loss_out(t,a):
    """Wrap each loss in `ls2` with BatchContrastiveLoss, evaluate it on (t, a) and return a name -> loss dict."""
    results = {}
    for name, loss_func in ls2.items():
        results[name] = BatchContrastiveLoss(loss_func)(t, a)
    return results

In [79]:
bs,feat=16,128 # 
targ = torch.randn(bs, feat) # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
aug = targ + 0.05 # bs x feat

In [98]:
batch_loss_out(targ,targ)

{'4_xentloss': tensor(-6.8875),
 '5_xentloss2': tensor(-6.8876),
 '6_oldxentloss': tensor(0.0010)}

In [81]:
batch_loss_out(targ,rand_aug)

{'4_xentloss': tensor(3.4285),
 '5_xentloss2': tensor(3.4285),
 '6_oldxentloss': tensor(3.4732)}

In [77]:
bs,feat=256,256 # 
targ = torch.randn(bs, feat) # bs x features
aug = targ+0.1 # bs x features
rand_aug = torch.randn(bs, feat) # bs x features
batch_loss_out(targ,-aug)

{'4_xentloss': tensor(7.5412),
 '5_xentloss2': tensor(7.5412),
 '6_oldxentloss': tensor(7.5418)}

## Export

In [69]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_metadata.ipynb.
Converted 01_preprocess.ipynb.
Converted 01_preprocess_mean_std.ipynb.
Converted 02_train.ipynb.
Converted 02_train_01_save_features.ipynb.
Converted 03_train3d.ipynb.
Converted 04_trainfull3d_deprecated.ipynb.
Converted 04_trainfull3d_labels.ipynb.
Converted 05_train_adjacent.ipynb.
Converted 06_seutao_features.ipynb.
Converted 07_adni.ipynb.
This cell doesn't have an export destination and was ignored:
e
This cell doesn't have an export destination and was ignored:
#
This cell doesn't have an export destination and was ignored:
e
Converted 07_adni_01.ipynb.
Converted 08_contrastive_loss.ipynb.
Converted 08_imagewang.ipynb.
Converted 08_train_self_supervised.ipynb.
Converted 08_train_self_supervised_train_1.ipynb.
Converted 08_train_self_supervised_train_2_nocombined.ipynb.
Converted 08_train_self_supervised_train_2_nocombined_contrast.ipynb.
Converted 08_train_self_supervised_train_3.ipynb.
Converted 08_train_self_supervised_train_4_nocombined_xent-Copy1.