In [1]:
from src.requirements import *
from src.audio_handler import *

In [2]:
torch.cuda.is_available()

True

In [3]:
def compute_mask_indices(B, T, mask_prob, mask_length, device="cpu"):

    mask = torch.zeros(B, T, dtype=torch.bool, device=device)

    num_masked_spans = int((T * mask_prob) / mask_length)
    for b in range(B):
        starts = torch.randint(0, T - mask_length, (num_masked_spans,))
        for s in starts:
            mask[b, s : s + mask_length] = True
    return mask

## SSL Pipeline
- Feature Encoder with 1-channel input (mono audio)
- Autoregressive Context NN using Gated Recurrent Network (GRU) (might remove this)
- Contrastive Predictor (what dimensions i get and what i actually need to get)
- SSL Autoregressive model: encoder -> context -> predictor -> z, q, masked_indices (z = only predicted masked output, q = predicted values, masked_indices = indices that were masked in the original which was then predicted by the model)

In [4]:
# Feature Encoder (SSL)
class FeatureEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, hidden_dim, kernel_size=7, stride=5, padding=3),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=4, padding=2),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.encoder(x)

# Context Network (RNN)
class AutoregressiveContext(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)

    def forward(self, z):
        # natural autoregressive context representation
        output, _ = self.gru(z)
        return output

# Contrastive Head
class ContrastivePredictor(nn.Module):
    def __init__(self, hidden_dim, proj_dim):
        super().__init__()
        self.project = nn.Linear(hidden_dim, proj_dim)

    def forward(self, x):
        return self.project(x)

# Full SSL Encoder Model
class SSLAutoregressiveModel(nn.Module):
    def __init__(self, feat_dim=128, proj_dim=128):
        super().__init__()
        self.encoder = FeatureEncoder()
        self.context = AutoregressiveContext(feat_dim, feat_dim)
        self.predictor = ContrastivePredictor(feat_dim, proj_dim)
        self.target_proj = nn.Linear(feat_dim, proj_dim)

    def forward(self, x, mask=None, mask_prob=0.065, mask_length=10):
        # z -> true latent features

        z = self.encoder(x)
            
        z = z.transpose(1, 2)
        B, T, F = z.shape

        # masking
        z_masked = z.clone()
        
        if mask is None:
            mask = compute_mask_indices(B, T, mask_prob, mask_length, device=z.device)
        
        # z_masked[torch.arange(z.size(0)).unsqueeze(1), mask_indices] = 0
        z_masked[mask.unsqueeze(-1).expand_as(z_masked)] = 0

        # c -> context
        c = self.context(z_masked)
        
        # q -> predicted queries
        q = self.predictor(c)

        # projection
        z_proj = self.target_proj(z)

        return z_proj, q, mask

## Contrastive loss
$$ L_{SSL} = \sum_i -log \frac{exp(sin(\hat z_i, z_i^+)/\tau)}{\sum_j exp(sin(\hat z_i, z_j^-)/\tau)} $$

- B = Batch size
- T = Temporal Length
- D = Feature dimension

## latent_features (z)
- Latent features are unique representations of an audio sample in N-dimension (here 512-dimensional vector).
- for "namaskar" with each letter representing one vector pos z = [z1, z2, z3, z4, z5, z6, z7, z8]
- say we have masked_indices of shape = (8, 1) given as mask_indices = [1, 1, 0, 0, 1, 0, 1, 1]
- so z_masked = [z1, z2, 0, 0, z5, 0, z7, z8]

## context (c)
- we now have a masked input which has some missing values. now the model needs to understand the context through z_masked.
- context (c) is programmed to understand what z_masked is saying so far.
- c tries to understand and contextualize "na--k-ar"

## predicted queries (q)
- we now have latent_features and context of masked_latent_features, now the model needs to predict based on context
- the model tries to predict queries (q) to fill the masked indices, say we got q = "--ma-k--"

## contrastive loss ($L_{SSL}$)
- we already have latent_features which is the tokenized audio sample, and now predicted queries.

$\hat z_i$ = q = predicted masked representations <br>
$z_i^+$ = not z_masked = whatever values/tokens were masked <br>
$z_i^-$ = negative samples = we are intended to give the model multiple wrong choices so that it learns to pick the most appropriate one (there are no correct answers) <br>
but, these features (instead of manual intervention) are fed to the formula using a random masked values from $z_k$ as a negative sample

In [5]:
# vectorization
def contrastive_loss(z, q, mask=None, temperature=0.1):
    """
    z: latent features
    q: predicted queries
    mask: hidden features
    """
    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    z_cpu = z.detach().cpu()
    q_cpu = q.detach().cpu()

    # batch matrxi-matrix product
    sim_cpu = torch.bmm(q_cpu, z_cpu.transpose(1,2)) / temperature
    sim = sim_cpu.to(z.device, non_blocking=True)

    sim_pos = sim.diagonal(dim1=1, dim2=2)

    logsum_exp = torch.logsumexp(sim, dim=-1)
    loss = - (sim_pos - logsum_exp)

    if mask is not None:
        loss = loss * mask.float()

    return loss.mean()

