In [1]:
import os
import torch
from transformers import Wav2Vec2FeatureExtractor, BertTokenizerFast, BertModel
from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder
from datasets import load_dataset, load_from_disk

In [2]:
class HubertConvFeatureExtractorWrapper(HubertPreTrainedModel):
    """Expose only the convolutional front end of a HuBERT checkpoint.

    The wrapper owns a single ``HubertFeatureEncoder`` built from the model
    config, so ``forward`` runs just that module on the given input values.
    """

    def __init__(self, config):
        """Build the conv encoder from ``config`` and run HF post-init."""
        super().__init__(config)

        # Hugging Face names this module ``HubertFeatureEncoder``.
        self.feature_extractor = HubertFeatureEncoder(config)

        self.post_init()

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Return the conv encoder's output for ``input_values``."""
        conv_features = self.feature_extractor(input_values)
        return conv_features

In [19]:
class LibriPreprocessor:
    """Staged preprocessing pipeline for a LibriSpeech-style dataset.

    Stages (each accepts an optional ``dataset`` that replaces the one held
    by the instance, and each returns the resulting dataset):

    1. ``speech_file_to_array``  - copy decoded audio into plain columns.
    2. ``filter_long_audio``     - drop clips at or above a duration limit.
    3. ``extract_features_and_tokenize`` - run the HuBERT conv front end on
       the audio and tokenize the transcript.
    4. ``encode_text``           - turn the token ids into BERT pooled
       sentence embeddings.
    """

    def __init__(
        self,
        dataset=None,
        max_length: int = 16,
        dataset_name: str = 'librispeech_asr',
        save_dir: str = 'E:/Datasets/',
        text_model_name: str = 'google/bert_uncased_L-2_H-768_A-12',
    ):
        """Load the audio/text models and remember where to read/write data.

        Args:
            dataset: Optional dataset to start from; stored on the instance.
            max_length: Maximum clip length in seconds (stored in samples,
                assuming 16 kHz, as ``self.max_length``).
            dataset_name: Name passed to ``datasets.load_dataset``.
            save_dir: Directory used both as download cache and as the root
                for every stage's output.
            text_model_name: Checkpoint for the BERT tokenizer and model.

        Raises:
            AssertionError: If CUDA is not available.
        """
        assert torch.cuda.is_available(), "CUDA is not available, should run on GPU"

        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('ntu-spml/distilhubert')
        self.feature_encoder = HubertConvFeatureExtractorWrapper.from_pretrained('ntu-spml/distilhubert')
        self.feature_encoder.eval()

        self.tokenizer = BertTokenizerFast.from_pretrained(text_model_name)
        self.text_model = BertModel.from_pretrained(text_model_name)
        self.text_model.eval()

        self.dataset_name = dataset_name
        self.cache_dir = save_dir
        self.save_dir = save_dir

        self.max_length = max_length*16000

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Previously the constructor argument was silently discarded.
        self.dataset = dataset

    def _resolve_dataset(self, dataset=None):
        """Adopt ``dataset`` if given and return the dataset to work on.

        Raises:
            ValueError: If no dataset was passed and none is stored.
        """
        if dataset is not None:
            self.dataset = dataset
        if self.dataset is None:
            raise ValueError(
                "No dataset available: pass `dataset=` or call `load_dataset()` first."
            )
        return self.dataset

    def _output_dir(self, name: str) -> str:
        """Return ``<save_dir>/<name>/`` regardless of a trailing slash on ``save_dir``."""
        return os.path.join(self.save_dir, f'{name}/')

    def load_dataset(self, dataset_split: str = 'train.360'):
        """Load the 'clean' configuration of ``self.dataset_name`` into ``self.dataset``."""
        self.dataset = load_dataset(self.dataset_name, 'clean', split=dataset_split, cache_dir=self.cache_dir)

    def _speech_file_to_array(self, data):
        """Per-example map function: flatten the ``audio`` dict into two columns."""
        data['speech'] = data['audio']['array']
        data['sampling_rate'] = data['audio']['sampling_rate']
        return data

    def speech_file_to_array(self, dataset=None, save_to_hd: bool = True):
        """Add ``speech``/``sampling_rate`` columns and drop the metadata columns.

        Args:
            dataset: Optional dataset replacing the stored one.
            save_to_hd: When true, save the result under
                ``<save_dir>/file_to_speech_array/``.

        Returns:
            The mapped dataset (also stored on the instance).
        """
        source = self._resolve_dataset(dataset)
        self.dataset = source.map(
            self._speech_file_to_array,
            remove_columns=['file', 'audio', 'speaker_id', 'chapter_id', 'id']
        )
        if save_to_hd:
            self.dataset.save_to_disk(self._output_dir('file_to_speech_array'))
        return self.dataset

    def filter_long_audio(self, dataset=None, max_audio_length: int = 16, save_to_hd: bool = True):
        """Keep only clips whose whole-second duration is below ``max_audio_length``.

        Args:
            dataset: Optional dataset replacing the stored one.
            max_audio_length: Exclusive upper bound on clip length, in seconds.
            save_to_hd: When true, save the result under ``<save_dir>/filtered/``.

        Returns:
            The filtered dataset (also stored on the instance).
        """
        source = self._resolve_dataset(dataset)
        self.dataset = source.filter(
            lambda x: len(x['speech'])//x['sampling_rate'] < max_audio_length,
            keep_in_memory=True
        )
        if save_to_hd:
            self.dataset.save_to_disk(self._output_dir('filtered'))
        return self.dataset

    def _extract_features_and_tokenize(self, data):
        """Batched map function: audio -> conv features, text -> token ids.

        Adds ``input_values``, ``latent_features``, ``input_ids``,
        ``attention_mask_text`` and ``token_type_ids_text`` to the batch.
        """
        # The whole batch must share one sampling rate because a single value
        # is handed to the feature extractor below.
        assert (
            len(set(data['sampling_rate'])) == 1
        ), f"Make sure all inputs have the same sampling rate of {self.feature_extractor.sampling_rate}."
        sampling_rate = data['sampling_rate'][0]

        # Unpadded values are stored per example; a padded tensor copy is
        # built separately so the batch can go through the conv encoder.
        input_values = self.feature_extractor(data['speech'], sampling_rate=sampling_rate)
        data['input_values'] = input_values.input_values
        padded_input_values = self.feature_extractor(
            data['speech'], padding=True, return_tensors='pt', sampling_rate=sampling_rate
        )

        # NOTE(review): features are computed on the padded batch, so shorter
        # clips presumably carry frames derived from padding - confirm that
        # downstream code accounts for this.
        with torch.no_grad():
            batch_inputs = padded_input_values['input_values'].to(self.device)
            latent_features = self.feature_encoder(batch_inputs).transpose(1, 2)
            data['latent_features'] = latent_features.cpu().numpy()

        # tokenize text
        tokenized_batch = self.tokenizer(data['text'], padding='longest', max_length=128, pad_to_max_length=False)
        data['input_ids'] = tokenized_batch['input_ids']
        data['attention_mask_text'] = tokenized_batch['attention_mask']
        data['token_type_ids_text'] = tokenized_batch['token_type_ids']

        return data

    def extract_features_and_tokenize(self, dataset=None, save_to_hd: bool = True):
        """Run ``_extract_features_and_tokenize`` over the dataset in batches of 16.

        The conv encoder is moved to ``self.device`` for the duration of the
        map and always moved back to the CPU afterwards, even on failure.

        Args:
            dataset: Optional dataset replacing the stored one.
            save_to_hd: When true, save the result under
                ``<save_dir>/features_and_tokens/``.

        Returns:
            The mapped dataset (also stored on the instance).
        """
        source = self._resolve_dataset(dataset)
        self.feature_encoder.to(self.device)
        try:
            self.dataset = source.map(
                self._extract_features_and_tokenize,
                batch_size=16,
                num_proc=1,
                batched=True,
                remove_columns=['text', 'sampling_rate'],
                keep_in_memory=True
            )
        finally:
            # Release GPU memory even if the map raised.
            self.feature_encoder.cpu()
        if save_to_hd:
            self.dataset.save_to_disk(self._output_dir('features_and_tokens'))
        return self.dataset

    def _encode_text(self, data):
        """Batched map function: token ids -> BERT pooled output (``sentence_embedding``)."""
        with torch.no_grad():
            # int64 is the index dtype documented for BERT inputs and is
            # accepted by every PyTorch version's embedding lookup.
            input_ids = torch.tensor(data['input_ids'], dtype=torch.long, device=self.device)
            attention_mask = torch.tensor(data['attention_mask_text'], dtype=torch.long, device=self.device)
            token_type_ids = torch.tensor(data['token_type_ids_text'], dtype=torch.long, device=self.device)
            embeddings = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
            data['sentence_embedding'] = embeddings.pooler_output.cpu().numpy()

        return data

    def encode_text(self, dataset=None, save_to_hd: bool = True, shard=0):
        """Run ``_encode_text`` over the dataset in batches of 16.

        The text model is moved to ``self.device`` for the duration of the
        map and always moved back to the CPU afterwards, even on failure.

        Args:
            dataset: Optional dataset replacing the stored one.
            save_to_hd: When true, save the result under
                ``<save_dir>/encoded/<shard>/``.
            shard: Identifier used in the output directory name.

        Returns:
            The mapped dataset (also stored on the instance), matching the
            other pipeline stages.
        """
        source = self._resolve_dataset(dataset)
        self.text_model.to(self.device)
        try:
            self.dataset = source.map(
                self._encode_text,
                batch_size=16,
                num_proc=1,
                batched=True,
                remove_columns=['input_ids', 'attention_mask_text', 'token_type_ids_text'],
                keep_in_memory=True
            )
        finally:
            # Release GPU memory even if the map raised.
            self.text_model.cpu()
        if save_to_hd:
            self.dataset.save_to_disk(self._output_dir(f'encoded/{shard}'))
        return self.dataset

    def save_dataset(
        self,
        save_in: str,
        save_path: str = 'librispeech_asr_encoded',
    ):
        """Save the current dataset to ``<save_in>/<save_path>``.

        Args:
            save_in: Parent directory.
            save_path: Directory name inside ``save_in``; defaults to the
                name that used to be hard-coded.

        Raises:
            ValueError: If no dataset is stored on the instance.
        """
        current = self._resolve_dataset()
        current.save_to_disk(os.path.join(save_in, save_path))

