In [1]:
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

import torch
import numpy as np
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fix the RNG state so that "Restart Kernel -> Run All" reproduces the same random
# tensors. NumPy is seeded as well because the mask / negative-index helpers used
# further down return NumPy arrays and presumably draw from NumPy's global RNG
# (TODO confirm against the installed transformers version).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Toy problem size: batch size, sequence length (time steps), feature dimension.
bs, seqlen, feat = 16, 100, 78

# Random stand-ins for the model predictions and the targets they are scored against.
preds = torch.rand(bs, seqlen, feat)
labels = torch.rand(bs, seqlen, feat)

In [3]:
# Draw a (batch, time) mask; the helper returns a NumPy array (its `.shape` prints
# as a plain tuple), which is converted to a tensor later on.
mask_shape = (bs, seqlen)
mask_time_indices = _compute_mask_indices(
    shape=mask_shape,
    mask_prob=0.5,
    mask_length=1,
)
mask_time_indices.shape

(16, 100)

In [4]:
# Sample 20 distractor indices for every (batch, time) position and wrap the
# resulting NumPy array in a tensor straight away.
sampled_negative_indices = torch.from_numpy(
    _sample_negative_indices(
        features_shape=(bs, seqlen),
        num_negatives=20,
        mask_time_indices=mask_time_indices,
    )
)
sampled_negative_indices.shape

torch.Size([16, 100, 20])

In [5]:
# 3. Gather the K sampled negatives (distractors) used by the contrastive loss.
# The sampled indices address rows of the labels flattened over (batch, time).
flat_labels = labels.view(-1, feat)
flat_indices = sampled_negative_indices.long().view(-1)
negative_features = flat_labels[flat_indices]
print(negative_features.shape)

# (bs * seqlen * K, feat) -> (bs, seqlen, K, feat) -> (K, bs, seqlen, feat)
negative_features = negative_features.view(bs, seqlen, -1, feat)
negative_features = negative_features.permute(2, 0, 1, 3)
negative_features.shape

torch.Size([32000, 78])


torch.Size([20, 16, 100, 78])

In [6]:
def compute_contrastive_logits(
    target_features: torch.FloatTensor,
    negative_features: torch.FloatTensor,
    predicted_features: torch.FloatTensor,
    temperature: float = 0.1,
) -> torch.Tensor:
    """
    Compute temperature-scaled cosine-similarity logits for the contrastive loss.

    The positive target and the negatives are stacked along dim 0 into one set of
    candidates, and every candidate is compared with `predicted_features` along the
    last (feature) dimension.

    Args:
        target_features: Positive target. It is concatenated with `negative_features`
            along dim 0, so all remaining dimensions must match.
        negative_features: Negative (distractor) features.
        predicted_features: Predictions; broadcast against the stacked candidates.
        temperature: Divisor applied to the similarities (smaller -> sharper logits).

    Returns:
        Logits with the feature dimension reduced away. Index 0 along dim 0 belongs
        to the positive, the remaining indices to the negatives. The similarity is
        computed in float32 and cast back to the dtype of the stacked candidates.
    """
    # Positive first, so that class index 0 is the correct answer downstream.
    candidates = torch.cat([target_features, negative_features], dim=0)

    # Always compute the similarity in float32, then restore the candidates' dtype.
    logits = torch.cosine_similarity(
        predicted_features.float(), candidates.float(), dim=-1
    ).type_as(candidates)

    # apply temperature
    return logits / temperature

In [7]:
# 4. Logits `sim(c_t, [q_t, \sim{q}_t]) / \kappa` from equation (3) of
# https://arxiv.org/pdf/2006.11477.pdf
# The labels get a leading axis of size 1 so they can be stacked with the K negatives.
logits = compute_contrastive_logits(
    target_features=labels.unsqueeze(0),
    negative_features=negative_features,
    predicted_features=preds,
)
logits.shape

torch.Size([21, 16, 100])

In [8]:
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked

neg_is_pos = (preds == negative_features).all(-1)
neg_is_pos.shape

torch.Size([20, 16, 100])

In [9]:
if neg_is_pos.any():
    logits[1:, neg_is_pos] = float("-inf")

In [10]:
# Sanity check: the masking step above must leave the shape of the logits unchanged.
logits.shape

torch.Size([21, 16, 100])

In [11]:
# 6. Contrastive loss \mathbf{L}_m = cross_entropy(logs) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
mask_time_indices = torch.from_numpy(mask_time_indices)

# One row per (time, batch) position, one column per candidate (column 0 = positive).
num_candidates = logits.size(0)
logits = logits.permute(2, 1, 0).reshape(-1, num_candidates)

# Positions where the mask is set get target 0; all other positions get -100.
# The transpose keeps the (time, batch) row order used for the logits.
unmasked = 1 - mask_time_indices.long()
target = (unmasked * -100).transpose(0, 1).flatten()
logits.shape, target.shape

(torch.Size([1600, 21]), torch.Size([1600]))

In [12]:
# Average the cross-entropy over the positions that carry a real target; rows whose
# target is -100 are skipped (spelled out here, -100 is also the default ignore_index).
contrastive_loss = nn.functional.cross_entropy(
    input=logits.float(),
    target=target,
    ignore_index=-100,
    reduction="mean",
)
contrastive_loss

tensor(3.0946)

### Class implementation 

In [21]:
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)
from src.utils.torch_utils import Wav2Vec2ContrastiveLoss


# Smoke-test the packaged loss on fresh random inputs; the predictions track
# gradients so the returned loss carries a grad_fn.
preds = torch.rand(bs, seqlen, feat, requires_grad=True)
labels = torch.rand(bs, seqlen, feat)

loss_fn = Wav2Vec2ContrastiveLoss(mask_rate=0.5, reduction="mean")
mask = loss_fn.create_mask((bs, seqlen))
print(f"shapes: {preds.shape}, {labels.shape}, {mask.shape}")

loss = loss_fn(preds, labels, mask)
loss

shapes: torch.Size([16, 100, 78]), torch.Size([16, 100, 78]), (16, 100)


tensor(3.0854, grad_fn=<NllLossBackward0>)

In [25]:
import torch

# Scratch tensor holding the integers 0..9, used to check slicing below.
t = torch.arange(0, 10)
t

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [26]:
# Half-open slice: keeps indices 1 through 5, index 6 is excluded.
t[1:6]

tensor([1, 2, 3, 4, 5])