In [None]:
from src.requirements import *
from src.audio_handler import *

In [None]:
def compute_mask_indices(B, T, mask_prob, mask_length, device="cpu"):
    """Build a random boolean span mask of shape ``(B, T)``.

    For every batch row, ``int(T * mask_prob / mask_length)`` span start
    positions are drawn uniformly at random and ``mask_length`` consecutive
    time steps from each start are set to ``True`` (= masked). Spans may
    overlap, so the number of masked steps per row is at most
    ``num_spans * mask_length``.

    Args:
        B: Batch size.
        T: Number of time steps per sequence.
        mask_prob: Approximate fraction of time steps to cover with spans.
        mask_length: Length of each masked span (must be >= 1).
        device: Device on which the mask is created.

    Returns:
        A ``torch.bool`` tensor of shape ``(B, T)``; ``True`` marks a masked step.

    Raises:
        ValueError: If ``mask_length`` is smaller than 1.
    """
    if mask_length < 1:
        raise ValueError(f"mask_length must be >= 1, got {mask_length}")

    mask = torch.zeros(B, T, dtype=torch.bool, device=device)

    num_masked_spans = int((T * mask_prob) / mask_length)
    if B == 0 or T == 0 or num_masked_spans <= 0:
        # Nothing to mask (also avoids calling randint with an empty range).
        return mask

    # A span can never be longer than the sequence itself.
    span = min(mask_length, T)

    # `high` is exclusive, so `T - span + 1` lets a span end exactly on the
    # last frame; with `T - span` the final frame could never be masked and
    # T == mask_length raised an error.
    starts = torch.randint(0, T - span + 1, (B, num_masked_spans), device=device)

    # Expand every start into the `span` consecutive indices it covers.
    offsets = torch.arange(span, device=device)
    cols = (starts.unsqueeze(-1) + offsets).reshape(B, -1)
    rows = torch.arange(B, device=device).unsqueeze(-1)
    mask[rows, cols] = True
    return mask

## SSL Pipeline
- Feature Encoder with 1-channel input (mono audio)
- Autoregressive Context NN using a Gated Recurrent Unit (GRU) (might remove this)
- Contrastive Predictor (TODO: check the dimensions it outputs against the dimensions actually needed)
- SSL Autoregressive model: encoder -> context -> predictor -> z, q, mask (z = projected true latent features used as targets, q = predicted queries computed from the masked input, mask = boolean tensor marking the time steps that were masked in the original features and that the model then has to predict)

