In [None]:
# datasets to process
# https://huggingface.co/datasets/aaparajit02/punjabi-asr / https://ai4bharat.iitm.ac.in/shrutilipi/ / https://www.kaggle.com/datasets/warcoder/punjabi-speech-recognition
# https://data.mendeley.com/datasets/sdbc8f5b77/2
# https://figshare.com/articles/dataset/Google-synth_A_Synthesized_Punjabi_Speech_Dataset/23615607/1
# https://figshare.com/articles/dataset/_strong_CMU-synth_A_synthesized_Punjabi_Speech_dataset_strong_/23606697/1
# https://huggingface.co/datasets/google/fleurs
# https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
# https://ai4bharat.iitm.ac.in/indicvoices/
# https://ai4bharat.iitm.ac.in/indicsuperb/

# features: ['speaker_id', 'audio', 'text', 'gender', 'duration'],

# CMU-synth and Google-synth are not good for TTS but are good for ASR

# Which columns to remove and retain

In [None]:
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, DatasetDict
import pandas as pd
import os
import glob
import pylab as plt
from pydub import AudioSegment

# Root folder that holds one saved DatasetDict per source corpus.
processed_root = '/mnt/data/Speech Dataset/processed_datasets/'

# os.listdir() returns entries in arbitrary order, so sort them: the order here
# fixes the row order after concatenation and therefore the result of the
# seeded shuffle later in the notebook. Plain files are skipped because
# load_from_disk() can only open dataset directories.
dirs = sorted(
    name for name in os.listdir(processed_root)
    if os.path.isdir(os.path.join(processed_root, name))
)
print(dirs)

datasets = []
for dataset_name in dirs:
    print(dataset_name)
    loaded = load_from_disk(os.path.join(processed_root, dataset_name))
    # Tag every row with '<dataset>__<split>' so its origin is still known
    # once the splits of all corpora have been concatenated.
    datasets.append(DatasetDict({
        split: part.add_column("source", [f'{dataset_name}__{split}'] * len(part))
        for split, part in loaded.items()
    }))

# Pair every split with its name once, keeping dataset-then-split order.
named_splits = [(name, dd[name]) for dd in datasets for name in dd]

# Bucket the splits by name. Any split whose name contains 'valid' counts as
# validation; 'train' and 'test' must match exactly.
all_data_splits = [part for _, part in named_splits]
train_data_splits = [part for name, part in named_splits if name == 'train']
train_valid_data_splits = [
    part for name, part in named_splits if name == 'train' or 'valid' in name
]
test_data_splits = [part for name, part in named_splits if name == 'test']
valid_data_splits = [part for name, part in named_splits if 'valid' in name]

split_groups = (
    all_data_splits,
    train_data_splits,
    train_valid_data_splits,
    test_data_splits,
    valid_data_splits,
)

# Number of source splits that landed in each bucket.
for group in split_groups:
    print(len(group))

# One merged dataset per bucket, in the same order as split_groups.
ds_all, ds_train, ds_train_valid, ds_test, ds_valid = [
    concatenate_datasets(group) for group in split_groups
]
for merged in (ds_all, ds_train, ds_train_valid, ds_test, ds_valid):
    print(merged)

# Clips whose duration (seconds) is at or above this limit are removed.
MAX_DURATION_S = 30


def _drop_long_clips(split_ds, max_seconds=MAX_DURATION_S):
    """Return ``split_ds`` without the rows whose 'duration' is >= ``max_seconds``.

    The indices to keep are gathered in a single pass over the 'duration'
    column, instead of building the list of rejected indices and testing every
    row index for membership in that list.

    Args:
        split_ds: dataset exposing a 'duration' column, ``len()`` and ``select()``.
        max_seconds: rows with a duration at or above this value are dropped.

    Returns:
        A ``(filtered, n_dropped)`` tuple: the selected subset and the number
        of rows that were removed.
    """
    keep = [
        i for i, seconds in enumerate(split_ds['duration'])
        # `not >=` rather than `<`: a NaN duration compares False either way,
        # and written like this such rows are kept instead of dropped.
        if not seconds >= max_seconds
    ]
    return split_ds.select(keep), len(split_ds) - len(keep)


ds_all_f, n_long_all = _drop_long_clips(ds_all)
ds_train_f, n_long_train = _drop_long_clips(ds_train)
ds_test_f, n_long_test = _drop_long_clips(ds_test)
ds_valid_f, n_long_valid = _drop_long_clips(ds_valid)
ds_train_valid_f, n_long_train_valid = _drop_long_clips(ds_train_valid)

# How many over-long clips were removed from all / train / test / valid.
print(n_long_all)
print(n_long_train)
print(n_long_test)
print(n_long_valid)

def get_split_duration(ds):
    """Return the total of ``ds['duration']`` divided by 3600.

    With per-row durations stored in seconds, the result is the length of the
    split in hours.
    """
    total = sum(ds['duration'])
    return total / 3600

# Total duration per filtered split (hours, assuming 'duration' is in seconds),
# printed in the order: all, train, test, valid, train+valid.
print(f'{get_split_duration(ds_all_f)}, {get_split_duration(ds_train_f)}, {get_split_duration(ds_test_f)}, {get_split_duration(ds_valid_f)}, {get_split_duration(ds_train_valid_f)}')

# Overlay the duration distributions on shared bins so the three curves are
# directly comparable. Clips of 30 s or longer were removed above, hence the
# 0-30 range. Labels and a legend make the figure readable on its own.
hist_kwargs = dict(bins=30, range=(0, 30), alpha=0.5)
plt.hist(ds_all_f['duration'], label='all', **hist_kwargs)
plt.hist(ds_train_f['duration'], label='train', **hist_kwargs)
plt.hist(ds_train_valid_f['duration'], label='train + valid', **hist_kwargs)
plt.xlabel('duration (s)')
plt.ylabel('number of clips')
plt.title('Clip durations after filtering')
plt.legend()

# Filtered (not yet shuffled) splits grouped under their final names.
ds = DatasetDict({'train': ds_train_f, 'valid': ds_valid_f, 'test': ds_test_f})


In [None]:
def _shuffle_and_flatten(split_ds, cache_path):
    """Shuffle ``split_ds`` with seed 42, then flatten its indices into ``cache_path``."""
    shuffled = split_ds.shuffle(seed=42)
    return shuffled.flatten_indices(
        num_proc=24,
        cache_file_name=cache_path,
        writer_batch_size=32,
    )


# NOTE(review): the cache file names are fixed; confirm that a file left over
# from an earlier run with different input data cannot be picked up here.
ds_train_f = _shuffle_and_flatten(ds_train_f, '/mnt/sea/tmp/ds_train_f2.cache')
ds_valid_f = _shuffle_and_flatten(ds_valid_f, '/mnt/sea/tmp/ds_valid_f2.cache')
ds_test_f = _shuffle_and_flatten(ds_test_f, '/mnt/sea/tmp/ds_test_f2.cache')

In [None]:
# Bundle the shuffled, flattened splits under their final names.
final_splits = {
    'train': ds_train_f,
    'valid': ds_valid_f,
    'test': ds_test_f,
}
ds_final = DatasetDict(final_splits)


In [None]:
# Show the repr of the assembled DatasetDict as the cell output.
ds_final

In [None]:
# Write the final train/valid/test bundle to disk with 24 worker processes.
ds_final.save_to_disk(
    '/home/kd/Desktop/ds/Punjabi_ASR_DS',
    num_proc=24,
)