In [1]:
from src.requirements import *
from src.audio_handler import *

## SSL Pipeline
- Feature Encoder with 1-channel input (mono audio)
- Autoregressive Context NN using a Gated Recurrent Unit (GRU) (might remove this)
- Contrastive Predictor (what dimensions i get and what i actually need to get)
- SSL Autoregressive model: encoder -> context -> predictor -> z, q, mask_indices (z = true latent features from the encoder, before masking; q = predicted queries computed from the masked sequence; mask_indices = the positions that were masked in the input and that the model has to predict)

In [None]:
# Feature Encoder (SSL)
class FeatureEncoder(nn.Module):
    """Strided 1-D convolutional front end that turns a waveform into features.

    Two ``Conv1d`` + ``ReLU`` stages are applied in sequence:
    kernel 10 / stride 5 / padding 3, then kernel 8 / stride 4 / padding 2.

    Args:
        in_channels: Number of input channels (1 for mono audio).
        hidden_dim: Number of output channels of both convolutions.
    """

    def __init__(self, in_channels=1, hidden_dim=512):
        super().__init__()
        first_conv = nn.Conv1d(
            in_channels, hidden_dim, kernel_size=10, stride=5, padding=3
        )
        second_conv = nn.Conv1d(
            hidden_dim, hidden_dim, kernel_size=8, stride=4, padding=2
        )
        # Layer order (conv, relu, conv, relu) fixes the state_dict keys.
        self.encoder = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

    def forward(self, x):
        """Run the convolution stack on ``x`` and return the resulting feature map."""
        features = self.encoder(x)
        return features

