<a href="https://colab.research.google.com/github/cifkao/ss-vq-vae/blob/main/experiments/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Timbre transfer demo

Copyright 2020 InterDigital R&D and Télécom Paris.  
Author: Ondřej Cífka

## Install packages

In [None]:
!git clone https://github.com/cifkao/ss-vq-vae.git

In [None]:
!pip uninstall torch torchvision torchaudio accelerate

In [None]:
# Install the project package from the cloned repository's `src` directory,
# plus `numba>0.57` and `ddsp` (its `ddsp.colab.colab_utils` is used below
# for recording and playing audio).
# NOTE(review): the reason for the numba lower bound is not documented here -- TODO confirm.
!pip install ./ss-vq-vae/src 'numba>0.57' ddsp

## Download the model

In [None]:
# Model directory inside the cloned repo: `config.yaml` is read from here and
# the pretrained checkpoint is downloaded here as `model_state.pt`.
logdir = '/'.join(['ss-vq-vae', 'experiments', 'model'])

In [None]:
!wget https://adasp.telecom-paris.fr/rc-ext/demos_companion-pages/vqvae_examples/ssvqvae_model_state.pt -O $logdir/model_state.pt

## Load the model

In [None]:
import os

import confugue
from ddsp.colab import colab_utils
import librosa
import torch

from ss_vq_vae.models.vqvae_oneshot import Experiment

In [None]:
# Build the experiment from the YAML config in `logdir`, running on the CPU.
cfg = confugue.Configuration.from_yaml_file(os.path.join(logdir, 'config.yaml'))
exp = cfg.configure(Experiment, logdir=logdir, device='cpu')

# Load the downloaded weights onto the experiment's device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
exp.model.load_state_dict(
    torch.load(os.path.join(logdir, 'model_state.pt'), map_location=exp.device)
)
# Switch to evaluation mode (equivalent to train(False)); the returned module
# is the cell's displayed output.
exp.model.eval()

In [None]:
# Base URL of the example clips hosted on the companion page.
INPUT_ROOT = 'https://adasp.telecom-paris.fr/rc-ext/demos_companion-pages/vqvae_examples/'

# Example clips selectable in the input forms below. Keys must match the forms'
# dropdown labels; values are absolute URLs under INPUT_ROOT.
INPUT_URLS = {
    label: INPUT_ROOT + rel_path
    for label, rel_path in (
        ('Electric Guitar', 'real/content/UnicornRodeo_Maybe_UnicornRodeo_Maybe_Full_25_ElecGtr2CloseMic3.0148.mp3'),
        ('Electric Organ', 'real/style/AllenStone_Naturally_Allen%20Stone_Naturally_Keys-Organ-Active%20DI.0253.mp3'),
        ('Jazz Piano', 'real/style/MaurizioPagnuttiSextet_AllTheGinIsGone_MaurizioPagnuttiSextet_AllTheGinIsGone_Full_12_PianoMics1.08.mp3'),
        ('Synth', 'real/content/Skelpolu_TogetherAlone_Skelpolu_TogetherAlone_Full_13_Synth.0190.mp3'),
    )
}

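## Choose or record inputs

Pick a *content* input and a *style* input from the example clips, or choose "Record" to capture your own audio (`record_seconds` sets the recording length). The model renders the content input with the timbre of the style input.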

In [None]:
#@title Content input
content_input = 'Electric Guitar'  #@param ["Record", "Electric Guitar", "Electric Organ", "Jazz Piano", "Synth"]
record_seconds = 8 #@param {type:"number"}

if content_input == 'Record':
    a_content = colab_utils.record(seconds=record_seconds, sample_rate=exp.sr, normalize_db=0.1)
else:
    !wget {INPUT_URLS[content_input]} -O content_input.mp3
    a_content, _ = librosa.load('content_input.mp3', sr=exp.sr)
colab_utils.play(a_content, sample_rate=exp.sr)

In [None]:
#@title Style input
style_input = 'Jazz Piano'  #@param ["Record", "Electric Guitar", "Electric Organ", "Jazz Piano", "Synth"]
record_seconds = 8 #@param {type:"number"}

if style_input == 'Record':
    a_style = colab_utils.record(seconds=record_seconds, sample_rate=exp.sr, normalize_db=0.1)
else:
    !wget {INPUT_URLS[style_input]} -O style_input.mp3
    a_style, _ = librosa.load('style_input.mp3', sr=exp.sr)
colab_utils.play(a_style, sample_rate=16000)

## Run the model

In [None]:
# Convert both inputs to model features and prepend a batch dimension of 1.
s_content = torch.as_tensor(exp.preprocess(a_content), device=exp.device)[None, :]
s_style = torch.as_tensor(exp.preprocess(a_style), device=exp.device)[None, :]
# One-element length tensors taken from axis 2 of the batched features
# (presumably the time axis -- TODO confirm against exp.preprocess).
l_content, l_style = (torch.as_tensor([x.shape[2]], device=exp.device) for x in [s_content, s_style])

# Inference only: no gradients needed.
with torch.no_grad():
    s_output = exp.model(input_c=s_content, input_s=s_style,
                         length_c=l_content, length_s=l_style)

# Drop the batch dimension, convert the output back to audio and play it at
# the model's sample rate (consistent with how the inputs were loaded).
a_output = exp.postprocess(s_output.cpu().numpy()[0])
colab_utils.play(a_output, sample_rate=exp.sr)