In [1]:
from pathlib import Path
import numpy as np
import transformers
import zarr
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchaudio
import multiprocessing as mp
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
import time

In [2]:
root_dir = Path.cwd()

data_folder = root_dir / "data" / "voxceleb"

audio_files_path = data_folder / "dev" / "aac"

clip_langauge_path = data_folder / "audio_clips_meta_data.csv"
speaker_path = data_folder / "vox2_meta.csv"

zarr_archive_path = root_dir / "data" / "preprocessed" / "voxceleb2_v1.zarr"

zarr_archive_path.parent.mkdir(exist_ok=True, parents=True)

speaker_df = pd.read_csv(speaker_path, delimiter="\t")
language_df = pd.read_csv(clip_langauge_path, index_col=False).drop(
    columns=["Unnamed: 0"]
)

print(f"Number of speakers:\t{speaker_df.shape[0]}")

print(f"Number of audio clips:\t{language_df.shape[0]}")
print(
    f"Number of english audio clips:\t{language_df[language_df.language == 'en'].shape[0]}"
)

Number of speakers:	6114
Number of audio clips:	1092009
Number of english audio clips:	537134


In [3]:
n_samples_per_speaker = 3


language_df = (
    language_df[language_df.language == "en"]
    .groupby("speaker_id")
    .filter(lambda x: len(x) >= n_samples_per_speaker)
    .groupby("speaker_id")
    .apply(lambda x: x.sample(n=n_samples_per_speaker, random_state=1337))
    .reset_index(drop=True)
)

  .apply(lambda x: x.sample(n=n_samples_per_speaker, random_state=1337))


In [4]:
print(f"Number of samples:\t{language_df.shape[0]}")
print(f"Number of unique speakers:\t{language_df.speaker_id.unique().shape[0]}")
language_df = language_df.rename(columns={"speaker_id": "client_id"})
language_df.head()

Number of samples:	12756
Number of unique speakers:	4252


Unnamed: 0,language,client_id,clip_id,audio_file,path
0,en,id00012,_raOc3-IRsw,00114.m4a,id00012/_raOc3-IRsw/00114.m4a
1,en,id00012,Z-G8-wqpxwU,00097.m4a,id00012/Z-G8-wqpxwU/00097.m4a
2,en,id00012,C_FAL9gv8bo,00021.m4a,id00012/C_FAL9gv8bo/00021.m4a
3,en,id00016,mW9EXHGCHi4,00127.m4a,id00016/mW9EXHGCHi4/00127.m4a
4,en,id00016,29NOrEy8ZY0,00004.m4a,id00016/29NOrEy8ZY0/00004.m4a


In [5]:
feature_size = 80
chunk_length = 30
hop_length = 160
target_sample_rate = 16000
feature_extractor = transformers.WhisperFeatureExtractor(
    feature_size=feature_size,
    chunk_length=chunk_length,
    hop_length=hop_length,
    device="cuda",
)

In [6]:
batch_size = 128

zarr_root = zarr.open(zarr_archive_path, mode="w")

df = language_df

zarr_root.create_array(
    "features",
    shape=(
        df.shape[0],
        feature_size,
        int(target_sample_rate / hop_length * chunk_length) // 3,
    ),
    chunks=(1, feature_size, int(target_sample_rate / hop_length * chunk_length) // 3),
    dtype="float32",
)

zarr_root.create_array(
    "attention_mask",
    shape=(
        df.shape[0],
        int(target_sample_rate / hop_length * chunk_length) // 3,
    ),
    chunks=(1, int(target_sample_rate / hop_length * chunk_length) // 3),
    dtype="int32",
)

zarr_root.create_array(
    "client_index",
    shape=(df.shape[0],),
    dtype="int32",
)

zarr_root.create_array(
    "client_id",
    shape=(df.shape[0],),
    dtype=f"S{len(df.iloc[0])}",
)

zarr_root.create_array(
    "path",
    shape=(df.shape[0],),
    dtype=f"S32",
)

client_indices = {
    v: k
    for k, v in df.groupby("client_id")
    .count()
    .reset_index()[["client_id"]]
    .to_dict()["client_id"]
    .items()
}

time_pre = []
time_feature = []
time_after = []


def process_batch(batch_inputs):
    batch_idx, batch = batch_inputs

    audio_16k = []
    row_indices = []
    batch_client_ids = []
    batch_age_numbers = []
    batch_genders = []
    client_ids = []
    paths = []

    time_pre_start = time.time()

    for row_index, (index, row) in enumerate(batch.iterrows()):
        audio_path = audio_files_path / row.client_id / row.clip_id / row.audio_file

        # row_index += batch_idx * batch_size
        row_index = index

        row_indices.append(index)

        if not audio_path:
            print(f"Missing audio file:\t{row.path}")
            continue

        audio, orig_sr = torchaudio.load(audio_path)

        audio_16k.append(
            torchaudio.functional.resample(
                audio, orig_freq=orig_sr, new_freq=target_sample_rate
            )[0, :]
        )

        batch_client_ids.append(row.client_id)

        client_ids.append(row.client_id.encode("utf-8"))
        paths.append(row.path.ljust(32).encode("utf-8"))

    audio_max_length = max(len(w) for w in audio_16k)

    audio_padded = np.array(
        [
            np.pad(
                w, (0, audio_max_length - len(w)), mode="constant", constant_values=0
            )
            for w in audio_16k
        ]
    )

    time_feature_start = time.time()
    time_pre.append(time_feature_start - time_pre_start)

    features = feature_extractor(
        audio_padded,
        return_tensors="np",
        return_attention_mask=True,
        sampling_rate=target_sample_rate,
        device="cuda",
    )

    time_after_start = time.time()
    time_feature.append(time_after_start - time_feature_start)

    zarr_root["features"][row_indices[0] : row_indices[-1] + 1, :] = features[
        "input_features"
    ][:, :, :1000]
    zarr_root["attention_mask"][row_indices[0] : row_indices[-1] + 1, :] = features[
        "attention_mask"
    ][:, :1000]

    zarr_root["client_index"][row_indices[0] : row_indices[-1] + 1] = np.array(
        [client_indices[x] for x in batch_client_ids]
    )

    zarr_root["client_id"][row_indices[0] : row_indices[-1] + 1] = np.array(client_ids)
    zarr_root["path"][row_indices[0] : row_indices[-1] + 1] = np.array(paths)

    time_after_end = time.time()

    time_after.append(time_after_end - time_after_start)


batches = np.array_split(df.reset_index(), np.ceil(df.shape[0] / batch_size))

for batch_idx, batch in tqdm(enumerate(batches), total=len(batches)):
    process_batch((batch_idx, batch))

  return bound(*args, **kwds)
100%|██████████████████████████████████████████████████████████████████| 100/100 [05:12<00:00,  3.12s/it]


In [7]:
print(f"Before:\t{sum(time_pre) / len(time_pre)}")
print(f"feature:\t{sum(time_feature) / len(time_feature)}")
print(f"after:\t{sum(time_after) / len(time_after)}")

Before:	1.2836652660369874
feature:	1.1886273860931396
after:	0.6267783164978027
