In [1]:
import argparse
import json
import math
import numpy as np
from librosa.feature import chroma_stft
import IPython.display as ipd
from tqdm import tqdm
from scipy.io.wavfile import write

import torch
import torchcrepe

import params
from model import MusicGenerator

import controls

import sys
sys.path.append('hifi-gan/')
from env import AttrDict
from models import Generator as HiFiGAN

import matplotlib.pyplot as plt
%matplotlib inline

# Feature statistics handed to MusicGenerator and to the controls.compute_*_z
# helpers further down in this notebook.
# NOTE(review): presumably the global std/mean of the log-mel features used for
# normalisation during training -- confirm against the training code.
FEAT_STD = 2.39
FEAT_MEAN = -5.38

  - 0.5 * np.log10(f_sq + const[3]))


Number of classifier parameters: 153928
Number of pitch extractor parameters: 837074


In [2]:
import librosa
from librosa.core import load
from librosa.filters import mel as librosa_mel_fn

# Mel filterbank shared by every get_mel call: 80 bands over 0-8000 Hz for a
# 1024-point FFT at 22050 Hz. The arguments are passed by keyword because
# librosa >= 0.10 made them keyword-only (the positional form raises TypeError);
# the keyword form is also accepted by older releases.
mel_basis = librosa_mel_fn(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)

def get_mel(wav_path):
    """Load an audio file and return its log-mel spectrogram.

    Args:
        wav_path: path to an audio file readable by ``librosa.load``; it is
            resampled to 22050 Hz on load.

    Returns:
        ``float32`` array of shape ``(80, n_frames)`` holding the natural log of
        the mel magnitudes (floored at 1e-5), one frame per 256 input samples.
    """
    wav, _ = load(wav_path, sr=22050)
    # Drop the tail so the signal is a whole number of 256-sample hops.
    wav = wav[:(wav.shape[0] // 256)*256]
    # Manual reflect padding of (n_fft - hop) / 2 = 384 samples per side, used
    # together with center=False so each hop yields exactly one frame.
    wav = np.pad(wav, 384, mode='reflect')
    stft = librosa.core.stft(wav, n_fft=1024, hop_length=256, win_length=1024, window='hann', center=False)
    # Magnitude with a small epsilon under the square root.
    stftm = np.sqrt(np.real(stft) ** 2 + np.imag(stft) ** 2 + (1e-9))
    mel_spectrogram = np.matmul(mel_basis, stftm)
    # Clip before the log so silent bins do not produce -inf.
    log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-5, a_max=None))
    return log_mel_spectrogram.astype("float32")

In [3]:
# Instantiate the generator from the hyper-parameters in `params` plus the
# feature statistics defined at the top of the notebook, then move it to the GPU.
generator = MusicGenerator(
    params.n_feats,
    params.n_classes,
    params.base_dim,
    params.class_dim,
    params.beta_min,
    params.beta_max,
    FEAT_MEAN,
    FEAT_STD,
)
generator = generator.cuda()

# Restore the trained weights and switch to evaluation mode.
generator.load_state_dict(torch.load('checkpts/music_generator.pt'))
generator.eval()
print(f'Number of parameters: {generator.nparams}')

Number of parameters: 16940583


In [4]:
# Read the HiFi-GAN configuration and expose it through attribute access.
with open('checkpts/hifigan-config.json') as f:
    h = AttrDict(json.loads(f.read()))

# Build the vocoder on the GPU and restore the weights stored under the
# 'generator' key of the checkpoint.
hifigan = HiFiGAN(h)
hifigan = hifigan.cuda()
hifigan.load_state_dict(torch.load('checkpts/hifigan-music.pt')['generator'])

# Evaluation mode, with weight normalisation removed.
hifigan.eval()
hifigan.remove_weight_norm()

Removing weight norm...


In [5]:
# Short local aliases for the control functions defined in `controls`.
# In this notebook the pitch/chroma/loudness *_x variants are called directly on
# the source mel-spectrogram to obtain the values to preserve, while the *_z
# variants are wrapped below with FEAT_MEAN / FEAT_STD and handed to
# generator.convert. compute_prob_x is aliased here but not called in the
# visible cells.
# NOTE(review): the x/z suffixes presumably distinguish raw from normalised
# features -- confirm in controls.py.
compute_pitch_x = controls.compute_pitch_x
compute_pitch_z = controls.compute_pitch_z
compute_chroma_x = controls.compute_chroma_x
compute_chroma_z = controls.compute_chroma_z
compute_loudness_x = controls.compute_loudness_x
compute_loudness_z = controls.compute_loudness_z
compute_prob_x = controls.compute_prob_x
compute_prob_z = controls.compute_prob_z

