In [None]:
from src.requirements import *

In [None]:
def compute_mask_indices(B, T, mask_prob=0.65, mask_length=10, device="cpu"):
    """Sample a random span mask over the time axis.

    For each of the ``B`` rows, ``max(1, int(mask_prob * T) // mask_length)``
    span start positions are drawn uniformly without replacement, and
    ``mask_length`` consecutive steps are marked from each start. Spans may
    overlap, so a row can end up with fewer masked steps than
    ``int(mask_prob * T)``.

    Args:
        B: Number of rows.
        T: Number of time steps per row.
        mask_prob: Target fraction of steps to mask; only used to choose the
            number of spans.
        mask_length: Length of each masked span. Clamped to ``T`` when longer.
        device: Device of the returned tensor.

    Returns:
        Boolean tensor of shape ``(B, T)``; ``True`` marks masked steps.

    Raises:
        ValueError: If ``mask_length`` is smaller than 1.
    """
    if mask_length < 1:
        raise ValueError(f"mask_length must be >= 1, got {mask_length}")

    mask = torch.zeros((B, T), dtype=torch.bool, device=device)
    if T == 0:
        return mask

    # A span longer than the sequence can at most cover the whole sequence.
    span_len = min(mask_length, T)
    num_masked_steps = int(mask_prob * T)
    # At least one span is always sampled, even when mask_prob * T < span_len.
    num_spans = max(1, num_masked_steps // span_len)
    # Starts 0..T - span_len are all valid (the last span ends exactly at T).
    num_starts = T - span_len + 1

    for b in range(B):
        # A random permutation of 0..num_starts-1 gives distinct start positions.
        span_starts = torch.randperm(num_starts, device=device)[:num_spans]
        for s in span_starts.tolist():
            mask[b, s:s + span_len] = True

    return mask

In [None]:
def contrastive_loss(z, q, mask, temperature=0.1, chunk_size=256):
    """Softmax contrastive (InfoNCE-style) loss averaged over masked steps.

    ``z`` and ``q`` are L2-normalised along the last dimension. For each
    ``q[b, t]`` the positive is ``z[b, t]`` and the candidates are every
    ``z[b, s]`` in the same row (positive included). The per-step loss
    ``logsumexp_s(sim[b, t, s]) - sim[b, t, t]`` (with ``sim`` being cosine
    similarity divided by ``temperature``) is summed over the steps where
    ``mask`` is nonzero and divided by the number of such steps.

    Queries are processed ``chunk_size`` time steps at a time, so the
    similarity tensor built per iteration is ``(B, chunk, T)`` instead of
    ``(B, T, T)``.

    Args:
        z: Tensor of shape ``(B, T, D)`` supplying positives and candidates.
        q: Tensor of shape ``(B, T, D)`` supplying the queries.
        mask: Tensor of shape ``(B, T)``; nonzero/True marks scored steps.
        temperature: Divisor applied to the cosine similarities.
        chunk_size: Number of query time steps handled per iteration.

    Returns:
        Scalar tensor: mean loss over masked steps (0 if nothing is masked).
    """
    T = z.shape[1]

    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    # (B, D, T): shared right-hand side of every chunk's similarity matmul.
    z_all_t = z.transpose(1, 2)

    total_loss = z.new_zeros(())
    total_valid = z.new_zeros(())

    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)

        q_chunk = q[:, start:end, :]
        z_pos = z[:, start:end, :]

        sim_all = torch.bmm(q_chunk, z_all_t) / temperature    # (B, C, T)
        sim_pos = torch.sum(q_chunk * z_pos, dim=-1) / temperature  # (B, C)

        # -log softmax(sim_all)[positive]
        loss_chunk = torch.logsumexp(sim_all, dim=-1) - sim_pos

        m = mask[:, start:end].to(loss_chunk.dtype)
        # Only masked steps contribute; the divisor counts the same steps.
        total_loss = total_loss + (loss_chunk * m).sum()
        total_valid = total_valid + m.sum()

        del q_chunk, z_pos, sim_all, sim_pos, loss_chunk, m

    # Epsilon keeps the result finite (0) when no step is masked.
    return total_loss / (total_valid + 1e-10)