### 0. Download "music" portiion of audioset dataset

In [None]:
# !mkdir audioset_download

In [None]:
# from audioset_download import Downloader
# d = Downloader(root_path='/mnt/sdb/audioset-download', labels=["Music"], n_jobs=36, download_type='eval', copy_and_replicate=False)
# d.download(format = 'wav')

### 1. Inpsect how many files from 'balanced dataset' are in the csv

In [1]:
from pathlib import Path

AUSIOSET_PATH = Path('/mnt/shared/alpaca/audioset_minigimbob/audio')
AUDIOCAP_PATH = Path('musiccaps-public.csv') # you can download this from kaggle

ytid_filepath_dict = {file.stem[1:12]: file for file in AUSIOSET_PATH.rglob('*') if file.is_file() and file.suffix == '.wav'}
# make filename_list a dict

In [2]:
import pandas as pd

caption_df_original = pd.read_csv(AUDIOCAP_PATH)

caption_df_filtered = caption_df_original[caption_df_original['ytid'].isin(ytid_filepath_dict.keys())]

In [3]:
len(caption_df_original), len(caption_df_filtered)
# (5521, 5480), we have 41 missing files

(5521, 5480)

In [4]:
ytid_caption_dict = caption_df_filtered.set_index('ytid')['caption'].to_dict()
# substitute special characters to space
special_characetrs = ['.', ',', '!', '?', ':', ';', '"', "'", '(', ')', '[', ']', '{', '}', '<', '>', '/', '\\', '|', '@', '#', '$', '%', '^', '&', '*', '+', '=', '~', '`']
for ytid, caption in ytid_caption_dict.items():
    for special_characetrs in special_characetrs:
        caption = caption.replace(special_characetrs, ' ')
        caption = caption.lower()
    ytid_caption_dict[ytid] = caption


In [5]:
# text normalization
# count total number of words in the captions (size of vocab)
from collections import Counter
import re

def get_word_counter(captions):
    word_counter = Counter()
    for caption in captions:
        words = re.findall(r'\w+', caption)
        for word in words:
            word_counter[word] += 1
    return word_counter

before_text_normalization = get_word_counter(ytid_caption_dict.values())
print(f'size of vocab before text normalization: {len(before_text_normalization)}')

size of vocab before text normalization: 5435


In [6]:
import spacy
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

get_lemma = spacy.load('en_core_web_sm')

'''
(before) This recording contains breaking and 32 shooting sounds. There is also a lot of deep rumbling noise. The whole audio is panned to the right side of the speakers.
(after) this record contain break and 32 shoot sound . there be also a lot of deep rumble noise . the whole audio be pan to the right side of the speaker .
'''
def process_caption(item):
    ytid, caption = item
    words = re.findall(r'\w+|\.', caption)
    normalized_caption = ' '.join([token.lemma_ for word in words for token in get_lemma(word)])
    return ytid, normalized_caption

# requires about 20GB of DRAM, ~ 3 min with 40 cores (serial ~ 16 min)
with ProcessPoolExecutor(max_workers=(os.cpu_count())) as executor:
    results = list(tqdm(executor.map(process_caption, ytid_caption_dict.items()), total=len(ytid_caption_dict)))

ytid_caption_dict = dict(results)

after_text_normalization = get_word_counter(ytid_caption_dict.values())
print(f'size of vocab after text normalization: {len(after_text_normalization)}')

100%|██████████| 5480/5480 [01:23<00:00, 65.57it/s] 


size of vocab after text normalization: 4250


In [None]:
print(f'vocab size reduced of {1 - len(after_text_normalization) / len(before_text_normalization):.2f}')

In [None]:
# check if ytid are all unique -> yes
len(ytid_caption_dict), len(set(ytid_caption_dict.keys()))

# check the maximum word length of the captions (not the number of alphabet) -> 136
# max([len(caption.split()) for caption in ytid_caption_dict.values()])

### 2. save audio as tensor

In [None]:
# load the wav file using torchaudio and save it as 16khz in the wavs folder
import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

# Function to load and process each file
def process_file(file):
    wav, sr = torchaudio.load(file)  # Load the audio file
    shape = wav.shape

    return file, shape, sr, wav.shape[1] / sr

# Main function to process files in parallel
def process_files_parallel(file_set):
    shape_dict = {}
    sr_dict = {}
    duration_dict = {}

    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        results = list(tqdm(executor.map(process_file, file_set), total=len(file_set)))

    for result in results:
        file, shape, sr, duration = result
        shape_dict[file] = shape
        sr_dict[file] = sr
        duration_dict[file] = duration

    return shape_dict, sr_dict, duration_dict

shape_dict, sr_dict, duration_dict = process_files_parallel(ytid_filepath_dict.values())


In [None]:
from collections import Counter

# print unique sr_dict values and its counts
print({k: v for k, v in sorted(Counter(sr_dict.values()).items(), key=lambda item: item[1])})
# print unique duration in increasing order
print({k: v for k, v in sorted(Counter(duration_dict.values()).items(), key=lambda item: item[0])}, end='\n\n')
# min = 9.5 sec, max = 10.007 sec

In [None]:
import torch

TARGET_SR = 16000
TARGET_WAV_LEN = 16000 * 10

def process_file(audio_path):
    wav, sr = torchaudio.load(audio_path)  # Load the audio file
    wav = wav.mean(dim=0, keepdim=True)  # Convert to mono
    wav = torchaudio.transforms.Resample(sr, TARGET_SR)(wav)  # Resample to 16kHz
    wav = wav.squeeze()
    
    # fill the audio file with itself if the length is less than TARGET_WAV_LEN
    wavlen = wav.shape[0]
    
    if wavlen < TARGET_WAV_LEN:
        wav = torch.cat([wav, wav[:TARGET_WAV_LEN - wavlen]])
    elif wavlen > TARGET_WAV_LEN:
        wav = wav[:TARGET_WAV_LEN]
            
    # get the caption of the audio file
    file_id = audio_path.stem[1:12]
    # extract the caption for corresponding track_id
    caption = ytid_caption_dict[file_id]
        
    return (wav, caption)

# Main function to process files in parallel
def process_files_parallel(file_set):
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        wav_caption_list = list(tqdm(executor.map(process_file, file_set), total=len(file_set)))
    
    return wav_caption_list

wav_caption_list = process_files_parallel(ytid_filepath_dict.values())

In [None]:
output_path = Path('pkls/musiccaps_dataset.pkl')
# save with pickle
import pickle
pickle.dump(wav_caption_list, open(output_path, 'wb'))