### 1. Inspect data integrity in the SDD dataset

In [None]:
from pathlib import Path

# Root folder holding the dataset's .mp3 clips, and the CSV with their captions.
AUSIOSET_PATH = Path('/mnt/shared/alpaca/song_describer_dataset/audio')
AUDIOCAP_PATH = Path('/mnt/shared/alpaca/song_describer_dataset/song_describer.csv') # you can download this from kaggle

# check whether there is missing audio files
# Walk the audio tree once and derive the ids from the collected paths, so the
# two sets always describe exactly the same files.
sdd_filepath_set = {
    path
    for path in AUSIOSET_PATH.rglob('*')
    if path.suffix == '.mp3' and path.is_file()
}
# The id is the file stem minus its last five characters
# (presumably a fixed tag such as '.2min' in the file names -- TODO confirm).
sdd_fileid_set = {str(path.stem[:-5]) for path in sdd_filepath_set}
# make filename_list a dict

import pandas as pd
caption_df_original = pd.read_csv(AUDIOCAP_PATH)

# extract track ids from caption_df_original and convert to a set
# (stringified, because the audio-side ids are taken from file names)
caption_track_id_set = {str(track_id) for track_id in caption_df_original['track_id']}
print(caption_track_id_set)

# compare the two sets in BOTH directions
# captions whose audio file is missing
missing_set = caption_track_id_set - sdd_fileid_set
print(len(missing_set)) # expected: 0
# audio files without any caption; these would raise KeyError when their
# caption is looked up in trackid_caption_dict
uncaptioned_set = sdd_fileid_set - caption_track_id_set
print(len(uncaptioned_set)) # expected: 0

# create a dictionary, key = track_id, value = caption
# Reading the two columns directly keeps the ids formatted exactly as in
# caption_track_id_set. If a track_id occurs in several rows, the last row's
# caption is the one that is kept.
trackid_caption_dict = {
    str(track_id): caption
    for track_id, caption in zip(caption_df_original['track_id'], caption_df_original['caption'])
}
print(trackid_caption_dict)

# filter some special characters (substitute with space)
# NOTE(review): the opening curly double quote is listed but the closing one is not -- confirm this is intended.
special_chars = {'&', ',', '"', "'", '/', ';', '“', '(', '‘', '’', '.', ')', '-', '\n', ':'}
# Every entry is a single character, so one translation table maps them all to
# a space in a single pass over each caption.
space_table = str.maketrans({char: ' ' for char in special_chars})
track_ids = trackid_caption_dict.keys()
# Only existing keys are reassigned, so the dict can be updated while iterating its keys.
for track_id in track_ids:
    caption = trackid_caption_dict[track_id].translate(space_table)
    trackid_caption_dict[track_id] = caption

In [None]:
import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

# Function to load and process each file
def process_file(file):
    """Decode one audio file and report its tensor shape and sample rate.

    Args:
        file: Path of the audio file, passed straight to ``torchaudio.load``.

    Returns:
        A ``(file, shape, sr)`` tuple: the input path, the shape of the decoded
        waveform tensor, and the sample rate reported by ``torchaudio.load``.
    """
    waveform, sample_rate = torchaudio.load(file)
    return file, waveform.shape, sample_rate

# Main function to process files in parallel
def process_files_parallel(file_set):
    """Run ``process_file`` over ``file_set`` in a process pool.

    Args:
        file_set: Sized collection of audio file paths.

    Returns:
        ``(shape_dict, sr_dict)``: two dicts keyed by file path, holding each
        file's waveform shape and sample rate respectively.
    """
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as pool:
        # tqdm only wraps the lazy result iterator to show progress;
        # list() drains it before the pool shuts down.
        progress = tqdm(pool.map(process_file, file_set), total=len(file_set))
        probed = list(progress)

    shape_dict = {path: shape for path, shape, _ in probed}
    sr_dict = {path: rate for path, _, rate in probed}
    return shape_dict, sr_dict

# Probe every collected mp3 in parallel; both dicts are keyed by the audio file path.
shape_dict, sr_dict = process_files_parallel(sdd_filepath_set)
# Dump the per-file waveform shapes and sample rates for manual inspection.
print(shape_dict)
print(sr_dict)

In [None]:
print(shape_dict)

# Duration of every file in seconds: number of samples (second shape entry)
# divided by the file's own sample rate. Ranking the raw shapes instead would
# compare them element by element, so the first shape entry would dominate and
# differing sample rates would be ignored.
duration_dict = {path: shape_dict[path][1] / sr_dict[path] for path in shape_dict}

# find the audio file of minimum length
min_length = min(duration_dict, key=duration_dict.get)
print(min_length, shape_dict[min_length]) # it's about 30 seconds...
max_length = max(duration_dict, key=duration_dict.get)
print(max_length, shape_dict[max_length]) # it's about 2 minutes...

# find the total length of the audio files
total_length = sum(duration_dict.values())

total_length_in_hrs = total_length / 3600
print(total_length_in_hrs) # 23 hours

# pad every waveform to this many samples
# NOTE(review): presumably the longest clip after resampling to 16 kHz; at
# 16 kHz this is roughly 360 s, which does not match the "about 2 minutes"
# note above -- confirm.
WAV_MAX_LEN = 5759471

### The audio files all have different lengths

In [None]:
import torchaudio
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os
import torch


# Earlier sequential (non-parallel) draft of the per-file processing, kept commented out for reference
# wav_caption_list = []

# for file in sdd_filepath_set:
#     wav, sr = torchaudio.load(file)  # Load the audio file
#     wav = wav.mean(dim=0, keepdim=True)  # Convert to mono
#     wav = torchaudio.transforms.Resample(sr, 16000)(wav)  # Resample to 16kHz
#     wav = wav.squeeze()
#     # make sure that wav is longer than 10 seconds
#     assert wav.shape[0] > 160000
    
