### 0. Download the "Music" portion of the AudioSet dataset

In [None]:
# Create the local `audioset_download` folder. Using Path.mkdir with
# exist_ok=True keeps the cell re-runnable: a plain `mkdir` reports an error
# when the folder already exists.
from pathlib import Path

Path('audioset_download').mkdir(parents=True, exist_ok=True)

In [None]:
from audioset_download import Downloader

# Options handed to the downloader: restrict to the "Music" label and use the
# 'eval' download type.
# NOTE(review): n_jobs presumably sets the number of parallel download jobs --
# confirm against the audioset_download documentation.
downloader_options = dict(
    root_path='/mnt/sdb/audioset-download',
    labels=["Music"],
    n_jobs=36,
    download_type='eval',
    copy_and_replicate=False,
)

d = Downloader(**downloader_options)
# request the clips as wav files
d.download(format='wav')

### 1. Inspect how many files from the 'balanced dataset' are in the csv

In [None]:
from pathlib import Path

# NOTE(review): the download cell writes to '/mnt/sdb/audioset-download' --
# confirm that the path below points at the same data.
AUSIOSET_PATH = Path('/mnt/shared/alpaca/audioset-download')
AUDIOCAP_PATH = Path('musiccaps-public.csv') # you can download this from kaggle

# Map each clip id (the first 11 characters of the file stem) to the path of
# its wav file, searching the download folder recursively.
ytid_filename_dict = {}
for file in AUSIOSET_PATH.rglob('*'):
    if file.suffix == '.wav' and file.is_file():
        ytid_filename_dict[file.stem[:11]] = file

In [None]:
import pandas as pd

# every caption row from the csv
caption_df_original = pd.read_csv(AUDIOCAP_PATH)

# keep only the rows whose ytid has a wav file in ytid_filename_dict
has_audio = caption_df_original['ytid'].isin(ytid_filename_dict.keys())
caption_df_filtered = caption_df_original[has_audio]

In [None]:
# Row counts before and after filtering.
# Recorded output: (5521, 4367), about 80% of the captions are in the audio set 'balanced'
len(caption_df_original.index), len(caption_df_filtered.index)

In [None]:
# ytid -> caption lookup built from the filtered rows
captions_by_ytid = caption_df_filtered.set_index('ytid')['caption']
ytid_caption_dict = captions_by_ytid.to_dict()

In [None]:
# Check that every ytid occurs only once in the filtered captions.
# The check is made on the DataFrame column: a dict cannot hold duplicate
# keys, so comparing len(dict) with len(set(dict.keys())) is always equal and
# can never reveal a duplicate. The result is printed because a notebook only
# displays the last expression of a cell.
print(
    'ytids unique:', caption_df_filtered['ytid'].is_unique,
    '| rows:', len(caption_df_filtered),
    '| dict entries:', len(ytid_caption_dict),
)

# Maximum caption length in whitespace-separated words (not characters) -> 136
# default=0 keeps the cell from raising when no caption survived the filter.
max((len(caption.split()) for caption in ytid_caption_dict.values()), default=0)

### 2. save audio as tensor

In [None]:

# Create the output folder for the resampled clips. Using Path.mkdir with
# exist_ok=True keeps the cell re-runnable: a plain `mkdir` reports an error
# when the folder already exists.
from pathlib import Path

Path('wavs').mkdir(parents=True, exist_ok=True)

In [None]:
# Convert every captioned clip into a fixed-length 16 kHz mono wav in wavs/.
#
# For each ytid in ytid_caption_dict:
# 1. if the audio cannot be loaded, add the ytid to error_keys
# 2. resample the audio to 16 kHz
# 3. if the audio is shorter than 9 seconds, discard it and add the ytid to error_keys
# 4. clip the audio to 10 seconds if it is longer, zero-pad it if it is shorter
# 5. average the channels to mono and save the clip as wavs/<ytid>.wav
# Afterwards the failed ytids are removed from ytid_caption_dict and the dict
# is pickled to ytid_caption_dict.pkl.

import pickle
from collections import defaultdict
from pathlib import Path

import torch
import torchaudio
from tqdm import tqdm

TARGET_SR = 16000                       # sample rate of every saved clip
CLIP_SECONDS = 10                       # every saved clip is exactly this long
MIN_SECONDS = 9                         # shorter clips are discarded
TARGET_LEN = TARGET_SR * CLIP_SECONDS   # samples per saved clip
MIN_LEN = TARGET_SR * MIN_SECONDS       # minimum samples after resampling

