# Rewriting DDSP components in PyTorch from scratch

## Imports for plotting and audio playback

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [17, 6]
from IPython.display import Audio
from ipywidgets import HTML

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa
from torch import Tensor

from fftconv import fft_conv
import librosa
import librosa.display
import numpy as np

## Utility functions for plotting signals and FFTs (NOT STFTs)

In [None]:
def plot_wave(y, sr, title=None):
    """Plot a 1-D waveform against time in seconds.

    Args:
        y: 1-D sequence of samples (callers pass ``tensor.numpy()``).
        sr: Sample rate in Hz.
        title: Optional plot title; omitted when None.
    """
    n_samples = len(y)
    # Sample k occurs at time k / sr. The previous linspace(0, n/sr, n) axis
    # stretched the signal so that the last sample landed at exactly n/sr.
    times = np.arange(n_samples) / sr
    plt.plot(times, y)
    plt.xlabel('Time (s)')
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
def plot_fft(fft, sr, title=None, n=None):
    """Plot a one-sided (rfft-layout) spectrum against frequency in Hz.

    Args:
        fft: 1-D sequence of real values, one per non-negative frequency bin
            (e.g. magnitudes or filter gains).
        sr: Sample rate in Hz.
        title: Optional plot title; omitted when None.
        n: Length of the time-domain signal the bins belong to. Defaults to
            ``2 * (len(fft) - 1)``, the same default ``torch.fft.irfft`` and
            ``np.fft.irfft`` use, so the axis matches those transforms.
            (The previous ``2 * len(fft) - 1`` assumed an odd length and
            shifted every bin slightly.)
    """
    n_bins = len(fft)
    if n is None:
        # max(..., 1) keeps a single-bin input valid (rfftfreq(0) divides by zero).
        n = max(2 * (n_bins - 1), 1)
    freqs = np.fft.rfftfreq(n, 1 / sr)
    plt.plot(freqs, fft)
    plt.xlabel('Frequency (Hz)')
    if title is not None:
        plt.title(title)
    plt.show()

## Utility function for showing audio controls with title.

In [None]:
def title_audio(data, rate, title=None):
    """Return an HTML widget with an audio player and an optional caption.

    Args:
        data: Audio samples accepted by ``IPython.display.Audio``.
        rate: Sample rate in Hz.
        title: Caption text. When None, no caption is rendered (previously
            the caption read "None:"). The text is HTML-escaped.
    """
    import html as html_lib  # stdlib; aliased so it does not clash with the HTML widget

    player = Audio(data, rate=rate)._repr_html_()
    if title is None:
        caption = ''
    else:
        caption = '<figcaption>{}:</figcaption>'.format(html_lib.escape(str(title)))
    return HTML('<figure>{}{}</figure>'.format(caption, player))

## Constants that will later become DDSP/network hyperparameters

In [None]:
# Audio sample rate (Hz) and the number of audio samples per control input.
sample_rate, hop_size = 16000, 512

In [None]:
# `sr` must be passed by keyword: it is keyword-only in librosa >= 0.10.
signal, sample_rate = librosa.load('/home/kureta/Music/violin/Violin Samples/yee_bach_dance_D#52.wav', sr=sample_rate)
# (samples,) -> (batch=1, channels=1, samples), the layout noted for fft_conv.
signal = torch.from_numpy(signal[None, None, :])

# Notes

- kernel dimensions in fft_conv are (out_channels, in_channels, size)
- signal dimensions are (batch, channels, length)
- I might need to implement my own FFT convolution

# Reverb
- Input to the reverb has shape (batch, channels, samples), the same layout as the signal above.
- In real-time (live) mode, we keep a buffer of reverb tails and add it onto the following chunks of audio.

In [None]:
# First batch item, first channel.
original = signal[0, 0]
plot_wave(original.numpy(), sample_rate, title='Original audio')
Audio(original, rate=sample_rate)

Prepare an impulse response that produces the original audio plus a copy delayed by half a second.

