In [1]:
from typing import Dict, List, Mapping, Tuple, Union

import numpy as np
import torch
from transformers import BertTokenizerFast

class LibriPreprocessor:
  """Collate callable that pads latent speech features and tokenized text.

  An instance is callable on a list of examples, which makes it usable as the
  ``collate_fn`` of a ``torch.utils.data.DataLoader``.
  """

  def __init__(self, tokenizer_name: str = 'bert-base-uncased'):
    """
    Args:
        tokenizer_name (:obj:`str`, defaults to ``'bert-base-uncased'``):
            Name or path handed to ``BertTokenizerFast.from_pretrained``.
    """
    self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name)

  def pad_latent_features(self, latent_features, padding='longest', return_tensors="pt"):
    """
    Zero-pad every item along its first axis to the length of the longest item.

    Args:
        latent_features (:obj:`List[Dict]`):
            One dict per example; the value stored under ``"latent_features"``
            must be convertible to a 2-D float array (the pad width below is
            given for exactly two axes). Only the first axis is padded.
        padding (:obj:`str`, defaults to ``'longest'``):
            Padding strategy. ``'longest'`` is the only supported value.
        return_tensors (:obj:`str`, defaults to ``"pt"``):
            ``"pt"`` returns a single stacked ``torch.float32`` tensor; any other
            value returns a list of ``numpy.float32`` arrays.
    Returns:
        A stacked tensor (first axis indexes the batch) or a list of padded arrays.
    Raises:
        ValueError: If ``padding`` is not ``'longest'`` or the batch is empty.
    """
    # Reject unknown strategies up front instead of failing later on an
    # unbound local variable.
    if padding != 'longest':
      raise ValueError(
        f"Unsupported padding strategy {padding!r}; only 'longest' is implemented."
      )

    arrays = [
      np.array(item['latent_features']).astype(np.float32)
      for item in latent_features
    ]
    if not arrays:
      raise ValueError("Cannot pad an empty batch of latent features.")

    padding_value = 0.0
    longest_latent_feature = max(array.shape[0] for array in arrays)

    # Pad only at the end of axis 0; axis 1 is left untouched.
    padded_features = [
      np.pad(array,
             ((0, longest_latent_feature - array.shape[0]), (0, 0)),
             mode='constant',
             constant_values=padding_value)
      for array in arrays
    ]

    if return_tensors == "pt":
      # The arrays are already float32, so no extra dtype cast is needed.
      return torch.stack([torch.from_numpy(array) for array in padded_features])

    return padded_features

  def __call__(
    self,
    batch: List[Dict[str, Union[List[int], torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
    """
    Collate a list of examples into one padded speech batch and one padded text batch.

    Args:
        batch (:obj:`List[Dict[str, Union[List[int], torch.Tensor]]]`):
            Examples to collate. Each must provide the keys
            ``"latent_features"`` and ``"input_ids"``; other keys are ignored.
    Returns:
        :obj:`Tuple`: ``(speech_batch, text_batch)`` where ``speech_batch`` is the
        result of :meth:`pad_latent_features` and ``text_batch`` is whatever
        ``self.tokenizer.pad`` returns for the ``input_ids`` with
        ``padding='longest'`` and ``return_tensors='pt'``.
    """
    latent_features = [{"latent_features": feature["latent_features"]} for feature in batch]
    input_sentences = [{"input_ids": feature["input_ids"]} for feature in batch]

    text_batch = self.tokenizer.pad(
        input_sentences,
        padding='longest',
        return_tensors='pt'
    )

    speech_batch = self.pad_latent_features(
        latent_features,
        padding='longest',
        return_tensors="pt",
    )

    return speech_batch, text_batch

In [2]:
from torch.utils.data import Dataset, DataLoader

class LibriSpeechDataset(Dataset):
  """Map-style ``Dataset`` that forwards length and indexing to a wrapped collection."""

  def __init__(self, libri_dataset):
    """Store a reference to the wrapped collection; it is not copied."""
    self.libri_dataset = libri_dataset

  def __len__(self):
    """Return the number of examples reported by the wrapped collection."""
    wrapped = self.libri_dataset
    return len(wrapped)

  def __getitem__(self, index):
    """Return whatever the wrapped collection yields for ``index``."""
    example = self.libri_dataset[index]
    return example
  

In [4]:
import os
from datasets import concatenate_datasets, load_from_disk

# Directory that contains the saved dataset shards.
rootdir = 'E:/Machine Learning/Datasets/librispeech/'

# Load the shard stored in sub-directory "0". The path is derived from rootdir so
# the dataset location is written down only once.
loaded_single_shard = load_from_disk(os.path.join(rootdir, '0'))


In [5]:
import os
from datasets import concatenate_datasets, load_from_disk

libri_shards_path = 'E:/Machine Learning/Datasets/librispeech/' # '../data/libri_small_shards'

# Each shard lives in a sub-directory named by its index (0, 1, 2, ...). List those
# directories explicitly instead of counting every sub-directory, so unrelated
# folders cannot shift or break the sequence, and sort numerically so that the
# concatenation order is 0, 1, 2, ... rather than lexicographic.
shard_dir_names = sorted(
    (entry.name for entry in os.scandir(libri_shards_path)
     if entry.is_dir() and entry.name.isdecimal()),
    key=int,
)
if not shard_dir_names:
    raise FileNotFoundError(f"No shard directories found in {libri_shards_path!r}")

libri_shards_list = []
for shard_dir_name in shard_dir_names:
    # os.path.join avoids the doubled separator that plain string formatting
    # produced when the base path already ends with '/'.
    loaded_libri_shard = load_from_disk(os.path.join(libri_shards_path, shard_dir_name))
    libri_shards_list.append(loaded_libri_shard)

libri_reconstructed = concatenate_datasets(libri_shards_list)

In [6]:
# Scratch check: the single index produced by range(1) is 0; bind it and echo it.
i = 0
print(i)

0


In [7]:
# Show which fields the first example of the concatenated dataset carries.
libri_reconstructed[0].keys()

dict_keys(['attention_mask_text', 'input_ids', 'input_values', 'latent_features', 'sampling_rate', 'speech', 'token_type_ids_text'])

In [None]:
# Build the batching pipeline: the preprocessor instance serves as the collate
# function, and the concatenated shards are wrapped in a map-style dataset.
preprocessor = LibriPreprocessor()
libri_dataset = LibriSpeechDataset(libri_reconstructed)
libri_dataloader = DataLoader(
    dataset=libri_dataset,
    batch_size=8,
    collate_fn=preprocessor,
)

In [None]:
from transformers import Wav2Vec2Model

# Load the pretrained checkpoint and keep handles to the two sub-modules that the
# inference loop applies: the feature projection and the encoder.
wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
feature_projection, encoder = wav2vec2.feature_projection, wav2vec2.encoder

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_projection.to(device)
encoder.to(device)

# Pure inference: disable autograd so no computation graph is built and held in
# memory for each batch.
with torch.no_grad():
  for batch in libri_dataloader:
    # The collate function returns (speech_batch, text_batch); only the speech
    # features are used here.
    inputs = batch[0].to(device)
    outputs = feature_projection(inputs)
    # The first element of the projection output is fed to the encoder.
    outputs = encoder(outputs[0])
    # Mean-pool the encoder output over dim 1.
    # NOTE(review): no attention mask is passed, so zero-padded frames of shorter
    # examples appear to be included in this mean - confirm whether that is intended.
    pooled_output = torch.mean(outputs['last_hidden_state'], dim=1)
    print(pooled_output.size())