In [None]:
import sys
import pickle
import torch
import numpy as np
import glob
import librosa
import pandas as pd

import os
from mido import MidiFile
import mido
from scipy.interpolate import interp1d

#--- import HiFiGAN modules
sys.path.append('../')
import models
import common.layers as layers 
from common.utils import load_wav #--- use same method that is used in hifigan for loading audio
from hifigan.data_function import mel_spectrogram
from hifigan.models import Denoiser

%matplotlib notebook
import matplotlib.pyplot as plt
import librosa.display

import IPython.display as ipd
ipd.display(ipd.HTML("<style>.container { width:85% !important; }</style>"))

# Load cfg and generator model from checkpoint
Also create a denoiser instance.

In [None]:
#--- get config from checkpoint, so no need to load args from disk
#args = pickle.load(open('../TMP_args.p', 'rb'))
#gen_config = models.get_model_config('HiFi-GAN', args)

DEVICE = 'cuda' # 'cpu' or 'cuda'

#--- in hifigan code they use 2 implementations
#--- (1) from fastpitch, when pre-calculating mel spec in prepare_dataset.sh. This is saved to disk and used for training/inference as mel spec INPUT 
#--- (2) from hifigan, during training, when calculating mel spec for OUTPUT (target) signal, and for inference
#--- NOTE use the same implementation that was used for INPUT in training. >>> starting 2023-05-28 and after, this should be 'hifigan' <<<
MEL_IMPL = 'hifigan' #'fastpitch' # 

assert DEVICE == 'cuda', 'ERROR: cpu not supported yet (mel code assumes torch tensors)'

#m_path = '../results/2023_01_20_hifigan_ssynth44khz_synthesized_input/hifigan_gen_checkpoint_10000.pt'
#m_path = '../results/2023_05_15_hifigan_ssynth44khz_synthesized_input_16k_spl0.5/hifigan_gen_checkpoint_3000.pt'
m_path = '../results/2023_05_28_hifigan_ssynth44khz_synthesized_input_16k_spl0.5_nonorm/hifigan_gen_checkpoint_3000.pt'

#--- map_location re-maps the stored tensors onto DEVICE, so loading does not depend on
#--- the device (e.g. a specific GPU index) the checkpoint was saved from
checkpoint = torch.load(m_path, map_location = DEVICE)
train_config = checkpoint['train_setup']
sampling_rate = train_config['sampling_rate']
gen_config = checkpoint['config']
#--- the generator config takes the number of mel bands from the training setup
gen_config['num_mel_filters'] = train_config['num_mels']

gen = models.get_model('HiFi-GAN', gen_config, DEVICE, forward_is_infer = True)
gen.load_state_dict(checkpoint['generator'])
gen.remove_weight_norm()
gen.eval()

#--- denoiser is built from the loaded generator; denoising_strength is passed to it at call time
denoising_strength = 0.05
denoiser = Denoiser(gen, win_length = train_config['win_length'], num_mel_filters = train_config['num_mels']).to(DEVICE)

# Mel spectrum class
make it identical to code in training, so we get the same features exactly <br/>
NOTE: this is the code used for mel of target audio, for source there is another impl. <br/>
TODO: verify and fix if needed

In [None]:
#--- NOTE starting 2023-05-28, I added a flag to the dataset-creation script, that makes it use the same mel implementation (unless told explicitly not to)
#         for both run-time calculation and save-to-disk calculation

#--- this is the implementation used to generate pre-calculated mels for synthetic wavs/fine-tuning (loaded from disk during training)
class MelSpec:
    """Thin wrapper around ``layers.TacotronSTFT`` configured from a training-setup dict.

    This is the mel implementation selected when ``MEL_IMPL == 'fastpitch'``.
    """

    #--- config keys, listed in the positional order in which they are passed to layers.TacotronSTFT
    _STFT_CFG_KEYS = ('filter_length', 'hop_length', 'win_length', 'num_mels',
                      'sampling_rate', 'mel_fmin', 'mel_fmax')

    def __init__(self, cfg):
        """Build the STFT object from ``cfg`` (a mapping holding the keys in ``_STFT_CFG_KEYS``)."""
        stft_args = [cfg[key] for key in self._STFT_CFG_KEYS]
        self.stft = layers.TacotronSTFT(*stft_args)

    def get_mel(self, audio):
        """Return ``self.stft.mel_spectrogram(audio)``; the audio is passed through as-is (no scaling here)."""
        return self.stft.mel_spectrogram(audio)

mel_spec = MelSpec(train_config)

#--- this is the implementation used to calculate mel spec of input on-the-fly during training and validation if we are NOT using synthetic wavs/fine-tuning
from functools import partial
mel_fmax = train_config['mel_fmax'] #--- in train.py, there's option to use different fmax for computing the loss.
mel_spec2 = partial(
    mel_spectrogram,
    n_fft = train_config['filter_length'],
    num_mels = train_config['num_mels'],
    sampling_rate = train_config['sampling_rate'],
    hop_size = train_config['hop_length'],
    win_size = train_config['win_length'],
    fmin = train_config['mel_fmin'],
    fmax = mel_fmax,
)

#--- map each supported MEL_IMPL value to the callable that computes the mel spectrogram
_mel_impl_table = {
    'fastpitch': mel_spec.get_mel,
    'hifigan': mel_spec2,
}
if MEL_IMPL not in _mel_impl_table:
    raise Exception('unknown MEL spec implementation')
get_mel_spec = _mel_impl_table[MEL_IMPL]

def array_to_torch(x):
    ''' Convert an array-like of samples to a float32 torch tensor with a leading batch axis.

        x:       numpy array (or any array-like, e.g. a list) of numeric samples
        returns: torch.float32 tensor of shape (1, *x.shape), not tracking gradients
    '''
    #--- np.array(...) always copies, so the tensor never aliases the caller's buffer.
    #--- (torch.autograd.Variable is a deprecated no-op wrapper; plain tensors already have requires_grad = False)
    x = torch.from_numpy(np.array(x, dtype = np.float32))
    return x.unsqueeze(0)

#--- simple wrapper to apply HiFiGAN generator to input audio
def generate_from_audio(x, hifigan_gen, return_numpy_arr = True):
    ''' Compute the mel spectrogram of x and run the HiFi-GAN generator on it.

        x:                1-D array of audio samples (converted with array_to_torch)
        hifigan_gen:      generator model, called on the mel spectrogram moved to the GPU
        return_numpy_arr: if True, return x_hat[0]...[0] as a numpy array (presumably the single
                          batch item / channel -- verify against the generator's output layout);
                          if False, return the generator output tensor as-is (e.g. to feed the denoiser)
    '''
    x = array_to_torch(x)

    #--- inference only: no_grad avoids building an autograd graph through the generator
    with torch.no_grad():
        mel = get_mel_spec(x)
        x_hat = hifigan_gen(mel.cuda())

    if return_numpy_arr:
        x_hat = x_hat[0].cpu().detach().numpy()[0]

    return x_hat

# load wav from validation set, get mel and apply model
Note: synthesis method should fit the one used to train the model (i.e., "10 harmonics" or "16 khz" etc.)

In [None]:
import torch.nn.functional as F
from tqdm import tqdm

#--- full-scale value used to normalize the samples returned by load_wav, per sample type
MAX_WAV_VALUE = {
    'PCM_24': 2**31, # data type in this case is int32
    'PCM_16': 2**15,
}

def wav_to_float(y, sample_type):
    ''' scale the samples of one file to float32 in [-1., 1.], according to its own sample type '''
    if sample_type not in MAX_WAV_VALUE:
        raise ValueError(f'unsupported sample type: {sample_type}')
    return y.astype(np.float32) / np.float32(MAX_WAV_VALUE[sample_type])

#--- read file lists (context managers make sure the files are closed)
with open('../data_ssynth/filelists/ssynth_audio_val.txt', 'r') as f:
    flist_validation = [fnm.rstrip() for fnm in f]

