### 1. Inpsect data integrity in sdd dataset

In [1]:
from pathlib import Path

AUSIOSET_PATH = Path('/mnt/shared/alpaca/song_describer_dataset/audio')
AUDIOCAP_PATH = Path('/mnt/shared/alpaca/song_describer_dataset/song_describer.csv') # you can download this from kaggle

# check whether there is missing audio files
sdd_fileid_set = {str(file.stem[:-5]) for file in AUSIOSET_PATH.rglob('*') if file.is_file() and file.suffix == '.mp3'}
sdd_filepath_set = {file for file in AUSIOSET_PATH.rglob('*') if file.is_file() and file.suffix == '.mp3'}
# make filename_list a dict

import pandas as pd
caption_df_original = pd.read_csv(AUDIOCAP_PATH)

# extract track ids from caption_df_original and conver to set
caption_track_id_set = set(str(x) for x in caption_df_original['track_id'])
print(caption_track_id_set)

# compare the two sets
missing_set = caption_track_id_set - sdd_fileid_set
print(len(missing_set)) # OK every audio files are in the csv and vice versa

# create a dictionary, key = track_id, value = caption
trackid_caption_dict = {str(row['track_id']): row['caption'] for index, row in caption_df_original.iterrows()}
print(trackid_caption_dict)

# filter some special characters (substitute with space)
special_chars = {'&', ',', '"', "'", '/', ';', '“', '(', '‘', '’', '.', ')', '-', '\n', ':'}
track_ids = trackid_caption_dict.keys()
for track_id in track_ids:
    caption = trackid_caption_dict[track_id]
    for char in special_chars:
        caption = caption.replace(char, ' ')
    trackid_caption_dict[track_id] = caption