def compute_pitch(x):
    """Pitch control callback passed to ``generator.convert``.

    Keeps only the first ``params.pitch_mel`` entries along dimension 1 of ``x``
    and forwards them to ``compute_pitch_z`` together with the notebook-level
    ``FEAT_MEAN`` / ``FEAT_STD``.
    """
    pitch_band = x[:, :params.pitch_mel, :]
    return compute_pitch_z(pitch_band, FEAT_MEAN, FEAT_STD)

def compute_chroma(x):
    """Chroma control callback passed to ``generator.convert``.

    Forwards ``x`` unchanged to ``compute_chroma_z`` together with the
    notebook-level ``FEAT_MEAN`` / ``FEAT_STD``.
    """
    chroma = compute_chroma_z(x, FEAT_MEAN, FEAT_STD)
    return chroma

def compute_loudness(x):
    """Loudness control callback passed to ``generator.convert``.

    Forwards ``x`` unchanged to ``compute_loudness_z`` together with the
    notebook-level ``FEAT_MEAN`` / ``FEAT_STD``.
    """
    loudness = compute_loudness_z(x, FEAT_MEAN, FEAT_STD)
    return loudness

def compute_prob(x):
    """Classifier-probability control callback passed to ``generator.convert``.

    Forwards ``x`` unchanged to ``compute_prob_z`` together with the
    notebook-level ``FEAT_MEAN`` / ``FEAT_STD``.
    """
    prob = compute_prob_z(x, FEAT_MEAN, FEAT_STD)
    return prob

In [6]:
# Instrument names in class-index order, and the reverse name -> index map.
# 'unknown' is given index -1 and has no entry in id2instrument.
id2instrument = ['piano', 'flute', 'harpsichord', 'string']
instrument2id = {name: index for index, name in enumerate(id2instrument)}
instrument2id['unknown'] = -1

# Guidance weights indexed as weights[source][target]. Each entry is read by the
# conversion cell as [pitch, chroma, loudness, classifier] weights.
# There is deliberately no entry for converting an instrument to itself.
weights = {
    'piano': {
        'flute': [0.0, 0.25, 0.05, 0.05],
        'harpsichord': [0.0, 0.3, 0.05, 0.05],
        'string': [0.0, 0.2, 0.03, 0.02],
    },
    'flute': {
        'piano': [0.0, 0.25, 0.07, 0.01],
        'harpsichord': [0.01, 0.15, 0.1, 0.02],
        'string': [0.01, 0.25, 0.1, 0.05],
    },
    'harpsichord': {
        'piano': [0.0, 0.25, 0.07, 0.05],
        'flute': [0.01, 0.25, 0.07, 0.1],
        'string': [0.0, 0.3, 0.05, 0.1],
    },
    'string': {
        'piano': [0.0, 0.25, 0.1, 0.01],
        'flute': [0.0, 0.25, 0.04, 0.1],
        'harpsichord': [0.0, 0.25, 0.1, 0.1],
    },
    # Same default weights towards every known instrument when the source is unlabelled.
    'unknown': {name: [0.0, 0.25, 0.05, 0.05] for name in id2instrument},
}

In [7]:
# Example excerpts shipped with the notebook (relative to examples/, without .wav):
#piano/0sDleZkIK-w_116
#flute/6GwfuWhOOdY_37
#harpsichord/5B4eEcvBIek_114
#string/npQJP_nF7NI_125

# performing timbre transfer
source_id = 'piano/0sDleZkIK-w_116'           # path to source musical excerpt
target_name = 'flute'                         # target musical instrument
n_timesteps = 200                             # number of steps for both reverse and (in case of OT sampling mode) forward diffusion
use_ot = True                                 # whether to use OT sampling mode or not
use_control = True                            # whether to use gradient guidance (controlled sampling) or not

