In an effort to "squeeze" nuisance variability out of the biological partition of my VAE latent space, I have added an NT-Xent (normalized temperature-scaled cross-entropy) metric loss to the training objective. The method works by generating "contrastive pairs", i.e. altered versions (e.g. flipped, resized, brightened) of the same underlying image. The NT-Xent loss encourages the model to place transformed versions of the same image near each other in latent space, and away from the other images in the batch.

Let's look at how this plays out.

In [12]:
import torch
from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning import distances
import numpy as np

# this is my implementation
def nt_xent(inputs, labels=None, temperature=1):
    """NT-Xent-style loss computed from a matrix of *distances*.

    Each row of ``inputs`` holds the distances from one anchor to a set of
    candidates. The distances are negated and divided by ``temperature`` to
    form logits, and a softmax cross-entropy is taken against ``labels``, so
    the loss is small when the labelled candidate is the closest one.

    Args:
        inputs: tensor of distances, expected shape (n_anchors, n_candidates)
            when ``labels`` is left at its default.
        labels: class-index target per row (index of the positive
            candidate). Defaults to column 0 for every row.
        temperature: positive scale; smaller values sharpen the softmax.

    Returns:
        Scalar tensor: mean cross-entropy over rows (CrossEntropyLoss default
        reduction).
    """
    if labels is None:
        # Build the default targets on the same device as the inputs so the
        # loss also works for GPU tensors (CrossEntropyLoss requires input and
        # target on the same device).
        labels = torch.zeros(inputs.shape[0], dtype=torch.long, device=inputs.device)

    loss_fun = torch.nn.CrossEntropyLoss()

    # Negate: a smaller distance must become a larger logit.
    return loss_fun(-inputs / temperature, labels)

In [11]:
# Simulate a (batch_size x batch_size) matrix of hypothetical distances:
# column 0 plays the positive pair (uniform on [0, mu_pos)) and columns 1:
# are the negatives (uniform on [0, mu_neg)). nt_xent uses column 0 as the
# target by default and negates its inputs, so the loss is small when the
# positive is the closest candidate.
batch_size = 10 # so that we have a decent sample size
mu_pos = 10   # scale of the positive-pair distances
mu_neg = 100  # scale of the negative-pair distances
temp=0.01

# torch.rand is unseeded, so the printed loss varies from run to run.
logit_array_sim = torch.rand((batch_size, batch_size))
logit_array_sim[:, 0] = logit_array_sim[:, 0]*mu_pos
logit_array_sim[:, 1:] = logit_array_sim[:, 1:]*mu_neg
# logit_array_sim = logit_array_sim.long()

print(nt_xent(logit_array_sim, temperature=temp))
# print(logit_array_sim)

tensor(169.0919)


In [None]:
# let's see what the built-in version from the metric learning package does
temp = .001

# Euclidean (L2) variant. LpDistance is a genuine distance (smaller = more
# similar), so it must not be flagged is_inverted=True. That flag tells the
# loss to treat larger values as more similar, and LpDistance's constructor
# asserts that it is not inverted. NTXentLoss flips the sign of non-inverted
# distances itself.
ntx_loss_euc = NTXentLoss(temperature=temp, distance=
                          distances.LpDistance(normalize_embeddings=False))
# Library-default distance (cosine similarity).
ntx_loss_cos = NTXentLoss(temperature=temp)

batch_size = 10
latent_dim = 5

# Two classes of 5 samples each, as integer class labels (the form the
# pytorch_metric_learning losses expect).
class_vec = torch.zeros((batch_size,), dtype=torch.long)
class_vec[5:] = 1

# Per-class means. With c1val == c2val both classes are drawn from the same
# distribution, i.e. there is no class structure for the loss to find.
c1val = 1
c2val = 1

mu_array = torch.ones((batch_size, latent_dim))

mu_array[:5, :] = c1val
mu_array[5:, :] = c2val

# Unit-variance Gaussian embeddings around the class means (unseeded).
embeddings = torch.normal(mu_array, std=torch.ones((batch_size, latent_dim)))

loss_euc = ntx_loss_euc(embeddings=embeddings, labels=class_vec)
loss_cos = ntx_loss_cos(embeddings=embeddings, labels=class_vec)

print(loss_euc)
# print(loss_cos)

In [172]:
# let's try to engineer my own version of multi-target NT-Xent
# Prototype of a multi-positive NT-Xent, computed in log space.
# Target encoding used below: 1 = positive, 0 = negative, -1 = excluded from
# both numerator and denominator (here the last column).
# target = None
batch_size = 10
mu_pos = 1
mu_neg = 10
temperature = .001