{'1180599', '359662', '1210725', '48719', '11839', '1162034', '428529', '1162697', '175043', '305163', '330650', '1051195', '1135707', '4883', '24837', '219026', '457126', '458251', '1166070', '461017', '537116', '359655', '785415', '339422', '458252', '134057', '1062831', '1243198', '1169716', '1051207', '959181', '1211608', '296238', '35349', '357385', '1350313', '1245182', '1194315', '1157362', '305160', '464504', '1353972', '40639', '43990', '793923', '1121847', '1014963', '145945', '986583', '1356875', '1093795', '11841', '202741', '1245191', '1313073', '1336413', '1350853', '244237', '1051201', '359664', '976784', '6725', '7248', '304775', '352685', '43987', '5089', '1194312', '132868', '352691', '518603', '457080', '537118', '786939', '785436', '1194314', '238099', '654631', '1188123', '959155', '437792', '1357005', '1051196', '260839', '1269561', '700217', '1178762', '938054', '190871', '1211605', '296236', '1241783', '402058', '348881', '1281209', '1187970', '1061564', '106707

In [2]:
# text normalization
# count total number of words in the captions (size of vocab)
from collections import Counter
import re

def get_word_counter(captions):
    word_counter = Counter()
    for caption in captions:
        words = re.findall(r'\w+', caption)
        for word in words:
            word_counter[word] += 1
    return word_counter

before_text_normalization = get_word_counter(trackid_caption_dict.values())
print(f'size of vocab before text normalization: {len(before_text_normalization)}')

size of vocab before text normalization: 2301


In [3]:
import spacy
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

get_lemma = spacy.load('en_core_web_sm')

# Function to get the WordNet POS tag

def process_caption(item):
    ytid, caption = item
    words = re.findall(r'\w+|\.', caption)
    normalized_caption = ' '.join([token.lemma_ for word in words for token in get_lemma(word)])
    return ytid, normalized_caption

# requires about 20GB of DRAM, ~ 3 min with 40 cores (serial ~ 16 min)
with ProcessPoolExecutor(max_workers=(os.cpu_count())) as executor:
    results = list(tqdm(executor.map(process_caption, trackid_caption_dict.items()), total=len(trackid_caption_dict)))

trackid_caption_dict = dict(results)

after_text_normalization = get_word_counter(trackid_caption_dict.values())
print(f'size of vocab after text normalization: {len(after_text_normalization)}')

100%|██████████| 706/706 [00:04<00:00, 173.28it/s]

size of vocab after text normalization: 1809





In [4]:
# show the reduced ratio
print(f'vocab size reduced of {1 - len(after_text_normalization) / len(before_text_normalization):.2f}')

vocab size reduced of 0.21


In [5]:
import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

# Function to load and process each file
def process_file(file):
    wav, sr = torchaudio.load(file)  # Load the audio file
    shape = wav.shape

    return file, shape, sr

# Main function to process files in parallel
def process_files_parallel(file_set):
    shape_dict = {}
    sr_dict = {}

    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        results = list(tqdm(executor.map(process_file, file_set), total=len(file_set)))

    for result in results:
        file, shape, sr = result
        shape_dict[file] = shape
        sr_dict[file] = sr

    return shape_dict, sr_dict

shape_dict, sr_dict = process_files_parallel(sdd_filepath_set)
print(shape_dict)
print(sr_dict)

100%|██████████| 706/706 [00:09<00:00, 76.26it/s]


{PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/70/1159370.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/34/243734.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/93/938293.2min.mp3'): torch.Size([2, 2946583]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/65/938065.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/18/461018.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/58/76658.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/87/357387.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/73/304773.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/34/1162034.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_

In [6]:
print(shape_dict)
# find the audio file of minimum length
min_length = min(shape_dict, key=shape_dict.get)
print(min_length, shape_dict[min_length]) # it's about 30 seconds...
max_length = max(shape_dict, key=shape_dict.get)
print(max_length, shape_dict[max_length]) # it's about 2 minutes...

# find the total length of the audio files
total_length = 0
for key in shape_dict:
    total_length += (shape_dict[key][1] / sr_dict[key])
    
total_length_in_hrs = total_length / 3600
print(total_length_in_hrs) # 23 hours

# padd all the torchaudio to the max length
WAV_MAX_LEN = 5759471

{PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/70/1159370.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/34/243734.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/93/938293.2min.mp3'): torch.Size([2, 2946583]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/65/938065.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/18/461018.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/58/76658.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/87/357387.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/73/304773.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_describer_dataset/audio/34/1162034.2min.mp3'): torch.Size([2, 5291759]), PosixPath('/mnt/shared/alpaca/song_

### The audios have all the different length

In [7]:

import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os
import torch


# Function to load and process each file
# wav_caption_list = []

# for file in sdd_filepath_set:
#     wav, sr = torchaudio.load(file)  # Load the audio file
#     wav = wav.mean(dim=0, keepdim=True)  # Convert to mono
#     wav = torchaudio.transforms.Resample(sr, 16000)(wav)  # Resample to 16kHz
#     wav = wav.squeeze()
#     # make sure that wav is longer than 10 seconds
#     assert wav.shape[0] > 160000
    
#     # get the caption of the audio file
#     file_id = file.stem[:-5]
#     # extract the caption for corresponding track_id
#     caption = trackid_caption_dict[file_id]
#     wav_caption_list.append((wav, caption))
#     assert False

# Function to load and process each file
def process_file(file):
    wav, sr = torchaudio.load(file)  # Load the audio file
    wav = wav.mean(dim=0, keepdim=True)  # Convert to mono
    wav = torchaudio.transforms.Resample(sr, 16000)(wav)  # Resample to 16kHz
    wav = wav.squeeze()
    # make sure that wav is longer than 10 seconds
    assert wav.shape[0] > 160000
    
    wavlen = wav.shape[0]
    # pad the wav to the WAV_MAX_LEN
    wav = torch.nn.functional.pad(wav, (0, WAV_MAX_LEN - wavlen))
    
    
    # get the caption of the audio file
    file_id = file.stem[:-5]
    # extract the caption for corresponding track_id
    caption = trackid_caption_dict[file_id]
    
    
    return (wav, caption, wavlen)

# Main function to process files in parallel
def process_files_parallel(file_set):

    
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        wav_caption_wavlen_list = list(tqdm(executor.map(process_file, file_set), total=len(file_set)))
    
    return wav_caption_wavlen_list

wav_caption_wavlen_list = process_files_parallel(sdd_filepath_set)

100%|██████████| 706/706 [01:02<00:00, 11.26it/s]


In [8]:
# save the processed data and caption in pickle format separately
import pickle

with open('pkls/sdd_dataset.pkl', 'wb') as f:
    pickle.dump(wav_caption_wavlen_list, f)

In [9]:
# test dataloder
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

from torch.nn.utils.rnn import pad_sequence


class SDDDataset(Dataset):
    def __init__(self, sdd_dataset_pkl_path: Path):
        
        self.sdd_dataset = pickle.load(open(sdd_dataset_pkl_path, 'rb'))
        self.num_data = len(self.sdd_dataset)
        self.wav_duration = 16000 * 10 # 10 seconds
                 
    def __len__(self):
        return self.num_data
    
    def __getitem__(self, idx):
        # read wav from pt file, read txt from list
        real_wav_len = self.sdd_dataset[idx][0].shape[0]
        # randomly select a starting point
        start_point = torch.randint(0, real_wav_len - self.wav_duration, (1,)).item()
        wav = self.sdd_dataset[idx][0][start_point:start_point+self.wav_duration]
        cap = self.sdd_dataset[idx][1]
        return wav, cap
    
    
training_data = SDDDataset(sdd_dataset_pkl_path='pkls/sdd_dataset.pkl')

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

In [10]:
wav, txt = next(iter(train_dataloader))
print(wav.shape, txt)

torch.Size([64, 160000]) ('a fun upbeat drum section play with a keyboard and some electronic horn type noise feel like the automatic song on a keyboard that you d play in music lesson', 'a high energy pop track feature synthesized sound and rock guitar', 'relax music play mostly on the piano that can be use while study meditate or just relax', 'this be an electronic song it be also instrumental song it start with a drum loop with snare kick and hihat after that it introduce synth pad and other electronic song the vibe of this song feel very old maybe around 2000s', 'an energetic song with a heavy riff and male vocal', 'an instrumental surf rock track with a twist open charleston beat with strum guitar and a mellow synth lead the song be a happy cyberpunk soundtrack', 'a song that s intensely personal as its superficial lyric that take you to nowhere', 'this folk ballad sung by a man in french slightly off tune have a nostalgic mood and feature violin and guitar', 'a grungy indie rock 