with open('../data_ssynth/filelists/ssynth_audio_train.txt', 'r') as f:
    flist_train = [fnm.rstrip() for fnm in f]

flist = flist_validation #flist_train #
n_files = len(flist)

synth_wavs_folder = 'wavs_synth_16k_spl0.5' # 'wavs_synth_10h'

mel_loss = np.zeros(n_files)
mel_len = np.zeros(n_files)
for file_index in tqdm(range(n_files)):
    #--- target: the original recording
    wav_fnm_target = flist[file_index]
    y_target, sr_target, sample_type_target = load_wav(f'../data_ssynth/{wav_fnm_target}')

    #--- input: the synthesized version of the same phrase, stored under synth_wavs_folder
    wav_fnm = wav_fnm_target.replace('wavs/', f'{synth_wavs_folder}/')
    y, sr, sample_type = load_wav(f'../data_ssynth/{wav_fnm}')

    #--- convert to float in [-1., 1.]; each file is scaled by its OWN sample type
    y = wav_to_float(y, sample_type)
    y_target = wav_to_float(y_target, sample_type_target)

    #--- keep the torch output for the denoiser, then convert both results to numpy
    y_hat = generate_from_audio(y, gen, return_numpy_arr = False)
    y_hat_den = denoiser(y_hat.squeeze(1), denoising_strength)
    y_hat = y_hat[0].cpu().detach().numpy()[0]
    y_hat_den = y_hat_den[0].cpu().detach().numpy()[0]

    #--- L1 distance between mel of target and mel of generated audio
    mel_target = get_mel_spec(array_to_torch(y_target)).squeeze(0)
    mel_hat = get_mel_spec(array_to_torch(y_hat)).squeeze(0)
    mloss = F.l1_loss(mel_target, mel_hat)
    mel_loss[file_index] = mloss.item()
    mel_len[file_index] = mel_target.shape[1]

In [None]:
#--- indices of the files with the 10 smallest / 10 largest mel loss
sort_ind = np.argsort(mel_loss)
print(sort_ind[0:10],sort_ind[-10:])

#--- scatter of per-file loss against the target mel length (mel_target.shape[1])
fig, ax = plt.subplots(figsize = (12,4))
ax.plot(mel_len, mel_loss,'.')
ax.set_xlabel('target mel length (mel_target.shape[1])')
ax.set_ylabel('mel L1 loss')
ax.set_title('mel L1 loss per file')
ax.grid()

#--- disabled debugging plots: compare target/generated mels of the LAST processed file
if False:
    fig, ax = plt.subplots(1,2,figsize = (12,4), sharex=True, sharey=True)
    ax[0].imshow(mel_target, aspect='auto',interpolation='none',origin='lower')
    ax[1].imshow(mel_hat, aspect='auto',interpolation='none',origin='lower')
    fig, ax = plt.subplots(figsize = (12,4))
    k1,k2 = 360,450 #317 #90
    ax.plot(mel_target[:,k1:k2].mean(1))
    ax.plot(mel_hat[:,k1:k2].mean(1))
    ax.grid()

## play result

In [None]:
play_normalize = True #False

#--- (label, signal) pairs, played in this order; the signals are the ones left over
#--- from the last file processed by the validation loop
_playlist = [
    ('Original audio:', y_target),
    ('Generated audio:', y_hat),
    ('Generated audio (denoised):', y_hat_den),
]
for _label, _signal in _playlist:
    print(_label)
    ipd.display(ipd.Audio(_signal, rate = sampling_rate, normalize = play_normalize))

# Try with synthetic input
### I define naive ADSR envelopes with straight lines, which is probably not the best option

In [None]:
def get_num_harmonics(min_freq_src_hz, max_freq_src_hz, sr, max_freq_tgt_hz):
    ''' Choose how many harmonics to synthesize and by which integer factor to raise the sampling rate.

        min_freq_src_hz: lowest source frequency (floored at alto_sax_range[0], a notebook-level global)
        max_freq_src_hz: highest source frequency
        sr:              current sampling rate
        max_freq_tgt_hz: highest frequency that the synthesized harmonics should reach

        returns (num_harmonics, new_sr_factor), where new_sr_factor is the smallest k in 1..9
        with k * sr > 2 * max_freq_src_hz * num_harmonics (IndexError if there is no such k)
    '''
    #--- can't naively use the minimum of the frequency track since we interpolate to f=0 Hz
    lowest_f0_hz = max(alto_sax_range[0], min_freq_src_hz)

    num_harmonics = int(max_freq_tgt_hz / lowest_f0_hz)
    required_sr = 2 * max_freq_src_hz * num_harmonics

    #--- take the smallest multiple of sr which is high enough (6 is the highest, assuming freqs.max() <= 932 Hz)
    candidate_factors = [factor for factor in range(1, 10) if factor * sr > required_sr]
    new_sr_factor = candidate_factors[0]
    return num_harmonics, new_sr_factor

In [None]:
from scipy.signal import decimate, butter, dlti # resample_poly
from scipy.interpolate import UnivariateSpline

def additive_synth_sawtooth(freq, env, sampling_rate, additive_synth_k = None, max_freq_hz = None):
    ''' Synthesize a band-limited sawtooth wave by additive synthesis and apply an amplitude envelope.

        The wave is sum_{k=1..K} (-1)**(k-1) * sin(k * phi) / k, where phi is the running
        integral of the instantaneous frequency.

        freq:             instantaneous frequency in Hz, one value per sample at sampling_rate
        env:              amplitude envelope, same length as freq
        sampling_rate:    sampling rate of freq, env and of the returned signal
        additive_synth_k: number of harmonics K (including the fundamental). When given,
                          max_freq_hz is ignored and the caller must keep K * freq.max() below Nyquist
        max_freq_hz:      used only when additive_synth_k is None: synthesize harmonics up to this
                          frequency by upsampling, synthesizing, low-pass filtering and decimating back

        returns: the synthesized signal multiplied by env (numpy array, same length as freq)
        raises:  ValueError if neither additive_synth_k nor max_freq_hz is given
    '''
    if additive_synth_k is None and max_freq_hz is None:
        raise ValueError('either additive_synth_k or max_freq_hz must be given')

    freq = np.asarray(freq, dtype = float)
    should_downsample = additive_synth_k is None
    dt = 1 / sampling_rate

    if should_downsample:
        #--- set number of harmonics of sawtooth wave
        num_harmonics, new_sr_factor = get_num_harmonics(freq[freq > 20].min(), freq.max(), sampling_rate, max_freq_hz)
        #--- make sure we stay below new nyquist
        assert freq.max() * num_harmonics < 0.5 * sampling_rate * new_sr_factor, f'Nyquist says you cannot synthesize {num_harmonics} harmonics at {new_sr_factor} X (current sampling rate)'
        additive_synth_k = num_harmonics
        sampling_rate_new = sampling_rate * new_sr_factor

        #--- interpolate (upsample) freq to the new sampling-rate grid.
        #--- the grids are built from integer sample counts (not np.arange(0, tmax, dt), whose length
        #--- is subject to floating-point rounding), so decimating back yields exactly len(env) samples.
        #--- np.interp is linear and clamps at the edges, so there is no out-of-bounds case to handle
        n_old = len(freq)
        t_old = np.arange(n_old) * dt
        dt = 1 / sampling_rate_new
        t_new = np.arange((n_old - 1) * new_sr_factor + 1) * dt
        freq = np.interp(t_new, t_old, freq)

    #--- phase is the integral of instantanous freq
    phi = np.cumsum(2 * np.pi * freq * dt)
    # to wrap: phi = (phi + np.pi) % (2 * np.pi) - np.pi 

    #--- fundamental, then harmonics 2..K with alternating sign and 1/k amplitude
    x = np.sin(phi)
    for k in range(2, additive_synth_k + 1):
        x += (-1)**(k-1) * np.sin(k * phi) / k

    #--- if we upsampled, go back to original rate
    if should_downsample:
        #--- for x, give a "anti-alias" filter to "decimate", but actually use it to filter above the desired max_freq_hz
        zpk = butter(12, max_freq_hz, output = 'zpk', fs = sampling_rate_new)
        aa_filt = dlti(*zpk)
        x = decimate(x, new_sr_factor, ftype = aa_filt)

    x *= env

    return x