logits = torch.zeros((batch_size, batch_size), dtype=torch.float32)
n_pos = 3  # the first n_pos columns of every row are positives
# target_array = torch.zeros((batch_size, batch_size), dtype=torch.float32)

logits[:, :n_pos] = mu_pos
logits[:, n_pos:] = mu_neg
logits[:, -1] = 1  # value irrelevant: this column is masked out via target == -1
# def nt_xent_multitarget(logits, temperature=1, target=None):
target = torch.zeros((logits.shape), dtype=torch.float32)
target[:, :n_pos] = 1

target[:, -1] = -1
# if target is None:
#     target = torch.zeros((logits.shape), dtype=torch.float32)
#     target[:, 0] = 1

# NOTE(review): unlike nt_xent, `logits` are used here as similarities (not
# negated). With mu_pos < mu_neg the negatives score higher than the
# positives, so the loss is about (mu_neg - mu_pos) / temperature + log(2)
# (the ~9000.69 recorded below). Negate `logits` first if they are meant to
# be distances.
logits_tempered = logits/temperature
# max_val = torch.max(logits_tempered)
# logits_normed = logits_tempered - max_val # should prevent overflow
# (torch.logsumexp below is already numerically stabilised, so no manual
# max shift is needed.)
logits_tempered[target==-1] = -torch.inf  # exp(-inf) = 0: drop excluded entries
logits_num = logits_tempered.clone()
logits_num[target==0] = -torch.inf  # numerator keeps only the positives
# logits_exp = torch.exp(logits_normed)

# Per-row loss = logsumexp(all kept entries) - logsumexp(positive entries)
#              = -log(sum_pos exp / sum_all exp)
numerator = -torch.logsumexp(logits_num, axis=1)
denominator = torch.logsumexp(logits_tempered, axis=1)

loss = numerator + denominator


print(torch.mean(loss))
#     return loss


tensor(9000.6943)


In [148]:
# Log-sum-exp of the positive entries in the first row of logits_num (the
# numerator term, before negation, of the multi-target prototype).
# NOTE: the output shown was produced by execution [148], i.e. before the
# current version of the prototype cell ([172]) ran, so a fresh run may differ.
print(torch.logsumexp(logits_num[0, :], axis=0))

tensor(100.)


In [None]:
# calculate for different temperatures and difference sizes
# Sweep nt_xent over 50 log-spaced temperatures (1e-4 .. 1e2) and 50
# log-spaced pos/neg distance ratios `delta` (1e-3 .. 1); np.logspace
# defaults to num=50. Uses batch_size from an earlier cell.
# NOTE: delta, mu_pos, temp and logit_array_sim keep their last loop values
# as globals afterwards, which later cells may pick up.
mu_neg = 100
temp_vec = np.logspace(-4, 2)
delta_vec = np.logspace(-3, 0)

# rows index temperature, columns index delta
loss_array = np.empty((len(temp_vec), len(delta_vec)))

for d, delta in enumerate(delta_vec):
    
    mu_pos = mu_neg*delta  # positive-pair distance
    
    # The random values are all overwritten below, so the matrix is constant:
    # positive column = mu_pos, every other column = mu_neg.
    logit_array_sim = torch.rand((batch_size, batch_size))
    logit_array_sim[:, 0] = mu_pos
    logit_array_sim[:, 1:] = mu_neg
    
    for t, temp in enumerate(temp_vec):
        # nt_xent returns a 0-d tensor; it is converted to a float on assignment
        loss_array[t, d] = nt_xent(logit_array_sim, temperature=temp)

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# 3-D surface of the nt_xent sweep: loss_array[t, d] is drawn at
# x = log10(delta_vec[d]), y = log10(temp_vec[t]).
fig = go.Figure()
fig.add_trace(go.Surface(z=loss_array, x=np.log10(delta_vec), y=np.log10(temp_vec)))

# Axis titles state the log10 scale that is actually plotted (the previous
# titles 'pos/neg' / 'temperature' suggested linear values).
fig.update_layout(scene=dict(
    xaxis_title='log10(pos/neg)',
    yaxis_title='log10(temperature)',
    zaxis_title='loss'))

fig.show()

In [None]:
# Largest nt_xent loss over the whole (temperature, delta) sweep.
np.max(loss_array)

**What if we instead want to enforce a LACK of differentiation?** A nihilistic criterion, if you will.

