In [1]:
import torch
import pytorch_lightning as pl
import soundfile as sf
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import Wav2Vec2FeatureExtractor, BertTokenizerFast, Wav2Vec2Model, BertModel, BertConfig
from sentence_transformers.util import cos_sim, semantic_search
from metrics import mean_reciprocal_rank
from config import ParallelSpeechAndTextModelPretrainingConfig

In [2]:
torch.cuda.device_count()

1

In [151]:
from typing import Dict, List, Union
from transformers import BertTokenizerFast, Wav2Vec2FeatureExtractor

class LibriPreprocessor:
  def __init__(self):
    self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    self.extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base')
  
  
  def speech_file_to_array_fn(self, data):
    speech_array, sampling_rate = sf.read(data["file"])
    data["speech"] = speech_array
    data["sampling_rate"] = sampling_rate
    data["target_text"] = data["text"]
    return data
    
    
  def prepare_dataset(self, data):    
    # check that all files have the correct sampling rate
    assert (
        len(set(data["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {self.extractor.feature_extractor.sampling_rate}."

    data["input_values"] = self.extractor(data["speech"], sampling_rate=data["sampling_rate"][0]).input_values
    
    tokenized_batch = self.tokenizer(data["target_text"], padding='longest', max_length=128, pad_to_max_length=False)
    data['input_ids'] = tokenized_batch['input_ids']
    data['attention_mask_text'] = tokenized_batch['attention_mask']
    data['token_type_ids_text'] = tokenized_batch['token_type_ids']
    
    return data


  def __call__(
    self,
    batch: List[Dict[str, Union[List[int], torch.Tensor]]],
    ) -> Dict[str, torch.Tensor]:
    """
    Collate function to be used when training with PyTorch Lightning.
    Args:
        extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The processor used for proccessing the data.
        tokenizer (:class:`~transformers.BertTokenizerFast`)
            The tokenizer used for proccessing the data.
        features (:obj:`List[Dict[str, Union[List[int], torch.Tensor]]]`):
            A list of features to be collated.
    Returns:
        :obj:`Dict[str, torch.Tensor]`: A dictionary of tensors containing the collated features.
    """ 
    input_features = [{"input_values": feature["input_values"]} for feature in batch]
    input_sentences = [{"input_ids": feature["input_ids"]} for feature in batch]
    
    speech_batch = self.extractor.pad(
        input_features,
        padding='longest',
        return_tensors="pt",
        )
    text_batch = self.tokenizer.pad(
        input_sentences,
        padding='longest',
        return_tensors='pt'
    )
    
    return speech_batch, text_batch



class LibriSpeechDataset(Dataset):
  def __init__(self, libri_dataset):
    self.libri_dataset = libri_dataset
  
  
  def __len__(self):
    return len(self.libri_dataset)
  
  
  def __getitem__(self, index):
    return self.libri_dataset[index]

In [152]:
libri = load_dataset('patrickvonplaten/librispeech_asr_dummy', 'clean', split='validation')
preprocessor = LibriPreprocessor()
libri_read = libri.map(preprocessor.speech_file_to_array_fn, remove_columns=['chapter_id', 'id', 'speaker_id'])
libri_prepared = libri_read.map(preprocessor.prepare_dataset, batch_size=8, num_proc=4, batched=True)
libri_dataset = LibriSpeechDataset(libri_prepared)
libri_dataloader = DataLoader(libri_dataset, batch_size=4, collate_fn=preprocessor)

Reusing dataset librispeech_asr (C:\Users\marco\.cache\huggingface\datasets\patrickvonplaten___librispeech_asr\clean\2.1.0\f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc)
100%|██████████| 73/73 [00:00<00:00, 572.42ex/s]


In [158]:
for i, batch in enumerate(libri_dataloader):
  speech, text = batch
  print(speech['input_values'].shape)
  print(text['input_ids'].shape)
  if i == 2:
    break

torch.Size([4, 199760])
torch.Size([4, 86])
torch.Size([4, 470400])
torch.Size([4, 86])
torch.Size([4, 292640])
torch.Size([4, 51])


In [151]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
compute_grad_on_last_n_layers = 1

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- Thi

In [147]:
count_parameters = lambda model : {'requires_grad':sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6,
                                   'does_not_require_grad':sum(p.numel() for p in model.parameters() if not p.requires_grad)/1e6}

In [148]:
for module in wav2vec2._modules:
  print(module)

print(count_parameters(wav2vec2))

for param in wav2vec2.feature_extractor.parameters():
  param.requires_grad = False
for param in wav2vec2.feature_projection.parameters():
  param.requires_grad = False
for param in wav2vec2.encoder.pos_conv_embed.parameters():
  param.requires_grad = False
for param in wav2vec2.encoder.layer_norm.parameters():
  param.requires_grad = False
for param in wav2vec2.encoder.dropout.parameters():
  param.requires_grad = False
  
for i, encoder_layer in enumerate(wav2vec2.encoder.layers._modules):
  if i < (len(wav2vec2.encoder.layers._modules) - compute_grad_on_last_n_layers):
    for param in wav2vec2.encoder.layers[i].parameters():
      param.requires_grad = False

print(count_parameters(wav2vec2))

feature_extractor
feature_projection
encoder
{'requires_grad': 94.371712, 'does_not_require_grad': 0.0}
{'requires_grad': 7.08864, 'does_not_require_grad': 87.283072}


In [149]:
for module in bert._modules:
  print(module)

print(count_parameters(bert))

for param in bert.embeddings.parameters():
  param.requires_grad = False
for i, encoder_layer in enumerate(bert.encoder.layer._modules):
  if i < (len(bert.encoder.layer) - compute_grad_on_last_n_layers):
    for param in bert.encoder.layer[i].parameters():
      param.requires_grad = False
      
print(count_parameters(bert))

embeddings
encoder
pooler
{'requires_grad': 109.48224, 'does_not_require_grad': 0.0}
{'requires_grad': 7.678464, 'does_not_require_grad': 101.803776}