#     # get the caption of the audio file
#     file_id = file.stem[:-5]
#     # extract the caption for corresponding track_id
#     caption = trackid_caption_dict[file_id]
#     wav_caption_list.append((wav, caption))
#     assert False

# Function to load and process each file
def process_file(file):
    """Load one clip, convert it to fixed-length 16 kHz mono, and attach its caption.

    Args:
        file: ``pathlib.Path`` of the audio file. Its stem minus the last five
            characters is used as the key into ``trackid_caption_dict``.

    Returns:
        ``(wav, caption, wavlen)``: ``wav`` is a 1-D tensor of exactly
        ``WAV_MAX_LEN`` samples (zero-padded at the end), ``caption`` is the
        text stored for the track, and ``wavlen`` is the number of real
        (non-padding) samples kept in ``wav``.

    Raises:
        AssertionError: If the resampled clip is not longer than 10 seconds.
        KeyError: If the track id has no entry in ``trackid_caption_dict``.
    """
    wav, sr = torchaudio.load(file)  # Load the audio file
    wav = wav.mean(dim=0, keepdim=True)  # Convert to mono
    wav = torchaudio.transforms.Resample(sr, 16000)(wav)  # Resample to 16kHz
    wav = wav.squeeze()
    # make sure that wav is longer than 10 seconds
    assert wav.shape[0] > 160000, f'{file} is not longer than 10 seconds after resampling'

    # A clip longer than WAV_MAX_LEN would make the pad amount negative and
    # leave wavlen larger than the stored tensor; truncate explicitly so the
    # returned length always matches the data.
    wav = wav[:WAV_MAX_LEN]
    wavlen = wav.shape[0]
    # pad the wav to the WAV_MAX_LEN
    wav = torch.nn.functional.pad(wav, (0, WAV_MAX_LEN - wavlen))

    # get the caption of the audio file
    file_id = file.stem[:-5]
    # extract the caption for corresponding track_id
    caption = trackid_caption_dict[file_id]

    return (wav, caption, wavlen)

# Main function to process files in parallel
def process_files_parallel(file_set):
    """Apply ``process_file`` to every path in ``file_set`` using a process pool.

    Args:
        file_set: Sized collection of audio file paths.

    Returns:
        A list with one ``process_file`` result per input path.
    """
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as pool:
        # tqdm wraps the lazy result iterator for a progress bar; list() drains
        # it while the pool is still alive.
        progress = tqdm(pool.map(process_file, file_set), total=len(file_set))
        processed = list(progress)

    return processed

# Preprocess every collected mp3 in parallel; each list entry is the
# (wav, caption, wavlen) tuple returned by process_file.
wav_caption_wavlen_list = process_files_parallel(sdd_filepath_set)

In [None]:
# save the processed (wav, caption, wavlen) tuples together in one pickle file
import os
import pickle

# open() does not create missing parent directories, so make sure the output
# folder exists before writing.
os.makedirs('pkls', exist_ok=True)
with open('pkls/sdd_dataset.pkl', 'wb') as f:
    pickle.dump(wav_caption_wavlen_list, f)

In [None]:
# test dataloader
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

from torch.nn.utils.rnn import pad_sequence


class SDDDataset(Dataset):
    """Dataset yielding a random fixed-length audio crop and its caption.

    The pickle file must contain a sequence of ``(wav, caption, wavlen)``
    entries, where ``wav`` is a 1-D waveform tensor and ``wavlen`` is the
    number of real (non-padding) samples at its start.
    """

    def __init__(self, sdd_dataset_pkl_path: Path):
        """Load the whole pickled dataset into memory.

        Args:
            sdd_dataset_pkl_path: Location of the pickle file; anything
                accepted by ``open()`` works.
        """
        # NOTE: unpickling can execute arbitrary code -- only load files
        # produced by this pipeline.
        # The context manager closes the file handle once loading is done.
        with open(sdd_dataset_pkl_path, 'rb') as pkl_file:
            sdd_dataset = pickle.load(pkl_file)
        self.wavs = [x[0] for x in sdd_dataset]
        self.txts = [x[1] for x in sdd_dataset]
        self.wavlens = [x[2] for x in sdd_dataset]
        self.num_data = len(self.wavs)
        self.wav_duration = 16000 * 10 # 10 seconds

    def __len__(self):
        """Return the number of (wav, caption) entries."""
        return self.num_data

    def __getitem__(self, idx):
        """Return ``(wav, cap)``: a random ``wav_duration``-sample crop and its caption.

        The crop start is drawn uniformly from every position that keeps the
        window inside the real audio. If the real audio is shorter than the
        window, the crop starts at 0 and runs into whatever follows the real
        samples in the stored tensor.
        """
        real_wav_len = self.wavlens[idx]
        # Largest start that keeps the window inside the real audio. randint's
        # upper bound is exclusive, hence the +1 below; the clamp keeps the
        # range non-empty for clips no longer than the window.
        last_start = max(real_wav_len - self.wav_duration, 0)
        # randomly select a starting point
        start_point = torch.randint(0, last_start + 1, (1,)).item()
        wav = self.wavs[idx][start_point:start_point+self.wav_duration]
        cap = self.txts[idx]
        return wav, cap
    
    
# Build the dataset from the pickle written by the preprocessing cell.
training_data = SDDDataset(sdd_dataset_pkl_path='pkls/sdd_dataset.pkl')

# Shuffled mini-batches of 64 (wav, caption) items.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

In [None]:
# Pull one batch and print the shape and caption of its first item as a sanity check.
wav, txt = next(iter(train_dataloader))
for w, t in zip(wav[:1], txt[:1]):
    print(w.shape, t)