In [None]:
def nt_xent_nil(inputs, labels=None, temperature=1):
    """Sign-flipped variant of ``nt_xent``.

    Same as ``nt_xent`` except that ``inputs`` are NOT negated: the logits are
    ``inputs / temperature``. When fed distances, the loss is therefore small
    when the labelled candidate is the *farthest* one in its row, i.e. it
    rewards pushing the "positive" away rather than pulling it in.

    Args:
        inputs: tensor of scores, expected shape (n_anchors, n_candidates)
            when ``labels`` is left at its default.
        labels: class-index target per row. Defaults to column 0 for every
            row.
        temperature: positive scale; smaller values sharpen the softmax.

    Returns:
        Scalar tensor: mean cross-entropy over rows (CrossEntropyLoss default
        reduction).
    """
    if labels is None:
        # Default targets live on the inputs' device so GPU tensors work too.
        labels = torch.zeros(inputs.shape[0], dtype=torch.long, device=inputs.device)

    loss_fun = torch.nn.CrossEntropyLoss()

    return loss_fun(inputs / temperature, labels)

In [None]:
# Same temperature x delta sweep as for nt_xent, but with nt_xent_nil:
# 50 log-spaced temperatures (1e-4 .. 1e2) by 50 log-spaced pos/neg ratios
# (1e-3 .. 1). Uses batch_size from an earlier cell; the loop variables
# remain defined as globals afterwards.
mu_neg = 100
temp_vec = np.logspace(-4, 2)
delta_vec = np.logspace(-3, 0)

# rows index temperature, columns index delta
loss_array_nil = np.empty((len(temp_vec), len(delta_vec)))

for d, delta in enumerate(delta_vec):
    
    mu_pos = mu_neg*delta  # score of the "positive" column
    
    # The random values are all overwritten below, so the matrix is constant.
    logit_array_sim = torch.rand((batch_size, batch_size))
    logit_array_sim[:, 0] = mu_pos
    logit_array_sim[:, 1:] = mu_neg
    
    for t, temp in enumerate(temp_vec):
        loss_array_nil[t, d] = nt_xent_nil(logit_array_sim, temperature=temp)

In [None]:
# 3-D surface of the nt_xent_nil sweep: loss_array_nil[t, d] is drawn at
# x = log10(delta_vec[d]), y = log10(temp_vec[t]).
fig = go.Figure()
fig.add_trace(go.Surface(z=loss_array_nil, x=np.log10(delta_vec), y=np.log10(temp_vec)))

# Axis titles state the log10 scale that is actually plotted.
fig.update_layout(scene=dict(
    xaxis_title='log10(pos/neg)',
    yaxis_title='log10(temperature)',
    zaxis_title='loss'))

fig.show()

#### What about binary cross-entropy? It is the loss function usually recommended when there are multiple positive examples.

In [None]:
def cross_entropy_multitarget(logits, temperature=1, target=None):
    """Softmax cross-entropy allowing several (weighted) positives per row.

    For each row ``i``::

        loss[i] = -log( sum_j target[i, j] * exp(logits[i, j] / T)
                        / sum_j exp(logits[i, j] / T) )

    i.e. the negative log of the softmax probability mass on the positive
    (non-zero ``target``) entries. Both sums are evaluated in log space with
    ``torch.logsumexp``. Computing ``exp`` directly returns ``nan``/``inf``
    once ``exp`` overflows or underflows, which happens quickly for large
    logits or small temperatures.

    Args:
        logits: 2-D tensor of similarity scores (higher = more similar),
            shape (batch, n_candidates).
        temperature: positive scale applied as ``logits / temperature``.
        target: non-negative weights with the same shape as ``logits``
            (1 = positive, 0 = negative). Defaults to column 0 as the only
            positive in every row.

    Returns:
        1-D tensor of per-row losses (no reduction). A row with no positive
        weight gets ``inf``.
    """
    scaled = logits / temperature

    if target is None:
        target = torch.zeros(scaled.shape, dtype=scaled.dtype, device=scaled.device)
        target[:, 0] = 1

    # log(target) is -inf where target == 0, which removes those entries from
    # the numerator's log-sum-exp while keeping fractional weights exact.
    log_numerator = torch.logsumexp(scaled + torch.log(target), dim=1)
    log_denominator = torch.logsumexp(scaled, dim=1)

    return log_denominator - log_numerator

In [None]:
# Evaluate cross_entropy_multitarget on simulated distances.
batch_size = 10 # so that we have a decent sample size
mu_pos = 1   # distance assigned to columns 0-4
mu_neg = 15  # distance assigned to columns 5-9
temp = 1

# The random values are fully overwritten below.
logit_array_sim = torch.rand((batch_size, batch_size))

target_array = torch.zeros((batch_size, batch_size), dtype=torch.float32)

logit_array_sim[:, 0:5] = mu_pos
logit_array_sim[:, 5:] = mu_neg

