In [2]:
import torch
import pytorch_lightning as pl
import soundfile as sf
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import Wav2Vec2FeatureExtractor, BertTokenizerFast, Wav2Vec2Model, BertModel, BertConfig
from sentence_transformers.util import cos_sim, semantic_search
from metrics import mean_reciprocal_rank

In [9]:
torch.cuda.device_count()

1

In [151]:
from typing import Dict, List, Union
from transformers import BertTokenizerFast, Wav2Vec2FeatureExtractor

class LibriPreprocessor:
  def __init__(self):
    self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    self.extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base')
  
  
  def speech_file_to_array_fn(self, data):
    speech_array, sampling_rate = sf.read(data["file"])
    data["speech"] = speech_array
    data["sampling_rate"] = sampling_rate
    data["target_text"] = data["text"]
    return data
    
    
  def prepare_dataset(self, data):    
    # check that all files have the correct sampling rate
    assert (
        len(set(data["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {self.extractor.feature_extractor.sampling_rate}."

    data["input_values"] = self.extractor(data["speech"], sampling_rate=data["sampling_rate"][0]).input_values
    
    tokenized_batch = self.tokenizer(data["target_text"], padding='longest', max_length=128, pad_to_max_length=False)
    data['input_ids'] = tokenized_batch['input_ids']
    data['attention_mask_text'] = tokenized_batch['attention_mask']
    data['token_type_ids_text'] = tokenized_batch['token_type_ids']
    
    return data


  def __call__(
    self,
    batch: List[Dict[str, Union[List[int], torch.Tensor]]],
    ) -> Dict[str, torch.Tensor]:
    """
    Collate function to be used when training with PyTorch Lightning.
    Args:
        extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The processor used for proccessing the data.
        tokenizer (:class:`~transformers.BertTokenizerFast`)
            The tokenizer used for proccessing the data.
        features (:obj:`List[Dict[str, Union[List[int], torch.Tensor]]]`):
            A list of features to be collated.
    Returns:
        :obj:`Dict[str, torch.Tensor]`: A dictionary of tensors containing the collated features.
    """ 
    input_features = [{"input_values": feature["input_values"]} for feature in batch]
    input_sentences = [{"input_ids": feature["input_ids"]} for feature in batch]
    
    speech_batch = self.extractor.pad(
        input_features,
        padding='longest',
        return_tensors="pt",
        )
    text_batch = self.tokenizer.pad(
        input_sentences,
        padding='longest',
        return_tensors='pt'
    )
    
    return speech_batch, text_batch



class LibriSpeechDataset(Dataset):
  def __init__(self, libri_dataset):
    self.libri_dataset = libri_dataset
  
  
  def __len__(self):
    return len(self.libri_dataset)
  
  
  def __getitem__(self, index):
    return self.libri_dataset[index]

In [152]:
libri = load_dataset('patrickvonplaten/librispeech_asr_dummy', 'clean', split='validation')
preprocessor = LibriPreprocessor()
libri_read = libri.map(preprocessor.speech_file_to_array_fn, remove_columns=['chapter_id', 'id', 'speaker_id'])
libri_prepared = libri_read.map(preprocessor.prepare_dataset, batch_size=8, num_proc=4, batched=True)
libri_dataset = LibriSpeechDataset(libri_prepared)
libri_dataloader = DataLoader(libri_dataset, batch_size=4, collate_fn=preprocessor)

Reusing dataset librispeech_asr (C:\Users\marco\.cache\huggingface\datasets\patrickvonplaten___librispeech_asr\clean\2.1.0\f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc)
100%|██████████| 73/73 [00:00<00:00, 572.42ex/s]


In [158]:
for i, batch in enumerate(libri_dataloader):
  speech, text = batch
  print(speech['input_values'].shape)
  print(text['input_ids'].shape)
  if i == 2:
    break

torch.Size([4, 199760])
torch.Size([4, 86])
torch.Size([4, 470400])
torch.Size([4, 86])
torch.Size([4, 292640])
torch.Size([4, 51])


In [32]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [71]:
sentences = ['This is a test sentence.', 'This is another test sentence.']
model_inputs = tokenizer(sentences, padding='longest', max_length=128, pad_to_max_length=False, return_tensors='pt')
model_outputs = model(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'])
model_outputs['mean_pooled_output'] = torch.mean(model_outputs['last_hidden_state'], dim=1)
print(model_outputs['mean_pooled_output'])

tensor([[ 6.6064e-02, -2.1769e-01, -1.5390e-01,  ..., -2.2918e-01,
         -3.9249e-04,  3.9901e-01],
        [ 1.0360e-01, -2.4060e-01, -5.9160e-02,  ..., -1.4624e-01,
          1.3627e-01,  2.9249e-01]], grad_fn=<MeanBackward1>)


In [44]:
model_outputs['last_hidden_state'].shape

torch.Size([2, 8, 768])

In [63]:
pooled = torch.mean(model_outputs['last_hidden_state'], dim=1)
print(pooled.shape)

torch.Size([2, 768])


In [5]:
positives = torch.rand(4,8)
negatives = positives[torch.randperm(positives.shape[0]),:]
print(positives)
print(negatives)

tensor([[0.5349, 0.1388, 0.1123, 0.2608, 0.4066, 0.1963, 0.2137, 0.1983],
        [0.8943, 0.9790, 0.1911, 0.9959, 0.8279, 0.2848, 0.1843, 0.6564],
        [0.7489, 0.1035, 0.6099, 0.6300, 0.7315, 0.5545, 0.7233, 0.2897],
        [0.0853, 0.2057, 0.4239, 0.2176, 0.1818, 0.4055, 0.3233, 0.5674]])
tensor([[0.5349, 0.1388, 0.1123, 0.2608, 0.4066, 0.1963, 0.2137, 0.1983],
        [0.0853, 0.2057, 0.4239, 0.2176, 0.1818, 0.4055, 0.3233, 0.5674],
        [0.8943, 0.9790, 0.1911, 0.9959, 0.8279, 0.2848, 0.1843, 0.6564],
        [0.7489, 0.1035, 0.6099, 0.6300, 0.7315, 0.5545, 0.7233, 0.2897]])