# Make sure the output folder exists even if the mkdir cell was skipped.
Path('wavs').mkdir(parents=True, exist_ok=True)

# Histogram of the *source* sample rates of the clips that were saved.
# (Counting after the rate has been forced to 16 kHz would only ever produce
# the single key 16000.)
sr_counter = defaultdict(int)

# One Resample module per distinct source rate, built once and reused.
resamplers = {}

cnt = 0          # number of discarded clips
error_keys = []  # ytids of the discarded clips

for ytid in tqdm(ytid_caption_dict):
    try:
        wav_path = ytid_filename_dict[ytid]
        wav, sr = torchaudio.load(wav_path)
    except Exception:
        # Deliberate best effort: any clip that cannot be read or decoded is
        # recorded and skipped instead of aborting the whole run.
        error_keys.append(ytid)
        cnt += 1
        continue

    if sr != TARGET_SR:
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(sr, TARGET_SR)
        wav = resamplers[sr](wav)

    if wav.size(1) < MIN_LEN:
        error_keys.append(ytid)
        cnt += 1
        continue

    # clip to the target length, then zero-pad on the right if still short
    wav = wav[:, :TARGET_LEN]
    if wav.size(1) < TARGET_LEN:
        wav = torch.nn.functional.pad(wav, (0, TARGET_LEN - wav.size(1)))

    # reduce dim to mono
    wav = torch.mean(wav, dim=0, keepdim=True)

    assert wav.size(1) == TARGET_LEN

    torchaudio.save(Path(f'wavs/{ytid}.wav'), wav, TARGET_SR)
    sr_counter[sr] += 1

print(cnt)
print(sr_counter)

# drop the captions whose audio was discarded (set gives O(1) membership tests)
failed_ytids = set(error_keys)
ytid_caption_dict = {key: value for key, value in ytid_caption_dict.items() if key not in failed_ytids}

# save dict
with open('ytid_caption_dict.pkl', 'wb') as f:
    pickle.dump(ytid_caption_dict, f)

# 680 audio files are not available -> 3687

### 3. filter captions
Some descriptions include special tokens such as ',', '/', and '.'.

In [None]:
# Inspect the non-alphanumeric characters that occur in the captions, then
# strip them from every caption and write the cleaned dict back to disk.
import pickle

with open('ytid_caption_dict.pkl', 'rb') as f:
    ytid_caption_dict = pickle.load(f)

# every character that is neither a letter nor a digit
special_tokens = {
    char
    for caption in ytid_caption_dict.values()
    for char in caption
    if not char.isalnum()
}

print(special_tokens) # output: {'&', ',', '"', "'", '/', ';', '“', '(', '‘', '’', '.', ')', '-', '\n', ':', ' '}
# musiclm_pytorch (x_clip)'s text encoder only filters ' ', 'tab', '\n'

# Remove the special tokens from every caption (except for ' ').
# NOTE: the characters are dropped, not replaced by a space, so text joined
# only by such a character (e.g. 'a/b') becomes one word ('ab').
cleaned_captions = {}
for ytid, caption in ytid_caption_dict.items():
    kept_chars = [char for char in caption if char == ' ' or char.isalnum()]
    cleaned_captions[ytid] = ''.join(kept_chars)
ytid_caption_dict = cleaned_captions

# save dict
with open('ytid_caption_dict.pkl', 'wb') as f:
    pickle.dump(ytid_caption_dict, f)

In [None]:
# Create the output folder for the per-clip tensors. Using Path.mkdir with
# exist_ok=True keeps the cell re-runnable: a plain `mkdir` reports an error
# when the folder already exists.
from pathlib import Path

Path('preprocessed_tensors').mkdir(parents=True, exist_ok=True)

In [None]:
# Export the resampled clips for training:
# - every clip is saved as preprocessed_tensors/<7-digit index>.pt
# - all clips stacked into one tensor are pickled to pkls/wavs.pkl
# - the captions, in the same order, are pickled to pkls/txts.pkl
# Along the way the total tensor size in GB and a histogram of the clip
# lengths (in samples) are printed; torch.stack below fails if the clips do
# not all have the same length.

import pickle

# import defaultdict
from collections import defaultdict
from tqdm import tqdm
import torchaudio
import torch
from pathlib import Path

# load saved ytid -> caption dict
with open('ytid_caption_dict.pkl', 'rb') as f:
    ytid_caption_dict = pickle.load(f)