## 2 octaves major scale in the range of the alto sax

In [None]:
range_notes = ['Db3', 'A5'] #['C3', 'A#5'] # alto sax range is ['Db3', 'A5'], take half-step below/above
alto_sax_range = librosa.note_to_hz(range_notes)

#--- envelope parameters
note_len_samples = 24000 #20000 #20000
onset_samples = 4500 #3000
amp = 0.03
amp_sustain = 0.8 # decay envelope to this relative level at the end of the note
freq_glide_level = 0.7 #--- during onset, glide into target frequency starting at this pitch (relative)

#--- single note envelope: linear rise over the onset, then linear decay to amp_sustain,
#--- cubed and scaled by amp
env_single = np.concatenate([
    np.linspace(0, 1, onset_samples),
    np.linspace(1, amp_sustain, note_len_samples - onset_samples),
])
env_single = env_single ** 3
env_single *= amp

#--- major scale in the alto sax range
scale_notes = ['D3', 'E3', 'F#3', 'G3', 'A3', 'B3', 'C#4', 'D4', 'E4', 'F#4', 'G4', 'A4', 'B4', 'C#5', 'D5', 'E5', 'F#5', 'G5', 'A5']

#--- collect per-note segments, with one note-length of zeros before and after, and join them once
silence = np.zeros(note_len_samples)
freq_segments = [silence]
env_segments = [silence]
for note in scale_notes:
    f0 = librosa.note_to_hz(note)
    #--- relative pitch: glide from freq_glide_level up to 1 during the onset, then stay at 1
    freq_env = np.ones(note_len_samples)
    freq_env[:onset_samples] *= np.linspace(freq_glide_level, 1, onset_samples)

    freq_segments.append(f0 * freq_env)
    env_segments.append(env_single)
freq_segments.append(silence)
env_segments.append(silence)

#--- limit the frequency track to the instrument range (this also lifts the zero-frequency padding to the lowest note)
freq = np.clip(np.concatenate(freq_segments), alto_sax_range[0], alto_sax_range[1])
env = np.concatenate(env_segments)

In [None]:
#--- synthesize the scale defined by (freq, env), with harmonics up to 16 kHz
#--- alternative (fixed number of harmonics):
#x = additive_synth_sawtooth(freq, env, sampling_rate, additive_synth_k=30)
x = additive_synth_sawtooth(freq, env, sampling_rate, max_freq_hz = 16000)
#--- in order to apply denoiser, we need the pytorch Tensor, so set return_numpy_arr to False
x_hat = generate_from_audio(x, gen, return_numpy_arr = False)

#--- note: the denoiser is applied here with 4x the default denoising_strength
x_hat_den = denoiser(x_hat.squeeze(1), 4*denoising_strength)
#x = x.numpy()[0]

In [None]:
print('Original synthesized input:')
ipd.display(ipd.Audio(x, rate = sampling_rate, normalize = play_normalize))

#--- NOTE: the two conversions below rebind x_hat / x_hat_den from torch tensors to numpy arrays,
#--- so this cell can run only once after the cell that produced them
print('Generated audio:')
x_hat = x_hat[0].cpu().detach().numpy()[0]
ipd.display(ipd.Audio(x_hat, rate = sampling_rate, normalize = play_normalize))

print('Generated audio (denoised):')
x_hat_den = x_hat_den[0].cpu().detach().numpy()[0]
#--- play the denoised signal (x_hat_den), not x_hat
ipd.display(ipd.Audio(x_hat_den, rate = sampling_rate, normalize = play_normalize))

# Synthesize from parallel audio+midi, and compare