# The source instrument is the directory component of source_id.
source_name = source_id.split('/')[0]
c_source_id = instrument2id[source_name]
c_target_id = instrument2id[target_name]
# The weights table has no entry for converting an instrument to itself, so
# report that clearly instead of failing with a bare KeyError.
if target_name not in weights[source_name]:
    raise ValueError('No control weights defined for conversion %s -> %s; available targets: %s'
                     % (source_name, target_name, sorted(weights[source_name])))
control_weights = weights[source_name][target_name]
print('Conversion: %s -> %s' % (source_name, target_name))
print('control weights are', control_weights)

# Source mel-spectrogram, trimmed so the number of frames is a multiple of 4.
source_mel = get_mel('examples/%s.wav' % source_id)
source_mel = source_mel[:, :4*(source_mel.shape[1]//4)]
x_source = torch.from_numpy(source_mel).float().cuda().unsqueeze(0)
lengths_source = torch.LongTensor([x_source.shape[-1]]).cuda()
c_source = torch.LongTensor([c_source_id]).cuda()
c_target = torch.LongTensor([c_target_id]).cuda()

# Control values measured on the source; detached so they act as fixed targets.
pitch_value = compute_pitch_x(x_source[:, :params.pitch_mel, :]).detach()
chroma_value = compute_chroma_x(x_source).detach()
loudness_value = compute_loudness_x(x_source).detach()

# Uncontrolled sampling is the same call with all four guidance weights set to
# zero, so choose the weights once and issue a single convert() call.
if use_control:
    guidance_weights = control_weights
else:
    guidance_weights = [0.0, 0.0, 0.0, 0.0]

x_converted = generator.convert(x_source, lengths_source, c_source, c_target,
                                compute_pitch, compute_chroma, compute_loudness, compute_prob,
                                pitch_value, chroma_value, loudness_value,
                                pitch_weight=guidance_weights[0], chroma_weight=guidance_weights[1],
                                loudness_weight=guidance_weights[2], clf_weight=guidance_weights[3],
                                use_ot=use_ot, n_timesteps=n_timesteps)

Conversion: piano -> flute
control weights are [0.0, 0.25, 0.05, 0.05]

Initial MSE in pitch = 0.3541
Initial MSE in chroma = 0.0914
Initial MSE in loudness = 43.0007
Initial min probability = 0.06%

t = 0.950
MSE in pitch = 0.1034
MSE in chroma = 0.0924
MSE in loudness = 17.5978
Min probability = 0.00%

t = 0.945
MSE in pitch = 0.1123
MSE in chroma = 0.0903
MSE in loudness = 17.0347
Min probability = 0.01%

t = 0.941
MSE in pitch = 0.1164
MSE in chroma = 0.0906
MSE in loudness = 20.5103
Min probability = 0.46%

t = 0.936
MSE in pitch = 0.1248
MSE in chroma = 0.0881
MSE in loudness = 20.8065
Min probability = 1.19%

t = 0.931
MSE in pitch = 0.1335
MSE in chroma = 0.0839
MSE in loudness = 19.3368
Min probability = 0.28%

t = 0.926
MSE in pitch = 0.1366
MSE in chroma = 0.0790
MSE in loudness = 17.4297
Min probability = 0.38%

t = 0.921
MSE in pitch = 0.1410
MSE in chroma = 0.0776
MSE in loudness = 20.3655
Min probability = 0.09%

t = 0.917
MSE in pitch = 0.1571
MSE in chroma = 0.0743
MSE

In [8]:
# Vocode the source mel-spectrogram with HiFi-GAN and play it back.
# The module is called directly rather than through .forward() so that
# nn.Module.__call__ (and any registered hooks) is honoured.
with torch.no_grad():
    audio_source = hifigan(x_source).cpu().squeeze().clamp(-1, 1)
ipd.display(ipd.Audio(audio_source, rate=22050))

In [9]:
# Vocode the converted mel-spectrogram with HiFi-GAN and play it back.
# The module is called directly rather than through .forward() so that
# nn.Module.__call__ (and any registered hooks) is honoured.
with torch.no_grad():
    audio_converted = hifigan(x_converted).cpu().squeeze().clamp(-1, 1)
ipd.display(ipd.Audio(audio_converted, rate=22050))