# Liquid and Dancefloor Drum and Bass Style Transfer Demo

This notebook demonstrates the pretrained style transfer system for Drum & Bass music from [our paper](https://biblio.ugent.be/publication/8619952).  
This notebook was written by Len Vande Veire.

In order to use it, copy some Drum & Bass tracks into the `./_music` directory, execute the following cells, and select your song from the list in the GUI that will appear below.

In [None]:
import glob
import IPython
import ipywidgets as widgets
import matplotlib.pyplot as plt
import os
from PIL import Image
import sys

import autodj
import autodj.dj.annotators.wrappers as annot 
from autodj.dj.song import Song 
from autodj.dj.timestretching import *
from data import CreateDataLoader
from generate_util import *
from models import create_model
from options.test_options import TestOptions
from torchvision import transforms

# Definition of functions

## Step 1: extract a fragment from the input audio

The Drum & Bass tracks you provide in the `./_music` directory will be analyzed and segmented at the (estimated) position of the drop. In the style transfer application, only the extracted fragments of the selected song will be transformed into the other genre.

In [None]:
def extract_segments_from_song(filename, phrases_before_drop = 2, length_in_phrases = 2):
    """Cut consecutive audio fragments out of a song, positioned relative to its drop.

    The drop is taken to be the start of the first segment whose type is
    ``'H'``. Fragments are measured in steps of 4 downbeats (one "phrase" in
    this function's terms), and each returned fragment spans one such step.

    Args:
        filename: Path of the audio file; passed to ``Song``.
        phrases_before_drop: Number of phrases before the drop at which the
            first fragment starts. Negative values start after the drop.
        length_in_phrases: Number of consecutive fragments to extract.

    Returns:
        A list of ``length_in_phrases`` slices of ``song.audio``.

    Raises:
        IndexError: If the song has no ``'H'`` segment, or if a requested
            fragment lies outside the annotated downbeats.
    """
    # NOTE(review): assumes song.downbeats holds times in seconds and that
    # song.audio is sampled at 44.1 kHz -- confirm against the Song class.
    sample_rate = 44100
    downbeats_per_phrase = 4

    annotation_modules = [
        annot.BeatAnnotationWrapper(),
        annot.OnsetCurveAnnotationWrapper(),
        annot.DownbeatAnnotationWrapper(),
        annot.StructuralSegmentationWrapper(),
        annot.ReplayGainWrapper(),
    ]
    song = Song(filename, annotation_modules=annotation_modules)
    if not song.hasAllAnnot():
        print('Annotating song...')
        song.annotate()
    else:
        print('Song already annotated!')

    song.open()
    song.openAudio()

    segments_H = [idx for idx, seg_type in enumerate(song.segment_types) if seg_type == 'H']
    if not segments_H:
        # Fail with an explanation instead of a bare "list index out of range".
        raise IndexError("no 'H' segment found in {!r}: cannot locate the drop".format(filename))

    drop_downbeat = song.segment_indices[segments_H[0]]
    n_downbeats = len(song.downbeats)

    extracted_segments = []
    for i in range(length_in_phrases):
        first = drop_downbeat + (i - phrases_before_drop) * downbeats_per_phrase
        last = first + downbeats_per_phrase
        # A negative index would silently wrap around to the end of the
        # downbeat list and produce a wrong (or empty) slice, so reject it
        # explicitly, together with indices past the last downbeat.
        if first < 0 or last >= n_downbeats:
            raise IndexError(
                'fragment {} spans downbeats {}..{}, outside the {} annotated downbeats'.format(
                    i, first, last, n_downbeats))
        start_idx = int(song.downbeats[first] * sample_rate)
        end_idx = int(song.downbeats[last] * sample_rate)
        extracted_segments.append(song.audio[start_idx:end_idx])

    return extracted_segments

## Step 2: convert the audio fragment to a spectrogram

The code below converts each extracted fragment into the spectrogram representation that the CycleGAN model will transform.

In [None]:
def load_extract(widgets):
    """Load the file selected in ``widgets['file']`` from ``./_music`` at 44.1 kHz.

    Returns the result of ``librosa.load`` unchanged.
    """
    selected_path = os.path.join('./_music', widgets['file'].value)
    return librosa.load(selected_path, sr=44100)

In [None]:
def convert_extract_to_png(y, sr, widgets):
    """Save the mel spectrogram of an audio fragment as a grayscale PNG in ``./_temp``.

    The directory is created if needed and emptied first, so that afterwards
    it contains only the PNG written by this call. The PNG is named after the
    file selected in ``widgets['file']``, with its extension replaced by
    ``.png``.

    Args:
        y: Audio samples, passed to ``mel_spectrogram``.
        sr: Sample rate of ``y``, passed to ``mel_spectrogram``.
        widgets: Mapping whose ``'file'`` entry exposes the selected file
            name through ``.value``.

    Returns:
        The mel spectrogram rescaled to the range [0, 1]. If the spectrogram
        is constant, an all-zero array is returned instead of dividing by zero.
    """
    os.makedirs('./_temp/', exist_ok=True)

    # Remove files left over from a previous fragment.
    for stale_file in glob.glob("./_temp/*"):
        os.remove(stale_file)

    S_mel, _ = mel_spectrogram(y, sr, crop_to_multiple_of_4=True)

    # Min-max normalisation. Shifting first leaves the value range in
    # S_mel.max(); a constant spectrogram then has a range of 0, in which case
    # the division is skipped and the shifted (all-zero) array is kept.
    S_mel = S_mel - S_mel.min()
    value_range = S_mel.max()
    if value_range > 0:
        S_mel = S_mel / value_range

    S_mel_as_uint8 = (S_mel * 255).astype(np.uint8)
    im = Image.fromarray(S_mel_as_uint8).convert("L")
    png_name = os.path.splitext(widgets['file'].value)[0] + '.png'
    im.save(os.path.join('./_temp', png_name))
    return S_mel

## Step 3: apply the CycleGAN model

This code loads the CycleGAN model and applies it to the extracted audio segment.

In [None]:
def apply_model():
    """Run the CycleGAN model on the first item found in ``./_temp``.

    Builds the test options, creates and sets up the model, feeds it the
    first item yielded by the data loader, and returns the model's current
    visuals.
    """
    cli_args = [
        '--dataroot', './_temp',
        '--name', 'maps_cyclegan',
        '--model', 'cycle_gan',
        '--dataset_mode', 'single',
        '--resize_or_crop', 'none',
        '--gpu_ids', '-1',  # presumably selects CPU execution -- confirm in options
    ]
    opt = TestOptions().parse(cli_args)

    # Settings required by the test code; applied in this fixed order.
    forced_settings = (
        ('num_threads', 1),        # test code only supports num_threads = 1
        ('batch_size', 1),         # test code only supports batch_size = 1
        ('serial_batches', True),  # no shuffle
        ('no_flip', True),         # no flip
        ('display_id', -1),        # no visdom display
    )
    for option_name, option_value in forced_settings:
        setattr(opt, option_name, option_value)

    model = create_model(opt)
    model.setup(opt)

    # Only the first item of the dataset is processed.
    loaded_dataset = CreateDataLoader(opt).load_data()
    first_item = next(iter(loaded_dataset))

    model.set_input(first_item)
    model.test()
    return model.get_current_visuals()

The code below converts the model output back into audio streams and, optionally, displays the intermediate results.

In [None]:
def display_results(y_orig, img_orig, visuals, widgets):
    """Turn the generated spectrogram back into audio and optionally show intermediate results.

    The genre radio button decides which generator output is used:
    ``visuals['fake_A']`` when the selected option equals
    ``w_genre_options['A']``, otherwise ``visuals['fake_B']``.

    Args:
        y_orig: Audio of the original fragment; passed on to
            ``tensor_to_spectrogram_and_audio``.
        img_orig: Image of the original fragment, plotted when intermediate
            steps are shown.
        visuals: Mapping of model outputs containing ``'fake_A'`` / ``'fake_B'``.
        widgets: Mapping with the ``'genre'`` and ``'show_intermediate'`` widgets.

    Returns:
        The tuple ``(y_lws, y_no_lws, y_lws_orig)`` taken from the result of
        ``tensor_to_spectrogram_and_audio``.
    """
    A_or_B = 'A' if widgets['genre'].value == w_genre_options['A'] else 'B'
    generated = visuals['fake_{}'.format(A_or_B)]
    # The two spectrogram outputs are not used here.
    y_lws, _S_lws, y_no_lws, _S_orig, y_lws_orig, img = tensor_to_spectrogram_and_audio(generated, y_orig)

    if widgets['show_intermediate'].value:
        plt.figure()
        plt.imshow(img)
        plt.figure()
        plt.imshow(img_orig)
        plt.show()

        def _display_audio(y, title):
            # Play the stream that was passed in. Previously this always
            # played y_orig, so both players below were the untransformed input.
            a = IPython.display.Audio(y, rate=44100,)
            print(title)
            display(a)

        _display_audio(y_lws, 'LWS-reconstructed phase')
        _display_audio(y_no_lws, 'Original phase')

    return y_lws, y_no_lws, y_lws_orig

# The actual application

This code defines a GUI with which you can easily interact with the demo. Have fun!

In [None]:
# Audio files offered in the selection list.
file_options = [
    name for name in os.listdir('./_music')
    if name.endswith(('.wav', '.mp3'))
]

# File selection list.
w_file = widgets.Select(
    options=file_options,
    description='Select a file:',
    disabled=False,
    layout=widgets.Layout(width='100%'),
)

# Conversion direction; the keys are matched against the model outputs
# 'fake_A' / 'fake_B' in display_results.
w_genre_options = {'B' : 'liquid > dancefloor', 'A': 'dancefloor > liquid'}
w_genre = widgets.RadioButtons(
    options=w_genre_options.values(),
    description='Convert to:',
    disabled=False,
    layout=widgets.Layout(width='100%'),
)

# Button that starts processing.
w_button = widgets.Button(description='Go!', disabled=False)

# Toggle for plotting spectrograms and per-fragment audio players.
w_debug = widgets.Checkbox(
    description='Show intermediate steps',
    value=False,
    disabled=False,
)

# Position of the first fragment relative to the drop.
w_before = widgets.IntSlider(
    description='Number of downbeats before drop:',
    min=-4, max=16, value=0,
)

# Number of fragments to process.
w_length = widgets.IntSlider(
    description='Number of downbeats to process:',
    min=1, max=8, value=2,
)

# Area that captures everything printed or displayed by the button callback.
w_out = widgets.Output()


# Lookup table handed to the processing functions.
widgets_ = dict(
    file=w_file,
    genre=w_genre,
    show_intermediate=w_debug,
    before=w_before,
    length=w_length,
)
# Results of the most recent run, stored at module level by on_button_clicked.
y_orig, y_lws_all, y_no_lws_all, y_orig_all = None, None, None, None
def on_button_clicked(b):
    """Button callback: extract fragments of the selected song, transform them, and play the result.

    All output is captured in ``w_out``, which is cleared first. For every
    extracted fragment the spectrogram PNG is written, the model is applied,
    and the resulting audio streams are appended to the module-level lists
    ``y_lws_all``, ``y_no_lws_all`` and ``y_orig_all``. Finally the
    concatenated original and transformed audio are displayed.

    Args:
        b: The clicked button (unused).
    """
    global y_orig, y_lws_all, y_no_lws_all, y_orig_all

    def concat_and_display(y_array, title):
        # Join the per-fragment streams into one player.
        joined = np.concatenate(y_array)
        player = IPython.display.Audio(joined, rate=44100,)
        print(title)
        display(player)

    with w_out:
        IPython.display.clear_output()

        # NOTE(review): the slider value is negated before being passed as
        # phrases_before_drop, so positive slider values select fragments
        # *after* the drop -- confirm this matches the slider's label.
        song_path = os.path.join('./_music', widgets_['file'].value)
        segments = extract_segments_from_song(
            song_path,
            -widgets_['before'].value,
            widgets_['length'].value,
        )

        sr = 44100  # every fragment is processed at this sample rate
        y_all = list(segments)
        n_extracts = len(y_all)

        y_lws_all = []
        y_no_lws_all = []
        y_orig_all = []

        for number, y in enumerate(y_all, start=1):
            print('\tProcessing extract {:1d}/{:1d}'.format(number, n_extracts))
            # Audio -> spectrogram image -> model
            img_orig = convert_extract_to_png(y, sr, widgets_)
            visuals = apply_model()
            # Generated image -> audio (plus optional intermediate output)
            y_lws, y_no_lws, y_lws_orig = display_results(y, img_orig, visuals, widgets_)
            # Collect the per-fragment streams for the final concatenation
            y_lws_all.append(y_lws)
            y_no_lws_all.append(y_no_lws)
            y_orig_all.append(y_lws_orig)

        print('AMPLITUDE, PHASE:')
        concat_and_display(y_all, 'Original')
        concat_and_display(y_no_lws_all, 'Transformed, original')

        # Optional extra players (uncomment to enable):
        #
        # Original audio that has first been transformed to the Mel scale and
        # then reconstructed from that representation, with phase inferred
        # using RTISI-LA:
        #
        # concat_and_display(y_orig_all, 'Original, mel-scale and then RTISI-LA')
        #
        # Transformed audio, where phase has been inferred using RTISI-LA:
        #
        # concat_and_display(y_lws_all, 'Transformed, RTISI-LA')

        
# Hook up the callback, then render the controls followed by the output area.
w_button.on_click(on_button_clicked)

display(w_file, w_genre, w_before, w_length, w_debug, w_button, w_out)

That's all folks :)