In [6]:
def contrastive_loss_chunked(z, q, mask=None, temperature=0.1, chunk_size=256):
    """
        z: [B, T, D] target latent features
        q: [B, T, D] predicted query features
        mask: optional [B, T] mask
        temperature: softmax temperature
        chunk_size: number of timesteps to process per chunk
    """
    B, T, D = z.shape
    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    total_loss = 0.0
    total_count = 0

    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        q_chunk = q[:, start:end, :]  # [B, chunk, D]

        # Compute similarity only for this chunk
        sim = torch.bmm(q_chunk, z.transpose(1, 2)) / temperature  # [B, chunk, T]
        sim_pos = torch.sum(q_chunk * z[:, start:end, :], dim=-1) / temperature  # [B, chunk]
        logsum_exp = torch.logsumexp(sim, dim=-1)  # [B, chunk]
        loss_chunk = -(sim_pos - logsum_exp)  # [B, chunk]

        if mask is not None:
            loss_chunk = loss_chunk * mask[:, start:end].float()

        total_loss += loss_chunk.sum()
        total_count += loss_chunk.numel()

        # optional: free intermediate tensor memory early
        del q_chunk, sim, sim_pos, logsum_exp, loss_chunk
        torch.cuda.empty_cache()

    return total_loss / total_count

In [7]:
path = os.path.join('data')
batch_size = 2

# target sampling rate is defaulted at 16kHz so manually defining 16000 is not necessary
# this parameter is added in case we decide to change to working sampling rate to less than or greater than 16kHz

train_dataset = AudioDataset(audio_path=path)
train_dl = DataLoader(
    dataset = train_dataset,
    batch_size = batch_size,
    pin_memory = True,
    collate_fn = collate_padding, 
    shuffle=True
)

100%|████████████████████████████████████████████████████████████████████████████| 607/607 [00:00<00:00, 288721.09it/s]
100%|████████████████████████████████████████████████████████████████████████████| 642/642 [00:00<00:00, 287256.58it/s]
100%|████████████████████████████████████████████████████████████████████████████| 593/593 [00:00<00:00, 595029.25it/s]
100%|████████████████████████████████████████████████████████████████████████████| 652/652 [00:00<00:00, 550681.88it/s]
100%|████████████████████████████████████████████████████████████████████████████| 614/614 [00:00<00:00, 276765.47it/s]
100%|████████████████████████████████████████████████████████████████████████████| 625/625 [00:00<00:00, 257458.26it/s]
100%|████████████████████████████████████████████████████████████████████████████| 599/599 [00:00<00:00, 265636.30it/s]
100%|████████████████████████████████████████████████████████████████████████████| 624/624 [00:00<00:00, 312320.49it/s]
100%|███████████████████████████████████

In [8]:
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = SSLAutoregressiveModel()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model = model.to(device)
model

SSLAutoregressiveModel(
  (encoder): FeatureEncoder(
    (encoder): Sequential(
      (0): Conv1d(1, 128, kernel_size=(7,), stride=(5,), padding=(3,))
      (1): ReLU()
      (2): Conv1d(128, 128, kernel_size=(5,), stride=(4,), padding=(2,))
      (3): ReLU()
    )
  )
  (context): AutoregressiveContext(
    (gru): GRU(128, 128, batch_first=True)
  )
  (predictor): ContrastivePredictor(
    (project): Linear(in_features=128, out_features=128, bias=True)
  )
  (target_proj): Linear(in_features=128, out_features=128, bias=True)
)

In [9]:
def dummy_test():
    # Testing if model is working
    model.eval()
    
    batch_size = 16
    seq_len = 16000
    
    dummy_input = torch.randn(batch_size, 1, seq_len).to(device)
    z = model.encoder(dummy_input).transpose(1, 2)
    B, T_enc, F = z.shape
    
    mask = compute_mask_indices(B, T_enc, mask_prob=0.065, mask_length=10, device=device)
    z_out, q_out, mask_out = model(dummy_input)
    
    print("z_out:", z_out.shape)
    print("q_out:", q_out.shape)

In [10]:
def train(model, train_dl, loss_fn, epochs, optimizer, scaler, name, device):
    
    accum = 32
    model.train()

    for i, batch in enumerate(tqdm(train_dl)):
        batch = batch.to(device)
        with torch.autocast(device):
            z, q, mask = model(batch)
            loss = loss_fn(z, q, mask, chunk_size=128) / accum
        scaler.scale(loss).backward()

        if (i+1) % accum == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()


In [11]:
name = "ssl_encoder"
epochs = 5
scaler = torch.GradScaler(device)

train(model, train_dl, contrastive_loss_chunked, epochs, optimizer, scaler, name, device)

100%|██████████████████████████████████████████████████████████████████████████████| 4968/4968 [10:12<00:00,  8.12it/s]
