In [None]:
import os
import json
import yaml
import shutil
import secrets
import string
from tqdm import tqdm
import numpy as np
import torch
import torchaudio
import librosa
from datasets import load_dataset, Audio
import ctranslate2
from encodec import EncodecModel
from encodec.utils import convert_audio
import sys
sys.path.append('./WhisperSeg')
from WhisperSeg.model import WhisperSegmenterFast

# Load configuration from YAML file. YAML is defined over Unicode encodings, so
# read it as UTF-8 explicitly rather than relying on the platform locale default.
with open('config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# safe_load returns None for an empty file (or a scalar/list for other top-level
# YAML); fail here with a clear message instead of an opaque TypeError below.
if not isinstance(config, dict):
    raise ValueError("config.yaml must contain a top-level mapping of settings")

# Extract configuration values (a missing key raises KeyError).
hf_dataset = config['hf_dataset']
data_dir = config['data_dir']
plot_dir = config['plot_dir']  # NOTE(review): not used anywhere in the visible code
whisperseg_config = config['whisperseg_config']

# Load the 'train' split of the Hugging Face dataset named in the config.
dataset = load_dataset(hf_dataset, split='train')

# Initialize models: the WhisperSeg segmenter (faster-whisper weights, on CPU)
# and the 24 kHz EnCodec model used for codec tokens.
segmenter = WhisperSegmenterFast("Systran/faster-whisper-large-v2", device="cpu")
encodec_model = EncodecModel.encodec_model_24khz()

def generate_random_id(length):
    """Return a random alphanumeric identifier of ``length`` characters.

    Characters are drawn from ASCII letters and digits using the ``secrets``
    module (cryptographically strong randomness). ``length <= 0`` yields ``""``.
    """
    pool = string.ascii_letters + string.digits
    picked = [secrets.choice(pool) for _ in range(length)]
    return "".join(picked)

def generate_features(audio, sr, min_freq, spec_time_step, num_trials):
    """Compute the mel features, Whisper encoder embedding and EnCodec codes of a clip.

    Args:
        audio: Waveform samples (array-like or tensor). A 1-D input is treated as
            a single channel; a 2-D input is assumed to be (channels, samples) —
            TODO confirm for any multi-channel datasets.
        sr: Sampling rate of ``audio`` in Hz.
        min_freq, spec_time_step, num_trials: Forwarded unchanged to
            ``segmenter.get_sliced_audios_features``.

    Returns:
        ``(mel, embedding, codecs)`` where
        - ``mel`` is ``ftr[0][2]`` from the WhisperSeg feature extractor. Only the
          first entry is used (NOTE(review): presumably the first slice of a
          longer clip — later slices are dropped).
        - ``embedding`` is a float32 CPU tensor with the CTranslate2 Whisper
          encoder output for that entry.
        - ``codecs`` is the EnCodec output for the whole clip: the first element
          of every encoded frame concatenated along the last axis.
    """
    ftr = segmenter.get_sliced_audios_features(audio, sr, min_freq, spec_time_step, num_trials)
    mel = ftr[0][2]

    # CTranslate2 encodes a batch; wrap the single feature matrix in a batch of 1.
    features = ctranslate2.StorageView.from_array(np.asarray([mel]))
    encoder_output = segmenter.model_list[0].encode(features)
    # Convert via NumPy directly instead of a nested Python list (.tolist()),
    # which is very slow for large outputs; float32 matches what torch.tensor()
    # produced from Python floats in the previous implementation.
    embedding = torch.from_numpy(np.array(encoder_output, dtype=np.float32))

    # EnCodec's weights are float32; a float64 waveform (HF audio decoding can
    # yield float64 — confirm per dataset) would cause a dtype mismatch.
    wav = torch.as_tensor(audio, dtype=torch.float32)
    # convert_audio expects (channels, samples) and rejects anything that is not
    # mono/stereo on the channel axis, so give a 1-D mono signal a channel axis.
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    wav = convert_audio(wav, sr, encodec_model.sample_rate, encodec_model.channels)
    wav = wav.unsqueeze(0)  # add batch axis -> (1, channels, samples)
    with torch.no_grad():
        encoded_frames = encodec_model.encode(wav)
    codecs = torch.cat([frame[0] for frame in encoded_frames], dim=-1)

    return mel, embedding, codecs

def process_subset(subset, config):
    """Extract and save features for every dataset item belonging to ``subset``.

    For each item whose ``'subset'`` column equals ``subset``, this calls
    ``generate_features`` and ``torch.save``s the three results as
    ``<id>.pt`` under ``<data_dir>/<subset>/features/`` in the sub-folders
    ``spectrograms``, ``whisper_embeddings`` and ``encodec_codecs``. Each item
    gets a fresh 32-character random id. A ``dataset.json`` manifest listing the
    saved file paths plus speaker/type/text/waveform-path metadata is then
    written to the same ``features`` directory.

    Args:
        subset: Value of the dataset's ``'subset'`` column to process.
        config: Per-group WhisperSeg settings providing ``'min_freq'``,
            ``'spec_time_step'`` and ``'num_trials'``.

    Raises:
        KeyError: if ``config`` lacks one of the keys above. This is raised once,
            up front, instead of failing every item individually.
    """
    # Loop-invariant settings: read once so a misconfigured group fails loudly
    # instead of printing one error per item and writing an empty manifest.
    min_freq = config['min_freq']
    spec_time_step = config['spec_time_step']
    num_trials = config['num_trials']

    features_dir = os.path.join(data_dir, subset, 'features')
    os.makedirs(features_dir, exist_ok=True)
    for subdir in ['spectrograms', 'whisper_embeddings', 'encodec_codecs']:
        os.makedirs(os.path.join(features_dir, subdir), exist_ok=True)

    processed_data = []
    # Give the predicate only the 'subset' column: filtering on full rows would
    # decode every example's audio just to compare one string field.
    subset_data = dataset.filter(lambda name: name == subset, input_columns='subset')

    for item in tqdm(subset_data, desc=f"Processing {subset}"):
        try:
            random_id = generate_random_id(32)
            mel, embedding, codecs = generate_features(
                item['audio']['array'],
                item['audio']['sampling_rate'],
                min_freq,
                spec_time_step,
                num_trials,
            )

            mel_file = os.path.join(features_dir, 'spectrograms', f'{random_id}.pt')
            embedding_file = os.path.join(features_dir, 'whisper_embeddings', f'{random_id}.pt')
            codecs_file = os.path.join(features_dir, 'encodec_codecs', f'{random_id}.pt')

            torch.save(mel, mel_file)
            torch.save(embedding, embedding_file)
            torch.save(codecs, codecs_file)

            # Only items whose three files were all saved reach the manifest.
            processed_data.append({
                'speaker': item.get('speaker', 'None'),
                'type': item.get('type', 'utterance'),
                'text': item['text'],
                'mel': mel_file,
                'embedding': embedding_file,
                'codecs': codecs_file,
                'waveform': item['audio']['path'],
            })

        except Exception as e:
            # Best-effort per item: report and skip so one bad clip does not abort
            # the subset. Files already saved for this item may remain on disk.
            print(f'Error processing item in {subset}: {e}')

    with open(os.path.join(features_dir, 'dataset.json'), 'w') as f:
        json.dump(processed_data, f)
# Distinct subset names, sorted so the processing order is the same on every run.
# Iterating a set of strings varies between runs under hash randomization.
# key=str keeps sorting safe if the column ever mixes types (e.g. None).
subsets = sorted(set(dataset['subset']), key=str)

# Process each subset with the WhisperSeg settings of the group that lists it
# ('songbirds' is checked before 'humans'); unlisted subsets are skipped.
# NOTE: this rebinds the module-level `config` (the YAML root loaded earlier) to
# the per-group settings; the debugging cell further down reads
# config['min_freq'] etc. from whatever value is left here.
for subset in subsets:
    if subset in whisperseg_config['songbirds']['datasets']:
        config = whisperseg_config['songbirds']
    elif subset in whisperseg_config['humans']['datasets']:
        config = whisperseg_config['humans']
    else:
        print(f"Skipping unknown subset: {subset}")
        continue

    process_subset(subset, config)

print("Processing complete!")

In [22]:
# Scratch cell: repeat the WhisperSeg + CTranslate2 steps of generate_features()
# on the first dataset item, leaving ftr / features / mel / encoded / embedding
# in the notebook namespace for inspection by the cells below.
item = dataset[0]
# NOTE(review): `config` is presumably the per-group WhisperSeg settings left
# behind by the processing loop (its last processed subset), not the YAML root
# — confirm it has 'min_freq' / 'spec_time_step' / 'num_trials'.
audio, sr, min_freq, spec_time_step, num_trials = item['audio']['array'], item['audio']['sampling_rate'], config['min_freq'], config['spec_time_step'],  config['num_trials']
ftr = segmenter.get_sliced_audios_features(audio, sr, min_freq, spec_time_step, num_trials)
# Batch of one: only ftr[0][2] (the first entry's features) is encoded.
features = ctranslate2.StorageView.from_array(np.asarray([ftr[0][2]]))
mel = ftr[0][2]
encoded = segmenter.model_list[0].encode(features)
embedding = torch.tensor(np.array(encoded).tolist(), device="cpu")
    

In [12]:
# Re-run the CTranslate2 Whisper encoder on `features` built in the cell above.
encoded = segmenter.model_list[0].encode(features)

In [17]:
# Display the raw encoder result (a ctranslate2 StorageView, per the output below).
encoded

<ctranslate2._ext.StorageView at 0x20b4558dd70>

In [18]:
# Convert the StorageView to a NumPy array so its values can be inspected.
z = np.array(encoded)

In [19]:
# Display the array; the output below shows a 3-D float32 array.
z

array([[[-0.5442839 , -0.47230586, -0.1608399 , ..., -3.0254629 ,
         -0.5218957 , -1.1900226 ],
        [ 1.0759406 , -0.36049455,  0.8323346 , ..., -2.6637006 ,
         -0.47056672, -0.29163405],
        [ 1.4053938 ,  0.7581593 ,  0.5551884 , ..., -2.3033926 ,
         -0.41483313, -0.643653  ],
        ...,
        [ 0.9287099 , -0.6650456 ,  0.26375106, ...,  0.3682955 ,
          0.9621666 , -0.2075649 ],
        [ 0.9439007 , -0.97815657,  0.43080607, ...,  0.68621343,
          1.2327944 , -0.23502705],
        [ 0.9160954 , -0.63507974,  0.41857827, ...,  0.8747503 ,
          1.6396738 ,  0.04554638]]], dtype=float32)