# HarmoniX

*"We have Auto-Tune at home"*

In [1]:
import pickle
import random

import IPython.display as ipd
import fmplib as fmp
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.signal
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ipywidgets import interact
from tqdm import tqdm

# Fix PyTorch's global random seed
torch.manual_seed(1)

# Render figures inline and set notebook-wide matplotlib defaults
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 4)
plt.rcParams['image.interpolation'] = 'nearest'

## Part 1: Pitch detection

I use the [YIN algorithm](http://audition.ens.fr/adc/pdf/2002_JASA_YIN.pdf)
for pitch detection. The following code is based on
[Patrice Guyot's implementation](https://github.com/patriceguyot/Yin/tree/master).

The algorithm works as follows:
1. First, we split the audio into overlapping windows. A default window
   size of $W = 512$ with a hop size of $H = 256$ typically works quite well.
2. For each window:
   1. We compute the "difference function":
      $$d_t(\tau) = \sum_{i = 1}^W (x_i - x_{i + \tau})^2$$
      which is essentially the squared Euclidean distance between the window
      and the window shifted by the period $\tau$.
   2. Next, we compute the "cumulative mean normalized difference function":
      $$d_t'(\tau) = \begin{cases}1 & \text{if $\tau = 0$}\\ d_t(\tau) \big/
      \left(\frac{1}{\tau} \sum_{i = 1}^\tau d_t(i) \right) & \text{otherwise} \end{cases}$$
      which prevents too-small $\tau$ from being chosen as the correct period.
   3. Finally, we scan through values of $\tau$ and find the first trough below
      some "harmonicity" threshold. A threshold of $0.1$ typically works quite well.
3. If no valid $\tau$ is found for a window, the algorithm labels it as "no pitch".

### Difference function

A naive implementation of the difference function is as follows:

In [2]:
def yin_df_slow(x, tau_max):
    """Compute the YIN difference function with a direct double loop.

    For every lag in ``range(tau_max)`` this accumulates the squared
    difference between ``x[i]`` and ``x[i + lag]`` over the first
    ``len(x) - tau_max`` samples, so every lag is summed over the same
    number of terms.

    Args:
        x: Indexable sequence of samples (one analysis window).
        tau_max: Number of lags to evaluate.

    Returns:
        Float array of length ``tau_max``; entry ``lag`` is the summed
        squared difference at that lag.
    """
    n_terms = len(x) - tau_max
    df = np.zeros(tau_max)
    for lag in range(tau_max):
        for start in range(n_terms):
            delta = x[start] - x[start + lag]
            df[lag] += delta**2
    return df

However, this implementation runs in $O(W^2)$ time and will not scale well. Luckily,
we can express $d_t(\tau)$ in terms of the autocorrelation function $r_t$ as:
$$d_t(\tau) = r_t(0) + r_{t + \tau}(0) - 2r_t(\tau)$$
where $r_t(\tau)$ is the autocorrelation of the signal with autocorrelation period $\tau$
starting at sample $t$.

By the [Wiener-Khinchin Theorem](https://mathworld.wolfram.com/Wiener-KhinchinTheorem.html),
we can efficiently compute $r_t$ using the FFT! This brings the computation down to $O(W \log W)$ time.

In [3]:
def yin_df_fast(x, tau_max):
    """Compute the YIN difference function via FFT-based autocorrelation.

    Uses the identity ``d(tau) = r_0(0) + r_tau(0) - 2 * r(tau)`` where, for a
    window of ``w`` samples,

    * ``r_0(0)``   is the energy of ``x[0 : w - tau]``,
    * ``r_tau(0)`` is the energy of ``x[tau : w]``,
    * ``r(tau)``   is the autocorrelation ``sum_i x[i] * x[i + tau]``.

    Both energy terms cover exactly the samples that take part in the
    autocorrelation at that lag, so a perfectly periodic signal yields a
    difference of (numerically) zero at its period.

    Args:
        x: 1-D sequence of samples (one analysis window).
        tau_max: Number of lags to return.

    Returns:
        Float64 array with ``min(tau_max, len(x))`` entries.
    """
    # Work in float64 so integer PCM input cannot overflow when squared
    x = np.asarray(x, dtype=np.float64)
    w = len(x)
    # energy[k] = sum of x[i]**2 for i < k (prefix sums of the signal energy)
    energy = np.concatenate(([0.0], np.cumsum(x * x)))
    # Convolving x with its reverse gives the full autocorrelation;
    # index w - 1 + tau holds the value for lag tau
    autocorr = scipy.signal.fftconvolve(x, x[::-1])[w - 1:]
    # energy[w:0:-1][tau] == energy[w - tau]  -> energy of x[0 : w - tau]
    # energy[w] - energy[tau]                 -> energy of x[tau : w]
    df = energy[w:0:-1] + (energy[w] - energy[:w]) - 2 * autocorr
    return df[:tau_max]

### Everything else

Everything else is rather straightforward and follows directly from the steps described in the paper.

One additional step I implemented here was median filtering the output to remove noisy outliers.

In [4]:
def yin_cmndf(df, w):
    """Cumulative mean normalized difference function (YIN step 3).

    For ``tau >= 1`` the result is ``df[tau] * tau / sum(df[1 : tau + 1])``;
    the entry for ``tau = 0`` is defined as 1.

    Wherever the running sum is zero (e.g. a silent or constant frame) the
    ratio is undefined; those entries are set to 1 so that they can never
    fall below a trough threshold and no NaN/RuntimeWarning is produced.

    Args:
        df: Difference function values, indexed by lag.
        w: Number of lags in ``df`` (must equal ``len(df)``).

    Returns:
        Float array of length ``w``.
    """
    lags = np.arange(1, w)
    running_sum = np.cumsum(df[1:])
    scaled = df[1:] * lags
    # Pre-fill with the neutral value 1 and only divide where it is defined
    cmndf = np.ones(scaled.shape, dtype=np.float64)
    np.divide(scaled, running_sum, out=cmndf, where=running_sum != 0)
    return np.insert(cmndf, 0, 1)


def yin_pitch_from_cmndf(cmndf, tau_min, tau_max, trough_thresh):
    """Pick the period (in samples) from a CMNDF curve.

    Scans lags ``tau_min .. tau_max - 1`` for the first value below
    ``trough_thresh``, then walks forward while the curve keeps strictly
    decreasing so that the bottom of that trough is returned.

    Returns:
        The lag at the bottom of the first qualifying trough, or -1 when no
        lag in range falls below the threshold.
    """
    for candidate in range(tau_min, tau_max):
        # Written as "not <" so that NaN entries are skipped
        if not (cmndf[candidate] < trough_thresh):
            continue
        trough = candidate
        while trough + 1 < tau_max and cmndf[trough + 1] < cmndf[trough]:
            trough += 1
        return trough
    return -1  # Failed to find a pitch


def detect_pitches(x, fs, ac_win_len=512, ac_hop_size=256, f_min=100, f_max=1200,
                   trough_thresh=0.1, med_filt_win_len=5):
    """Estimate one quantized pitch per analysis frame with YIN.

    Args:
        x: 1-D audio signal.
        fs: Sample rate of ``x`` in Hz.
        ac_win_len: Analysis window length in samples.
        ac_hop_size: Hop between consecutive windows in samples.
        f_min: Lowest detectable frequency in Hz (sets the largest lag).
        f_max: Highest detectable frequency in Hz (sets the smallest lag).
        trough_thresh: CMNDF threshold below which a trough counts as a pitch.
        med_filt_win_len: Odd length of the median filter applied to the
            per-frame pitches.

    Returns:
        ``(pitches, frame_indices)``. ``pitches[i]`` is the rounded output of
        ``fmp.freq_to_pitch`` for frame ``i`` or -1 when no pitch was found;
        ``frame_indices`` is ``np.arange(n_frames)`` (multiply by
        ``ac_hop_size`` to get sample positions).
    """
    n_samples = len(x)
    tau_min = int(fs // f_max)
    # A lag can never exceed the window length; clamping keeps the difference
    # function and its normalization the same size for very low `f_min`.
    tau_max = min(int(fs // f_min), ac_win_len)

    # Step 1: Split into overlapping windows. The `+ 1` keeps a final window
    # that ends exactly on the last sample.
    starts = range(0, n_samples - ac_win_len + 1, ac_hop_size)
    windows = [x[s:s + ac_win_len] for s in starts]
    frame_indices = np.arange(len(windows))
    if len(windows) == 0:
        # Signal shorter than one window: nothing to analyse
        return (np.empty(0), frame_indices)

    periods = np.empty(len(windows))
    for i, win in enumerate(windows):
        # Step 2a: Difference function
        df = yin_df_fast(win, tau_max)
        # Step 2b: Cumulative mean normalized difference function
        cmndf = yin_cmndf(df, tau_max)
        # Step 2c: Estimate the period (in samples) from the CMNDF
        periods[i] = yin_pitch_from_cmndf(cmndf, tau_min, tau_max, trough_thresh)

    # Postprocessing: Convert samples -> frequencies -> pitches
    quantized_pitches = np.full_like(periods, -1)
    voiced = periods > 0
    quantized_pitches[voiced] = fmp.freq_to_pitch(fs / periods[voiced]).round()

    # Postprocessing: Filter outliers using a median filter.
    # medfilt pads with zeros, which near the edges can turn into a bogus
    # pitch value of 0. Pad explicitly with the "no pitch" marker instead and
    # discard the padded region afterwards.
    pad = med_filt_win_len // 2
    padded = np.pad(quantized_pitches, pad, mode='constant', constant_values=-1)
    filtered = scipy.signal.medfilt(padded, med_filt_win_len)
    quantized_pitches = filtered[pad:pad + len(quantized_pitches)]
    return (quantized_pitches, frame_indices)

As a bonus, pitch detection also gives us the positions of onsets
fairly reliably, so we don't need an extra step for that!

## Part 2: Melody-to-chord generation

I use a bidirectional LSTM (BiLSTM) to generate chords from a melody. BiLSTMs
have several advantages over the standard HMM approach, with the main ones being:
* They carry context across the whole melody and can learn chord *progressions*,
  unlike HMMs, whose transitions depend only on the current state (the Markov property).
* They have more tunable parameters.
* Neural networks are cool.

The BiLSTM architecture I use is taken from [this paper](https://archives.ismir.net/ismir2017/paper/000134.pdf).
There are probably better architectures out there, but this one works well enough
and is super quick and easy to train without overfitting.

(I mean technically there will always be some overfitting when working with
neural networks, but the effects of overfitting are negligible here.)

Here it is implemented in PyTorch:

In [5]:
class ChordGenerator(nn.Module):
    """Bidirectional LSTM that maps a melody to per-note chord logits.

    Each melody step is a 12-dimensional vector (one entry per pitch class);
    each output step is a vector of 24 chord logits.

    Args:
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=128, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(12, hidden_size, num_layers, bidirectional=True, dropout=0.2)
        # Forward and backward hidden states are concatenated -> 2 * hidden_size
        self.hidden2chords = nn.Linear(hidden_size * 2, 24)

    def forward(self, melody):
        """Return chord logits of shape ``(L, 24)`` for an ``L x 12`` melody.

        ``melody`` may be a NumPy array, a tensor or a nested list of any
        real dtype; it is converted to a float32 tensor (without copying when
        it already is float32).
        """
        melody = torch.as_tensor(melody, dtype=torch.float32)
        seq_len = len(melody)
        # nn.LSTM expects (sequence, batch, features); use a batch of one
        lstm_out, _ = self.lstm(melody.reshape(seq_len, 1, -1))
        chord_logits = self.hidden2chords(lstm_out.reshape(seq_len, -1))
        return chord_logits

### (Optional) Training the model

To train the model, we first load the dataset (the same as the one used in the paper):

In [13]:
def load_training_data(corpus_path, val_ratio=0.1):
    """Load the pickled melody/chord corpus and split it into train/val sets.

    Every corpus entry is read as ``song[0][0]`` (melody, one 12-vector per
    step) and ``song[0][1]`` (chord label per step, where 0 marks a rest and
    1..24 are chords). Rest steps are dropped, the remaining labels are
    shifted to 0..23 and one-hot encoded. Songs that consist only of rests
    are skipped.

    Args:
        corpus_path: Path to the pickled corpus.
        val_ratio: Approximate fraction of songs assigned to validation.

    Returns:
        ``(train, val)``: two lists of ``(melody, chord_onehot)`` pairs with
        ``melody`` a float32 ``(L, 12)`` array and ``chord_onehot`` an
        ``(L, 24)`` array.
    """
    # Private fixed-seed generator: yields the same split as seeding the
    # global NumPy RNG with 0, but leaves the global RNG state untouched.
    split_rng = np.random.RandomState(0)
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(corpus_path, 'rb') as f:
        data_corpus = pickle.load(f)
    # Inputs and targets for training/validation
    input_melody_train = []
    output_chord_train = []
    input_melody_val = []
    output_chord_val = []
    # Process each song in the corpus
    for song in data_corpus:
        melody = np.array(song[0][0]).astype(np.float32)
        chord = np.array(song[0][1])
        # Filter out the rests
        not_rests = (chord != 0)
        melody = melody[not_rests]
        chord = chord[not_rests] - 1
        if len(chord) == 0:
            continue
        # Convert chord to one-hot vector
        chord_onehot = np.zeros((len(chord), 24))
        chord_onehot[np.arange(len(chord)), chord] = 1
        # Randomly assign to training or validation set
        if split_rng.rand() > val_ratio:
            input_melody_train.append(melody)
            output_chord_train.append(chord_onehot)
        else:
            input_melody_val.append(melody)
            output_chord_val.append(chord_onehot)
    print(f'Successfully loaded {len(data_corpus)} pieces!')
    return list(zip(input_melody_train, output_chord_train)), list(zip(input_melody_val, output_chord_val))

# Load the corpus and split it into training/validation sets (default val_ratio)
data_train, data_val = load_training_data('./corpus.bin')

Successfully loaded 5786 pieces!


Then we use stochastic gradient descent with cross-entropy loss to train the model:

In [None]:
model = ChordGenerator()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.7)

NUM_EPOCHS = 10
# Dedicated, seeded generator so the per-epoch song order is reproducible
shuffle_rng = random.Random(1)
for epoch in range(NUM_EPOCHS):
    # Train
    model.train()
    # Draw a freshly shuffled copy rather than shuffling `data_train` in
    # place, so re-running this cell starts from the same list every time.
    epoch_songs = shuffle_rng.sample(data_train, len(data_train))
    train_losses = []
    for melody, chord_target in tqdm(epoch_songs, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS} (train)'):
        optimizer.zero_grad()
        chord_logits = model(melody)
        loss = loss_function(chord_logits, torch.from_numpy(chord_target))
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Validate
    model.eval()
    val_losses = []
    val_accuracies = []
    with torch.no_grad():
        for melody, chord_target in tqdm(data_val, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS} (val)'):
            target = torch.from_numpy(chord_target)
            chord_logits = model(melody)
            loss = loss_function(chord_logits, target)
            val_losses.append(loss.item())
            # Compare the predicted chord index with the one-hot target's index
            hits = chord_logits.argmax(dim=1) == target.argmax(dim=1)
            # Store plain floats so np.mean works on a list of numbers
            val_accuracies.append(hits.float().mean().item())
    print(f'Average training loss: {np.mean(train_losses):.4}')
    print(f'Average validation loss: {np.mean(val_losses):.4}')
    print(f'Average validation accuracy: {np.mean(val_accuracies):.4}')

Finally, we can save the model to load later:

In [None]:
# torch.save(model.state_dict(), './blstm_model2.pt')

## Part 2.5: Chord synthesis

The model above outputs a list of numbers between 0 and 23 (inclusive) to
denote the chords. To turn these numbers into actual chords, we can create
a $24 \times 12$ matrix where each row is a normalized binary 12-vector encoding a "chord template":

In [6]:
def binary_chord_template(pitches):
    """Build a unit-length 12-bin pitch-class template for a chord.

    Every pitch is folded into its pitch class (mod 12); the matching bins
    are set to 1 and the vector is scaled to Euclidean norm 1.
    """
    template = np.zeros(12)
    active_classes = np.mod(pitches, 12)
    template[active_classes] = 1
    magnitude = np.linalg.norm(template)
    return template / magnitude

Then we can greedily choose the closest pitches below the melody as the notes in the chord to play:

In [7]:
def get_shift_steps(pitch_classes, model, horrible=False):
    """Choose, for every melody note, the semitone shifts that form its chord.

    The model predicts one of 24 chords per note. Each chord tone is then
    placed at the closest position strictly below the melody note, so all
    shifts lie in ``-12 .. -1`` (a chord tone equal to the melody's pitch
    class is placed one octave down).

    Args:
        pitch_classes: Integer pitch classes (0..11), one per melody note.
        model: Callable mapping an ``(L, 12)`` float32 one-hot array to an
            ``(L, 24)`` tensor of chord logits.
        horrible: If True, pick the lowest-scoring chord instead of the
            highest-scoring one.

    Returns:
        Integer array of shape ``(L, 3)``; each row holds the three shifts in
        ascending order. Empty input yields an array of shape ``(0, 3)``.
    """
    if len(pitch_classes) == 0:
        # Keep the result two-dimensional so callers can still index columns
        return np.empty((0, 3), dtype=int)

    pitch_classes_onehot = np.zeros((len(pitch_classes), 12), dtype=np.float32)
    pitch_classes_onehot[np.arange(len(pitch_classes)), pitch_classes] = 1

    major_chord_templates = np.array([binary_chord_template(np.array([0, 4, 7]) + i) for i in range(12)])
    minor_chord_templates = np.array([binary_chord_template(np.array([0, 3, 7]) + i) for i in range(12)])
    # Interleave: even chord indices are minor triads, odd ones major triads.
    # NOTE(review): this must match the chord numbering the model was trained
    # on -- verify against the corpus encoding.
    chord_templates = np.empty((24, 12))
    chord_templates[1::2] = major_chord_templates
    chord_templates[::2] = minor_chord_templates

    # Pure inference: run the model once and skip autograd bookkeeping
    with torch.no_grad():
        chord_logits = model(pitch_classes_onehot)
    if horrible:
        chord_preds = chord_logits.argmin(dim=1).tolist()
    else:
        chord_preds = chord_logits.argmax(dim=1).tolist()

    shift_steps = []
    for pitch, chord in zip(pitch_classes, chord_preds):
        # Pitch classes that belong to the predicted chord
        chord_tones = np.flatnonzero(chord_templates[chord])
        steps = chord_tones - pitch
        # Move every tone at or above the melody one octave down
        steps[steps >= 0] -= 12
        shift_steps.append(np.sort(steps))
    return np.array(shift_steps)

(By the way, one fun thing we can do with the BiLSTM is choose the argmin of the logits to
get the worst possible chords for the melody. This is represented by the `horrible` parameter
in the function above.)

## Part 3: Pitch shifting

I use a [digital phase vocoder](https://www.di.ens.fr/~mallat/papiers/Vocoder.pdf) to pitch shift.
It's kind of hard to explain exactly how this algorithm works
([here's a great video explanation](https://youtu.be/PjKlMXhxtTM?si=JFyHBT-c-U7vVCQS)),
but the TL;DR is:
1. First, we take the STFT of the signal.
2. Next, we resample the STFT to get a stretched/squeezed STFT.
    * Suppose I want to pitch one octave down (i.e., halve the frequency).
      Then I would discard every other column of the STFT to get a squeezed
      STFT that represents a signal played at the same frequency but twice
      as fast.
    * However, we can't just naively resample, or else we get discontinuities
      in the reconstructed signal! To address this problem, we resample the
      *magnitude* of the STFT as described above, but keep track of phase separately
      to ensure continuity.
3. Next, we take the inverse STFT to get a time-stretched signal
   at the same frequency.
4. Finally, we resample the reconstructed signal to get a pitch-shifted signal.

In [8]:
def pitch_shift(x, win_len, hop_size, indices, shift_steps):
    """Pitch-shift ``x`` piecewise with a phase vocoder.

    The signal is cut at the sample positions in ``indices``; the segment
    starting at ``indices[k]`` is shifted by ``shift_steps[k]`` semitones,
    and everything before ``indices[0]`` is left unshifted.

    Args:
        x: 1-D audio signal.
        win_len: STFT window length in samples.
        hop_size: STFT hop size in samples.
        indices: Increasing sample positions where segments start.
        shift_steps: Semitone shift per segment (negative = down).

    Returns:
        Array with ``len(x)`` samples containing the shifted signal. If
        ``indices`` is empty there is nothing to shift and a float copy of
        ``x`` is returned.
    """
    if len(indices) == 0:
        return np.array(x, dtype=float)

    # Step 1: STFT
    # NOTE(review): assumes fmp.stft returns (win_len // 2 + 1, n_frames),
    # as the buffer allocation below implies -- confirm against fmplib.
    stft = fmp.stft(x, win_len, hop_size)

    # Step 2: Construct list of points to resample at.
    # Within a segment, consecutive read positions are `scale` samples apart:
    # scale > 1 squeezes the segment in time (later undone as a downward
    # shift), scale < 1 stretches it.
    segment_ends = np.concatenate((indices[1:], [len(x)]))
    pieces = [np.arange(indices[0])]
    for seg_start, seg_end, step in zip(indices, segment_ends, shift_steps):
        scale = 2**(-step / 12)
        pieces.append(np.arange(seg_start, seg_end, scale))
    resample_idx = np.concatenate(pieces)
    # One (fractional) STFT frame position per output hop
    resample_stft = resample_idx[::hop_size] / hop_size
    n_bins = win_len // 2 + 1
    stretched_stft = np.zeros((n_bins, len(resample_stft)), dtype=np.complex128)

    # Step 3: Resample the STFT based on the resample points
    curr_phase = np.ones(n_bins, dtype=np.complex128)
    n_frames = stft.shape[1]
    for idx, pos in enumerate(resample_stft):
        frame = int(pos)
        # Positions without a following frame cannot be interpolated;
        # their output columns stay zero
        if frame + 1 >= n_frames:
            continue
        # Interpolate magnitude between two windows
        win1 = stft[:, frame]
        win2 = stft[:, frame + 1]
        delta = pos - frame
        mag = (1 - delta) * np.abs(win1) + delta * np.abs(win2)
        # Reconstruct the complex values of the stretched STFT
        stretched_stft[:, idx] = mag * curr_phase
        # Accumulate phase. The advance is taken from the angle difference
        # rather than win2 / win1, so bins that are exactly zero cannot
        # inject NaN/inf that would poison every later frame.
        curr_phase *= np.exp(1j * (np.angle(win2) - np.angle(win1)))
        curr_phase /= np.abs(curr_phase)

    # Step 4: Reconstruct the signal using the ISTFT
    stretched = fmp.istft(stretched_stft, hop_size)

    # Step 5: Resample the ISTFT signal to pitch shift.
    # For every original sample, find its position in the stretched signal
    # and read the stretched signal there.
    inverted_resample_idx = np.interp(np.arange(len(x)), resample_idx, np.arange(len(resample_idx)))
    shifted = np.interp(inverted_resample_idx, np.arange(len(stretched)), stretched)
    return shifted

## Putting it all together

In [9]:
def harmonize(x, fs, horrible=False, show_parts=False, ac_win_len=1024, ac_hop_size=256, f_min=100, f_max=1200,
              trough_thresh=0.05, med_filt_win_len=9, stft_win_len=2048, stft_hop_size=128,
              model_path='./blstm_model.pt'):
    """Add three generated harmony voices to a monophonic recording.

    Pipeline: detect the melody's pitch per frame, find the frames where the
    pitch class changes (note onsets), predict a chord per note, pitch-shift
    the input once per chord tone, and mix the three voices with the input.

    Args:
        x: 1-D audio signal.
        fs: Sample rate of ``x`` in Hz.
        horrible: If True, use the least likely chord for every note.
        show_parts: If True, display an audio player for each harmony voice.
        ac_win_len, ac_hop_size, f_min, f_max, trough_thresh,
        med_filt_win_len: Forwarded to ``detect_pitches``.
        stft_win_len, stft_hop_size: Forwarded to ``pitch_shift``.
        model_path: Path of the saved ``ChordGenerator`` state dict.

    Returns:
        The input plus the three shifted voices. If no pitch is detected
        anywhere, a float copy of ``x`` is returned unchanged.
    """
    # Step 1: Pitch estimation
    pitches, times = detect_pitches(x, fs, ac_win_len, ac_hop_size, f_min, f_max,
                                    trough_thresh, med_filt_win_len)
    voiced = (pitches != -1)
    if not voiced.any():
        # No melody notes found, hence nothing to harmonize
        return np.array(x, dtype=float)
    # Keep -1 as its own class: `-1 % 12` is 11, which would make "no pitch"
    # indistinguishable from pitch class B and hide B onsets after silence.
    frame_classes = np.where(voiced, pitches % 12, -1)
    class_changed = np.concatenate(([True], frame_classes[1:] != frame_classes[:-1]))
    onsets = class_changed & voiced
    # Frame index -> sample position of each note onset
    indices = times[onsets] * ac_hop_size
    pitch_classes = frame_classes[onsets].astype(int)

    # Step 2: Chord generation
    model = ChordGenerator()
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()
    shift_steps = get_shift_steps(pitch_classes, model, horrible)

    # Step 3: Pitch shifting (one voice per chord tone)
    shifted = [None] * 3
    for i in range(3):
        shifted[i] = pitch_shift(x, stft_win_len, stft_hop_size, indices, shift_steps[:, i])
        if show_parts:
            ipd.display(ipd.Audio(shifted[i], rate=fs))
    combined = np.sum(shifted, axis=0) + x
    return combined

### Example: Happy Birthday

In [12]:
# Load the example recording and play it back unprocessed
snd = fmp.load_wav("audio/happy_birthday.wav")
# NOTE(review): the sample rate is hard-coded; assumes load_wav returns 22.05 kHz audio -- TODO confirm
fs = 22050
ipd.display(ipd.Audio(snd, rate=fs))

Nice:

In [13]:
# Harmonize with the highest-scoring chords; show_parts=True also plays each harmony voice separately
harmonized = harmonize(snd, fs, ac_win_len=1024, trough_thresh=0.1, med_filt_win_len=11, show_parts=True)
ipd.display(ipd.Audio(harmonized, rate=fs))

Horrible:

In [14]:
# Same settings, but horrible=True selects the lowest-scoring chord (argmin of the logits) for every note
harmonized = harmonize(snd, fs, ac_win_len=1024, trough_thresh=0.1, med_filt_win_len=11, horrible=True)
ipd.display(ipd.Audio(harmonized, rate=fs))