In [None]:
# One second of silence with two unit impulses: at t = 0 (dry signal) and
# at t = 0.5 s (the echo).
ir = torch.zeros(1, 1, sample_rate)
ir[0, 0, [0, sample_rate // 2]] = 1.0
# Store the kernel time-reversed. NOTE(review): presumably because fft_conv,
# like F.conv1d, computes cross-correlation — confirm.
ir = ir.flip(-1)

In [None]:
# Undo the storage flip so the IR is shown in forward time.
plot_wave(ir.flip(-1)[0, 0].numpy(), sample_rate, title='Impulse Response')

Do the actual convolution.

In [None]:
ir_len = ir.shape[-1]
with torch.no_grad():
    # Pad ir_len - 1 zeros on the left and ir_len on the right so that the
    # delayed copy is not cut off at the end of the input.
    padded = F.pad(signal, (ir_len - 1, ir_len))
    result = fft_conv(padded, ir)

In [None]:
delayed = result[0, 0]
plot_wave(delayed.numpy(), sample_rate, title='With 0.5 sec. delay')
Audio(delayed, rate=sample_rate)

# Final Reverb Class
Works both offline and in real time (streaming, one chunk at a time).

## TODOs:
- [ ] The reverberated signal is too loud; the IR should probably be normalized.

In [None]:
class Reverb(nn.Module):
    """Learnable convolution reverb that works offline and in streaming mode.

    The impulse response is a trainable parameter of shape
    ``(n_channels, n_channels, duration_in_samples)``; inputs are
    ``(batch, n_channels, samples)``.

    Offline (``live=False``): each call convolves the whole input and drops
    everything past its end.
    Live (``live=True``): calls are treated as consecutive chunks of one
    stream. The part of the reverb that extends past a chunk is kept in
    ``buffer`` and added onto the start of the following chunks. Call
    ``reset()`` before streaming a new signal.

    Args:
        sample_rate: Sample rate in Hz.
        duration: IR length in seconds.
        batch_size: Batch size of the live-mode tail buffer; live inputs
            should have the same batch size.
        live: Whether to start in streaming mode.
        n_channels: Number of input/output channels.
    """

    def __init__(self, sample_rate=16000, duration=1.0, batch_size=1, live=False, n_channels=1):
        super().__init__()

        # IR length in *samples* (attribute name kept for compatibility).
        self.duration = int(sample_rate * duration)
        self.sample_rate = sample_rate
        self.batch_size = batch_size
        self.live = live
        self.n_channels = n_channels

        # ir.shape = (out_channels, in_channels, size), uniform in [-1, 1).
        # NOTE(review): full-length uniform taps give a large gain; likely the
        # cause of the "too loud" TODO.
        self.ir = nn.Parameter(torch.rand(n_channels, n_channels, self.duration) * 2.0 - 1.0)
        # Pending reverb tail for live mode; not stored in the state_dict.
        self.register_buffer('buffer', torch.zeros(self.batch_size, n_channels, self.duration), persistent=False)

    def reset(self):
        """Clear the pending live-mode tail (e.g. before streaming a new signal)."""
        self.buffer.zero_()

    def forward(self, signal):
        if self.live:
            # Streaming is inference-only: the buffer update is not differentiable.
            with torch.no_grad():
                return self.forward_live(signal)
        return self.forward_learn(signal)

    def _convolve(self, signal):
        """Convolve ``signal`` with the IR; callers split the result at ``signal.shape[-1]``."""
        # NOTE(review): the IR is flipped presumably because fft_conv, like
        # F.conv1d, computes cross-correlation — confirm.
        ir = self.ir.flip(-1)
        # Left pad of duration - 1 makes the output start at the first input
        # sample; right pad of duration leaves room for the full tail.
        return fft_conv(F.pad(signal, (self.duration - 1, self.duration)), ir)

    def forward_learn(self, signal):
        """Offline: reverberate ``signal`` and drop everything past its end."""
        return self._convolve(signal)[..., :signal.shape[-1]]

    def forward_live(self, signal):
        """Streaming: reverberate one chunk, carrying its tail to later chunks."""
        signal_length = signal.shape[-1]
        result = self._convolve(signal)

        # Separate the reverberated chunk from the tail that spills past it.
        out = result[..., :signal_length]
        tail = result[..., signal_length:]

        # Add AT MOST the first signal_length samples of the pending tail. The
        # reverb can be shorter than the chunk, in which case the buffer is
        # shorter than the current signal and is used in full.
        previous_tail = self.buffer[..., :signal_length]
        prev_tail_len = previous_tail.shape[-1]
        out[..., :prev_tail_len] += previous_tail

        # Consume those samples: zero them, then rotate them to the end so the
        # remaining tail lines up with the start of the next chunk.
        self.buffer[..., :prev_tail_len] = 0.0
        self.buffer = self.buffer.roll(-prev_tail_len, dims=-1)

        # Accumulate this chunk's tail.
        self.buffer += tail

        return out

In [None]:
# Two-channel reverb whose live-mode buffer holds a batch of 10.
reverb = Reverb(n_channels=2, batch_size=10)

In [None]:
# `sr` must be passed by keyword: it is keyword-only in librosa >= 0.10.
signal, sample_rate = librosa.load('/home/kureta/Music/violin/Violin Samples/yee_bach_dance_D#52.wav', sr=sample_rate)
signal = torch.from_numpy(signal)
# Prepend singleton dims until the layout is (batch, channels, samples).
while signal.ndim < 3:
    signal = signal.unsqueeze(0)

In [None]:
# Tile to the reverb's (batch, channels) layout. Its IR expects
# `reverb.n_channels` input channels (repeat(10, 1, 1) left only one), and its
# live buffer has `reverb.batch_size` rows. `expand` to a fixed target shape
# is idempotent, so re-running this cell does not keep growing the tensor.
signal = signal.expand(reverb.batch_size, reverb.n_channels, -1).contiguous()

In [None]:
# Streaming pass: feed the signal chunk by chunk, as a real-time host would.
chunk_size = 855  # arbitrary chunk length in samples
reverb.live = True
# Drop any tail left in the buffer by a previous run of this cell.
reverb.buffer.zero_()
with torch.no_grad():
    # Cover the whole signal instead of a hardcoded 53 chunks.
    chunks = [
        reverb(signal[..., start:start + chunk_size])
        for start in range(0, signal.shape[-1], chunk_size)
    ]
result = torch.cat(chunks, -1)

In [None]:
# Offline pass: the whole signal in a single call.
with torch.no_grad():
    reverb.live = False
    result = reverb(signal)

In [None]:
# First batch item, all channels: (channels, samples).
reverberated = result[0]
# This is the random-IR reverb, not the 0.5 s delay from above.
plot_wave(reverberated[0].numpy(), sample_rate, title='Reverberated audio (channel 0)')
Audio(reverberated, rate=sample_rate)

# Noise
For every control input, generate `hop_size` samples of band-filtered noise.

In [None]:
def get_noise(batch_size=1, hop_size=512):
    """Return uniform white noise in [-1, 1) with shape ``(batch_size, 1, hop_size)``."""
    unit = torch.rand(batch_size, 1, hop_size)
    return 2.0 * unit - 1.0

In [None]:
# Batch size and number of control steps for the noise experiments below.
batch_size, seq_len = 10, 100

In [None]:
# Frame-by-frame generator state: the current two-hop noise frame and the
# overlap-add accumulator.
noise_buffer = get_noise(batch_size, 2 * hop_size)
windowed_buffer = torch.zeros(batch_size, 1, 3 * hop_size)

In [None]:
# Per-step filter gains over the sample_rate // 2 + 1 rfft bins:
#   - a loud (10.0) 110-bin band starting at bin `step`;
#   - a quieter (1.0) 1000-bin band starting at bin 1200 + 40 * step.
bands = torch.zeros(batch_size, seq_len, 1, sample_rate // 2 + 1)
for step in range(seq_len):
    bands[:, step, 0, step:step + 110] = 10.0
    bands[:, step, 0, 1200 + 40 * step:2200 + 40 * step] = 1.0
# Real spectrum -> symmetric (zero-phase) IR; fftshift moves its peak to the
# middle of the kernel.
nir = torch.fft.fftshift(torch.fft.irfft(bands, dim=-1), dim=-1)
# (batch, seq, 1, taps) -> (seq, batch, 1, taps): iterating yields one step.
nir = nir.permute(1, 0, 2, 3)

In [None]:
# Filter gains for the first batch item at the first control step.
plot_fft(bands[0, 0, 0].numpy(), sample_rate, 'Noise filter bands (frequency domain)')

In [None]:
# Frame-by-frame (streaming reference) filtered-noise synthesis.
# State is initialized here, so re-running the cell renders the same kind of
# output instead of continuing from buffers left over by a previous run.
frame_len = 2 * hop_size
window = torch.hann_window(frame_len)
noise_frame = get_noise(batch_size, frame_len)
ola_buffer = torch.zeros(batch_size, 1, 3 * hop_size)
n_taps = nir.shape[-1]

result = []
with torch.no_grad():
    for step_ir in nir:
        # Treat the batch as channels so one grouped convolution filters each
        # batch item with its own IR: (1, batch, frame) * (batch, 1, taps).
        filtered = fft_conv(
            F.pad(noise_frame.view(1, -1, frame_len), (n_taps - 1, n_taps)),
            step_ir.view(batch_size, 1, n_taps),
            groups=batch_size,
        )
        # The IR peak sits at n_taps // 2 (fftshift); keep the frame_len
        # samples aligned with the input frame.
        filtered = filtered[..., n_taps // 2:n_taps // 2 + frame_len]
        windowed = (filtered * window).permute(1, 0, 2)  # -> (batch, 1, frame_len)

        # Overlap-add the new frame into the last two hops. The first of them is
        # now complete (new first half + previous second half) and is emitted.
        ola_buffer[..., -frame_len:] += windowed
        result.append(ola_buffer[..., -frame_len:-hop_size])

        # Advance one hop: shift the accumulator and noise left, clear the freed
        # accumulator slot, and draw fresh noise for the new half frame.
        ola_buffer = ola_buffer.roll(-hop_size, -1)
        ola_buffer[..., -hop_size:] = 0.0
        noise_frame = noise_frame.roll(-hop_size, -1)
        noise_frame[..., -hop_size:] = get_noise(batch_size, hop_size)

    result = torch.cat(result, -1)
    # Fade out the final hop with the falling half of the window.
    result[..., -hop_size:] *= window[-hop_size:]

In [None]:
noise_item1 = result[1, 0]
plot_wave(noise_item1.numpy(), sample_rate, title='Filtered noise')
Audio(noise_item1, rate=sample_rate)

In [None]:
noise_item0 = result[0, 0]
plot_wave(noise_item0.numpy(), sample_rate, title='Filtered noise')
Audio(noise_item0, rate=sample_rate)

In [None]:
# Spectrogram of the first batch item. librosa.stft's default hop (512)
# matches specshow's default hop_length. sr must be given, otherwise specshow
# assumes 22050 Hz and the time/frequency axes are wrong.
stft = librosa.stft(result[0, 0].numpy())
librosa.display.specshow(librosa.amplitude_to_db(np.abs(stft)), sr=sample_rate, x_axis='time', y_axis='hz');

## Get rid of the for loop (vectorized version)

In [None]:
# Same per-step filter gains as the frame-by-frame version: a loud 110-bin
# band starting at bin `step` and a quieter 1000-bin band at 1200 + 40 * step.
bands = torch.zeros(batch_size, seq_len, 1, sample_rate // 2 + 1)
for step in range(seq_len):
    bands[:, step, 0, step:step + 110] = 10.0
    bands[:, step, 0, 1200 + 40 * step:2200 + 40 * step] = 1.0
# Zero-phase IRs centered by fftshift; kept as (batch, seq, 1, taps).
nir = torch.fft.fftshift(torch.fft.irfft(bands, dim=-1), dim=-1)

In [None]:
# Slice a (batch, 1, 1, samples) signal into 2-hop frames with a 1-hop stride
# (50% overlap). The output is (batch, 2 * hop_size, n_frames).
unfold = nn.Unfold(kernel_size=(1, 2 * hop_size), stride=(1, hop_size), padding=0)

In [None]:
# Streaming state: the previous chunk's last windowed frame (zeros at start).
noise_buffer = torch.zeros(batch_size, 1, 2 * hop_size)
# One extra hop of noise so that unfolding yields exactly seq_len frames.
noise = torch.rand(batch_size, 1, 1, seq_len * hop_size + hop_size) * 2.0 - 1.0
framed_noise = unfold(noise)  # (batch, 2 * hop_size, seq_len)
hann = torch.hann_window(2 * hop_size).unsqueeze(0).unsqueeze(-1)  # (1, frame, 1)
windowed_noise = framed_noise * hann
# Prepend the buffer as frame 0, then overlap-add all frames back into a signal.
windowed_noise = torch.cat([noise_buffer.permute(0, 2, 1), windowed_noise], dim=-1)
re_noise = F.fold(windowed_noise, (1, hop_size * seq_len + 2 * hop_size),
                  kernel_size=(1, 2 * hop_size), stride=(1, hop_size), padding=(0, 0))

# Shape checks in place of the previous debug prints.
assert framed_noise.shape == (batch_size, 2 * hop_size, seq_len)
assert windowed_noise.shape == (batch_size, 2 * hop_size, seq_len + 1)
assert re_noise.shape == (batch_size, 1, 1, hop_size * (seq_len + 2))
framed_noise.shape, noise.shape, re_noise.shape

In [None]:
# Frames as rows: (batch, seq_len, 2 * hop_size). Recomputed from `noise`
# instead of permuting `framed_noise` again, so re-running this cell does not
# flip the axes back.
framed_noise = unfold(noise).permute(0, 2, 1)
framed_noise.shape

In [None]:
frame_len = 2 * hop_size
n_taps = nir.shape[-1]
n_filters = batch_size * seq_len
with torch.no_grad():
    # Filter every frame with its own IR in one grouped convolution; frames
    # become channels: (1, batch * seq_len, frame_len).
    filtered = fft_conv(
        F.pad(framed_noise.reshape(1, -1, frame_len), (n_taps - 1, n_taps)),
        nir.reshape(n_filters, 1, n_taps),
        groups=n_filters,
    )
    # Keep the frame-aligned part (the IR peak is at n_taps // 2 after fftshift).
    filtered = filtered[..., n_taps // 2:n_taps // 2 + frame_len]
    windowed = filtered * torch.hann_window(frame_len)

# (1, batch * seq_len, frame_len) -> (batch, frame_len, seq_len), fold's layout.
windowed = windowed.reshape(batch_size, seq_len, frame_len).permute(0, 2, 1)
# Prepend the previous chunk's last frame. Keep this chunk's LAST FRAME as the
# next buffer, in the same (batch, 1, frame_len) layout as the initial one.
# (Previously `w[:, :-1]` sliced the window axis instead, which broke re-runs.)
windowed, noise_buffer = (
    torch.cat([noise_buffer.permute(0, 2, 1), windowed], dim=-1),
    windowed[..., -1:].permute(0, 2, 1),
)
assert windowed.shape == (batch_size, frame_len, seq_len + 1)
assert noise_buffer.shape == (batch_size, 1, frame_len)

# Overlap-add. The first and last hop are only half-covered by windows and are
# dropped before playback.
shit = F.fold(windowed, (1, hop_size * seq_len + 2 * hop_size),
              kernel_size=(1, frame_len), stride=(1, hop_size), padding=(0, 0))
shit = shit.squeeze(1)  # (batch, 1, hop_size * (seq_len + 2))
shit.shape

The first `2 * hop_size` samples overlap with the tail of the previously generated noise (`noise_buffer`). So, for streaming:

- initialize a `2 * hop_size` `noise_buffer` to zero
- generate noise `hop_size` longer than requested
- unfold, filter, window
- store the last element as the next `noise_buffer`
- prepend the previous `noise_buffer` and overlap-add to `2 * hop_size` longer noise
- drop first and last `hop_size` segments and return

In [None]:
# Drop the half-covered first and last hop.
trimmed = shit[1, 0, hop_size:-hop_size]
plot_wave(trimmed.numpy(), sample_rate, title='Filtered noise')
Audio(trimmed, rate=sample_rate)

In [None]:
# Spectrogram of the same trimmed segment that is played back (the
# half-covered first and last hop are excluded). sr is passed so the axes are
# in seconds/Hz; librosa.stft's default hop matches specshow's default of 512.
segment = shit[1, 0, hop_size:-hop_size].numpy()
stft = librosa.stft(segment)
librosa.display.specshow(librosa.amplitude_to_db(np.abs(stft)), sr=sample_rate, x_axis='time', y_axis='hz');