In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import Wav2Vec2Model

In [None]:
# Load pre-trained wav2vec 2.0 model
model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')

# Freeze all parameters, then unfreeze only the last transformer layer.
# Wav2Vec2's encoder keeps its transformer blocks in `layers`, and the flag has
# to reach every Parameter: `requires_grad_` does that recursively, whereas
# assigning `module.requires_grad = True` only sets an unused Python attribute
# and leaves the layer frozen.
model.requires_grad_(False)
model.encoder.layers[-1].requires_grad_(True)

In [None]:
# Define contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on squared Euclidean distances.

    Label convention: ``label == 0`` marks a similar pair (its squared distance
    is penalised); ``label == 1`` marks a dissimilar pair (the squared amount by
    which its distance falls short of ``margin`` is penalised).
    """

    def __init__(self, margin=0.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, x1, x2, label):
        """Return the mean contrastive loss over all compared positions.

        Args:
            x1, x2: embeddings of the same shape; the distance is taken over the
                last (feature) dimension, so any leading dims are allowed.
            label: per-pair labels (0 = similar, 1 = dissimilar) whose shape
                matches the *leading* dims of x1 (e.g. ``(batch,)``); they are
                broadcast over any remaining dims (e.g. frames).
        """
        # Shape is x1.shape[:-1], e.g. (batch,) or (batch, frames).
        distance = torch.nn.functional.pairwise_distance(x1, x2)
        label = label.to(device=distance.device, dtype=distance.dtype)
        # Plain broadcasting would align a (batch,) label with the LAST axis of
        # a (batch, frames) distance; append singleton dims so it aligns with
        # the leading (batch) axis instead.
        if label.dim() < distance.dim():
            label = label.reshape(label.shape + (1,) * (distance.dim() - label.dim()))
        similar_term = (1 - label) * distance.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - distance, min=0.0).pow(2)
        return torch.mean(similar_term + dissimilar_term)

# Define optimizer over the trainable (unfrozen) parameters only; frozen ones
# never receive gradients, and if nothing is trainable Adam fails loudly with
# an empty-parameter-list error instead of silently training nothing.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)

In [None]:
# Define batch size and number of epochs
batch_size = 32
num_epochs = 10

# Define training data loader.
# TODO: `MyDataset` is a placeholder; it is expected to yield
# (clean_audio, noisy_audio) pairs of equal-length waveform tensors.
train_dataset = MyDataset(...)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Build the loss once rather than on every batch. With this ContrastiveLoss,
# label 0 = similar pair (pulled together), label 1 = dissimilar pair
# (pushed at least `margin` apart).
criterion = ContrastiveLoss()

# Train the encoder using contrastive learning
model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        clean_audio = batch[0]
        noisy_audio = batch[1]

        # Utterance-level representations: last_hidden_state is
        # (batch, frames, hidden) per the transformers docs; mean-pool over the
        # frames so there is one vector per example and per-example labels
        # line up with the distances.
        clean_rep = model(clean_audio).last_hidden_state.mean(dim=1)
        noisy_rep = model(noisy_audio).last_hidden_state.mean(dim=1)

        # Size labels from the actual batch (the last batch may be smaller than
        # batch_size) and create them on the representations' device.
        n_examples = clean_rep.size(0)

        # Positive pairs: each clean clip with its own noisy version -> label 0.
        loss = criterion(clean_rep, noisy_rep, clean_rep.new_zeros(n_examples))

        # Negative pairs: each clean clip vs. another example's noisy clip ->
        # label 1, so the loss cannot be minimised by collapsing all inputs to
        # one point. roll(1) never pairs an example with itself when n > 1.
        if n_examples > 1:
            loss = loss + criterion(clean_rep, noisy_rep.roll(1, dims=0),
                                    clean_rep.new_ones(n_examples))

        # Backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate the trained encoder on downstream speech recognition task
    # ...

# Integrate the trained encoder with the decoder for speech recognition task
# ...

In [None]:
# Separator cell. The cells below are an alternative, self-contained version of
# the clean/noisy contrastive training above: they re-import the libraries and
# redefine `model` and `ContrastiveLoss`, shadowing the earlier definitions.
''' ------------------------------------------------------------------------------- '''

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import Wav2Vec2Model

# Fresh copy of the pre-trained wav2vec 2.0 checkpoint; every parameter stays
# trainable in this version of the pipeline.
model = Wav2Vec2Model.from_pretrained(pretrained_model_name_or_path='facebook/wav2vec2-base-960h')

# Define contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on (non-squared) Euclidean distances.

    Label convention (the opposite of the ContrastiveLoss defined earlier in
    this notebook): ``label == 1`` marks a similar pair (its distance is
    penalised); ``label == 0`` marks a dissimilar pair (the amount by which its
    distance falls short of ``margin`` is penalised).
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, z1, z2, label):
        """Return the mean contrastive loss over all compared positions.

        Args:
            z1, z2: embeddings of the same shape; the distance is taken over the
                last (feature) dimension, so any leading dims are allowed.
            label: per-pair labels (1 = similar, 0 = dissimilar) whose shape
                matches the *leading* dims of z1 (e.g. ``(batch,)``); they are
                broadcast over any remaining dims (e.g. frames).
        """
        # Reduce over the feature axis (last dim). For (batch, frames, hidden)
        # inputs, dim=1 would wrongly collapse the frames axis instead.
        dist = torch.norm(z1 - z2, dim=-1)
        label = label.to(device=dist.device, dtype=dist.dtype)
        # Align a (batch,) label with the leading axis of dist rather than
        # letting broadcasting line it up with the last axis.
        if label.dim() < dist.dim():
            label = label.reshape(label.shape + (1,) * (dist.dim() - label.dim()))
        similar_term = label * dist
        dissimilar_term = (1 - label) * torch.clamp(self.margin - dist, min=0.0)
        return torch.mean(similar_term + dissimilar_term)

# Define training dataset.
# NOTE(review): TensorDataset needs numeric data, so `clean_audio` and
# `noisy_audio` must already be loaded as equal-length waveforms (not file
# paths) and defined before this cell runs. TODO: load them from the files.
# float32 matches the dtype of the model's weights.
train_dataset = torch.utils.data.TensorDataset(
    torch.tensor(clean_audio, dtype=torch.float32),
    torch.tensor(noisy_audio, dtype=torch.float32),
)

# Define data loader
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define optimizer
lr = 1e-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# Define number of epochs
num_epochs = 10

# Weights are checkpointed here after every epoch.
checkpoint_path = 'contrastive_encoder.pt'

# Build the loss once. With this ContrastiveLoss, label 1 = similar pair and
# label 0 = dissimilar pair.
criterion = ContrastiveLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    for clean_batch, noisy_batch in train_loader:
        optimizer.zero_grad()

        # Forward pass must run with autograd enabled (no torch.no_grad()):
        # otherwise the loss is detached from the parameters and
        # loss.backward() raises. Mean-pool last_hidden_state over frames
        # (dim 1) to get one vector per example, so per-example labels line up.
        z1 = model(clean_batch)['last_hidden_state'].mean(dim=1)
        z2 = model(noisy_batch)['last_hidden_state'].mean(dim=1)

        # Positive and negative labels, sized from the actual batch (the last
        # batch may be smaller than batch_size) and on the same device as z1.
        n_examples = z1.size(0)
        label = z1.new_ones(n_examples)
        neg_label = z1.new_zeros(n_examples)

        # Positive pairs: each clean clip with its own noisy version.
        loss = criterion(z1, z2, label)
        # Negative pairs: each clean clip vs. a different clean clip. roll(1)
        # never pairs an example with itself when n > 1 (flip(0) does, for the
        # middle element of an odd-sized batch).
        if n_examples > 1:
            loss = loss + criterion(z1, z1.roll(1, dims=0), neg_label)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Checkpoint the weights as a state_dict, which torch.load can read with
    # its default weights_only=True (a pickled Module cannot).
    torch.save(model.state_dict(), checkpoint_path)

# Load trained weights and attach a decoder head for the downstream task.
# The whole wav2vec2 model (feature extractor + transformer encoder) is reused:
# the bare `model.encoder` expects projected features rather than raw audio and
# returns a ModelOutput, so it cannot be chained in nn.Sequential.
encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
encoder.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
# TODO: `num_classes` (e.g. the output vocabulary size) must be defined before this cell.
decoder = nn.Linear(encoder.config.hidden_size, num_classes)


class EncoderWithHead(nn.Module):
    """Run the wav2vec2 model and apply `head` to every frame of last_hidden_state."""

    def __init__(self, encoder, head):
        super().__init__()
        self.encoder = encoder
        self.head = head

    def forward(self, input_values):
        return self.head(self.encoder(input_values).last_hidden_state)


model = EncoderWithHead(encoder, decoder)
