In [None]:
# jukebox conda environment is required for this notebook to run

# mp3 support
!conda install -c conda-forge ffmpeg --yes 

import os
import math
import numpy as np
import librosa as l
import torch as t
import soundfile as sf

from jukebox.make_models import make_vqvae 
from jukebox.hparams import Hyperparams, setup_hparams

In [None]:
# --- Audio settings ---------------------------------------------------------
# Sample rate (Hz) used both when loading tracks and when writing chunks.
sr = 44100
# Nominal duration of each dataset chunk, in seconds.
chunk_length_in_seconds = 10

# --- Input ------------------------------------------------------------------
# Folders scanned (non-recursively) for .wav / .flac / .mp3 tracks.
folders = [
    r'c:\Music\Artist\Album',
    r'c:\Music\Artist2\Album2',
]

# --- Output -----------------------------------------------------------------
# Folder that receives the numbered audio chunks and their embeddings.
output_path = r'e:\ML\jukebox-upsampling-dataset'

In [None]:
def setup_vqvae(chunk_length, sr):
    """Create the Jukebox VQ-VAE on the CUDA device.

    Args:
        chunk_length: Number of audio samples per chunk; forwarded to the
            model hyperparameters as ``sample_length``.
        sr: Sample rate. It is stored on the local ``Hyperparams`` object
            only and is NOT forwarded to ``setup_hparams``.
            NOTE(review): the 'vqvae' preset presumably supplies its own
            sample rate - confirm it matches the rate used for loading audio.

    Returns:
        Whatever ``make_vqvae`` returns for the 'vqvae' hyperparameter preset
        with the overrides below, placed on 'cuda'.
    """
    base_hps = Hyperparams(
        levels=3,
        sample_length=chunk_length,
        sr=sr,
        n_samples=1,
        hop_fraction=[0.5, 0.5, 0.125],
    )

    # Only these two entries are handed on to the preset; keys missing from
    # base_hps fall back to 0 (sample_length_in_seconds is never set above).
    overrides = {
        'sample_length': base_hps.get('sample_length', 0),
        'sample_length_in_seconds': base_hps.get('sample_length_in_seconds', 0),
    }

    vqvae_hps = setup_hparams('vqvae', overrides)
    return make_vqvae(vqvae_hps, 'cuda')

def process(folders, output_path, chunk_length_in_seconds, sr):
    """Cut audio tracks into fixed-length chunks and save each chunk with its VQ-VAE embedding.

    For every supported audio file found directly inside each folder, the
    track is loaded at ``sr``, normalized, and split into consecutive
    non-overlapping chunks. Each full chunk is written as a 24-bit PCM WAV
    file, and the level-2 VQ-VAE embedding of that chunk is written next to it
    with ``np.save``. Trailing audio shorter than one chunk is discarded.

    Output files share a zero-padded running index: ``NNNNNN.wav`` and
    ``NNNNNN.emb.npy`` (``np.save`` appends ``.npy`` to the given name).

    Args:
        folders: Iterable of directory paths to scan (non-recursively).
        output_path: Directory receiving the output files; created if missing.
        chunk_length_in_seconds: Nominal chunk duration in seconds.
        sr: Sample rate used for loading and writing audio.

    Raises:
        ValueError: If ``sr * chunk_length_in_seconds`` is below 128 samples,
            so the rounded chunk length would be zero.
    """
    audio_extensions = ('.wav', '.flac', '.mp3')
    counter = 0

    # chunk size is rounded down to be multiple of 128
    chunk_length = 128 * (sr * chunk_length_in_seconds // 128)
    if chunk_length <= 0:
        # Fail before loading the model rather than inside range() later on.
        raise ValueError(
            'chunk length must be at least 128 samples, got sr=%r, chunk_length_in_seconds=%r'
            % (sr, chunk_length_in_seconds))

    os.makedirs(output_path, exist_ok=True)

    vqvae = setup_vqvae(chunk_length, sr).cuda()

    for folder in folders:
        # os.listdir order is arbitrary; sorting keeps chunk numbering reproducible.
        for track in sorted(os.listdir(folder)):
            # Compare the real file extension, case-insensitively, instead of
            # a substring test that also matches names like 'song.mp3.txt'.
            if os.path.splitext(track)[1].lower() not in audio_extensions:
                continue

            full_path = os.path.join(folder, track)
            print(full_path)

            y, _ = l.load(full_path, sr=sr)
            y = l.util.normalize(y)

            # The stop value is len(y) + 1 so that a track whose length is an
            # exact multiple of chunk_length still yields its final chunk.
            for end in range(chunk_length, len(y) + 1, chunk_length):
                chunk = y[end - chunk_length:end]

                sf.write(os.path.join(output_path, f'{counter:06d}.wav'), chunk, sr, 'PCM_24')

                # Inference only: skip autograd bookkeeping for every chunk.
                with t.no_grad():
                    # Adds a leading and a trailing singleton axis to the chunk.
                    x = t.tensor(chunk).unsqueeze(0).unsqueeze(2).cuda()
                    zs = vqvae.encode(x, start_level=2)
                    emb = vqvae.bottleneck.decode(zs, start_level=2, end_level=None)

                np.save(os.path.join(output_path, f'{counter:06d}.emb'),
                        emb[0].squeeze(0).cpu().detach().numpy())

                counter += 1
                    
# Build the dataset from the configured folders.
process(
    folders=folders,
    output_path=output_path,
    chunk_length_in_seconds=chunk_length_in_seconds,
    sr=sr,
)