In [None]:
from src.requirements import *
from src.audio_handler import AudioDataset, collate_padding
from src.ssl_model import *

## Contrastive loss
$$ L_{SSL} = \sum_i -\log \frac{\exp(\mathrm{sim}(\hat z_i, z_i^+)/\tau)}{\sum_j \exp(\mathrm{sim}(\hat z_i, z_j^-)/\tau)} $$

- B = Batch size
- T = Temporal Length
- D = Feature dimension

## latent_features (z)
- Latent features are unique representations of an audio sample as vectors in an N-dimensional space (here, 512-dimensional vectors).
- Example: for "namaskar", let each letter correspond to one latent vector, so z = [z1, z2, z3, z4, z5, z6, z7, z8].
- Say we have mask_indices of shape = (8, 1), given as mask_indices = [1, 1, 0, 0, 1, 0, 1, 1], where 1 keeps a position and 0 masks it.
- Then z_masked = [z1, z2, 0, 0, z5, 0, z7, z8].

## context (c)
- We now have a masked input with some missing values, and the model needs to infer the context from z_masked.
- The context (c) is meant to capture what z_masked is saying so far.
- c tries to understand and contextualize "na--s-ar" (positions 3, 4 and 6, i.e. "m", "a" and "k", are masked out by mask_indices).

## predicted queries (q)
- we now have latent_features and context of masked_latent_features, now the model needs to predict based on context
- the model tries to predict queries (q) to fill the masked indices, say we got q = "--ma-k--"

## contrastive loss ($L_{SSL}$)
- we already have latent_features which is the tokenized audio sample, and now predicted queries.

$\hat z_i$ = q = predicted masked representations <br>
$z_i^+$ = not z_masked = whatever values/tokens were masked <br>
$z_i^-$ = negative samples = we deliberately give the model multiple wrong choices so that it learns to pick the most appropriate one (none of the negatives is the correct answer) <br>
However, instead of being picked by hand, these negatives are drawn at random from the masked values $z_k$ and fed into the formula as negative samples.

In [None]:
# vectorization
def contrastive_loss(z, q, mask=None, temperature=0.1):
    """Contrastive loss between predicted queries and latent targets.

    Both inputs are L2-normalised along the last dimension, so the batched
    matrix product below yields cosine similarities. For every position the
    similarity with the latent at the same index is the positive score, and
    the log-sum-exp over the whole similarity row is the normaliser.

    Args:
        z: Latent features. Used as a 3-D tensor (``bmm`` / ``transpose(1, 2)``);
            assumed ``[B, T, D]`` as described in the notebook notes.
        q: Predicted queries with the same shape as ``z``.
        mask: Optional tensor multiplied element-wise into the per-position
            loss (shape ``[B, T]`` or broadcastable to it). Positions where it
            is 0 contribute 0 to the sum.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor: the mean of the (optionally masked) per-position loss
        over all ``B * T`` positions. Positions zeroed by ``mask`` still count
        in the denominator.
    """
    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    # Compute on the original tensors: detaching them (or round-tripping via
    # the CPU) would cut the result off from the autograd graph, leaving the
    # loss without gradients for the model.
    sim = torch.bmm(q, z.transpose(1, 2)) / temperature  # [B, T, T]

    # Positive score = similarity of each query with the latent at its index.
    sim_pos = sim.diagonal(dim1=1, dim2=2)  # [B, T]
    logsum_exp = torch.logsumexp(sim, dim=-1)  # [B, T]
    loss = logsum_exp - sim_pos

    if mask is not None:
        loss = loss * mask.float()

    return loss.mean()

