In [None]:
from src.requirements import *
from src.audio_handler import *

In [None]:
# Mini-batch size handed to the DataLoader.
batch = 16
# Folder that holds the audio files, given relative to the working directory.
path = os.path.join('data')

# `device` is passed through but is not implemented yet.
# `target_sr` already defaults to 16 kHz, so spelling out 16000 is not
# strictly necessary; it is kept explicit in case the working sampling rate
# is later changed to something below or above 16 kHz.
train_dataset = AudioDataset(
    audio_path=path,
    device='gpu',
    target_sr=16000,
)

In [None]:
# Shuffled mini-batch iterator over the training set.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch,
    shuffle=True,
)

## SSL Pipeline
- Feature Encoder with 1-channel input (mono audio)
- Autoregressive context network built on a Gated Recurrent Unit (GRU) (might remove this)
- Contrastive Predictor (open question: which dimensions it outputs versus which dimensions the loss actually needs)
- SSL Autoregressive model: encoder -> masking -> context -> predictor, returning `z, q, mask_indices` (`z` = the encoder latents before masking, used as the targets; `q` = the predictor outputs for every timestep, used as the queries; `mask_indices` = the timesteps that were zeroed out before the context network and that the model has to predict)
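- Contrastive Loss: InfoNCE-style, $\mathcal{L}_t = -\log \dfrac{\exp(\operatorname{sim}(q_t, z_t)/\tau)}{\sum_{t'} \exp(\operatorname{sim}(q_t, z_{t'})/\tau)}$, where $\operatorname{sim}$ is cosine similarity, $\tau$ is the temperature, $t'$ runs over all timesteps of the same utterance, and the loss is averaged over the masked timesteps $t$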

In [None]:
# Feature Encoder (SSL)
class FeatureEncoder(nn.Module):
    """Two-layer strided 1-D convolutional front end.

    Each convolution is followed by a ReLU. The strides are 5 and 4, so the
    time axis is shortened by roughly a factor of 20 overall.

    Args:
        in_channels: number of input channels (1 for mono audio).
        hidden_dim: number of output channels of both convolutions.
    """

    def __init__(self, in_channels=1, hidden_dim=512):
        super().__init__()
        layers = [
            nn.Conv1d(in_channels, hidden_dim, kernel_size=10, stride=5, padding=3),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
        ]
        # Kept as a Sequential named `encoder` so parameter names stay
        # `encoder.0.*` / `encoder.2.*`.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape [batch, in_channels, samples] to [batch, hidden_dim, frames]."""
        features = self.encoder(x)
        return features

# Context Network (Transformer or RNN)
class AutoregressiveContext(nn.Module):
    """Single-layer, unidirectional GRU run along the time axis.

    Args:
        input_dim: feature size of each input timestep.
        hidden_dim: feature size of each output timestep.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

    def forward(self, z):
        """Return the per-timestep GRU outputs for ``z`` of shape [batch, time, input_dim]."""
        # nn.GRU returns (per-step outputs, final hidden state); only the
        # per-step outputs form the autoregressive context representation.
        return self.gru(z)[0]

# Contrastive Head
class ContrastivePredictor(nn.Module):
    """Linear projection applied to the last dimension of its input.

    Args:
        hidden_dim: size of the incoming feature vectors.
        proj_dim: size of the projected feature vectors.
    """

    def __init__(self, hidden_dim, proj_dim):
        super().__init__()
        self.project = nn.Linear(in_features=hidden_dim, out_features=proj_dim)

    def forward(self, x):
        """Project ``x`` from [..., hidden_dim] to [..., proj_dim]."""
        projected = self.project(x)
        return projected

# Full SSL Encoder Model
class SSLAutoregressiveModel(nn.Module):
    """Masked contrastive SSL model: encoder -> masking -> GRU context -> projection.

    Args:
        feat_dim: channel width of the encoder latents and of the GRU.
        proj_dim: width of the space in which queries and targets are compared.
    """

    def __init__(self, feat_dim=512, proj_dim=256):
        super().__init__()
        # Pass feat_dim through so the encoder width always matches the GRU
        # input size (the encoder's own default is 512).
        self.encoder = FeatureEncoder(hidden_dim=feat_dim)
        self.context = AutoregressiveContext(feat_dim, feat_dim)
        self.predictor = ContrastivePredictor(feat_dim, proj_dim)
        # Queries are proj_dim wide, so the targets have to be proj_dim wide
        # as well or no similarity between them can be computed. When the two
        # widths already agree the latents are returned untouched.
        if proj_dim == feat_dim:
            self.target_proj = nn.Identity()
        else:
            self.target_proj = nn.Linear(feat_dim, proj_dim)

    def forward(self, x, mask_indices):
        """Encode ``x``, zero the masked timesteps, and predict from the context.

        Args:
            x: input batch for the encoder, [B, in_channels, samples].
            mask_indices: either a boolean mask of shape [B, T] (True = masked)
                or an integer tensor of shape [B, K] holding K timestep
                indices per batch element. T is the number of encoder frames.

        Returns:
            Tuple ``(z, q, mask_indices)``: the unmasked latents mapped to the
            query width [B, T, proj_dim], the queries [B, T, proj_dim], and
            the mask (moved to the latents' device).

        Raises:
            ValueError: if a boolean mask does not have shape [B, T].
        """
        z = self.encoder(x).transpose(1, 2)  # [B, T, F]
        mask_indices = mask_indices.to(z.device)
        z_masked = z.clone()

        # Apply masking
        if mask_indices.dtype == torch.bool:
            if mask_indices.shape != z.shape[:2]:
                raise ValueError(
                    f"boolean mask has shape {tuple(mask_indices.shape)}, "
                    f"expected {tuple(z.shape[:2])}"
                )
            z_masked[mask_indices] = 0
        else:
            # The batch index must live on the same device as the latents.
            batch_idx = torch.arange(z.size(0), device=z.device).unsqueeze(1)
            z_masked[batch_idx, mask_indices.long()] = 0

        c = self.context(z_masked)
        q = self.predictor(c)  # queries

        return self.target_proj(z), q, mask_indices

In [None]:
# Contrastive Loss Function
def contrastive_loss(z, q, mask_indices, temperature=0.1):
    """Mean contrastive loss over the masked timesteps.

    For every masked timestep ``t`` of batch element ``b`` the query
    ``q[b, t]`` is scored by cosine similarity against every latent in
    ``z[b]``; ``z[b, t]`` is the positive and all other timesteps of the same
    batch element are the negatives.

    Args:
        z: target latents, [B, T, D].
        q: queries, [B, T, D].
        mask_indices: per batch element either a boolean mask over the T
            timesteps (True = masked) or a sequence/tensor of timestep
            indices. Rows may have different lengths.
        temperature: divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor with the loss averaged over all masked timesteps. When
        nothing is masked, a zero that is still attached to ``q``'s graph so
        that ``backward()`` keeps working.

    Raises:
        ValueError: if ``z`` and ``q`` differ in their last dimension.
    """
    if z.size(-1) != q.size(-1):
        raise ValueError(
            f"z has feature size {z.size(-1)} but q has {q.size(-1)}; "
            "targets and queries must share the same last dimension"
        )

    # With unit-length vectors a matrix product yields cosine similarities.
    z_unit = F.normalize(z, dim=-1)
    q_unit = F.normalize(q, dim=-1)

    total = q.new_zeros(())
    count = 0
    for b in range(q.size(0)):
        row = torch.as_tensor(mask_indices[b], device=q.device)
        if row.dtype == torch.bool:
            steps = row.nonzero(as_tuple=True)[0]
        else:
            steps = row.reshape(-1).long()
        if steps.numel() == 0:
            continue

        logits = (q_unit[b, steps] @ z_unit[b].t()) / temperature  # [K, T]
        positive = logits[torch.arange(steps.numel(), device=logits.device), steps]
        # -log(exp(pos) / sum(exp(all))) written as logsumexp(all) - pos,
        # which cannot overflow for small temperatures.
        total = total + (torch.logsumexp(logits, dim=1) - positive).sum()
        count += steps.numel()

    if count == 0:
        return q.sum() * 0.0
    return total / count

# Timesteps are masked in contiguous spans placed at random; the number of
# spans is chosen so that roughly `mask_prob` (6.5% by default) of the T
# timesteps end up masked.
def compute_mask_indices(B, T, mask_prob=0.065, mask_length=10, device=None):
    """Build a random span mask for a batch of sequences.

    Args:
        B: batch size.
        T: number of timesteps per sequence.
        mask_prob: approximate fraction of timesteps to mask.
        mask_length: length of every masked span.
        device: device on which the mask is created (default: torch's default).

    Returns:
        Boolean tensor of shape [B, T], True where a timestep is masked. Spans
        may overlap, so fewer than ``num_spans * mask_length`` steps can be
        masked. The mask is all False when ``T * mask_prob / mask_length``
        rounds down to zero or when ``T < mask_length``.

    Raises:
        ValueError: if ``mask_length`` is smaller than 1.
    """
    if mask_length < 1:
        raise ValueError(f"mask_length must be >= 1, got {mask_length}")

    mask = torch.zeros(B, T, dtype=torch.bool, device=device)

    num_masked_spans = int((T * mask_prob) / mask_length)
    # Valid span starts are 0 .. T - mask_length inclusive; there are none
    # when the sequence is shorter than one span.
    num_starts = max(T - mask_length + 1, 0)
    if num_masked_spans <= 0 or num_starts == 0:
        return mask

    for b in range(B):
        span_starts = torch.randperm(num_starts)[:num_masked_spans]
        for start in span_starts.tolist():
            mask[b, start:start + mask_length] = True

    return mask

In [None]:
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model to its device first and only then build the optimizer, so
# the optimizer is constructed from the parameters as they exist on that
# device (the order recommended by the torch.optim documentation).
model = SSLAutoregressiveModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Bare expression: shows the module summary as the notebook cell output.
model

In [None]:
# Testing if model is working
model.eval()
with torch.no_grad():
    # [batch, channels, samples]: two mono clips of 16000 samples each.
    dummy_input = torch.randn(2, 1, 16000).to(device)
    # Mask positions address encoder frames, not raw samples, so the valid
    # index range is taken from the encoder output length.
    num_frames = model.encoder(dummy_input).size(-1)
    # Ten random frame indices per clip, shaped [batch, 10].
    dummy_mask_indices = torch.randint(0, num_frames, (dummy_input.size(0), 10))
    outputs = model(dummy_input, dummy_mask_indices)

# The model returns a tuple (latents, queries, mask indices), not one tensor.
z, q, _ = outputs
print("Output shape: ", z.shape, q.shape)

In [None]:
# The smoke test leaves the model in eval mode; switch back before optimising.
model.train()

# Cache: raw sample count -> number of encoder frames, so the encoder is
# probed once per distinct input length rather than once per batch.
frames_per_length = {}

for waveforms in train_dl:
    # NOTE(review): assumes the loader yields a float tensor shaped
    # [B, samples] or [B, 1, samples] — confirm against AudioDataset.
    waveforms = waveforms.to(device)
    if waveforms.dim() == 2:
        # Conv1d expects [B, channels, samples]; add the mono channel axis.
        waveforms = waveforms.unsqueeze(1)

    num_samples = waveforms.size(-1)
    if num_samples not in frames_per_length:
        with torch.no_grad():
            frames_per_length[num_samples] = model.encoder(waveforms[:1]).size(-1)

    # The mask is defined over encoder frames, not over raw samples.
    # NOTE(review): the model's forward and the loss must accept the mask
    # format returned by compute_mask_indices — confirm.
    masked_indices = compute_mask_indices(
        waveforms.size(0), frames_per_length[num_samples]
    ).to(device)

    z, q, masked_indices = model(waveforms, masked_indices)
    loss = contrastive_loss(z, q, masked_indices)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()