In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from transformers import WavLMModel, AutoProcessor

class finetune_wavlm(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
        self.model.train()
        self.conv1d = nn.Conv1d(768, 512, kernel_size=17, stride=17, padding=1)
        self.linear = nn.Linear(512, 10)
    
    def add_padding(self, audio_feats):
        """
        Add padding to the audio_feats got from wavlm to match seq_len_after_pad along temporal dimension.

        Args:
            audio_feats (list): list of tensors of shape (1, audio_feat_dim, scene_length*(50)-1)
            # ! for memory management you only get paths to the numpy files.

        Returns:
            dict:
                audio_feats (list): list of tensors of shape (1, audio_feat_dim, seq_len_after_pad)
                masks (list): list of tensors of shape (1, seq_len_after_pad)
        """
        num_bins_after_pad = 150
        seq_len_after_pad = 17*num_bins_after_pad # as we have 50 secs of audio
        duplicated_audio_feats = []
        for audio_feat in audio_feats:
            for i in range(0, audio_feat.shape[0], 50):
                dup = audio_feat[i,:].unsqueeze(0)
                audio_feat = torch.cat((audio_feat, dup), dim=0)
            audio_feat = audio_feat.transpose(0, 1).unsqueeze(0)
            duplicated_audio_feats.append(audio_feat)
        audio_feats = duplicated_audio_feats
        # audio_feats shape is 1, 768, scene_length*(50)-1+duplicated
        # I want to pad it to 1, 768, seq_len_after_pad
        audio_feats = [audio_feat.squeeze(0) for audio_feat in audio_feats]
        audio_feats = torch.stack(audio_feats)
        return {"audio_feats": audio_feats}
    
    def forward(self, input,seq_len):
        x, extract_feats = self.model(input, return_dict=False)
        print(f"X shape: {x.shape}")
        # print(f"Mask shape: {mask.shape}")
        # print add padding
        x = self.add_padding(x)
        x = self.conv1d(x["audio_feats"])
        print(x.shape)
        # do mean along temporal dimension to get 1, 512 vector. divide it by seq_len to get mean
        denominator = torch.tensor(seq_len).to('cuda')
        denominator[denominator == 0] = 300 # nothing in audio
        x = torch.sum(x, dim=2)/denominator
        x = self.linear(x)
        return x

model = finetune_wavlm().to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
input = torch.rand(16000*10)
# # pad the tensor to 20 seconds
input = torch.cat((input, torch.zeros(16000*10)), dim=0)
# processor = AutoProcessor.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
# input = processor(input, return_tensors="pt", padding=True, sampling_rate=16000).input_values.to('cuda')
print(f"input shape: {input.shape}")
output = model(input,seq_len=10)
target = torch.rand(1, 10).to('cuda')
loss = F.binary_cross_entropy_with_logits(output, target)
loss.backward()
optimizer.step()
print(output.shape)

  from .autonotebook import tqdm as notebook_tqdm


input shape: torch.Size([1, 320000])
output: torch.Size([1, 999, 512])


UnboundLocalError: local variable 'x' referenced before assignment