In [None]:
def binary_array_to_seg_inds(arr, shift_end_ind = True):
    ''' Locate the runs of True (non-zero) values in a 1-D binary array.

        returns an (n_segments, 2) integer array of [start, end] indices per run.
        end is the index of the last element of the run when shift_end_ind is True,
        and one past it (slice-style) when shift_end_ind is False
    '''
    #--- pad with a zero on both sides so that runs touching the edges also produce two transitions
    padded = np.concatenate(([0], np.int_(arr), [0]))
    edges = np.flatnonzero(np.diff(padded))

    #--- transitions alternate between run start and run end, so pair them up
    seg_inds = edges.reshape((edges.shape[0] // 2, 2))
    if shift_end_ind:
        seg_inds[:, 1] -= 1
    return seg_inds

def read_midi_to_df(midi_fnm, try_to_fix_note_order = True, time_offset_sec = 0.):
    ''' Read a midi file into pandas dataframes with absolute time stamps in seconds.

        midi_fnm:              path of the midi file
        try_to_fix_note_order: if the note events don't pass verify_midi(), try to repair
                               overlapping notes (see below); if the repair fails, the
                               un-repaired events are returned
        time_offset_sec:       added to all time stamps; events that end up at negative time are dropped

        returns (df, df_pitch, df_aftertouch, df_cc):
            df            - the remaining events after removing meta/control messages (expected: note on/off)
            df_pitch      - 'pitchwheel' messages
            df_aftertouch - 'aftertouch' messages
            df_cc         - 'control_change' messages

        raises Exception if the file holds more than one distinct tempo value, or if a
        note-on has no matching note-off while repairing
    '''
    mid = MidiFile(midi_fnm)
    
    #assert(len(mid.tracks) == 1)
    #--- merge all tracks into a single message stream, one dataframe row per message
    tr = mido.merge_tracks(mid.tracks)
    df =  pd.DataFrame([m.dict() for m in tr])
    #--- tempo is taken from the 'set_tempo' message(s); .loc returns a Series when there are several
    tempo = df.set_index('type').loc['set_tempo','tempo']
    if type(tempo) == pd.Series:
        uniq_tempo = tempo.unique()
        if len(uniq_tempo) > 1:
            raise Exception('multiple tempo changes not supported')
        else:
            tempo = uniq_tempo[0]
            
    #--- message times are deltas in ticks: accumulate, then convert to seconds
    df['ts_sec'] = mido.tick2second(df.time.cumsum(), mid.ticks_per_beat, tempo)
    if time_offset_sec != 0.:
        df['ts_sec'] += time_offset_sec
        df = df[df['ts_sec'] >= 0.]
        
    #--- extract controls (pitch and aftertouch has seperated messages, not included in CC)
    #--- dropna(axis = 1) removes the columns that don't apply to the selected message type
    df_aftertouch = df[df.type == 'aftertouch'].dropna(axis = 1).reset_index(drop = True)
    df_pitch = df[df.type == 'pitchwheel'].dropna(axis = 1).reset_index(drop = True)
    df_cc = df[df.type == 'control_change'].dropna(axis = 1).reset_index(drop = True)
    
    #--- some mete-messages like "channel prefix" contain non-zero time value. so remove them *after* calculating 'ts_sec'
    for type_remove in ['channel_prefix', 'track_name', 'instrument_name', 'time_signature', 'key_signature', 
                        'smpte_offset', 'set_tempo', 'end_of_track', 'midi_port', 'program_change', 'control_change', 'pitchwheel', 'aftertouch', 'marker']:
        df = df[df.type != type_remove]
    
    #--- after reset_index the row labels equal the row positions (the repair loop below relies on that)
    df = df.dropna(axis = 1).reset_index(drop = True)
    
    #--- sometimes, instead of a sequence of on-off notes, we get on-on-off-off. try to fix that
    if try_to_fix_note_order:
        try:
            verify_midi(df)
        except AssertionError:
            print(f'{midi_fnm}: note order problem in midi dataframe, trying to fix...')
            #--- keep the un-repaired events, to be returned if the repair fails
            df_copy = df.copy()
            for ii, inote in df.iterrows():
                if inote.type == 'note_on':
                    #--- first note-off of the same note, at or after this row
                    assoc_note_off = df.iloc[ii:].query('type == "note_off" and note == @inote.note')
                    if len(assoc_note_off) == 0:
                        raise Exception('note on with no associated note off')
                    inote_off = assoc_note_off.iloc[0]
                    inote_off_ind = inote_off.name
                    #--- the note-off is not the very next event: check if another note starts before it
                    if inote_off_ind > ii + 1:
                        next_note_on = df.iloc[ii+1:].query('type == "note_on"')
                        if len(next_note_on) > 0:
                            next_note_on = next_note_on.iloc[0]
                            if next_note_on.ts_sec < inote_off.ts_sec:
                                #--- end the current note 1 msec before the next one starts
                                df.loc[inote_off_ind, 'ts_sec'] = next_note_on.ts_sec - .001        
#             #--- indices of where we expect to see "note off" and see "note on"
#             off_err_ind = df[((df.index % 2) == 1) & (df.type == 'note_on')].index
#             for ind in off_err_ind:
#                 curr_note = df.loc[ind]
#                 next_note = df.loc[ind + 1]
#                 prev_note = df.loc[ind - 1]
#                 if next_note.type == 'note_off' and next_note.note == prev_note.note:
#                     df.loc[ind + 1, 'ts_sec'] = curr_note.ts_sec - 0.001
            #--- re-order the events by the updated time stamps (stable: ties keep their order)
            df = df.sort_values(by = 'ts_sec', kind = 'stable').reset_index(drop = True)
            try:
                verify_midi(df)
                print('fixed')
            except AssertionError:
                print('fix failed, calling verify_midi() on returned dataframe will fail')
                #--- if fix failed, return the original copy
                df = df_copy
                
    return df, df_pitch, df_aftertouch, df_cc

def verify_midi(midi_df):
    ''' Check that the events are strictly alternating note-on / note-off pairs of the same note.

        raises AssertionError when: an even-position row is not 'note_on', an odd-position row
        is not 'note_off', or a note-on and the note-off that follows it have different note numbers
    '''
    event_types = midi_df['type']
    #--- the checks run in this order and stop at the first failure
    assert (event_types.iloc[::2] == 'note_on').all()
    assert (event_types.iloc[1::2] == 'note_off').all()

    notes = midi_df['note']
    assert (notes.iloc[::2].to_numpy() == notes.iloc[1::2].to_numpy()).all()

def midi_phrase_from_dataframe(p, midi_df, sr):
    ''' Select the midi events that fall inside a phrase, completing notes cut by the phrase edges.

        p:       phrase record with sample_start / sample_end (in samples)
        midi_df: events dataframe with 'type', 'note' and 'ts_sec' columns; its row labels are
                 expected to be consecutive integers (the neighbouring event is looked up by label +/- 1)
        sr:      sampling rate used to convert the phrase bounds to seconds

        returns the events with sample_start / sr <= ts_sec <= sample_end / sr. If the first event
        is a note-off, the event just before it is prepended when it is the note-on of the same note;
        if the last event is a note-on, the event just after it is appended when it is the matching note-off.

        raises IndexError if no event falls inside the phrase
    '''
    t0 = p.sample_start / sr
    t1 = p.sample_end / sr
    midi_p = midi_df[(midi_df.ts_sec >= t0) & (midi_df.ts_sec <= t1)]

    #--- check for missing note_on (at start)
    first_note = midi_p.iloc[0]
    if first_note['type'] == 'note_off':
        prev_label = first_note.name - 1
        #--- the phrase may start at the very first event of midi_df, then there is nothing to prepend
        if prev_label in midi_df.index:
            candidate = midi_df.loc[prev_label]
            if candidate['type'] == 'note_on' and candidate['note'] == first_note['note']:
                midi_p = pd.concat([candidate.to_frame().T, midi_p])

    #--- check for missing note_off (at end)
    last_note = midi_p.iloc[-1]
    if last_note['type'] == 'note_on':
        next_label = last_note.name + 1
        #--- the phrase may end at the very last event of midi_df, then there is nothing to append
        if next_label in midi_df.index:
            candidate = midi_df.loc[next_label]
            if candidate['type'] == 'note_off' and candidate['note'] == last_note['note']:
                midi_p = pd.concat([midi_p, candidate.to_frame().T])

    return midi_p
    
def phrase_to_midi_string(p, midi_df, sr):
    ''' Build the "<wav path>|<space separated midi notes>" line of a phrase.

        p:       phrase record with phrase_id, sample_start and sample_end
        midi_df: midi events dataframe
        sr:      sampling rate of the phrase sample indices

        returns the string, or '' (after printing a message) if the phrase events don't pass verify_midi()
    '''
    midi_p = midi_phrase_from_dataframe(p, midi_df, sr)
    try:
        verify_midi(midi_p)
    except Exception:
        print(f'phrase {p.phrase_id} verification failed')
        return ''

    #--- only the note-on events contribute to the note list
    note_on = midi_p[midi_p.type == 'note_on']
    return f"wavs/{p.phrase_id}.wav|{' '.join(note_on.note.astype(int).astype(str).to_list())}"

In [None]:
from scipy.signal import decimate, butter, dlti # resample_poly
from scipy.interpolate import UnivariateSpline

#--- NOTE: this rebinds range_notes / alto_sax_range, which were set to ['Db3', 'A5'] in the
#--- major-scale cell; get_num_harmonics() reads the same alto_sax_range global
range_notes = ['C3', 'A#5'] # alto sax range is ['Db3', 'A5'], take half-step below/above
alto_sax_range = librosa.note_to_hz(range_notes)

#--- pitch detection parameters (in samples), passed to librosa.pyin in phrase_to_synth
win = 1024 # frame_length
ac_win = 512 # autocorrelation window
hop = 256 # hop_length

def phrase_to_synth(seg, sr, midi_p, t0, num_harmonics = None, max_freq_hz = None, spline_smoothing = None, verbose = False):
    ''' Synthesize a sawtooth signal that follows the pitch and the amplitude envelope of a recorded phrase.

        seg:     audio samples of the phrase
        sr:      sampling rate of seg
        midi_p:  midi events of the phrase ('type', 'ts_sec' columns are used), time stamps in file time
        t0:      start time of the phrase in the file (sec), subtracted from the midi time stamps
        verbose: print progress information

        At most one of these may be given (the assert below rejects giving both). They are described
        as "exactly one": with both set to None, max_freq_hz = None is passed on to get_num_harmonics():
            - num harmonics: how many harmonics (inc the fundamental) are used in the saw-tooth additive synthesis
                             in this case the max-freq is note-dependent (f0*num_harmonics) and the caller is responsible
                             to make sure that (highest note in hz) * (num_harmonics) < nyquist
            - max_freq_hz:   synthesize up to this frequency. This is done by upsampling, synthesizing the required amound of harmonics,
                             and downsampling back to sr
        spline_smoothing: if not None, the envelope is smoothed with a 2nd order UnivariateSpline,
                          using this value as its smoothing factor (s)

        returns (x, env, fnew): the synthesized signal (gain-matched to the RMS of seg), the
        envelope that was applied (scaled by the same gain), and the per-sample frequency track
    '''
    if verbose:
        print(f'pitch detection range: {alto_sax_range.round(1)} Hz, {(sr/alto_sax_range).astype(int)} samples')
        print(f'pitch detection: frame len {win}, auto-corr len {ac_win} (min freq of {sr/ac_win:.1f} Hz), hop len {hop}')
    
    assert(num_harmonics is None or max_freq_hz is None)
    #--- frame-wise pitch track: f1 is NaN where no pitch was found, vflag1 is the voiced flag per frame
    f1, vflag1, vprob1 = librosa.pyin(seg, 
                                      fmin = alto_sax_range[0], 
                                      fmax = alto_sax_range[1], 
                                      sr = sr, 
                                      frame_length=win, 
                                      win_length=ac_win, 
                                      hop_length=hop, 
                                      center=True, 
                                      max_transition_rate=100)
    times1 = librosa.times_like(f1, sr = sr, hop_length = hop)
    no_note1 = (~vflag1)
    tmin = times1[0]
    tmax = times1[-1]
    
    #--- midi note on/off times, relative to the start of the phrase
    note_on = midi_p.loc[midi_p.type == 'note_on']
    note_off = midi_p.loc[midi_p.type == 'note_off']
    #note_hz = librosa.midi_to_hz(note_on.note)
    note_on_ts = note_on['ts_sec'].values - t0
    note_off_ts = note_off['ts_sec'].values - t0
    
    #-------------------------------------------------------------------------------------------------------------------
    #--- interpolate missing pitch, where possible. otherwise, set to 0 (in order to accumulate 0 phase when integrating)
    #-------------------------------------------------------------------------------------------------------------------
    #--- step A, interpolate within (intra-) midi notes
    #--- (the k-th note-on is paired with the k-th note-off, so midi_p is expected to hold matched pairs)
    n_notes = note_on.shape[0]
    if verbose:
        print(f'samples with non-detected pitch: {np.isnan(f1).sum()}')
    for k in range(n_notes):
        #--- first, find missing pitch samples which are inside a detected midi note
        midi_note_span = (times1 >= note_on_ts[k]) & (times1 <= note_off_ts[k])
        
        #--- if no missing pitch samples are in the midi note span, we don't need this note, so skip
        if not (midi_note_span & no_note1).any():
            continue
        
        #--- if we don't have at least 2 pitch samples in the note span, we can't extrapolate, so skip
        if (midi_note_span & ~no_note1).sum() < 2:
            continue
            
        #--- build the interpolating function from detected pitch samples
        #--- (nearest-neighbour, also used to extrapolate towards the note edges)
        pitch_intrp = interp1d(times1[midi_note_span & ~no_note1], 
                               f1[midi_note_span & ~no_note1], 
                               fill_value = 'extrapolate', 
                               kind = 'nearest',
                               assume_sorted = True)
        #--- the time samples where we want to interpolate: inside midi note AND missing pitch
        t_intrp = times1[midi_note_span & no_note1]
        f1[midi_note_span & no_note1] = pitch_intrp(t_intrp)

    if verbose:
        print(f'after interpolating using midi notes: samples with non-detected pitch: {np.isnan(f1).sum()}')

    #--- step B, interpolate across (inter-) midi notes
    max_gap_to_interpolate_sec = 0.1 #--- don't interpolate gaps above this interval in seconds
    no_note1 = np.isnan(f1)
    #--- [start, end) frame indices of each run of missing pitch
    seg_inds = binary_array_to_seg_inds(no_note1, shift_end_ind = False)
    seg_lens_sec = np.diff(seg_inds, 1)[:,0] * hop / sr
    for k, inds in enumerate(seg_inds):
        #--- don't interpolate head or tail of signal, or if gap is too long
        #--- TODO check energy envelope in gap (interpolate only above env threshold)
        gap_len = seg_lens_sec[k]
        if (inds[0] == 0) or (inds[1] == len(f1)) or gap_len > max_gap_to_interpolate_sec:
            continue
        gap_len_samples = inds[1] - inds[0]
        if verbose:
            print(f'interpolating over {gap_len_samples} samples over gap of {gap_len:.3f} sec')
        #--- linear interpolation using 1 sample before and after
        new_freqs = np.linspace(f1[inds[0] - 1], f1[inds[1]], gap_len_samples + 2)
        f1[inds[0]:inds[1]] = new_freqs[1:-1]

    no_note1 = np.isnan(f1)
    #--- NOTE(review): seg_inds is recomputed here but not read again below
    seg_inds = binary_array_to_seg_inds(no_note1, shift_end_ind = False)
    if verbose:
        print(f'after interpolating over small gaps: samples with non-detected pitch: {np.isnan(f1).sum()}')
    #--- lastly, fill with zeros the samples that are still missing
    f1[np.isnan(f1)] = 0.
    
    #--- set number of harmonics of sawtooth wave
    if num_harmonics is not None:
        additive_synth_k = num_harmonics # 10
        should_downsample = False
    else:
        num_harmonics, new_sr_factor = get_num_harmonics(f1[f1 > 20].min(), f1.max(), sr, max_freq_hz)
        #--- make sure we stay below new nyquist
        assert f1.max() * num_harmonics < 0.5 * sr * new_sr_factor, f'Nyquist says you cannot synthesize {num_harmonics} harmonics at {new_sr_factor} X (current sampling rate)'
        additive_synth_k = num_harmonics
        #--- from here on sr is the UPSAMPLED rate, until it is restored after decimation
        sr *= new_sr_factor
        should_downsample = True
    
    #--- now interpolate to sampling-rate grid
    dt = 1 / sr
    fintrp = interp1d(times1, f1)
    tnew = np.arange(tmin, tmax, dt)
    fnew = fintrp(tnew)
    
    #--- phase is the integral of instantanous freq
    phi = np.cumsum(2 * np.pi * fnew * dt)
    # to wrap: phi = (phi + np.pi) % (2 * np.pi) - np.pi 
        
    #--- sawtooth series: harmonic k has amplitude 1/k and alternating sign
    x = np.sin(phi) #(np.sin(phi) + .5*np.sin(2*phi) + .333*np.sin(3*phi) + .25*np.sin(4*phi))
    for k in range(2, additive_synth_k + 1):
        x += (-1)**(k-1) * np.sin(k*phi) / k
    
    #--- if we upsampled, go back to original rate
    if should_downsample:
        #--- for x, give a "anti-alias" filter to "decimate", but actually use it to filter above the desired max_freq_hz
        zpk = butter(12, max_freq_hz, output = 'zpk', fs = sr)
        aa_filt = dlti(*zpk) 
        x = decimate(x, new_sr_factor, ftype = aa_filt)
        fnew = decimate(fnew, new_sr_factor) #--- fnew is just used to zero the envelope, so decimate so size fits
        sr = int(sr / new_sr_factor)
    
    #--- amplitude envelope: per-sample RMS of the recording (hop of 1 sample), scaled and cut to the length of x
    env = librosa.feature.rms(y = seg, frame_length = 512, hop_length = 1, center = True)
    env = 1.3 * np.sqrt(2)*env[0, :len(x)]
    #--- NOTE(review): in the downsampling path fnew went through decimate()'s filter, so samples
    #--- that were exactly 0 may not be exactly 0 any more -- confirm that this test still marks
    #--- the no-pitch regions there
    env[fnew == 0] = 0. # don't apply envelope where there was no pitch found

    #--- make envelope go to zero smoothly. This also takes care of the non-continous phase at jumps of f1 to 0
    env_segments = binary_array_to_seg_inds(env == 0)
    decay_time_sec = 0.05 #--- 50 msec decay time
    decay_time_samples = int(decay_time_sec * sr)
    for env_seg in env_segments:
        #--- a zero run that starts at sample 0 has nothing before it to fade out
        if env_seg[0] == 0:
            continue
        #--- linear fade-out over the samples just before the zero run
        ind_start = max(0, env_seg[0] - decay_time_samples)
        decay_len = env_seg[0] - ind_start
        decay_factor = np.linspace(1, 0, decay_len)
        env[ind_start: env_seg[0]] *= decay_factor  
    
    if spline_smoothing is not None:
        ts = t0 + np.arange(0, len(env)) / sr
        spl = UnivariateSpline(ts, env, s = spline_smoothing, k = 2)
        env = spl(ts)
        #--- the smoothed curve can undershoot below zero, clip it
        env[env < 0.] = 0.
        
    x *= env
    #--- match the RMS of the synthesized signal to the RMS of the recording
    gain = np.sqrt((x**2).mean()) / np.sqrt((seg**2).mean()) 
    x /= gain
    env /= gain
    
    return x, env, fnew

In [None]:
#--- read phrase info for 1 file
file_id = 'Funky_Nadley'
midi_fnm = f'../data_ssynth_TMP/midi/{file_id}.mid'
print(f'reading midi file {os.path.basename(midi_fnm)}')
midi_df, midi_pitch, midi_aftertouch, midi_cc = read_midi_to_df(midi_fnm)
#--- fail early if the note events are not clean note-on/note-off pairs
verify_midi(midi_df)

data_dir = '../data_ssynth_TMP/wavs'
phrase_df_fnm = '../data_ssynth_TMP/phrase_df.csv'
phrase_df = pd.read_csv(phrase_df_fnm, index_col = 0).reset_index(drop = True)
#--- keep only the phrases of this file. regex = False: match file_id as a literal substring,
#--- so characters like '.' or '(' in a file id are not interpreted as a regular expression
phrase_df = phrase_df[phrase_df.file_nm.str.contains(file_id, regex = False)]

## synthesize using original envelopes of pitch and amplitude

In [None]:
#--- choose a phrase
phrase_ind = 42 #14 #12 #5
p = phrase_df.iloc[phrase_ind]

#--- phrase start time (sec) in the file, audio of the phrase, and its midi events
t0 = p.sample_start / sampling_rate
wav_fnm = f'{data_dir}/{p.phrase_id}.wav'
seg, sr = librosa.load(wav_fnm, sr = sampling_rate)
midi_p = midi_phrase_from_dataframe(p, midi_df, sampling_rate)
midi_p_cc = midi_phrase_from_dataframe(p, midi_cc, sampling_rate)

#--- filter 'errors'
#--- (note-on events with very low velocity are dropped together with their note-off)
min_velocity = 3
err_notes = (midi_p.type == 'note_on') & (midi_p.velocity <= min_velocity)
#--- NOTE(review): assumes the matching note-off is the row labelled (index + 1) and that this label exists in midi_p -- confirm
err_notes.loc[err_notes[err_notes].index + 1] = True #--- add the corresponding note-off
midi_p = midi_p[~err_notes]

#x, env, freq = phrase_to_synth(seg, sr, midi_p, t0, num_harmonics = 30, spline_smoothing = 2, verbose = False)
#--- NOTE: this rebinds the notebook-level x, env and freq
x, env, freq = phrase_to_synth(seg, sr, midi_p, t0, max_freq_hz=16000, spline_smoothing = .5, verbose = False)

#--- apply hifi-gan
#--- (the synthesized input is attenuated by pre_gain before it is fed to the generator)
pre_gain = 0.6
x_hat = generate_from_audio(pre_gain * x, gen)
#--- first player: original recording, second player: generated audio
ipd.display(ipd.Audio(seg, rate = sr, normalize=play_normalize))
ipd.display(ipd.Audio(x_hat, rate = sr, normalize=play_normalize))

#--- optional waveform comparison plot
do_plot = False
if do_plot:
    fig, ax = plt.subplots(figsize = (8,4))
    ax.plot(seg)
    ax.plot(x)
    ax.plot(x_hat)
    ax.legend(['orig', 'synth_in', 'generated'])

# choose a note and fit (manually...) an ADSR env using cubic-Bezier curves

In [None]:
#--- compare 2 spline smoothing params (0.5 was used for training)
x1, env1, freq1 = phrase_to_synth(seg, sr, midi_p, t0, max_freq_hz=16000, spline_smoothing = 2, verbose = False)
x2, env2, freq2 = phrase_to_synth(seg, sr, midi_p, t0, max_freq_hz=16000, spline_smoothing = .1, verbose = False)
#--- sample range of the chosen note (picked by hand for this phrase)
k1, k2 = 52283, 67500
env0 = env1[k1:k2] # use this to manually fit an ADSR env using Bezier etc.

In [None]:
import numpy as np
from scipy.special import comb

def get_bezier_parameters(X, Y, degree=3):
    """ Least-squares Bézier fit using the Moore-Penrose pseudoinverse.

    The data points are assigned uniformly spaced curve parameters t in [0, 1]
    (in the order given), and the control points minimizing the squared error are
    solved for. The first and last control points are then set to the first and
    last data points.

    Parameters:

    X: array of x data.
    Y: array of y data. Y[0] is the y point for X[0].
    degree: degree of the Bézier curve. 2 for quadratic, 3 for cubic.

    Returns:

    list of degree + 1 control points, each an [x, y] pair.

    Raises:

    ValueError if degree < 1, if X and Y differ in length, or if there are
    fewer than degree + 1 points.

    Based on https://stackoverflow.com/questions/12643079/b%C3%A9zier-curve-fitting-with-scipy
    and probably on the 1998 thesis by Tim Andrew Pastva, "Bézier Curve Fitting".
    """
    if degree < 1:
        raise ValueError('degree must be 1 or greater.')

    if len(X) != len(Y):
        raise ValueError('X and Y must be of the same length.')

    if len(X) < degree + 1:
        raise ValueError(f'There must be at least {degree + 1} points to '
                         f'determine the parameters of a degree {degree} curve. '
                         f'Got only {len(X)} points.')

    #--- Bernstein matrix: row per data point (parameter t), column per control point k,
    #--- entry = C(degree, k) * t^k * (1 - t)^(degree - k). Plain ndarrays are used
    #--- (np.matrix is no longer recommended by numpy)
    T = np.linspace(0, 1, len(X))[:, np.newaxis]
    k = np.arange(degree + 1)
    M = T ** k * (1 - T) ** (degree - k) * comb(degree, k)

    #--- least squares solution: control_points = pinv(M) @ data_points
    points = np.array(list(zip(X, Y)))
    final = (np.linalg.pinv(M) @ points).tolist()

    #--- force the curve to pass through the first and last data points
    final[0] = [X[0], Y[0]]
    final[len(final)-1] = [X[len(X)-1], Y[len(Y)-1]]
    return final

#--- functions copied from: https://stackoverflow.com/questions/12643079/b%C3%A9zier-curve-fitting-with-scipy
def bernstein_poly(i, n, t):
    """
     Bernstein-type basis term, as a function of t:  C(n, i) * t**(n - i) * (1 - t)**i

     Note that the exponents are swapped with respect to the usual ordering,
     so i = 0 gives t**n and i = n gives (1 - t)**n.
    """
    binom = comb(n, i)
    t_power = t ** (n - i)
    one_minus_t_power = (1 - t) ** i
    return binom * t_power * one_minus_t_power


def bezier_curve(points, nTimes=50):
    """
       Evaluate the Bezier curve defined by a set of control points.

       points should be a list of lists, or list of tuples
       such as [ [1,1],
                 [2,3],
                 [4,5], ..[Xn, Yn] ]
        nTimes is the number of evenly spaced parameter values in [0, 1]
        at which the curve is evaluated (defaults to 50)

        Returns (xvals, yvals): two arrays of length nTimes.
        NOTE: because of the exponent order used in bernstein_poly, the samples
        run from the LAST control point (t = 0) to the FIRST one (t = 1).

        See http://processingjs.nihongoresources.com/bezierinfo/
    """
    degree = len(points) - 1
    ctrl_x = np.array([pt[0] for pt in points])
    ctrl_y = np.array([pt[1] for pt in points])

    t = np.linspace(0.0, 1.0, nTimes)

    #--- one row of basis values per control point, shape (len(points), nTimes)
    basis = np.array([bernstein_poly(idx, degree, t) for idx in range(0, degree + 1)])

    return np.dot(ctrl_x, basis), np.dot(ctrl_y, basis)

#--- Manually fit an ADSR envelope to env0 with one cubic Bezier curve per segment,
#--- and compare it with a clamped cubic spline through the same segment boundaries.
from matplotlib.patches import Polygon

fig, ax = plt.subplots(figsize = (14,6))
#ax.plot(env1,'.')
#ax.plot(env2,':')
#--- manually chosen segment boundaries (sample indices into env0): start, end of A, D, S, R
adsr = np.array([0, 1900, 3600, 8100, len(env0)])
n0, n1, n2, n3, n4 = adsr
#--- envelope value at each boundary. The last boundary equals len(env0), hence the k-1;
#--- the index is clamped at 0 so the first boundary reads env0[0] instead of wrapping around to env0[-1]
adsr_env = [env0[max(k - 1, 0)] for k in adsr]
e0, e1, e2, e3, e4 = adsr_env
ax.plot(env0, '.-')

#--- shade the four segments, each up to the envelope value at the end of the attack
a = e1
cols = ['b', 'r', 'g', 'c']
for k in range(4):
    poly = Polygon([[adsr[k], 0], [adsr[k+1], 0], [adsr[k+1], a], [adsr[k],a]], facecolor=cols[k], alpha = 0.15, edgecolor='0.2', closed=True)
    ax.add_patch(poly)

#--- segment lengths in samples
nA, nD, nS, nR = np.diff(adsr)

#curveA = get_bezier_parameters(np.arange(adsr[1]), env0[:nA])
#Ax, Ay =  bezier_curve(curveA, nA)

#--- the inner control points of each segment below are hand-tuned

#=== Attack
Ax, Ay =  bezier_curve([[n0, e0],  [.8 * n1, e0],      [.5 * n1, e1],          [n1, e1]], nA)

#=== Decay
de1 = (e1 - e2)
Dx, Dy =  bezier_curve([[n1, e1], [n1 + .5 * nD, e1], [n2 - .5 * nD, e2 + .4 * de1], [n2, e2]], nD)

#--- Sustain
#--- the first inner control point mirrors the last inner control point of the decay about (n2, e2)
de2 = (e2 - e3)
Sx, Sy =  bezier_curve([[n2, e2], [n2 + .5 * nD, e2 - .4 * de1], [n3 - .4 * nS, e3 + .1 * de2], [n3, e3]], nS)

#--- Release
#--- the first inner control point mirrors the last inner control point of the sustain about (n3, e3)
de3 = (e3 - e4)
Rx, Ry =  bezier_curve([[n3, e3], [n3 + .4 * nS, e3 - .1 * de2], [n4 - 1.2 * nR, e4 + 0 * de3], [n4, e4]], nR)

ax.plot(Ax, Ay, 'r')
ax.plot(Dx, Dy, 'r')
ax.plot(Sx, Sy, 'r')
ax.plot(Rx, Ry, 'r')

#--- compare with cubic spline
from scipy.interpolate import CubicSpline
spl = CubicSpline(adsr, adsr_env, bc_type='clamped')
env_spl = spl(np.arange(n4))
ax.plot(env_spl, ':')

#ax.plot(np.arange(len(env1)) - k1, env1,':')

#ax.set_xlim([n1,15000])
#ax.set_ylim([.1,.14])
#ax.legend(['original envelope', 'piece-wise cubic Bezier', 'cubic spline'])
ax.grid()

## synthesize using midi phrase

In [None]:
#--- Build per-sample frequency / amplitude envelopes from the note events in midi_p,
#--- synthesize a sawtooth signal from them, and pass it through the generator and the denoiser.
EWI_TEST = False
if EWI_TEST:
    midi_fnm = '/home/mlspeech/itamark/ssynth/code/ewimididemo.mid'
    #-- we only need pitch and aftertouch, since CC messages from the EWI are identical to aftertouch (CC we get are: 2 - breath control, 7 - main volume, 11 - expression) 
    #--- midi CC 5 (portamento time) - don't use it at the moment
    midi_p, midi_pitch, midi_aftertouch, midi_p_cc = read_midi_to_df(midi_fnm, time_offset_sec = -3.)
    min_velocity = 7
    #--- drop note-on events whose velocity is <= min_velocity, together with the row that follows each of them
    err_notes = (midi_p.type == 'note_on') & (midi_p.velocity <= min_velocity)
    err_notes.loc[err_notes[err_notes].index + 1] = True #--- add the corresponding note-off (assumes it is the next index label - TODO confirm)
    midi_p = midi_p[~err_notes].copy() #--- explicit copy, so the .loc assignments below do not write into a filtered slice
    #--- clip to alto sax range
    note_min, note_max = librosa.hz_to_midi(alto_sax_range)
    note_min +=1
    note_max -=1
    midi_p.loc[midi_p.note < note_min, 'note'] = note_min
    midi_p.loc[midi_p.note > note_max, 'note'] = note_max
    
    #tmin = 18 #3
    tmax = np.inf #24.75 #27.75 #18
    midi_p, midi_pitch, midi_aftertouch, midi_p_cc = (df[(df.ts_sec <= tmax)] for df in [midi_p, midi_pitch, midi_aftertouch, midi_p_cc])
    env_control = midi_aftertouch
    freq_control = midi_pitch
    freq_control_gain_st = 2 #--- in unit of semi-tones (the pitch-bend at max control value)
    seg_len = int((0.5 + midi_p.ts_sec.max()) * sampling_rate)
    t0 = 0
    use_midi_cc = True
    use_midi_pitch = True
    gain = 0.07
else:
    #--- NOTE(review): this branch relies on midi_p, midi_p_cc, x and t0 from earlier cells
    midi_pitch = None
    midi_aftertouch = None
    env_control = midi_p_cc
    seg_len = len(x)
    use_midi_cc = True
    use_midi_pitch = False
    gain = 0.16
    
#--- parameters

attack_time_sec = 15e-3
attack_time_samples = int(attack_time_sec * sampling_rate)

attack_gain_lin = np.linspace(0, 1, attack_time_samples)
attack_gain_quartic = attack_gain_lin ** 4
attack_gain_sigmoid = 1 / (1 + np.exp(-10*(attack_gain_lin-.5)))

#--- rescale the sigmoid so that the attack gain spans exactly [0, 1]
range_to_zero_one = lambda x: (x - x.min()) / (x.max() - x.min())
attack_gain = range_to_zero_one(attack_gain_sigmoid)

num_notes = midi_p.shape[0]

freq_midi = np.zeros(seg_len)
env_midi = np.zeros(seg_len)

#--- rows of midi_p are read in (note-on, note-off) pairs; stop before an unpaired trailing row
for k in np.arange(0, num_notes - 1, 2):
    row_on = midi_p.iloc[k]
    row_off = midi_p.iloc[k + 1]
    
    t_on, t_off = row_on.ts_sec, row_off.ts_sec
    k_on, k_off = int((t_on - t0) * sampling_rate), int((t_off - t0) * sampling_rate)
    cc_note = env_control[(env_control.ts_sec >= t_on) & (env_control.ts_sec <= t_off)]
    #--- skip notes with fewer than 3 control events
    if len(cc_note) < 3:
        continue
    note_len = k_off - k_on

    #t_on, t_off = row_on.ts_sec - t0, row_off.ts_sec - t0
    #k_on, k_off = int(t_on * sampling_rate), int(t_off * sampling_rate)
    #note_len = k_off - k_on
    
    #--- freq env
    note_hz = librosa.midi_to_hz(row_on.note)
    freq_midi[k_on:k_off] = note_hz #+ 5*np.sin(2*np.pi*15 * np.arange(0,note_len) / sampling_rate)
    
    #--- amplitude env
    note_attack_time_samples = min(int(note_len / 2), attack_time_samples) #--- make sure attack is not longer than the note itself
    if not use_midi_cc:
        #--- linear ramp up during the attack, then linear ramp down to the end of the note
        note_gain = gain * row_on.velocity / 128
        env_midi[k_on:k_off] = note_gain * np.r_[np.linspace(0, 1, note_attack_time_samples), np.linspace(1, 0, note_len - note_attack_time_samples)]
    else:
        #env_cc = gain * interp1d(cc_note.ts_sec, cc_note.value, fill_value = 'extrapolate')(np.linspace(t_on, t_off, note_len)) / 128 
        env_cc = gain * interp1d(cc_note.ts_sec, cc_note.value, kind='linear', bounds_error=False, fill_value = 0)(np.linspace(t_on, t_off, note_len)) / 128 
        env_cc[env_cc < 0] = 0
        #--- apply the attack shape; truncate it for notes shorter than the attack, so both sides have the same length
        n_attack = min(note_len, attack_time_samples)
        env_cc[:n_attack] *= attack_gain[:n_attack]
        env_midi[k_on:k_off] = env_cc
        
    if use_midi_pitch:
        #--- freq_control is only defined in the EWI test, so it is looked up only when pitch-bend is used
        pitch_note = freq_control[(freq_control.ts_sec >= t_on) & (freq_control.ts_sec <= t_off)]
        #env_pitch = interp1d(pitch_note.ts_sec, pitch_note.pitch / 2**13, kind='linear', bounds_error=False, fill_value = 'extrapolate')(np.linspace(t_on, t_off, note_len))
        if pitch_note.shape[0] > 1:
            env_pitch = interp1d(pitch_note.ts_sec, pitch_note.pitch / 2**13, kind='linear', bounds_error=False, fill_value = 'extrapolate')(np.linspace(t_on, t_off, note_len))
            #--- pitch-bend in semi-tones -> multiplicative frequency factor
            freq_mult = 2 ** (freq_control_gain_st * env_pitch / 12)
            freq_midi[k_on:k_off] *= freq_mult

#x_midi = additive_synth_sawtooth(freq_midi, env_midi, sampling_rate, additive_synth_k=30)
x_midi = additive_synth_sawtooth(freq_midi, env_midi, sampling_rate, max_freq_hz = 16000)
x_midi_hat = generate_from_audio(x_midi, gen, return_numpy_arr = False)
x_midi_hat_den = denoiser(x_midi_hat.squeeze(1), denoising_strength)

In [None]:
#--- plot the normalized "midi" envelope, the optional pitch-wheel / aftertouch streams, and note on/off times
fig, ax = plt.subplots(figsize = (8,4))
ts = t0 + np.arange(0, len(env_midi)) / sampling_rate
ax.plot(ts, env_midi / gain,':.')
legend_labels = ['out midi env']

#--- midi_pitch / midi_aftertouch are None unless EWI_TEST is set.
#--- Compare against None explicitly: `if df_a and df_b` raises for pandas objects (ambiguous truth value)
if midi_aftertouch is not None and midi_pitch is not None:
    ax.plot(midi_pitch.ts_sec, midi_pitch.pitch / 4096,'.')
    ax.plot(midi_aftertouch.ts_sec, midi_aftertouch.value / 128,'.')
    legend_labels += ['pitchwheel', 'aftertouch']

#--- vertical markers: tall red for note-on rows, shorter cyan for all other rows
for irow, row in midi_p.iterrows():
    if row.type == 'note_on':
        ax.plot([row.ts_sec, row.ts_sec], [0., 1], 'r-o')
    else:
        ax.plot([row.ts_sec, row.ts_sec], [0., 0.6], 'c.-')

ax.grid()
#--- only label the curves that were actually drawn, so the note markers do not inherit the wrong label
ax.legend(legend_labels)

In [None]:
#--- audio players for the signals computed above
if not EWI_TEST:
    print('Original recording:')
    ipd.display(ipd.Audio(seg, rate = sr))
    print('Sawtooth signal synthesized using real envelopes:')
    ipd.display(ipd.Audio(x, rate = sr))
    print('Generated audio, using real envelopes:')
    #x_hat = x_hat[0].cpu().detach().numpy()[0]
    ipd.display(ipd.Audio(x_hat, rate = sampling_rate, normalize = False))

print('Sawtooth signal synthesized using "midi" envelopes:')
#--- x_midi is synthesized at sampling_rate (not sr), so it is played back at that rate
ipd.display(ipd.Audio(x_midi, rate = sampling_rate))

print('Generated audio, using "midi" envelopes:')
#--- convert to numpy only while still a tensor, so that re-running this cell does not fail
if torch.is_tensor(x_midi_hat):
    x_midi_hat = x_midi_hat[0].cpu().detach().numpy()[0]
ipd.display(ipd.Audio(x_midi_hat, rate = sampling_rate, normalize = False))

print('Generated audio, using "midi" envelopes (denoised):')
if torch.is_tensor(x_midi_hat_den):
    x_midi_hat_den = x_midi_hat_den[0].cpu().detach().numpy()[0]
ipd.display(ipd.Audio(x_midi_hat_den, rate = sampling_rate, normalize = False))

In [None]:
#--- plot the synthesized "midi" signal and its envelope, the control events the envelope was built from,
#--- and note on/off times from the midi
#--- (the envelope of the original recording is currently commented out below)

ts = t0 + np.arange(0, len(x_midi)) / sampling_rate
#spl = UnivariateSpline(ts, env, s=.5, k = 2)
#env_spline = spl(ts)

fig, ax = plt.subplots(figsize = (8,4))
ax.plot(ts[::10],x_midi[::10],'g:')
#ax.plot(ts,env,'.')
ax.plot(ts,env_midi,'x')
#--- env_control is midi_aftertouch in the EWI test and midi_p_cc otherwise;
#--- midi_aftertouch itself is None when EWI_TEST is False, so it cannot be plotted directly
ax.plot(env_control.ts_sec, env_control.value / 128 * gain, 'o')
for irow, row in midi_p.iterrows():
    if row.type == 'note_on':
        ax.plot([row.ts_sec, row.ts_sec], [0., 0.15], 'g:')
    else:
        ax.plot([row.ts_sec, row.ts_sec], [0., 0.12], 'r.-')

# "APPENDIX"
## Compare MEL spectra of the 2 implementations that are used in the HiFiGAN code
(they are not the same :-( )

In [None]:
#--- compute the mel output for the same input y with both implementations, for the comparison plot below
#--- NOTE(review): mel_spec, mel_spec2 and y are not defined in this cell - presumably created in earlier cells; confirm
mel1 = mel_spec.get_mel(y)
mel2 = mel_spec2(y)

In [None]:
#--- overlay one slice (index k along the last axis) of the two mel outputs
k = 25
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(mel1[0, :, k], 'bo', label='mel-1')
ax.plot(mel2[0, :, k], 'r.', label='mel-2')
ax.legend()

## measure timing of mel + inference

## network's impulse response

In [None]:
#--- unit impulse at index sampling_rate, with sampling_rate zeros before and after it
imp = np.zeros(2 * sampling_rate + 1, dtype=np.float32)
imp[sampling_rate] = 1.
ir = generate_from_audio(imp, gen)

#--- same impulse as a (1, num_samples) torch tensor, for the mel call
imp = torch.tensor(imp[np.newaxis, :]).clone().detach()
imp_mel = mel_spec2(imp).cpu().numpy()

In [None]:
#--- fixed sample window shared by the plots and the audio player below
#--- NOTE(review): the impulse sits at index sampling_rate, so this window only contains it when 40000 <= sampling_rate < 50000
win = slice(40000, 50000)

fig, ax = plt.subplots(figsize=(8, 4))
#ax.imshow(imp_mel[0], aspect='auto')
#ax.plot(imp_mel[0,:,170:174])
imp_ = imp.cpu().detach().numpy()[0, win]
ir_win = ir[win]
ax.plot(ir_win)
ax.plot(0.1 * imp_, 'r.-')
ipd.display(ipd.Audio(ir_win, rate=sampling_rate, normalize=True))

In [None]:
#--- compare the aftertouch stream with the control-change rows whose control number is 2, plus the scaled pitch stream
is_cc2 = df_cc.control == 2
df_cc2 = df_cc[is_cc2]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(midi_aftertouch.ts_sec, midi_aftertouch.value, 'o')
ax.plot(df_cc2.ts_sec, df_cc2.value, 'x')
#--- pitch values divided by 2**5
ax.plot(midi_pitch.ts_sec, midi_pitch.pitch / 32, '.')

In [None]:
#--- display the rows of df_cc with control == 2 (selected in the previous cell)
df_cc2