A notebook for extracting embeddings from [OpenAI's Jukebox model](https://openai.com/index/jukebox/), following the approach described in [Castellon et al. (2021)](https://arxiv.org/abs/2107.05677) with some modifications taken from [Spotify's LLark paper](https://arxiv.org/pdf/2310.07160):

- Source: Output of the 36th layer of the Jukebox encoder
- Original Jukebox encoding: 4800-dimensional vectors at 345Hz
- Audio is split into 25-second chunks, and chunks shorter than 25 seconds are zero-padded before being passed through Jukebox. Jukebox's top-level model takes at most 8192 tokens (1,048,576 samples, about 23.8 seconds at 44.1kHz), so only the first ~23.8 seconds of each chunk are actually encoded
- Approach: Mean-pooling within ~100ms frames (34 activations, about 98.7ms each), resulting in:
    - Downsampled frequency: ~10Hz
    - Embedding size: about 1.15 × 10^6 values per chunk
    - For each chunk the 2D array shape will be [240, 4800]
- This method retains temporal information while reducing the embedding size
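
The pooling arithmetic can be checked without loading Jukebox. The sketch below is a minimal NumPy re-implementation, assuming the constants used in the extraction cell (`T = 8192` tokens per 1,048,576-sample window at 44.1kHz). It mirrors what `windowed_average` does with `torch.nn.AvgPool1d(frame_len, stride=frame_len)`, where activations that don't fill a whole frame are dropped. `pooling_frame_len` and `mean_pool` are hypothetical helper names, not part of the notebook.

```python
import math

import numpy as np

SAMPLE_RATE = 44100       # Hz, Jukebox's input sample rate
T = 8192                  # tokens in one top-level prior context window
WINDOW_SAMPLES = 1048576  # raw samples covered by those T tokens (128 samples per token)

def pooling_frame_len(target_hz: float) -> int:
    """Number of consecutive activations averaged into one pooled frame."""
    acts_rate = T / (WINDOW_SAMPLES / SAMPLE_RATE)  # ≈ 344.53 activations per second
    return math.floor(acts_rate / target_hz)

def mean_pool(acts: np.ndarray, frame_len: int) -> np.ndarray:
    """Average non-overlapping windows of frame_len time steps.

    Trailing steps that don't fill a whole window are dropped, matching
    AvgPool1d with stride == kernel_size and ceil_mode=False.
    """
    n_frames = acts.shape[0] // frame_len
    trimmed = acts[: n_frames * frame_len]
    return trimmed.reshape(n_frames, frame_len, acts.shape[1]).mean(axis=1)

frame_len = pooling_frame_len(10)  # 34 activations ≈ 98.7 ms
# Stand-in activations: real ones are [8192, 4800]; a small width keeps the demo light.
acts = np.random.rand(T, 16).astype(np.float32)
pooled = mean_pool(acts, frame_len)
print(frame_len, pooled.shape)  # 34 (240, 16)
```

Because 8192 is not a multiple of 34, the last 32 activations (about 93ms of audio) are discarded, which is why each chunk yields 240 rather than 241 frames.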

Running this as a Colab notebook gives us an easily reproducible environment and lets us take advantage of the inexpensive T4 GPUs that Colab offers.

Join the Jukebox community at https://discord.gg/aEqXFN9amV


**How to handle memory problems:** In theory, this notebook is designed to avoid out-of-memory (OOM) errors, but here are some tricks if you still encounter one:
* Restart runtime: At the top of the notebook, click "Runtime" and then "Restart runtime", then run everything again. Do this every time you start a second run within the same session or after you've interrupted one.
* Decrease sample count: Choose a lower number for 'hps.n_samples'.

**Please note that the core Jukebox architecture hasn't been updated since 2020. This means Jukebox is very slow and inefficient at generating samples. It will take a couple of hours to get any results.**

In [None]:
#@title Check which GPU you were assigned by running this cell.
!nvidia-smi -i 0 -e 0
!nvidia-smi -L
your_lyrics = """
"""

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')  # same mount point used by the model loader below

In [None]:
#@title Select your model and settings

from google.colab import drive
drive.mount('/content/gdrive')

!pip install --upgrade git+https://github.com/craftmine1000/jukebox-saveopt.git

import jukebox
import torch as t
import librosa
import os
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample, \
                           load_prompts
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
# MPI Connect. MPI doesn't like being initialized twice, hence the following
try:
    if device is not None:
        pass
except NameError:
    rank, local_rank, device = setup_dist_from_mpi()

model = "5b" # @param ["5b_lyrics","5b","1b_lyrics"]
if model == '5b':
  your_lyrics = """
  """

save_and_load_models_from_drive = True

#START GDRIVE MODEL LOADER
if save_and_load_models_from_drive == True:
  import os.path
  !apt install pv
  !mkdir -p /root/.cache/jukebox/models/1b_lyrics /root/.cache/jukebox/models/5b_lyrics /root/.cache/jukebox/models/5b
  !mkdir -p /content/gdrive/MyDrive/jukebox/models/5b /content/gdrive/MyDrive/jukebox/models/5b_lyrics /content/gdrive/MyDrive/jukebox/models/1b_lyrics


def load_5b_vqvae():
    if os.path.exists("/root/.cache/jukebox/models/5b/vqvae.pth.tar") == False:
      if os.path.exists("/content/gdrive/MyDrive/jukebox/models/5b/vqvae.pth.tar") == False:
        print("5b_vqvae not stored in Google Drive. Downloading for the first time.")
        !wget https://openaipublic.azureedge.net/jukebox/models/5b/vqvae.pth.tar -O /content/gdrive/MyDrive/jukebox/models/5b/vqvae.pth.tar
      else:
        print("5b_vqvae stored in Google Drive.")
      print('Copying 5b VQVAE')
      !pv /content/gdrive/MyDrive/jukebox/models/5b/vqvae.pth.tar > /root/.cache/jukebox/models/5b/vqvae.pth.tar

def load_1b_lyrics_level2():
  if os.path.exists("/root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar") == False:
    if os.path.exists("/content/gdrive/MyDrive/jukebox/models/1b_lyrics/prior_level_2.pth.tar") == False:
      print("1b_lyrics_level_2 not stored in Google Drive. Downloading for the first time. This will take a few more minutes.")
      !wget https://openaipublic.azureedge.net/jukebox/models/1b_lyrics/prior_level_2.pth.tar -O /content/gdrive/MyDrive/jukebox/models/1b_lyrics/prior_level_2.pth.tar
    else:
      print("1b_lyrics_level_2 stored in Google Drive.")
    print("Copying 1B_Lyrics Level 2")
    !pv /content/gdrive/MyDrive/jukebox/models/1b_lyrics/prior_level_2.pth.tar > /root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar

def load_5b_lyrics_level2():
  if os.path.exists("/root/.cache/jukebox/models/5b_lyrics/prior_level_2.pth.tar") == False:
    if os.path.exists("/content/gdrive/MyDrive/jukebox/models/5b_lyrics/prior_level_2.pth.tar") == False:
      print("5b_lyrics_level_2 not stored in Google Drive. Downloading for the first time. This will take up to 10-15 minutes.")
      !wget https://openaipublic.azureedge.net/jukebox/models/5b_lyrics/prior_level_2.pth.tar -O /content/gdrive/MyDrive/jukebox/models/5b_lyrics/prior_level_2.pth.tar
    else:
      print("5b_lyrics_level_2 stored in Google Drive.")
    print("Copying 5B_Lyrics Level 2")
    !pv /content/gdrive/MyDrive/jukebox/models/5b_lyrics/prior_level_2.pth.tar > /root/.cache/jukebox/models/5b_lyrics/prior_level_2.pth.tar

def load_5b_level1():
  if os.path.exists('/root/.cache/jukebox/models/5b/prior_level_1.pth.tar') == False:
    if os.path.exists("/content/gdrive/MyDrive/jukebox/models/5b/prior_level_1.pth.tar") == False:
      print("5b_level_1 not stored in Google Drive. Downloading for the first time. This may take a few more minutes.")
      !wget https://openaipublic.azureedge.net/jukebox/models/5b/prior_level_1.pth.tar -O /content/gdrive/MyDrive/jukebox/models/5b/prior_level_1.pth.tar
    else:
      print("5b_level_1 stored in Google Drive.")
    print("Copying 5B Level 1")
    !pv /content/gdrive/MyDrive/jukebox/models/5b/prior_level_1.pth.tar > /root/.cache/jukebox/models/5b/prior_level_1.pth.tar

def load_5b_level0():
  if os.path.exists('/root/.cache/jukebox/models/5b/prior_level_0.pth.tar') == False:
    if os.path.exists("/content/gdrive/MyDrive/jukebox/models/5b/prior_level_0.pth.tar") == False:
      print("5b_level_0 not stored in Google Drive. Downloading for the first time. This may take a few minutes.")
      !wget https://openaipublic.azureedge.net/jukebox/models/5b/prior_level_0.pth.tar -O /content/gdrive/MyDrive/jukebox/models/5b/prior_level_0.pth.tar
    else:
      print("5b_level_0 stored in Google Drive.")
    print("Copying 5B Level 0")
    !pv /content/gdrive/MyDrive/jukebox/models/5b/prior_level_0.pth.tar > /root/.cache/jukebox/models/5b/prior_level_0.pth.tar

def load_5b_level2():
  if os.path.exists('/root/.cache/jukebox/models/5b/prior_level_2.pth.tar') == False:
    if os.path.exists("/content/gdrive/MyDrive/jukebox/models/5b/prior_level_2.pth.tar") == False:
      print("5b_level_2 not stored in Google Drive. Downloading for the first time. This will take up to 10-15 minutes.")
      !wget https://openaipublic.azureedge.net/jukebox/models/5b/prior_level_2.pth.tar -O /content/gdrive/MyDrive/jukebox/models/5b/prior_level_2.pth.tar
    else:
      print("5b_level_2 stored in Google Drive.")
    print("Copying 5B Level 2")
    !pv /content/gdrive/MyDrive/jukebox/models/5b/prior_level_2.pth.tar > /root/.cache/jukebox/models/5b/prior_level_2.pth.tar

if save_and_load_models_from_drive == True:
  if model == '5b_lyrics':
    load_5b_vqvae()
    load_5b_lyrics_level2()
    load_5b_level1()
    load_5b_level0()
  elif model == '5b':
    load_5b_vqvae()
    load_5b_level2()
    load_5b_level1()
    load_5b_level0()
  elif model == '1b_lyrics':
    load_5b_vqvae()
    load_1b_lyrics_level2()
    load_5b_level1()
    load_5b_level0()
#END GDRIVE MODEL LOADER


In [None]:
import jukebox
import torch as t
import librosa
import os
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, sample_partial_window, upsample, load_prompts
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

# MPI Connect
try:
    if device is not None:
        pass
except NameError:
    rank, local_rank, device = setup_dist_from_mpi()

# Model selection and configuration
model = "5b" #@param ["5b_lyrics", "5b", "1b_lyrics"]
if model == '5b':
    your_lyrics = """"""

hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 2 #@param {type:"integer"}
hps.name = '/content/gdrive/MyDrive/Project_1' #@param {type:"string"}
chunk_size = 64 if model in ('5b', '5b_lyrics') else 128

# GPU detection and configuration
gpu_info = !nvidia-smi -L
if gpu_info[0].find('Tesla T4') >= 0:
    max_batch_size = 2
elif gpu_info[0].find('Tesla K80') >= 0:
    max_batch_size = 8
elif gpu_info[0].find('Tesla P100') >= 0:
    max_batch_size = 3
elif gpu_info[0].find('Tesla V100') >= 0:
    max_batch_size = 3
elif gpu_info[0].find('L4') >= 0:
    max_batch_size = 3
elif gpu_info[0].find('A100') >= 0:
    max_batch_size = 6
else:
    max_batch_size = 3

print(f'{gpu_info[0]} detected, max_batch_size set to {max_batch_size}')

hps.levels = 3
hps.hop_fraction = [0.5, 0.5, 0.125]

import torch

# Load model
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vqvae = vqvae.to(device)
top_prior = top_prior.to(device)

## Download MusicNet

This notebook uses the MusicNet dataset as an example. The embeddings for the full dataset can be found on Hugging Face: https://huggingface.co/datasets/jonflynn/musicnet_jukebox_embeddings

In [None]:
%%time
# Upload your kaggle.json API key to the working directory (/content by default) first
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

!mkdir -p /content/musicnet
!kaggle datasets download -d imsparsh/musicnet-dataset
!unzip musicnet-dataset.zip -d /content/musicnet
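
`process_musicnet_dataset` looks for `.wav` files under `<input_dir>/musicnet/musicnet/train_data` and `<input_dir>/musicnet/musicnet/test_data`. A quick check like the sketch below can confirm the archive unpacked where the extraction expects before any GPU time is spent. `find_musicnet_wavs` is a hypothetical helper, and the sketch assumes the Kaggle archive has the nested `musicnet/musicnet/` layout the notebook relies on.

```python
from pathlib import Path
from typing import Dict, List

def find_musicnet_wavs(input_dir: str) -> Dict[str, List[Path]]:
    """Collect .wav files from the train/test folders process_musicnet_dataset walks."""
    root = Path(input_dir) / "musicnet" / "musicnet"
    found = {}
    for split in ("train_data", "test_data"):
        split_dir = root / split
        # rglob recurses like os.walk, so files in nested folders are found too
        found[split] = sorted(split_dir.rglob("*.wav")) if split_dir.is_dir() else []
    return found

wavs = find_musicnet_wavs("/content/musicnet/")
for split, files in wavs.items():
    print(f"{split}: {len(files)} .wav files")
```

If either count is zero, the `unzip` target or the folder names inside the archive probably differ from what `process_musicnet_dataset` assumes.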

## Auth with GCP to store embeddings

In [None]:
from google.colab import auth
auth.authenticate_user()

In [None]:
import os

# Point GOOGLE_APPLICATION_CREDENTIALS at your service account key (upload the JSON key to /content/admin.json first)
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/content/admin.json'

# Authenticate with gcloud
!gcloud auth activate-service-account --key-file $GOOGLE_APPLICATION_CREDENTIALS

In [None]:
!pip install librosa

Set up your GCP environment:

In [None]:
GOOGLE_CLOUD_PROJECT = "your-gcp-project"
GCS_BUCKET_NAME = "jukebox-embeddings"
GCP_REGION = "us-central1"

os.environ["GOOGLE_CLOUD_PROJECT"] = GOOGLE_CLOUD_PROJECT
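
The extraction writes to `gs://{GCS_BUCKET_NAME}/musicnet_embeddings/`, and `split_gcs_bucket_and_filepath` splits such URIs into a bucket name and an object prefix. A slightly more defensive variant is sketched below. `parse_gcs_uri` is a hypothetical name; unlike the notebook's helper, it strips only the leading `gs://`, accepts a bare `gs://bucket` URI (returning an empty prefix), and rejects paths that are not on GCS.

```python
from typing import Tuple

def parse_gcs_uri(uri: str) -> Tuple[str, str]:
    """Split 'gs://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    if not uri.startswith("gs://"):
        raise ValueError(f"not a GCS URI: {uri!r}")
    bucket, _, path = uri[len("gs://"):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name in {uri!r}")
    return bucket, path

print(parse_gcs_uri("gs://jukebox-embeddings/musicnet_embeddings/"))
# ('jukebox-embeddings', 'musicnet_embeddings/')
```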

## Extract embeddings
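
For intuition about what actually reaches the model, here is a small NumPy sketch of the chunk, pad, and truncate path: `chunk_audio_file` zero-pads each chunk to 25 s, `maybe_pad_audio_to_max_len` pads anything shorter than 1,048,576 samples, and `get_z` keeps only the first 1,048,576 samples (about 23.8 s). `split_into_chunks` and `model_input` are hypothetical helpers that approximate that behavior; they assume mono audio already resampled to 44.1kHz and index by samples rather than by seconds as the notebook does.

```python
import numpy as np

SAMPLE_RATE = 44100
WINDOW_SAMPLES = 1048576  # samples the top-level prior can encode (≈ 23.78 s)

def split_into_chunks(audio: np.ndarray, chunk_seconds: float = 25.0):
    """Split mono audio into zero-padded, non-overlapping chunks.

    Returns (chunk, start_time_seconds) pairs, similar to chunk_audio_file.
    """
    chunk_len = int(chunk_seconds * SAMPLE_RATE)
    chunks = []
    for start in range(0, max(len(audio), 1), chunk_len):
        chunk = audio[start:start + chunk_len]
        chunk = np.pad(chunk, (0, chunk_len - len(chunk)))  # pad the last, shorter chunk
        chunks.append((chunk, start / SAMPLE_RATE))
    return chunks

def model_input(chunk: np.ndarray) -> np.ndarray:
    """Pad or truncate a chunk to exactly WINDOW_SAMPLES, as the notebook does before encoding."""
    if len(chunk) < WINDOW_SAMPLES:
        chunk = np.pad(chunk, (0, WINDOW_SAMPLES - len(chunk)))
    return chunk[:WINDOW_SAMPLES]

audio = np.zeros(60 * SAMPLE_RATE, dtype=np.float32)  # one minute of silence
chunks = split_into_chunks(audio)
print([start for _, start in chunks])  # [0.0, 25.0, 50.0]
dropped = len(chunks[0][0]) - WINDOW_SAMPLES
print(f"{dropped / SAMPLE_RATE:.2f} s of each 25 s chunk is not encoded")  # 1.22 s ...
```

Choosing a chunk length equal to the window (about 23.78 s) would avoid the unencoded tail, at the cost of departing from the 25 s convention this notebook follows.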

In [None]:
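JUKEBOX_SAMPLE_RATE = 44100             # Hz; audio is resampled to 44.1kHz mono
T = 8192                                # context length of the top-level prior, in tokens
JUKEBOX_EXPECTED_SAMPLES_LEN = 1048576  # raw samples covered by T tokens (128 samples per token)
JUKEBOX_SAMPLE_SECONDS = JUKEBOX_EXPECTED_SAMPLES_LEN / JUKEBOX_SAMPLE_RATE  # ≈ 23.78 s
ACTS_SAMPLE_RATE = T / JUKEBOX_SAMPLE_SECONDS  # ≈ 344.5 activation frames per second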

import argparse
import io
import logging
import pathlib
import tempfile
from functools import lru_cache
from math import floor
from pathlib import Path
import librosa as lr
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, List, Callable
import numpy as np
import torch
from google.cloud import storage
import concurrent.futures
import re

class EmptyFileError(ValueError):
    pass

@lru_cache(None)
def gcs_client():
    return storage.Client()

@lru_cache(None)
def gcs_bucket(bucket_name: str):
    return gcs_client().get_bucket(bucket_name)

def to_device(tensor, device):
    if isinstance(tensor, (list, tuple)):
        return [to_device(t, device) for t in tensor]
    elif isinstance(tensor, dict):
        return {k: to_device(v, device) for k, v in tensor.items()}
    elif isinstance(tensor, torch.Tensor):
        return tensor.to(device)
    return tensor

def split_gcs_bucket_and_filepath(filepath: str) -> Tuple[str, str]:
    return filepath.replace("gs://", "").split("/", maxsplit=1)

def read_wav_bytes(filepath: str) -> bytes:
    assert filepath.startswith("gs://"), f"Expected a file path on GCS, but got: {filepath!r}"
    bucket_name, file_name = split_gcs_bucket_and_filepath(filepath)
    return gcs_bucket(bucket_name).blob(file_name).download_as_bytes()  # download_as_string is deprecated

def list_gcs_folders(input_dir: str) -> Sequence[str]:
    bucket, prefix = split_gcs_bucket_and_filepath(input_dir)
    blobs = gcs_client().list_blobs(bucket, prefix=prefix, delimiter='/')

    folder_paths = []
    for page in blobs.pages:
        for prefix in page.prefixes:
            folder_paths.append(f"gs://{bucket}/{prefix}")

    return folder_paths

def load_audio_from_file(fpath: str) -> np.ndarray:
    try:
        audio, _ = lr.load(fpath, sr=JUKEBOX_SAMPLE_RATE)
    except ValueError as ve:
        raise EmptyFileError(f"file {fpath} failed to read with exception {ve!r}; it is probably empty.")
    if audio.ndim == 1:
        audio = audio[np.newaxis]
    audio = audio.mean(axis=0)
    norm_factor = np.abs(audio).max()
    if norm_factor > 0:
        audio /= norm_factor
    return audio.flatten()

def download_gcs_file(gcs_path: str) -> str:
    bucket_name, file_name = split_gcs_bucket_and_filepath(gcs_path)
    bucket = gcs_bucket(bucket_name)
    blob = bucket.blob(file_name)
    _, local_path = tempfile.mkstemp()
    blob.download_to_filename(local_path)
    return local_path

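def chunk_audio_file(audio_path: str, chunk_duration: float) -> List[Tuple[np.ndarray, float]]:
    """Load, mono-mix and peak-normalize a file, then split it into zero-padded
    chunks of chunk_duration seconds. Returns (chunk, start_time_seconds) pairs."""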
    audio, sr = lr.load(audio_path, sr=JUKEBOX_SAMPLE_RATE)
    total_duration = lr.get_duration(y=audio, sr=sr)
    if audio.ndim == 1:
        audio = audio[np.newaxis]
    audio = audio.mean(axis=0)
    norm_factor = np.abs(audio).max()
    if norm_factor > 0:
        audio /= norm_factor
    audio = audio.flatten()

    chunks = []

    # If the total duration is less than or equal to chunk_duration,
    # treat the entire audio as a single chunk
    if total_duration <= chunk_duration:
        chunk = audio
        if len(chunk) < int(chunk_duration * sr):
            chunk = np.pad(chunk, (0, int(chunk_duration * sr) - len(chunk)))
        chunks.append((chunk, 0))
    else:
        for start_time in np.arange(0, total_duration, chunk_duration):
            start_sample = int(start_time * sr)
            end_sample = min(int(start_sample + chunk_duration * sr), len(audio))
            chunk = audio[start_sample:end_sample]

            # Pad the chunk if it's shorter than chunk_duration
            if len(chunk) < int(chunk_duration * sr):
                chunk = np.pad(chunk, (0, int(chunk_duration * sr) - len(chunk)))

            chunks.append((chunk, start_time))

    return chunks

def maybe_pad_audio_to_max_len(audio: np.ndarray) -> np.ndarray:
    if len(audio) < JUKEBOX_EXPECTED_SAMPLES_LEN:
        audio = np.pad(audio, (0, JUKEBOX_EXPECTED_SAMPLES_LEN - len(audio)))
    return audio

def get_z(audio: np.ndarray, vqvae):
    assert len(audio) >= JUKEBOX_EXPECTED_SAMPLES_LEN, f"expected samples with shape {JUKEBOX_EXPECTED_SAMPLES_LEN}; got shape {audio.shape}."
    audio = audio[:JUKEBOX_EXPECTED_SAMPLES_LEN]
    audio_tensor = torch.FloatTensor(audio[np.newaxis, :, np.newaxis]).to(device)
    zs = vqvae.encode(audio_tensor)
    z = zs[-1].flatten()[np.newaxis, :]
    if z.shape[-1] < T:
        raise ValueError("Audio file is not long enough")
    return z.to(device).long()  # Convert to long and ensure it's on the correct device

def get_cond(hps, top_prior):
    sample_length_in_seconds = 62
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens) * top_prior.raw_to_tokens
    metas = [dict(artist="unknown", genre="unknown", total_length=hps.sample_length, offset=0, lyrics="""lyrics go here!!!""")] * hps.n_samples
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, device)]
    x_cond, y_cond, prime = top_prior.get_cond(None, top_prior.get_y(labels[-1], 0))
    x_cond = x_cond[0, :T][np.newaxis, ...].to(device)
    y_cond = y_cond[0][np.newaxis, ...].to(device)
    return x_cond, y_cond

def get_final_activations(z: torch.Tensor, x_cond: torch.Tensor, y_cond: torch.Tensor, top_prior: Any) -> torch.Tensor:
    x = z[:, :T].to(device).long()  # Ensure x is a CUDA LongTensor
    top_prior.prior.only_encode = True
    out = top_prior.prior.forward(x, x_cond=x_cond, y_cond=y_cond, encoder_kv=None, fp16=False)
    return out

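def windowed_average(acts: torch.Tensor, frame_len: int, ceil_mode: bool = False) -> torch.Tensor:
    """Mean-pool [time, 4800] activations over non-overlapping windows of frame_len steps.

    Returns a [1, n_frames, 4800] tensor with n_frames = time // frame_len
    (an incomplete trailing window is dropped unless ceil_mode=True).
    """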
    assert acts.ndim == 2, "expected 2d inputs"
    assert acts.shape[1] == 4800
    acts = torch.unsqueeze(acts, 0)
    acts = torch.transpose(acts, 1, 2)
    pool = torch.nn.AvgPool1d(frame_len, stride=frame_len, ceil_mode=ceil_mode)
    acts = pool(acts)
    return torch.transpose(acts, 1, 2)

def get_acts_from_file(audio: np.ndarray, hps, vqvae, top_prior, meanpool=True, pool_frames_per_second=None):
    input_audio_len = len(audio)
    latent_audio_len = floor(T * input_audio_len / JUKEBOX_EXPECTED_SAMPLES_LEN)
    audio = maybe_pad_audio_to_max_len(audio)
    z = get_z(audio, vqvae)
    x_cond, y_cond = get_cond(hps, top_prior)
    acts = get_final_activations(z, x_cond, y_cond, top_prior)
    acts = acts.squeeze().type(torch.float32)
    acts = acts[:latent_audio_len, :]
    if meanpool:
        logging.warning(f"mean pooling at f={pool_frames_per_second}")
        if not pool_frames_per_second:
            acts = acts.mean(dim=0)
        else:
            frame_len = floor(ACTS_SAMPLE_RATE / pool_frames_per_second)
            acts = windowed_average(acts, frame_len)
            acts = torch.squeeze(acts, 0)
    acts = acts.cpu().numpy()
    logging.info(f"acts after pooling has shape {acts.shape}")
    return acts

def upload_to_gcs(storage_client_factory: Callable[[], storage.Client], bucket_name: str, blob_path: str, representation: np.ndarray) -> None:
    try:
        storage_client = storage_client_factory()
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(blob_path)
        with blob.open("wb") as f:
            np.save(f, representation)
        print(f"Uploaded to gs://{bucket_name}/{blob_path}")
    except Exception as e:
        print(f"Failed to upload {blob_path} to {bucket_name}: {str(e)}")

def upload_to_gcs_in_batches(upload_batch: List[Tuple[Callable, str, str, np.ndarray]], batch_size: int):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i in range(0, len(upload_batch), batch_size):
            batch = upload_batch[i:i + batch_size]
            # Consume the iterator so each batch finishes before the next one is submitted
            list(executor.map(lambda args: upload_to_gcs(*args), batch))

def process_musicnet_dataset(input_dir: str, output_dir: str, batch_size: int = 10):
    storage_client_factory = lambda: storage.Client()
    chunk_duration = 25.0

    train_data_dir = os.path.join(input_dir, 'musicnet', 'musicnet', 'train_data')
    test_data_dir = os.path.join(input_dir, 'musicnet', 'musicnet', 'test_data')

    # Combine the train and test directories into a single list
    data_dirs = [train_data_dir, test_data_dir]

    # Iterate through both directories
    for data_dir in data_dirs:
        for subdir, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.wav'):
                    local_audio_path = os.path.join(subdir, file)
                    print(f"Processing {local_audio_path}")

                    audio_chunks = chunk_audio_file(local_audio_path, chunk_duration)

                    upload_batch = []
                    for chunk, start_time in audio_chunks:
                        # Construct output filename
                        output_filename = f"{os.path.basename(local_audio_path)}_{int(start_time)}.npy"

                        if output_dir.startswith("gs://"):
                            # Construct output path in GCS
                            bucket_name, blob_path = split_gcs_bucket_and_filepath(output_dir)
                            output_path = os.path.join(blob_path, output_filename)

                            # Check if output file already exists in GCS
                            storage_client = storage_client_factory()
                            out_bucket = storage_client.bucket(bucket_name)
                            out_blob = out_bucket.blob(output_path)
                            if out_blob.exists():
                                print(f"Output file {output_filename} already exists. Skipping this chunk.")
                                continue  # Skip processing this chunk

                            # Process the chunk since it hasn't been processed yet
                            with torch.no_grad():
                                representation = get_acts_from_file(
                                    chunk,
                                    hps,
                                    vqvae,
                                    top_prior,
                                    meanpool=True,
                                    pool_frames_per_second=10,
                                )

                            # Add to upload batch
                            upload_batch.append((storage_client_factory, bucket_name, output_path, representation))

                    print(f"Number of chunks to upload: {len(upload_batch)}")
                    if upload_batch:
                        print("Uploading representations to GCS...")
                        upload_to_gcs_in_batches(upload_batch, batch_size)

In [None]:
process_musicnet_dataset("/content/musicnet/", f"gs://{GCS_BUCKET_NAME}/musicnet_embeddings/")
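
Each chunk is stored with `np.save` as a float32 array of shape `[frames, 4800]`, named `<wav file name>_<start second>.npy` (for example `track.wav_25.npy`). The sketch below shows how a downstream script might decode one object's bytes (for instance from `Blob.download_as_bytes()`) and recover the source file and start time from the object name. `load_embedding` and `parse_embedding_name` are hypothetical helpers, and the round trip uses an in-memory buffer so it runs without GCS.

```python
import io
import re
from typing import Tuple

import numpy as np

# Matches the naming used in process_musicnet_dataset: f"{wav_name}_{int(start_time)}.npy"
_NAME_RE = re.compile(r"^(?P<wav>.+\.wav)_(?P<start>\d+)\.npy$")

def parse_embedding_name(blob_name: str) -> Tuple[str, int]:
    """Return (source wav file name, chunk start in seconds) for an embedding object."""
    match = _NAME_RE.match(blob_name.rsplit("/", 1)[-1])
    if match is None:
        raise ValueError(f"unexpected embedding file name: {blob_name!r}")
    return match["wav"], int(match["start"])

def load_embedding(data: bytes) -> np.ndarray:
    """Decode the bytes of one .npy object written with np.save."""
    return np.load(io.BytesIO(data), allow_pickle=False)

# Round trip with a stand-in array instead of a real download.
buffer = io.BytesIO()
np.save(buffer, np.zeros((240, 4800), dtype=np.float32))
emb = load_embedding(buffer.getvalue())
print(emb.shape, emb.dtype)  # (240, 4800) float32
print(parse_embedding_name("musicnet_embeddings/track.wav_25.npy"))  # ('track.wav', 25)
```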