In [1]:
import h5py

In [3]:
# Open the ModelNet test split read-only; later cells read `points` and `point_normals` through `f`.
f = h5py.File('../../datasets/modelnet_test', mode='r')

In [6]:
# Inspect the point-cloud dataset; its repr reports the shape (samples, points, xyz) and dtype.
f['points']

<HDF5 dataset "points": shape (2468, 4096, 3), type "<f4">

In [20]:
import k3d

# Draw the normals of sample 20 as arrows anchored at that sample's points.
# For a plain point-cloud view instead: k3d.points(f['points'][20], point_size=0.01)
sample_points = f['points'][20]
sample_normals = f['point_normals'][20]
k3d.vectors(origins=sample_points, vectors=sample_normals)

Output()

In [23]:
# Squared length of every normal of sample 20 (one value per point).
# The recorded output (~29-38) shows these normals are not unit length.
normals_20 = f['point_normals'][20]
(normals_20 * normals_20).sum(axis=1)

array([34.520233, 33.833126, 28.984823, ..., 37.794834, 36.680023,
       35.602394], dtype=float32)

In [28]:
import torch
import torch.nn.functional as F

In [29]:
# Toy inputs for the loss below. Seeded so the reported loss is reproducible on Restart & Run All.
torch.manual_seed(0)
v1 = torch.rand(4, 128)  # view-1 embeddings: batch of 4, dim 128
v2 = torch.rand(4, 128)  # view-2 embeddings, same samples in the same order
sample_idx = torch.tensor([0, 1, 2, 3])  # every row is a distinct sample
labels = torch.tensor([0, 0, 2, 3])  # samples 0 and 1 share a class (false-negative case)

In [54]:
def contrastive_loss_with_one_pos(v1_embeddings, v2_embeddings,
                                  anchor_mode, support_mode,
                                  sample_idx1, sample_idx2,
                                  labels1, labels2,
                                  discard_self_sim,
                                  temperature=1.0):
    """Contrastive loss between two views of a batch, with false-negative masking.

    Embeddings are L2-normalised and compared with dot products. For every anchor,
    the positives are the candidates with the same sample id. Candidates with the
    same label that are not positives (likely false negatives) are pushed out of
    the softmax denominator. The loss is the negative mean log-probability of the
    positives, averaged over anchors.

    Args:
        v1_embeddings: view-1 embeddings, one row per sample (B rows).
        v2_embeddings: view-2 embeddings, one row per sample (B rows).
        anchor_mode: 'one' -> anchors are the view-1 rows only;
            'all' -> anchors are both views stacked as [v1; v2].
        support_mode: 'one' -> candidates are the view-2 rows only;
            'all' -> candidates are both views stacked as [v1; v2].
        sample_idx1, sample_idx2: sample ids of the view-1 / view-2 rows.
        labels1, labels2: class labels of the view-1 / view-2 rows.
        discard_self_sim: if True, an embedding's similarity with itself is removed
            from both the positives and the denominator. A view-1 row paired with the
            view-2 row of the same sample is NOT a self pair and is kept.
        temperature: logits are divided by this before the softmax; the default 1.0
            keeps the un-scaled similarities.

    Returns:
        Scalar tensor. Anchors with no positive are ignored (they would contribute
        0/0 = NaN); if no anchor has a positive the result is NaN.

    Raises:
        ValueError: if anchor_mode or support_mode is not 'one' or 'all'.
    """
    batch_size = v1_embeddings.size(0)
    device = v1_embeddings.device

    # Work in the index space of the stacked [v1; v2] tensors: anchors and candidates
    # are row slices of the stack, so ids/labels/positions always line up with the
    # embeddings they describe, whatever the mode combination.
    view1_rows = slice(0, batch_size)
    view2_rows = slice(batch_size, 2 * batch_size)
    both_rows = slice(0, 2 * batch_size)

    if anchor_mode == 'one':
        anchor_rows = view1_rows
    elif anchor_mode == 'all':
        anchor_rows = both_rows
    else:
        raise ValueError(f"anchor_mode must be 'one' or 'all', got {anchor_mode!r}")

    if support_mode == 'one':
        support_rows = view2_rows
    elif support_mode == 'all':
        support_rows = both_rows
    else:
        raise ValueError(f"support_mode must be 'one' or 'all', got {support_mode!r}")

    stacked_features = torch.cat((F.normalize(v1_embeddings, dim=1),
                                  F.normalize(v2_embeddings, dim=1)), dim=0)
    stacked_ids = torch.cat((sample_idx1.reshape(-1), sample_idx2.reshape(-1))).to(device)
    stacked_labels = torch.cat((labels1.reshape(-1), labels2.reshape(-1))).to(device)
    stacked_positions = torch.arange(2 * batch_size, device=device)

    def _pair_mask(values):
        """(n_anchors, n_candidates) bool mask: True where anchor and candidate values match."""
        return values[anchor_rows].view(-1, 1) == values[support_rows].view(1, -1)

    positive_mask = _pair_mask(stacked_ids)
    if discard_self_sim:
        # Same stacked position <=> literally the same embedding.
        positive_mask &= ~_pair_mask(stacked_positions)
    # Same-label non-positives are likely false negatives; discarded self pairs also
    # land here because an embedding always shares its own label.
    denominator_exclude = _pair_mask(stacked_labels) & ~positive_mask

    logits = stacked_features[anchor_rows] @ stacked_features[support_rows].T / temperature
    # Large finite negative rather than -inf: exp() is ~0, and -inf * 0 would give NaN below.
    logits = logits.masked_fill(denominator_exclude, -100)
    log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    positive_mask = positive_mask.float()
    positives_per_anchor = positive_mask.sum(dim=1)
    has_positive = positives_per_anchor > 0
    mean_over_pos = ((log_probs * positive_mask).sum(dim=1)[has_positive]
                     / positives_per_anchor[has_positive])
    return -mean_over_pos.mean()

In [55]:
# Both views as anchors and as candidates, ignoring each embedding's similarity with itself.
contrastive_loss_with_one_pos(
    v1, v2,
    anchor_mode='all', support_mode='all',
    sample_idx1=sample_idx, sample_idx2=sample_idx,
    labels1=labels, labels2=labels,
    discard_self_sim=True,
)

tensor([[ True,  True, False, False,  True,  True, False, False],
        [ True,  True, False, False,  True,  True, False, False],
        [False, False,  True, False, False, False,  True, False],
        [False, False, False,  True, False, False, False,  True],
        [ True,  True, False, False,  True,  True, False, False],
        [ True,  True, False, False,  True,  True, False, False],
        [False, False,  True, False, False, False,  True, False],
        [False, False, False,  True, False, False, False,  True]])
tensor([[ True,  True, False, False, False,  True, False, False],
        [ True,  True, False, False,  True, False, False, False],
        [False, False,  True, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False],
        [False,  True, False, False,  True,  True, False, False],
        [ True, False, False, False,  True,  True, False, False],
        [False, False, False, False, False, False,  True, False],
        [

tensor(1.7741)