# Create both output folders up front. 'pkls' is not created anywhere else,
# so without this the final open() would fail only after the whole loop ran.
tensor_dir = Path('preprocessed_tensors')
pkl_dir = Path('pkls')
tensor_dir.mkdir(parents=True, exist_ok=True)
pkl_dir.mkdir(parents=True, exist_ok=True)

total_size_in_bytes = 0

# counter of the wav length
counter_wav_len = defaultdict(int)

wav_list = []
txt_list = []

# tqdm wraps the sorted list (not enumerate) so it knows the total and can
# show a real progress bar; sorting fixes the index <-> ytid assignment.
for idx, ytid in enumerate(tqdm(sorted(ytid_caption_dict))):
    resampled_wav_path = Path(f'wavs/{ytid}.wav')
    wav, sr = torchaudio.load(resampled_wav_path)
    wav = wav.squeeze()
    assert sr == 16000
    counter_wav_len[wav.numel()] += 1
    total_size_in_bytes += wav.numel() * wav.element_size()

    # save the wav as tensor.pt
    output_no = str(idx).zfill(7)
    output_file = Path(f'preprocessed_tensors/{output_no}.pt')
    torch.save(wav, output_file)

    wav_list.append(wav)
    txt_list.append(ytid_caption_dict[ytid])

total_size_in_gb = total_size_in_bytes / 1024**3
print(total_size_in_gb) # 3K = 2.5GB -> 10GB = 12K, 100GB = 120K
print(counter_wav_len)

wav_list = torch.stack(wav_list)

with open('pkls/wavs.pkl', 'wb') as f:
    pickle.dump(wav_list, f)
    print('wavs.pkl saved')

with open('pkls/txts.pkl', 'wb') as f:
    pickle.dump(txt_list, f)
    print('txts.pkl saved')


In [None]:
# create dataloader for mulan
from torch.utils.data import Dataset


class MuLanDataset(Dataset):
    """In-memory dataset of (waveform, caption) pairs.

    Both pickles are loaded completely in ``__init__``; item ``idx`` pairs
    entry ``idx`` of the waveform pickle with entry ``idx`` of the caption
    pickle.

    Args:
        audio_pt_path: Folder holding the per-clip ``.pt`` files. It is only
            stored on the instance; items are served from the pickles.
        txt_pickle_path: Pickle file containing the captions.
        wav_pickle_path: Pickle file containing the waveforms.

    Raises:
        ValueError: If the two pickles do not hold the same number of entries.
    """

    def __init__(self, audio_pt_path: Path, txt_pickle_path: Path, wav_pickle_path: Path):
        super().__init__()

        self.audio_pt_path = audio_pt_path

        # NOTE: pickle.load can execute arbitrary code -- only point this at
        # pickle files you created yourself.
        with open(wav_pickle_path, 'rb') as f:
            self.wavs = pickle.load(f)

        with open(txt_pickle_path, 'rb') as f:
            self.txts = pickle.load(f)

        # Fail early instead of serving misaligned pairs or raising an
        # IndexError in the middle of training.
        if len(self.wavs) != len(self.txts):
            raise ValueError(
                f'wav/caption count mismatch: {len(self.wavs)} wavs vs {len(self.txts)} captions'
            )

        self.num_data = len(self.txts)

    def __len__(self):
        """Return the number of (waveform, caption) pairs."""
        return self.num_data

    def __getitem__(self, idx):
        """Return the ``(waveform, caption)`` pair stored at position ``idx``."""
        return self.wavs[idx], self.txts[idx]

    # NOTE: a lazy alternative would read one clip per item from
    # ``self.audio_pt_path`` instead of holding every waveform in memory. The
    # export cell names those files with a zero-padded 7-digit index, so the
    # lookup would be ``torch.load(self.audio_pt_path / f'{idx:07d}.pt')``.
    
    
from torch.utils.data import DataLoader

# Build the dataset from the exported pickles.
training_data = MuLanDataset(
    audio_pt_path=Path('preprocessed_tensors'),
    txt_pickle_path=Path('pkls/txts.pkl'),
    wav_pickle_path=Path('pkls/wavs.pkl'),
)

# Shuffled batches of 64 (waveform, caption) pairs.
train_dataloader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)

### Test Dataloader

In [None]:
# Pull a single batch and show the waveform batch shape and the number of
# captions that came with it.
batch_iterator = iter(train_dataloader)
wav, txt = next(batch_iterator)
print(wav.shape)
print(len(txt))