<a href="https://colab.research.google.com/github/ethman/tagbox/blob/main/TagBox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TagBox
### VQGAN+CLIP for Music!

By [Ethan Manilow](https://ethman.github.io/), [Patrick O'Reilly](https://oreillyp.github.io/), [Prem Seetharaman](https://pseeth.github.io/), and [Bryan Pardo](https://bryan-pardo.github.io/)


This notebook is an interactive demo for [TagBox](https://github.com/ethman/tagbox). Similar to VQGAN+CLIP, TagBox uses [automatic music tagger networks](https://arxiv.org/pdf/2006.00751.pdf) to guide gradient ascent in the embedding space of [OpenAI's Jukebox](https://openai.com/blog/jukebox/).

See our paper for a full description of the technique.


*Last updated: Oct. 25, 2021*

# Setup



In [None]:
#@title Check GPU
#@markdown Run this cell to see what GPU the Colab Notebook is running.

!nvidia-smi

Tue Oct 26 15:52:28 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   28C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
#@title Install/Load Packages (may take a few minutes)
#@markdown Only run this cell once.  
#@markdown Don't forget to turn on GPUs prior to running (Edit -> Notebook Settings).


!pip install git+https://github.com/ethman/tagbox.git
!pip install youtube-dl

!mkdir -p input_audio
!mkdir -p output_audio


import ipywidgets as widgets
import numpy as np
from IPython.display import Audio, display
from scipy.signal import get_window
from scipy.io.wavfile import write
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchaudio
from tqdm.notebook import tqdm, trange
import warnings
import os

from jukebox.utils.dist_utils import setup_dist_from_mpi
from tagbox import JUKEBOX_SAMPLE_RATE, TAGGER_SR
from tagbox import load_audio_for_jbx, audio_for_jbx, make_labels, to_np
from tagbox import setup_jbx
from tagbox.utils import encode, to_np
from sota_music_taggers.tag_metadata import JAMENDO_TAGS, MTAT_TAGS
from sota_music_taggers.predict import MusicTagger


# Suppress all warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

# Notebook-wide state: `device` is filled in lazily by the Run cell; the
# others are set by the Tagger Params / tag-selection cells.
device, selection, tagger_training_data, src_dict = None, None, None, None


def disp(a):
    """Convert a (scalar) tensor to a plain Python float for progress display."""
    return float(to_np(a))


def decode(vqvae, xs_quantized, level=0):
    """Decode quantized latents back to audio with a single VQ-VAE level.

    Args:
        vqvae: Jukebox VQ-VAE; only its `decoders` list is used here.
        xs_quantized: Per-level quantized latents (the second value returned
            by `vqvae.bottleneck(...)`), indexable by level.
        level: Which level's decoder to run. Defaults to 0, the level whose
            output this function has always returned.

    Returns:
        The output of `vqvae.decoders[level]` applied to
        `xs_quantized[level:level + 1]` with `all_levels=False`.
    """
    # Only the requested level is decoded. Previously every level was decoded
    # and all but level 0 were thrown away, wasting compute and GPU memory on
    # every optimization step.
    decoder = vqvae.decoders[level]
    return decoder(xs_quantized[level:level + 1], all_levels=False)


def make_masked_audio(input_audio, jbx_audio, n_fft):
    """Mask `input_audio` with a soft time-frequency mask derived from `jbx_audio`.

    With X and J the STFTs of `input_audio` and `jbx_audio`, the per-bin mask
    is ``|J| / (max(|X|, |J|) + eps)``, which lies in [0, 1). The mask scales
    X (X's phase is kept) and the result is inverted back to audio.

    Args:
        input_audio: Audio tensor to be masked. Singleton dimensions are
            squeezed away; the last remaining axis is time.
        jbx_audio: Audio tensor (e.g. Jukebox's output) that defines the mask;
            after squeezing it should have the same shape as `input_audio`.
        n_fft: FFT size. The hop size is ``n_fft // 2`` and the analysis /
            synthesis window is a square-root Hann window.

    Returns:
        The masked audio with a leading singleton dimension added, containing
        the same number of samples as `input_audio`.
    """
    eps = 1e-8
    input_audio = input_audio.squeeze()
    jbx_audio = jbx_audio.squeeze()
    hop_length = n_fft // 2

    window = torch.from_numpy(np.sqrt(get_window('hann', n_fft)))
    window = window.to(input_audio.device).to(input_audio.dtype)

    def stft(x):
        return torch.stft(x, n_fft, hop_length=hop_length, window=window,
                          return_complex=True)

    input_stft = stft(input_audio)
    jbx_spec = torch.abs(stft(jbx_audio))
    mask = jbx_spec / (torch.maximum(torch.abs(input_stft), jbx_spec) + eps)

    # A real-valued mask scales the magnitude and leaves the phase untouched,
    # so multiplying the complex STFT directly equals
    # mask * |X| * exp(1j * angle(X)), without the angle/exp round trip.
    masked_stft = mask * input_stft

    # Use the length of the last (time) axis: `len()` would return the channel
    # count for multi-channel input that survives the squeeze.
    masked_audio = torch.istft(masked_stft, n_fft, hop_length=hop_length,
                               window=window,
                               length=input_audio.shape[-1])
    return masked_audio.unsqueeze(0)


def _normalize_n_fft(n_fft):
    """Return the FFT size(s) as a non-empty list (default [2048, 1024, 512])."""
    if n_fft is None:
        return [2048, 1024, 512]
    if not isinstance(n_fft, (list, tuple)):
        return [n_fft]
    if len(n_fft) == 0:
        raise ValueError('mask_audio=True requires at least one FFT size in '
                         '`n_fft`.')
    return list(n_fft)


def _load_taggers(model_types, training_data, device):
    """Build one frozen, eval-mode MusicTagger per requested architecture."""
    if not isinstance(model_types, list):
        model_types = [model_types]

    taggers = []
    for model_type in model_types:
        tagger = MusicTagger(model_type, training_data, return_feats=False)
        tagger.model.to(device)
        tagger.model.eval().requires_grad_(False)
        taggers.append(tagger)
    return taggers


def _tag_loss(tagger, audio, resampler, labels, label_loss):
    """BCE between `labels` and the tagger's predictions averaged over dim 0."""
    pred_labels = tagger(resampler(audio.squeeze()))
    pred_labels = torch.mean(pred_labels, dim=0, keepdim=True)
    return label_loss(pred_labels, labels)


def run_tagbox(audio_path, source_tags, n_steps, step_size, device,
               model_types, training_data,
               mask_audio=False, n_fft=None, offset=0.0, dur=None):
    """Run TagBox on an audio file.

    Encodes the audio with Jukebox's '5b' VQ-VAE, then optimizes the encoder
    outputs with Adam so that the music tagger(s) predict `source_tags` for
    the decoded audio (or, with `mask_audio`, for the input masked by it).

    Args:
        audio_path: Path of the audio file to process.
        source_tags: Tag indices passed to `make_labels` to build the target.
        n_steps: Number of optimization steps; must be >= 1.
        step_size: Adam learning rate.
        device: Torch device for Jukebox, the taggers and the resampler.
        model_types: A tagger architecture name or a list of names.
        training_data: Tagger training set, forwarded to `MusicTagger`.
        mask_audio: If True, the loss is computed on the input audio masked by
            Jukebox's output (source separation); otherwise on Jukebox's
            output directly (style transfer).
        n_fft: FFT size or list/tuple of sizes used for masking; defaults to
            [2048, 1024, 512]. Only used when `mask_audio` is True.
        offset: Forwarded to `load_audio_for_jbx`.
        dur: Forwarded to `load_audio_for_jbx`.

    Returns:
        dict with 'jbxd_audio' (decoding from the start of the final step,
        i.e. before the last optimizer update), 'orig_audio' (Jukebox
        reconstruction of the input) and 'hist_loss' (per-step loss values).
        When `mask_audio` is True it also has 'jbxd_masked' and 'jbxd_diff'
        (orig_audio minus jbxd_masked).

    Raises:
        ValueError: if `n_steps` < 1, or `mask_audio` is True and `n_fft` is
            an empty list.
    """
    # Validate cheap arguments before loading any (large) models.
    if n_steps < 1:
        raise ValueError(f'n_steps must be at least 1, got {n_steps}.')
    if mask_audio:
        n_fft = _normalize_n_fft(n_fft)

    labels = make_labels(source_tags).to(device)

    # Load input audio
    audio = load_audio_for_jbx(audio_path, device=device,
                               offset=offset, dur=dur)

    vqvae = setup_jbx('5b', device)
    vqvae.bottleneck.train()

    resampler = torchaudio.transforms.Resample(JUKEBOX_SAMPLE_RATE,
                                               TAGGER_SR).to(device)
    taggers = _load_taggers(model_types, training_data, device)

    # Reference reconstruction: pass thru once -> JBX will resize. It is a
    # fixed target, so build it without an autograd graph; otherwise each
    # step's backward() would also have to traverse (and free) its graph.
    with torch.no_grad():
        orig_audio = vqvae.decode(vqvae.encode(audio))

    encoded_audio = encode(vqvae, audio)
    encoded_audio = [e.detach().requires_grad_(True) for e in encoded_audio]

    optimizer = optim.Adam(encoded_audio, lr=step_size)
    label_loss = nn.BCELoss()

    jbxd_audio = None
    hist_loss = []

    with trange(n_steps) as progress:
        for _ in progress:
            optimizer.zero_grad()

            _, xs_quantised, _, _ = vqvae.bottleneck(encoded_audio)
            jbxd_audio = decode(vqvae, xs_quantised)[-1]

            loss_dict = {}
            if mask_audio:
                for fft_size in n_fft:
                    masked_audio = make_masked_audio(orig_audio, jbxd_audio,
                                                     fft_size)
                    for tagger in taggers:
                        loss_dict[f'{tagger}_{fft_size}'] = _tag_loss(
                            tagger, masked_audio, resampler, labels,
                            label_loss)
            else:
                for tagger in taggers:
                    loss_dict[f'{tagger}'] = _tag_loss(
                        tagger, jbxd_audio, resampler, labels, label_loss)

            total_loss = sum(loss_dict.values()) / len(loss_dict)
            hist_loss.append({k: to_np(v) for k, v in loss_dict.items()})
            total_loss.backward()
            optimizer.step()

            progress.set_postfix(loss=disp(total_loss))

    results = {
        'jbxd_audio': to_np(jbxd_audio),
        'orig_audio': to_np(orig_audio),
        'hist_loss': hist_loss,
    }

    if mask_audio:
        # Use the last FFT size in `n_fft` explicitly instead of relying on the
        # loop variable leaking out of the training loop.
        with torch.no_grad():
            jbxd_masked = make_masked_audio(orig_audio, jbxd_audio, n_fft[-1])
        results.update({
            'jbxd_masked': to_np(jbxd_masked),
            'jbxd_diff': to_np(orig_audio) - to_np(jbxd_masked),
        })

    return results

Collecting git+https://github.com/ethman/tagbox.git
  Cloning https://github.com/ethman/tagbox.git to /tmp/pip-req-build-drgu7md2
  Running command git clone -q https://github.com/ethman/tagbox.git /tmp/pip-req-build-drgu7md2
Collecting jukebox@ git+https://github.com/openai/jukebox.git
  Cloning https://github.com/openai/jukebox.git to /tmp/pip-install-2a0pv9_n/jukebox_cef11d5438bf4719b7a2c61c7f66315f
  Running command git clone -q https://github.com/openai/jukebox.git /tmp/pip-install-2a0pv9_n/jukebox_cef11d5438bf4719b7a2c61c7f66315f
Collecting sota_music_taggers@ git+https://github.com/ethman/sota-music-tagging-models.git
  Cloning https://github.com/ethman/sota-music-tagging-models.git to /tmp/pip-install-2a0pv9_n/sota-music-taggers_ed8622a78aa44faea90769adaf56858b
  Running command git clone -q https://github.com/ethman/sota-music-tagging-models.git /tmp/pip-install-2a0pv9_n/sota-music-taggers_ed8622a78aa44faea90769adaf56858b


In [3]:
#@title Load audio from YouTube

youtube_id = 'ZJUg8UuJHkQ' #@param {type:"string"}

youtube_prefix = 'https://www.youtube.com/watch?v='
if youtube_prefix not in youtube_id:
  yt = youtube_prefix + youtube_id
else:
  yt = youtube_id

!youtube-dl --extract-audio --audio-format wav {yt} 

!mv *.wav input_audio/


[youtube] ZJUg8UuJHkQ: Downloading webpage
[youtube] ZJUg8UuJHkQ: Downloading player bc6d77fc
[download] Destination: Howl's Moving Castle Main Theme (Violin, Piano cover) ft. Zorsy-ZJUg8UuJHkQ.webm
[K[download] 100% of 5.21MiB in 02:06
[ffmpeg] Destination: Howl's Moving Castle Main Theme (Violin, Piano cover) ft. Zorsy-ZJUg8UuJHkQ.wav
Deleting original file Howl's Moving Castle Main Theme (Violin, Piano cover) ft. Zorsy-ZJUg8UuJHkQ.webm (pass -k to keep)


# Set Parameters

In [4]:
#@title Tagger Params { run: "auto", display-mode: "form" }

tagger_training_data = 'MTG-Jamendo' #@param ["MTG-Jamendo", "MagnaTagATune"] {allow-input: false}


#@markdown ### Tagger architectures
#@markdown Note that some combinations of `tagger_training_data` and models do not work.
#@markdown Sorry for the inconvenience!
fcn = True #@param {type:"boolean"}
hcnn = True #@param {type:"boolean"} 
musicnn = True #@param {type:"boolean"}
crnn = False #@param {type:"boolean"}
sample = False #@param {type:"boolean"}
se = False #@param {type:"boolean"}
attention = False #@param {type:"boolean"}
short = False #@param {type:"boolean"}
short_res = False #@param {type:"boolean"}


# Every architecture name is also the name of its checkbox variable above;
# keep the enabled ones, in this fixed order (not the form order).
all_models = ['fcn', 'musicnn', 'crnn', 'sample', 'se', 'attention', 'hcnn',
              'short', 'short_res']
model_types = [name for name in all_models if globals()[name]]


### Preset Tag Selection

Shortcuts that combine common sets of tags. These are especially useful for MagnaTagATune, which has many conceptually overlapping tags ("vocals", "vocal", "singing", etc.).

Only one of [`Preset Tags`, `Manual Tags`] is used: TagBox uses the tags selected in whichever of these two cells was run most recently.

In [5]:
#@markdown Please run `Tagger Params` cell (above) and rerun this cell to see options.
#@markdown Hold down [ctrl] (or [cmd] on Mac) to select multiple tags.

selection = 'preset'

# MagnaTagATune presets: display name -> tag indices. Several presets pool
# conceptually overlapping tags into a single choice.
mtat_sources = {
    'Vocals': [13, 15, 18, 19, 20, 26, 27, 33, 35, 36, 39, 40, 48, 49],
    'Guitar': [0],
    'Piano': [9, 22],
    'Strings': [4, 12, 43],
    'Drums': [5, 11, 41],
    'Violin': [12],
    'Harpsichord': [22],
    'Sitar': [32],
    'Harp': [42],
    'Cello': [43],
    'Flute': [25],
    'Synth': [14],
    'Male Vocals': [13, 18, 19, 20, 27, 33, 39],
    'Female Vocals': [13, 15, 26, 36, 40, 48],
    'Choir': [35],
}

# MTG-Jamendo presets: display name -> tag indices.
mtg_sources = {
    'Synthesizer': [3],
    'Strings': [8],
    'Drums': [9],
    'Drum Machine': [10],
    'Guitar': [12],
    'Acoustic Guitar': [23],
    'Piano': [25],
    'Electric Guitar': [28],
    'Violin': [37],
    'Voice': [40],
    'Keyboard': [41],
    'Bass': [43],
    'Computer': [44],
}

# Training dataset chosen in `Tagger Params` -> (preset table, widget label).
preset_options = {
    'MTG-Jamendo': (mtg_sources, 'MTG'),
    'MagnaTagATune': (mtat_sources, 'MTAT'),
}
if tagger_training_data not in preset_options:
  raise ValueError('Run Tagger Params cell before making selection.')

src_dict, desc = preset_options[tagger_training_data]
tags_list = src_dict.keys()

tag_picker = widgets.SelectMultiple(options=tags_list,
                                    description=desc,
                                    disabled=False)

tag_picker

SelectMultiple(description='MTG', options=('Synthesizer', 'Strings', 'Drums', 'Drum Machine', 'Guitar', 'Acous…

### Manual Tag Selection

Select each tag individually.

Only one of [`Preset Tags`, `Manual Tags`] is used: TagBox uses the tags selected in whichever of these two cells was run most recently.

In [None]:
#@markdown ### Select Tags Manually (multi-select)
#@markdown Please run `Tagger Params` cell (above) and rerun this cell to see options.
#@markdown Hold down [ctrl] (or [cmd] on Mac) to select multiple tags.

selection = 'manual'

# Training dataset chosen in `Tagger Params` -> (full tag vocabulary, widget label).
manual_options = {
    'MTG-Jamendo': (JAMENDO_TAGS, 'MTG'),
    'MagnaTagATune': (MTAT_TAGS, 'MTAT'),
}
if tagger_training_data not in manual_options:
  raise ValueError('Run Tagger Params cell before making selection.')

tags_list, desc = manual_options[tagger_training_data]

tag_picker = widgets.SelectMultiple(options=tags_list,
                                    description=desc,
                                    disabled=False)

tag_picker

SelectMultiple(description='MTAT', options=('guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'elec…

# Run

In [6]:
#@markdown Pick file to run. (Run this cell to refresh)

# Sort the listing so the dropdown order is stable (os.listdir order is arbitrary).
file_picker = widgets.Dropdown(
    options=sorted(os.listdir('input_audio')),
    description='Select File',
    disabled=False
)


file_picker

Dropdown(description='Select File', options=("Howl's Moving Castle Main Theme (Violin, Piano cover) ft. Zorsy-…

In [33]:
#@title Run

if device is None:
  rank, local_rank, device = setup_dist_from_mpi()

#@markdown ### TagBox Settings
#@markdown Check `use_mask` for source separation, else will do style transfer.
use_mask = True #@param {type:"boolean"}

#@markdown \# FFTs (only used if `use_mask=True`)
fft512 = False #@param {type:"boolean"}
fft1024 = True #@param {type:"boolean"}
fft2048 = True #@param {type:"boolean"}

# Selected FFT sizes for masking, smallest first.
n_ffts = [size for size, enabled in ((512, fft512), (1024, fft1024), (2048, fft2048))
          if enabled]

#@markdown Gradient Ascent params
lr = 10.0  #@param {type: "number"}
steps = 100  #@param {type: "number"}

#@markdown Run on the whole file? Set `duration_sec` to 0 (or less) for no duration limit.
#@markdown If you see an 'Out Of Memory' error (OOM), try decreasing the duration.
start_sec =   60#@param {type: "number"}
duration_sec =   20#@param {type: "number"}


if duration_sec <= 0.0:
  duration_sec = None

audio_file = file_picker.value
if not audio_file:
  raise ValueError('No input file selected. Load audio and rerun the file picker cell.')
# Relative path, consistent with the cells that create and list `input_audio`.
audio_path = os.path.join('input_audio', audio_file)

if selection == 'preset':
  print_tags = list(tag_picker.value)
  # Each preset expands to one or more tag indices.
  source_tags = [idx for preset in print_tags for idx in src_dict[preset]]
elif selection == 'manual':
  print_tags = list(tag_picker.value)
  source_tags = [tags_list.index(t) for t in print_tags]
else:
  raise ValueError('Have you made a tag selection yet?')

tagger_data = 'mtat' if tagger_training_data == 'MagnaTagATune' else 'jamendo'

if len(source_tags) == 0:
  raise ValueError('No tags selected. Try rerunning prev cells.')

if len(model_types) == 0:
  raise ValueError('No models selected. Try rerunning prev cells.')

if use_mask and len(n_ffts) == 0:
  raise ValueError('No FFT sizes selected. Check at least one FFT box or turn off use_mask.')

print(f'Running TagBox on {audio_file}.')
print(f'Using model(s): {model_types}.')
print(f'Using tags: {print_tags}.\n')
print('-' * 40)
print()

# Number fields can come back as floats; the step count must be an int.
result_dict = run_tagbox(audio_path, source_tags, int(steps), lr, device,
                         model_types, tagger_data,
                         mask_audio=use_mask, n_fft=n_ffts,
                         offset=start_sec, dur=duration_sec)

audio_nm = os.path.splitext(audio_file)[0]
os.makedirs('output_audio', exist_ok=True)


def show_and_save(title, audio, suffix):
  """Print `title`, play `audio` inline and save it as output_audio/<audio_nm>_<suffix>.wav."""
  print(f'\n{title}')
  display(Audio(audio, rate=JUKEBOX_SAMPLE_RATE))
  write(f'output_audio/{audio_nm}_{suffix}.wav', JUKEBOX_SAMPLE_RATE, audio)


orig = result_dict['orig_audio']
show_and_save('Jukebox Reconstruction', orig, 'jukebox')

raw = result_dict['jbxd_audio']
show_and_save('Raw TagBox output', raw, 'tagbox')

if use_mask:
  diff = result_dict['jbxd_diff']
  show_and_save('Masked output (separated)', diff, 'masked')

  masked = result_dict['jbxd_masked']
  show_and_save('Raw Masked output (separated)', masked, 'raw_masked')



Running TagBox on Howl's Moving Castle Main Theme (Violin, Piano cover) ft. Zorsy-ZJUg8UuJHkQ.wav.
Using model(s): ['fcn', 'musicnn', 'hcnn'].
Using tags: ['Strings'].

----------------------------------------

Downloading from azure
Restored from /root/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode


HBox(children=(FloatProgress(value=0.0), HTML(value='')))



Jukebox Reconstruction



Raw TagBox output



Masked output (separated)



Raw Masked output (separated)