In [None]:
def contrastive_loss_chunked(z, q, mask=None, temperature=0.1, chunk_size=256):
    """Contrastive loss computed over slices of the time axis.

    Same quantity as the un-chunked formulation (cosine similarity of every
    query against every latent of the same sequence, positive on the matching
    index), but the similarity matrix is built ``chunk_size`` query rows at a
    time so that only a ``[B, chunk, T]`` block is materialised per step
    instead of the full ``[B, T, T]`` matrix.

    Args:
        z: Latent features of shape ``[B, T, D]``.
        q: Predicted queries, sliced along dim 1 in step with ``z``.
        mask: Optional tensor indexed as ``mask[:, start:end]`` and multiplied
            into the per-position loss; zeros drop a position from the sum.
        temperature: Divisor applied to the similarities.
        chunk_size: Number of query positions processed per iteration.

    Returns:
        Scalar tensor: summed (optionally masked) loss divided by the total
        number of positions ``B * T``. Masked-out positions still count in
        the denominator.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        ZeroDivisionError: If ``T`` is 0 (nothing to average).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    B, T, D = z.shape
    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    # Loop-invariant view of the keys, hoisted out of the chunk loop.
    z_t = z.transpose(1, 2)  # [B, D, T]

    total_loss = 0.0
    total_count = 0

    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        q_chunk = q[:, start:end, :]  # [B, chunk, D]

        # Similarity of this chunk's queries against all T latents.
        sim = torch.bmm(q_chunk, z_t) / temperature  # [B, chunk, T]
        # Positive score: dot product with the latent at the same time index.
        sim_pos = torch.sum(q_chunk * z[:, start:end, :], dim=-1) / temperature  # [B, chunk]
        logsum_exp = torch.logsumexp(sim, dim=-1)  # [B, chunk]
        loss_chunk = logsum_exp - sim_pos  # [B, chunk]

        if mask is not None:
            loss_chunk = loss_chunk * mask[:, start:end].float()

        total_loss = total_loss + loss_chunk.sum()
        total_count += loss_chunk.numel()

        # No torch.cuda.empty_cache() here: calling it every chunk hands the
        # cached blocks back to the driver only to re-request them on the next
        # iteration, and it cannot release tensors that autograd still holds
        # for the backward pass.

    return total_loss / total_count

In [None]:
def train(model, train_dl, loss_fn, epochs, optimizer, scaler, scheduler, device):
    """Run mixed-precision training with gradient accumulation and save the model.

    Each batch is moved to ``device``, passed through ``model`` (which must
    return ``(z, q, mask)``) under ``torch.autocast``, and scored with
    ``loss_fn(z, q, mask, chunk_size=128)``. Gradients are accumulated over
    ``accum`` (16) batches before the optimizer steps; a trailing, incomplete
    accumulation window is applied at the end of the epoch rather than being
    dropped or carried into the next epoch.

    Args:
        model: Module called as ``model(batch)`` returning ``(z, q, mask)``.
        train_dl: Iterable of batches; each batch must support ``.to(device)``.
        loss_fn: Callable ``loss_fn(z, q, mask, chunk_size=...)`` returning a
            scalar tensor.
        epochs: Number of passes over ``train_dl``.
        optimizer: Optimizer over the model's parameters.
        scaler: Gradient scaler exposing ``scale``, ``step`` and ``update``.
        scheduler: LR scheduler; stepped once per epoch.
        device: Device string or ``torch.device`` (e.g. ``"cuda"``, ``"cpu"``).

    Returns:
        List with one entry per epoch: the detached loss of that epoch's last
        batch, already divided by ``accum``.

    Side effects:
        Writes the final ``state_dict`` to ``models/ssl_model_prototype.pth``,
        creating the ``models`` directory if needed.
    """
    accum = 16
    # autocast expects a bare device type ("cuda" / "cpu"), not "cuda:0" or a
    # torch.device object.
    device_type = torch.device(device).type
    model.train()
    cumulative_loss = []

    # Start from clean gradients in case an earlier (interrupted) run left
    # partially accumulated values on the parameters.
    optimizer.zero_grad()

    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")

        pending = False  # True while accumulated gradients await an optimizer step
        for i, batch in enumerate(tqdm(train_dl)):
            batch = batch.to(device)

            with torch.autocast(device_type):
                z, q, mask = model(batch)
                # Divide so the accumulated gradient averages over the window.
                loss = loss_fn(z, q, mask, chunk_size=128) / accum
            scaler.scale(loss).backward()
            pending = True

            if (i+1) % accum == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                pending = False

        # Apply the last, incomplete accumulation window; otherwise its
        # gradients would be lost or leak into the next epoch's first step.
        if pending:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        scheduler.step()
        # Detach so the history does not keep autograd references alive.
        loss = loss.detach()
        print(f"Loss: {loss}")
        cumulative_loss.append(loss)
        torch.cuda.empty_cache()

    save_dir = "models"
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create directories
    torch.save(model.state_dict(), os.path.join(save_dir, "ssl_model_prototype.pth"))

    return cumulative_loss

In [None]:
# Data-loading configuration for this run.
batch_size = 2
path = os.path.join('data', 'metadata.tsv')

# The dataset is built from the metadata file located at `path`.
train_dataset = AudioDataset(metadata_path=path)

# Shuffled loader over the dataset; batches are assembled by `collate_padding`.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_padding,
    pin_memory=True,
)

In [None]:
# Hyper-parameters for this run.
epochs = 1
learning_rate = 1e-4

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model and optimisation objects.
model = SSLAutoregressiveModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
scaler = torch.GradScaler(device)

# Kick off training with the chunked contrastive loss; one loss entry per epoch.
cumulative_loss = train(
    model=model,
    train_dl=train_dl,
    loss_fn=contrastive_loss_chunked,
    epochs=epochs,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    device=device,
)