# NOTE(review): only column 0 is marked positive, so columns 1-4 sit at the
# same distance as the positive but count as negatives. Use
# target_array[:, 0:5] = 1 if they are meant to be positives as well.
target_array[:, 0:1] = 1
target_array[:, 1:] = 0

# Negate the distances so that closer = larger score for the softmax.
loss = cross_entropy_multitarget(-logit_array_sim, temp, target=target_array)
print(torch.mean(loss))

# nt_xent already negates its inputs internally, so the equivalent call would
# be nt_xent(logit_array_sim, temperature=temp) without the leading minus.
# nt_xent(-logit_array_sim, temperature=temp)

In [None]:
# TODO: unfinished import. The original line `from pytorch.met` is a
# SyntaxError, so it is kept here commented out. Complete it (presumably a
# `from pytorch_metric_learning ... import ...`) or delete it.
# from pytorch.met

In [None]:
# BCEWithLogitsLoss sanity check on a large random 0/1 target matrix.
# mu_pos, mu_neg and temp are (re)set here but not used by this cell;
# batch_size = 1000 stays in effect for the cells below.
batch_size = 1000 # so that we have a decent sample size
mu_pos = 1
mu_neg = 100
temp=0.0001
# n_classes = 5
target = torch.randint(2, (batch_size, batch_size), dtype=torch.float32)  # random 0/1 targets, shape (batch_size, batch_size)
output = torch.full([batch_size, batch_size], 1.5)  # A prediction (logit); every 1.5 is overwritten below
output[target==1] = 100  # confidently correct on positives (loss ~0)
output[target==0] = 0.1  # near chance on negatives: per-element loss log(1 + e^0.1) ~= 0.74
# pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss()
criterion(output, target) 

In [None]:
# Single-row BCEWithLogitsLoss check with tempered scores: the first half of
# the row is negative (target 0, logit mu_neg/temp) and the second half is
# positive (target 1, logit mu_pos/temp). Uses batch_size from an earlier cell.
# NOTE(review): BCE maps logit 0 to p = 0.5, so each negative at logit 0.01
# still costs about log(2); negatives need negative logits for a low loss.
mu_pos = 1000
mu_neg = 1
temp = 100

target_vec = torch.zeros((1, batch_size), dtype=torch.float32)
target_vec[0, 0:int(batch_size/2)] = 0 
target_vec[0, int(batch_size/2):] = 1

logits = torch.zeros([1, batch_size], dtype=torch.float32)
logits[0, 0:int(batch_size/2)] = mu_neg/temp
logits[0, int(batch_size/2):] = mu_pos/temp
# logits = logits.to(torch.long)

criterion = torch.nn.BCEWithLogitsLoss()
criterion(logits, target_vec) 

In [None]:
# Inspect the current value of the global `logits` tensor.
logits

In [None]:


# Evaluate nt_xent_nil on a constant score matrix (column 0 = mu_pos, the rest
# = mu_neg). It uses whatever mu_pos, mu_neg, temp and batch_size were last
# set by earlier cells (1000, 1, 100 and 1000 if the cells ran top to bottom).
logit_array_sim = torch.rand((batch_size, batch_size))
logit_array_sim[:, 0] = mu_pos  # random-scaled alternative: logit_array_sim[:, 0]*mu_pos
logit_array_sim[:, 1:] = mu_neg  # random-scaled alternative: logit_array_sim[:, 1:]*mu_neg
# logit_array_sim = logit_array_sim.long()

print(nt_xent_nil(logit_array_sim, temperature=temp))

In [None]:
# Hand-rolled version of nt_xent's core computation with 10 anchors and 100
# candidates. The positive (column 0) is at distance 1 and every negative is
# at 50, so with this tiny temperature the loss should be essentially 0.
temperature = .0001
labels = torch.zeros(10, dtype=torch.long)
inputs = torch.zeros((10, 100))
inputs[:, 0] = 1
inputs[:, 1:] = 50

loss_fun = torch.nn.CrossEntropyLoss()
loss = loss_fun(-inputs/temperature, labels)  # negate: smaller distance -> larger logit
print(loss)

In [None]:
# Sanity check of the random simulation: column 0 is uniform on [0, mu_pos)
# and the other columns are uniform on [0, mu_neg), so the printed means
# should be roughly mu_pos / 2 (0.05) and mu_neg / 2 (50).
mu_neg = 100
delta = 0.001
mu_pos = mu_neg*delta
    
logit_array_sim = torch.rand((batch_size, batch_size))
logit_array_sim[:, 0] = logit_array_sim[:, 0]*mu_pos
logit_array_sim[:, 1:] = logit_array_sim[:, 1:]*mu_neg

print(torch.mean(logit_array_sim[:, 0]))
print(torch.mean(logit_array_sim[:, 1:]))