In [None]:
# Scratch copy of the body of a CLIP model-output function, kept for reference.
# It is a fragment: it ends in a bare `return` and uses names that are not
# defined anywhere in this notebook, so the cell cannot run on its own.
# NOTE(review): `ch` is presumably `torch` under an alias, and `image`,
# `label`, `model`, `weights`, `buffers` presumably come from the enclosing
# function's arguments -- confirm against the original source.

# Read the cached embeddings and sampling settings stored on CLIPModelOutput.
all_im_embs = CLIPModelOutput.image_embeddings
all_txt_embs = CLIPModelOutput.text_embeddings
N = CLIPModelOutput.num_computed_embeddings
sim_bs = CLIPModelOutput.sim_batch_size

# The cached image embeddings must exist before this code can be used.
if all_im_embs is None:
    raise AssertionError(
        "Run traker.task.get_embeddings first before featurizing!"
    )

# tailored for open_clip
# https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245
# Add a leading dimension of size 1 to the single image/label pair and run
# the model functionally with the supplied weights and buffers. The third
# value returned by the model is discarded.
clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
image_embeddings, text_embeddings, _ = ch.func.functional_call(
    model, (weights, buffers), args=(), kwargs=clip_inputs
)

# Draw `sim_bs` distinct indices (replacement=False) into the cached embeddings.
# NOTE(review): the sampling weights are arange(N) = 0, 1, ..., N-1. If
# `ch.multinomial` has torch.multinomial semantics, index 0 has zero weight
# and is never drawn, and larger indices are favoured. Verify whether a
# uniform subset was intended.
ii = ch.multinomial(
    input=ch.arange(N).float(), num_samples=sim_bs, replacement=False
)

# Sum of two negated log-sum-exp terms:
#   * image embedding against (own text embedding - each sampled text embedding)
#   * text embedding against (own image embedding - each sampled image embedding)
result = -ch.logsumexp(
    -image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1
) + -ch.logsumexp(
    -text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1
)
return result.sum()  # shape of result should be [1]

In [None]:
# Symmetric CLIP contrastive loss, written as NumPy-style pseudocode.
# NOTE(review): this appears to be the reference pseudocode from the CLIP
# paper. `image_encoder`, `text_encoder`, `l2_normalize`, `cross_entropy_loss`
# and all inputs (I, T, W_i, W_t, t, n) are not defined in this notebook, so
# the cell is illustrative only and will not execute.

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)  # [n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
# (entry [i, j] compares image i with text j; exp(t) is the scale factor)
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
# Images and texts are aligned, so the correct match for index i is i itself.
labels = np.arange(n)
# The same logits are scored along each axis and the two losses are averaged.
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t) / 2

In [113]:
import torch

# Synthetic stand-ins for CLIP embeddings, used by the cells below.
# The generator is seeded so that Restart Kernel -> Run All produces the same
# tensors, and therefore the same printed scores and losses, on every run.
SEED = 0
NUM_REFERENCE = 3  # number of rows in each reference ("all") embedding set
EMBED_DIM = 25  # embedding width shared by images and texts

torch.manual_seed(SEED)

all_img_embds = torch.randn(NUM_REFERENCE, EMBED_DIM)  # reference image embeddings
all_txt_embds = torch.randn(NUM_REFERENCE, EMBED_DIM)  # reference text embeddings
input_img_emb = torch.randn(1, EMBED_DIM)  # single query image embedding
input_txt_emb = torch.randn(1, EMBED_DIM)  # single query text embedding


In [114]:
import torch


def out_fn(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb):
    """Score a query (image, text) embedding pair against reference embeddings.

    The score is the sum of two negated log-sum-exp terms: the query image
    against the gaps between the query text and every reference text, and the
    query text against the gaps between the query image and every reference
    image. Every reference row is used; no random subset is drawn.

    Args:
        all_img_embds: 2-D tensor of reference image embeddings, one per row.
        all_txt_embds: 2-D tensor of reference text embeddings, one per row.
        input_img_emb: 2-D tensor holding the query image embedding(s).
        input_txt_emb: 2-D tensor holding the query text embedding(s).

    Returns:
        1-D tensor with one score per query row.
    """
    # Index over every reference row; the row count is taken from the image set.
    every_row = torch.arange(all_img_embds.shape[0])

    txt_gap = input_txt_emb - all_txt_embds[every_row]
    img_gap = input_img_emb - all_img_embds[every_row]

    img_term = torch.logsumexp(-input_img_emb @ txt_gap.T, dim=1)
    txt_term = torch.logsumexp(-input_txt_emb @ img_gap.T, dim=1)
    return -img_term + -txt_term