In [4]:
# Load the 'clean' configuration, 'train.360' split, caching the download
# under the local "original" directory.
dataset = load_dataset('librispeech_asr', 'clean', split='train.360', cache_dir='E:/Datasets/librispeech/original/') # <-- this is the correct way to load the dataset

Reusing dataset librispeech_asr (E:/Datasets/librispeech/original/librispeech_asr\clean\2.1.0\14c8bffddb861b4b3a4fcdff648a56980dbb808f3fc56f5a3d56b18ee88458eb)


In [None]:
# Build the preprocessor; every stage output is written below this directory.
preprocessor = LibriPreprocessor(dataset=None, save_dir='E:/Datasets/librispeech/')

In [7]:
# Copy the decoded audio into 'speech'/'sampling_rate' columns and save the
# result to disk (the recorded run took about 45 minutes for 104,014 examples).
dataset_file_to_speech_array = preprocessor.speech_file_to_array(dataset=dataset, save_to_hd=True)

100%|██████████| 104014/104014 [45:03<00:00, 38.48ex/s]  


In [5]:
# Reload the dataset saved by the speech_file_to_array stage instead of recomputing it.
file_to_speech_array = load_from_disk('E:/Datasets/librispeech/file_to_speech_array/')

In [6]:
# Take one shard (index 64 of 128) as a small sample for trying out the pipeline.
little_bits = file_to_speech_array.shard(128, 64)
# Display the sample size.
len(little_bits)