In [None]:
# Feature Encoder (SSL)
class FeatureEncoder(nn.Module):
    """Two-layer strided 1-D convolutional encoder.

    Stacks two ``Conv1d`` + ``ReLU`` stages (strides 5 and 4), so the time
    axis of the input is downsampled by roughly a factor of 20 while the
    channel axis grows from ``in_channels`` to ``hidden_dim``.
    """

    def __init__(self, in_channels=1, hidden_dim=256):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layers are applied.
        first_conv = nn.Conv1d(in_channels, hidden_dim, kernel_size=10, stride=5, padding=3)
        second_conv = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=8, stride=4, padding=2)
        self.encoder = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, in_channels, time)`` into ``(batch, hidden_dim, time')``."""
        features = self.encoder(x)
        return features

# Context Network (RNN)
class AutoregressiveContext(nn.Module):
    """Single-layer, batch-first GRU that turns latent frames into context vectors."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)

    def forward(self, z):
        """Return the GRU output for every time step; the final hidden state is dropped."""
        # Each output step only depends on the current and earlier inputs,
        # which gives the autoregressive context representation.
        per_step_states = self.gru(z)[0]
        return per_step_states

# Contrastive Head
class ContrastivePredictor(nn.Module):
    """Linear head mapping context vectors into the contrastive projection space."""

    def __init__(self, hidden_dim, proj_dim):
        super().__init__()
        self.project = nn.Linear(in_features=hidden_dim, out_features=proj_dim)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.project(x)
        return projected

# Full SSL Encoder Model
class SSLAutoregressiveModel(nn.Module):
    """Self-supervised model: encoder -> span masking -> GRU context -> predictor.

    The convolutional encoder produces latent frames ``z``. A subset of frames
    is zeroed out, the GRU context network reads the masked sequence, and a
    linear predictor emits one query ``q`` per frame. The *unmasked* latents
    are passed through a separate linear layer to obtain the targets.

    Args:
        feat_dim: Channel width of the encoder output and of the GRU.
        proj_dim: Dimension of the space in which queries and targets are compared.
    """

    def __init__(self, feat_dim=256, proj_dim=256):
        super().__init__()
        # The encoder width must follow feat_dim; otherwise any non-default
        # feat_dim makes the GRU / linear layers below fail on shape mismatch.
        self.encoder = FeatureEncoder(hidden_dim=feat_dim)
        self.context = AutoregressiveContext(feat_dim, feat_dim)
        self.predictor = ContrastivePredictor(feat_dim, proj_dim)
        self.target_proj = nn.Linear(feat_dim, proj_dim)

    def forward(self, x, mask=None, mask_prob=0.065, mask_length=10):
        """Run the full pipeline.

        Args:
            x: Input of shape ``(B, 1, num_samples)`` (the encoder is built
                with its default single input channel).
            mask: Optional ``(B, T)`` mask over encoder frames, truthy where a
                frame should be hidden. A random span mask is drawn when omitted.
            mask_prob: Forwarded to ``compute_mask_indices`` when ``mask`` is None.
            mask_length: Forwarded to ``compute_mask_indices`` when ``mask`` is None.

        Returns:
            Tuple ``(z_proj, q, mask)``: projected true latents ``(B, T, proj_dim)``,
            predicted queries ``(B, T, proj_dim)`` and the boolean mask ``(B, T)``.
        """
        # z -> true latent features, moved to (B, T, feat_dim) for the GRU
        z = self.encoder(x).transpose(1, 2)
        batch_size, num_frames, _ = z.shape

        if mask is None:
            mask = compute_mask_indices(
                batch_size, num_frames, mask_prob, mask_length, device=z.device
            )
        else:
            # Accept 0/1 or CPU masks from callers without surprising indexing.
            mask = mask.to(device=z.device, dtype=torch.bool)

        # Zero out masked frames without touching `z`, which is still needed
        # as the prediction target.
        z_masked = z.masked_fill(mask.unsqueeze(-1), 0.0)

        # c -> context computed from the masked sequence
        c = self.context(z_masked)

        # q -> predicted queries
        q = self.predictor(c)

        # targets: projection of the unmasked latents
        z_proj = self.target_proj(z)

        return z_proj, q, mask

## Contrastive loss
$$ L_{SSL} = \sum_i -\log \frac{\exp(\mathrm{sim}(\hat z_i, z_i^+)/\tau)}{\sum_j \exp(\mathrm{sim}(\hat z_i, z_j^-)/\tau)} $$

- B = Batch size
- T = Temporal Length
- D = Feature dimension

## latent_features (z)
- Latent features are vector representations of an audio sample in an N-dimensional space (here 256-dimensional vectors, matching the default `feat_dim=256`).
- for "namaskar" with each letter representing one vector pos z = [z1, z2, z3, z4, z5, z6, z7, z8]
- say we have a mask of shape (8,) given as mask = [0, 0, 1, 1, 0, 1, 0, 0] (1 = masked, 0 = kept)
- so z_masked = [z1, z2, 0, 0, z5, 0, z7, z8]

## context (c)
- we now have a masked input which has some missing values. now the model needs to understand the context through z_masked.
- context (c) is programmed to understand what z_masked is saying so far.
- c tries to understand and contextualize "na--s-ar"

## predicted queries (q)
- we now have latent_features and context of masked_latent_features, now the model needs to predict based on context
- the model tries to predict queries (q) to fill the masked indices, say we got q = "--ma-k--"

## contrastive loss ($L_{SSL}$)
- we already have latent_features which is the tokenized audio sample, and now predicted queries.

$\hat z_i$ = q = predicted masked representations <br>
$z_i^+$ = the true latent features at the masked positions, i.e. the values/tokens that were zeroed out in z_masked <br>
$z_i^-$ = negative samples = we intentionally give the model multiple wrong choices (none of them is the correct answer) so that it learns to pick the most appropriate one <br>
These negatives are not chosen manually: random masked values $z_k$ are fed to the formula as negative samples.

In [None]:
# inefficient approach
# def contrastive_loss_ineffa(latent_features, predicted_queries, mask_indices, temperature=0.1):
  
#     B, T, D = predicted_queries.shape
#     loss = 0
#     count = 0
#     for b in range(B):
#         for t in mask_indices[b]:
#             positive = latent_features[b, t]
#             query = predicted_queries[b, t]
#             sim_pos = F.cosine_similarity(query, positive, dim=0)
#             sim_all = F.cosine_similarity(query.unsqueeze(0), latent_features[b], dim=1)
#             numerator = torch.exp(sim_pos / temperature)
#             denominator = torch.exp(sim_all / temperature).sum()
#             loss += -torch.log(numerator / denominator)
#             count += 1
#     return loss / count

# vectorization
def contrastive_loss(z, q, mask=None, temperature=0.1):
    """InfoNCE-style contrastive loss between predicted queries and latents.

    For every time step ``t`` the query ``q[b, t]`` is compared, by cosine
    similarity, with all latent frames ``z[b, :]`` of the same sequence. The
    frame at the same position is the positive; all other frames of that
    sequence act as negatives.

    Args:
        z: Latent (target) features of shape ``(B, T, D)``.
        q: Predicted queries of shape ``(B, T, D)``.
        mask: Optional ``(B, T)`` mask; non-zero / ``True`` entries select the
            time steps that contribute to the loss. All steps count when omitted.
        temperature: Softmax temperature applied to the similarities.

    Returns:
        Scalar tensor: mean loss over the selected time steps (``0`` if the
        mask selects nothing).
    """
    z = F.normalize(z, dim=-1)
    q = F.normalize(q, dim=-1)

    # (B, T, T): similarity of every query with every latent frame.
    sim = torch.bmm(q, z.transpose(1, 2)) / temperature

    # Positive pairs sit on the diagonal (same time step).
    sim_pos = sim.diagonal(dim1=1, dim2=2)

    # -log softmax of the positive among all candidates, computed stably.
    per_step = torch.logsumexp(sim, dim=-1) - sim_pos

    if mask is None:
        return per_step.mean()

    # Average over the masked steps only. Taking `.mean()` of the masked
    # tensor would divide by B*T and shrink the loss by the mask ratio.
    weights = mask.to(per_step.dtype)
    return (per_step * weights).sum() / weights.sum().clamp(min=1.0)

In [None]:
# Location of the audio files and the number of clips per training batch.
path = os.path.join("data")
batch_size = 4

# Notes on AudioDataset arguments:
# - the device parameter is not implemented yet
# - the target sampling rate defaults to 16 kHz, so passing 16000 explicitly
#   is unnecessary; the parameter only exists in case the working sampling
#   rate is later changed to something below or above 16 kHz
train_dataset = AudioDataset(audio_path=path)

# collate_padding is used as the collate function for every batch.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_padding,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)

In [None]:
# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-4

# Module.to() moves parameters in place, so the optimizer sees the same
# parameter objects whether it is built before or after the move.
model = SSLAutoregressiveModel().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Last expression of the cell: displays the module summary in the notebook.
model

In [None]:
def dummy_test():
    """Smoke-test the global ``model`` on random input and print output shapes.

    Feeds a batch of 16 random single-channel signals of 16000 samples through
    the model and prints the shapes of the projected latents and the predicted
    queries. The model is left in eval mode afterwards; call ``model.train()``
    before training.
    """
    model.eval()

    batch_size = 16
    seq_len = 16000

    dummy_input = torch.randn(batch_size, 1, seq_len, device=device)

    # The model draws its own mask when none is given, so no separate encoder
    # pass or mask is needed here; no_grad avoids building an autograd graph
    # for a pure shape check.
    with torch.no_grad():
        z_out, q_out, _ = model(dummy_input)

    print("z_out:", z_out.shape)
    print("q_out:", q_out.shape)

In [None]:
# Mixed-precision training pass over `train_dl`, wrapped in the PyTorch profiler.
# NOTE(review): `profile`, `ProfilerActivity` and `tqdm` are not imported in this
# cell; presumably they come from the star import of src.requirements — confirm.

# Gradient scaler used together with autocast below.
scaler = torch.GradScaler(device)

model.train()
# The profiler context spans the whole loop, so every batch is recorded
# (CPU and CUDA activity, input shapes, memory usage and Python stacks).
with profile(
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes = True,
    profile_memory = True,
    with_stack = True) as prof:
    
    for batch in tqdm(train_dl):
        # assumes collate_padding yields a single tensor per batch — TODO confirm
        batch = batch.to(device)
        optimizer.zero_grad()
        # Forward pass and loss run under autocast for the selected device type.
        with torch.autocast(device):
            z, q, mask = model(batch)
            loss = contrastive_loss(z, q, mask)
        # AMP update sequence: backward on the scaled loss, optimizer step
        # through the scaler, then refresh the scale factor for the next step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# Print the 10 profiler entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))