In [115]:
# Compare the score for two kinds of query against the same reference sets:
#   1) the separately drawn random query pair;
#   2) a query pair that is row 0 of the reference sets themselves.
print(out_fn(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb))
print(out_fn(all_img_embds, all_txt_embds, all_img_embds[:1], all_txt_embds[:1]))


tensor([-12.2665])
tensor([-0.0601])


In [116]:
def compute_loss(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb):
    """Image-to-text cross-entropy of a query image against reference texts.

    Only ``input_img_emb`` and ``all_txt_embds`` are used; ``all_img_embds``
    and ``input_txt_emb`` are accepted but ignored.

    Args:
        all_img_embds: unused.
        all_txt_embds: 2-D tensor of reference text embeddings, one per row.
        input_img_emb: 2-D tensor holding a single query image embedding.
        input_txt_emb: unused.

    Returns:
        Scalar tensor: cross-entropy of the similarity scores with row 0 of
        ``all_txt_embds`` as the target class.
    """
    # Dot-product similarity between the query image and every reference text.
    img_to_txt_scores = input_img_emb @ all_txt_embds.T

    # The target is always class 0, created on the same device as the query.
    target = torch.zeros(1, dtype=torch.long, device=input_img_emb.device)

    return torch.nn.functional.cross_entropy(img_to_txt_scores, target)

In [117]:
# Image-to-text loss for the random query pair, then for a query that is
# row 0 of the reference sets (the row compute_loss uses as its target).
print(compute_loss(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb))
print(compute_loss(all_img_embds, all_txt_embds, all_img_embds[:1], all_txt_embds[:1]))


tensor(1.5377)
tensor(0.0468)


In [92]:
def loss_fn(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb):
    """Compute both directional cross-entropy losses for one query pair.

    Args:
        all_img_embds: 2-D tensor of reference image embeddings, one per row.
        all_txt_embds: 2-D tensor of reference text embeddings, one per row.
        input_img_emb: 2-D tensor holding a single query image embedding.
        input_txt_emb: 2-D tensor holding a single query text embedding.

    Returns:
        Tuple ``(loss_i, loss_t)`` of scalar tensors: the image-to-text loss
        and the text-to-image loss. Both use row 0 of the respective
        reference set as the target class.
    """
    cross_entropy = torch.nn.functional.cross_entropy

    def target_for(query):
        # Single target (class 0) placed on the same device as the query.
        return torch.zeros(1, dtype=torch.long, device=query.device)

    # Query image against every reference text, and query text against
    # every reference image.
    img_vs_texts = input_img_emb @ all_txt_embds.T
    txt_vs_images = input_txt_emb @ all_img_embds.T

    image_loss = cross_entropy(img_vs_texts, target_for(input_img_emb))
    text_loss = cross_entropy(txt_vs_images, target_for(input_txt_emb))
    return image_loss, text_loss

# Evaluate both directional losses for the random query pair and report their
# mean, mirroring `(loss_i + loss_t) / 2` in the CLIP pseudocode cell.
# NOTE(review): this cell's execution count (92) is lower than the cells above
# it (113-117), and the recorded loss_i differs from what compute_loss printed
# for the same arguments. The saved output presumably comes from an earlier
# set of random embeddings -- re-run after Restart Kernel -> Run All.
loss_i, loss_t = loss_fn(all_img_embds, all_txt_embds, input_img_emb, input_txt_emb)

print(loss_i, loss_t)
print((loss_i + loss_t) / 2)

tensor(0.1029) tensor(5.6286)
tensor(2.8657)