813

In [7]:
# Keep only clips whose whole-second duration (samples // sampling rate) is below 16.
little_bits_filtered = little_bits.filter(lambda x: len(x['speech'])//x['sampling_rate'] < 16, keep_in_memory=True)

100%|██████████| 1/1 [01:28<00:00, 88.54s/ba]


In [38]:
# Number of examples that survived the length filter.
len(little_bits_filtered)

767

In [None]:
# Process the full dataset in 64 shards: filter, extract audio features and
# tokens, then encode the text, one shard at a time.
num_shards = 64

# NOTE(review): the `-0` is a no-op; presumably a leftover offset used to
# resume a partially completed run - confirm before removing.
for i in range(num_shards-0):
    libri_shard = file_to_speech_array.shard(num_shards, i)
    print(f"Shard {i} has {len(libri_shard)} entries.")
    # Intermediate stages are not written to disk (save_to_hd=False).
    libri_filtered = preprocessor.filter_long_audio(dataset=libri_shard, max_audio_length=16, save_to_hd=False)
    print(f"Shard {i} has been filtered to {len(libri_filtered)} entries.")
    libri_audio_features = preprocessor.extract_features_and_tokenize(dataset=libri_filtered, save_to_hd=False)
    print(f"Shard {i} has had features extracted from audio file.")
    # Only the final stage is saved, in a per-shard directory keyed by `i`.
    libri_encoded = preprocessor.encode_text(dataset=libri_audio_features, save_to_hd=True, shard=i)
    print(f'Shard {i} has had text encoded and was saved to disk.')

In [None]:
# Stand-alone globals used by the module-level helper functions below.

# Runtime settings.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
max_length = 16 * 16000

# Audio side: input normaliser plus the HuBERT convolutional front end.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('ntu-spml/distilhubert')
feature_encoder = HubertConvFeatureExtractorWrapper.from_pretrained('ntu-spml/distilhubert')

# Text side: BERT tokenizer and encoder.
tokenizer = BertTokenizerFast.from_pretrained('google/bert_uncased_L-2_H-768_A-12')
text_model = BertModel.from_pretrained('google/bert_uncased_L-2_H-768_A-12')

def _extract_fts(data):
    """Batched map function: audio -> conv features, text -> token ids.

    Reads ``speech``, ``sampling_rate`` and ``text`` from the batch and adds
    ``input_values``, ``latent_features``, ``input_ids``,
    ``attention_mask_text`` and ``token_type_ids_text``.
    """
    import torch

    # A single rate is passed to the feature extractor, so the batch must be uniform.
    distinct_rates = set(data['sampling_rate'])
    assert (
        len(distinct_rates) == 1
    ), f"Make sure all inputs have the same sampling rate of {feature_extractor.sampling_rate}."
    rate = data['sampling_rate'][0]

    # Per-example (unpadded) values are stored; the padded tensor feeds the encoder.
    unpadded = feature_extractor(data['speech'], sampling_rate=rate)
    data['input_values'] = unpadded.input_values
    padded = feature_extractor(data['speech'], padding=True, return_tensors='pt', sampling_rate=rate)

    # Latent features from the conv module, swapped on axes 1 and 2.
    with torch.no_grad():
        conv_out = feature_encoder(padded['input_values'].to(device))
        data['latent_features'] = conv_out.transpose(1, 2).cpu().numpy()

    # Tokenize the transcripts, padding to the longest one in the batch.
    tokens = tokenizer(data['text'], padding='longest', max_length=128, pad_to_max_length=False)
    column_sources = (
        ('input_ids', 'input_ids'),
        ('attention_mask_text', 'attention_mask'),
        ('token_type_ids_text', 'token_type_ids'),
    )
    for column, source_key in column_sources:
        data[column] = tokens[source_key]

    return data

        
    
def extract_fts(dataset):
    """Apply ``_extract_fts`` to ``dataset`` in batches of 16 and return the result.

    The conv encoder is moved to the module-level ``device`` (the same one
    ``_extract_fts`` sends its inputs to) for the duration of the map, and is
    always moved back to the CPU afterwards, even if the map raises.

    Args:
        dataset: Dataset with ``speech``, ``sampling_rate`` and ``text`` columns.

    Returns:
        The mapped dataset with ``text`` and ``sampling_rate`` removed.
    """
    feature_encoder.to(device)
    try:
        dataset = dataset.map(
            _extract_fts,
            batch_size=16,
            num_proc=1,
            batched=True,
            remove_columns=['text', 'sampling_rate'],
            keep_in_memory=True
        )
    finally:
        # Release accelerator memory regardless of success.
        feature_encoder.cpu()
    return dataset


def _encode_text(data):
    """Batched map function: token ids -> BERT pooled output.

    Reads ``input_ids``, ``attention_mask_text`` and ``token_type_ids_text``
    from the batch and adds ``sentence_embedding``.
    """
    import torch

    def column_as_tensor(column):
        # Integer tensor on the module-level device.
        return torch.tensor(data[column], dtype=torch.int, device=device)

    with torch.no_grad():
        model_output = text_model(
            input_ids=column_as_tensor('input_ids'),
            attention_mask=column_as_tensor('attention_mask_text'),
            token_type_ids=column_as_tensor('token_type_ids_text'),
        )
        data['sentence_embedding'] = model_output.pooler_output.cpu().numpy()

    return data


def encode_text(dataset=None):
    """Apply ``_encode_text`` to ``dataset`` in batches of 16 and return the result.

    The text model is moved to the module-level ``device`` (the same one
    ``_encode_text`` creates its tensors on) for the duration of the map, and
    is always moved back to the CPU afterwards, even if the map raises.

    Args:
        dataset: Dataset with ``input_ids``, ``attention_mask_text`` and
            ``token_type_ids_text`` columns.

    Returns:
        The mapped dataset with the three token columns removed.

    Raises:
        ValueError: If ``dataset`` is ``None``.
    """
    if dataset is None:
        raise ValueError("encode_text() requires a dataset")

    text_model.to(device)
    try:
        dataset = dataset.map(
            _encode_text,
            batch_size=16,
            num_proc=1,
            batched=True,
            remove_columns=['input_ids', 'attention_mask_text', 'token_type_ids_text'],
            keep_in_memory=True
        )
    finally:
        # Release accelerator memory regardless of success.
        text_model.cpu()
    return dataset

In [None]:
# Run audio feature extraction and tokenization on the filtered sample shard.
little_bits_fts = extract_fts(dataset=little_bits_filtered)

In [14]:
# Replace the token columns of the sample shard with BERT sentence embeddings.
little_bits_encoded = encode_text(dataset=little_bits_fts)

100%|██████████| 48/48 [04:36<00:00,  5.76s/ba]


In [15]:
# Inspect which columns remain on an example after both mapping passes.
little_bits_encoded[0].keys()

dict_keys(['speech', 'input_values', 'latent_features', 'sentence_embedding'])