# Context Network (RNN)
class AutoregressiveContext(nn.Module):
    """Batch-first GRU that maps a feature sequence to per-step context vectors.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Size of the GRU hidden state (and of each output vector).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim, hidden_size=hidden_dim, batch_first=True
        )

    def forward(self, z):
        """Return the GRU output for every time step of ``z``."""
        # Only the per-step outputs are used; the final hidden state is dropped.
        context, _last_hidden = self.gru(z)
        return context

# Contrastive Head
class ContrastivePredictor(nn.Module):
    """Single linear layer that projects context vectors into the query space.

    Args:
        hidden_dim: Size of the incoming vectors.
        proj_dim: Size of the projected vectors.
    """

    def __init__(self, hidden_dim, proj_dim):
        super().__init__()
        self.project = nn.Linear(in_features=hidden_dim, out_features=proj_dim)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.project(x)
        return projected

# Full SSL Encoder Model
class SSLAutoregressiveModel(nn.Module):
    """Encoder -> masking -> GRU context -> linear predictor.

    Args:
        feat_dim: Width of the encoder features and of the GRU hidden state.
        proj_dim: Width of the predicted queries.

    NOTE(review): the returned latents have ``feat_dim`` channels while the
    queries have ``proj_dim``; a loss that compares them with cosine
    similarity needs ``proj_dim == feat_dim`` -- confirm the intended sizes.
    """

    def __init__(self, feat_dim=512, proj_dim=256):
        super().__init__()
        # Forward feat_dim to the encoder so its output width always matches
        # the GRU input size (it used to stay at the encoder default of 512).
        self.encoder = FeatureEncoder(hidden_dim=feat_dim)
        self.context = AutoregressiveContext(feat_dim, feat_dim)
        self.predictor = ContrastivePredictor(feat_dim, proj_dim)

    def forward(self, x, mask_indices):
        """Encode ``x``, zero the masked frames and predict queries.

        Args:
            x: Input passed straight to the encoder ([B, C, samples]).
            mask_indices: Either a boolean mask of shape [B, T] (True = masked)
                or an integer tensor of frame positions of shape [B, K].
                T is the number of encoder output frames, not raw samples.

        Returns:
            ``(z, q, mask_indices)``: the unmasked latents [B, T, F], the
            predicted queries [B, T, proj_dim] and the mask exactly as given.
        """
        # z -> true latent features
        z = self.encoder(x).transpose(1, 2)  # [B, T, F]

        # masking
        mask = torch.as_tensor(mask_indices, device=z.device)
        if mask.dtype == torch.bool:
            # Boolean layout: broadcast the [B, T] mask over the feature axis.
            z_masked = z.masked_fill(mask.unsqueeze(-1), 0)
        else:
            # Index layout: row b zeroes the frames listed in mask[b].
            z_masked = z.clone()
            batch_rows = torch.arange(z.size(0), device=z.device).unsqueeze(1)
            z_masked[batch_rows, mask.long()] = 0

        # c -> context
        c = self.context(z_masked)

        # q -> predicted queries
        q = self.predictor(c)

        return z, q, mask_indices

## Contrastive loss
$$ L_{SSL} = \sum_i -\log \frac{\exp(\mathrm{sim}(\hat z_i, z_i^+)/\tau)}{\sum_j \exp(\mathrm{sim}(\hat z_i, z_j^-)/\tau)} $$

- B = Batch size
- T = Temporal Length
- D = Feature dimension

## latent_features (z)
- Latent features are unique representations of an audio sample in N-dimension (here 512-dimensional vector).
- for "namaskar" with each letter representing one vector pos z = [z1, z2, z3, z4, z5, z6, z7, z8]
- say we have a mask of length 8 (one entry per vector) given as mask_indices = [1, 1, 0, 0, 1, 0, 1, 1], where 1 = keep and 0 = masked
- so z_masked = [z1, z2, 0, 0, z5, 0, z7, z8]

## context (c)
- we now have a masked input which has some missing values. now the model needs to understand the context through z_masked.
- context (c) is programmed to understand what z_masked is saying so far.
- c tries to understand and contextualize "na--s-ar"

## predicted queries (q)
- we now have latent_features and context of masked_latent_features, now the model needs to predict based on context
- the model tries to predict queries (q) to fill the masked indices, say we got q = "--ma-k--"

## contrastive loss ($L_{SSL}$)
- we already have latent_features which is the tokenized audio sample, and now predicted queries.

$\hat z_i$ = q = predicted masked representations <br>
$z_i^+$ = the true latent value at masked position $i$, i.e. the value that was hidden in z_masked <br>
$z_i^-$ = negative samples = multiple wrong choices that we intentionally give the model so that it learns to pick the most appropriate one (none of the negatives is the correct answer) <br>
However, instead of being chosen by hand, the negative samples fed to the formula are randomly chosen latent values $z_k$.

In [None]:
# Contrastive Loss Function
# contrastive_loss(latent_features (z), predicted_queries (q), masked_indices (mask_indices), temperature)
# B = Batch size
# T = Temporal Length
# D = Feature dimension

def contrastive_loss(latent_features, predicted_queries, mask_indices, temperature=0.1):
    """Mean InfoNCE loss over the masked positions.

    For every masked position ``t`` of utterance ``b`` the query
    ``predicted_queries[b, t]`` is scored by cosine similarity against every
    frame of ``latent_features[b]``; the frame at ``t`` is the positive.

    Args:
        latent_features: Target latents, shape [B, T, D].
        predicted_queries: Predicted queries, shape [B, T, D]. The last
            dimension must equal that of ``latent_features``.
        mask_indices: Either a boolean mask [B, T] (True = masked) or, per
            utterance, a sequence/tensor of integer frame positions.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor with the loss averaged over all masked positions, or a
        zero that is still attached to the graph when nothing is masked.
    """
    batch_size = predicted_queries.shape[0]
    device = predicted_queries.device

    total = predicted_queries.new_zeros(())
    count = 0
    for b in range(batch_size):
        positions = torch.as_tensor(mask_indices[b], device=device)
        if positions.dtype == torch.bool:
            # Boolean mask row -> positions of the True entries.
            positions = positions.nonzero(as_tuple=True)[0]
        else:
            positions = positions.reshape(-1).long()
        if positions.numel() == 0:
            continue

        # Cosine similarity as a matrix product of unit vectors; this avoids
        # materialising a [K, T, D] intermediate.
        queries = F.normalize(predicted_queries[b, positions], dim=-1)  # [K, D]
        candidates = F.normalize(latent_features[b], dim=-1)            # [T, D]
        logits = queries @ candidates.t() / temperature                 # [K, T]

        rows = torch.arange(positions.numel(), device=device)
        positive = logits[rows, positions]
        # -log(exp(pos) / sum(exp(all))) written in its numerically stable form.
        total = total + (torch.logsumexp(logits, dim=1) - positive).sum()
        count += positions.numel()

    if count == 0:
        # Nothing was masked: return a differentiable zero instead of 0 / 0.
        return predicted_queries.sum() * 0.0
    return total / count

# Spans of `mask_length` consecutive frames are masked at random start positions;
# the number of spans is chosen so that roughly `mask_prob` (6.5%) of the frames are masked.

def compute_mask_indices(B, T, mask_prob=0.065, mask_length=10):
    """Build a random boolean span mask.

    Args:
        B: Batch size (number of rows).
        T: Number of frames per row.
        mask_prob: Target fraction of frames to mask.
        mask_length: Length of each masked span, in frames.

    Returns:
        Bool tensor of shape [B, T]; True marks a masked frame. Spans may
        overlap, so a row can contain fewer than
        ``num_spans * mask_length`` masked frames.
    """
    mask = torch.zeros(B, T, dtype=torch.bool)

    num_masked_spans = int((T * mask_prob) / mask_length)

    # A span cannot be longer than the sequence itself.
    span = min(mask_length, T)
    if num_masked_spans <= 0 or span <= 0:
        return mask

    # Valid starts are 0 .. T - span inclusive, so the last frame is reachable.
    num_starts = T - span + 1

    for b in range(B):
        span_starts = torch.randperm(num_starts)[:num_masked_spans]
        for start in span_starts.tolist():
            mask[b, start : start + span] = True

    return mask

In [None]:
# Mini-batch size and the folder that holds the audio data.
batch = 16
path = os.path.join('data')

# Notes on the AudioDataset arguments:
# - `device` is accepted but not implemented yet.
# - `target_sr` already defaults to 16 kHz, so passing 16000 is optional; it is
#   spelled out in case the working sampling rate is later changed to a value
#   below or above 16 kHz.
train_dataset = AudioDataset(
    audio_path=path,
    device='gpu',
    target_sr=16000,
)

In [None]:
# Iterate the training set in shuffled mini-batches of `batch` items.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=batch,
    shuffle=True,
)

In [None]:
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The contrastive loss compares predicted queries with the encoder latents by
# cosine similarity, so both must have the same width; the default
# proj_dim (256) would not match the 512-d latents.
feat_dim = 512
model = SSLAutoregressiveModel(feat_dim=feat_dim, proj_dim=feat_dim)

# Move the model to its device before building the optimizer, as recommended
# by the torch.optim documentation.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model

In [None]:
# Testing if model is working
model.eval()

# Two one-second mono clips at 16 kHz: [batch, channel, samples].
dummy_input = torch.randn(2, 1, 16000, device=device)

with torch.no_grad():
    # The mask addresses encoder output frames, not raw samples.
    num_frames = model.encoder(dummy_input).size(-1)
    # Integer frame positions per row ([B, K]); mask the first few frames.
    dummy_mask = torch.arange(min(10, num_frames)).repeat(dummy_input.size(0), 1)
    latent, queries, _ = model(dummy_input, dummy_mask)

# forward() returns a tuple, so report the shape of each tensor separately.
print("Output shape: ", tuple(latent.shape), tuple(queries.shape))

In [None]:
# One pass over the training data.
model.train()

# Number of encoder frames for each distinct input length, probed once.
frames_by_length = {}

# `waveforms` (not `batch`) so the batch-size variable is not overwritten.
for waveforms in train_dl:
    waveforms = waveforms.to(device)
    # assumes the dataset yields mono waveforms, collated to [B, samples] or
    # [B, 1, samples] -- TODO confirm against AudioDataset
    if waveforms.dim() == 2:
        waveforms = waveforms.unsqueeze(1)  # add the channel axis -> [B, 1, samples]

    num_samples = waveforms.size(-1)
    if num_samples not in frames_by_length:
        with torch.no_grad():
            frames_by_length[num_samples] = model.encoder(waveforms[:1]).size(-1)
    num_frames = frames_by_length[num_samples]

    # The mask is defined over encoder frames, not raw samples.
    frame_mask = compute_mask_indices(waveforms.size(0), num_frames)

    # The model and the loss address masked frames by integer position, so
    # turn the boolean mask into a rectangular [B, K] index tensor. Spans may
    # overlap, giving rows different counts; keep the shortest count.
    keep = int(frame_mask.sum(dim=1).min())
    if keep == 0:
        continue  # nothing masked in at least one row; skip this batch
    masked_indices = torch.stack(
        [row.nonzero(as_tuple=True)[0][:keep] for row in frame_mask]
    )

    latent_features, predicted_queries, masked_indices = model(waveforms, masked_indices)
    loss = contrastive_loss(latent_features, predicted_queries, masked_indices)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()