In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [17, 6]
from IPython.display import Audio

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa
from torch import Tensor

from fftconv import fft_conv
import librosa
import librosa.display
import numpy as np

In [None]:
def plot_wave(y, sr):
    """Plot a waveform against time in seconds.

    Args:
        y: 1-D sequence of samples (anything `len()` and matplotlib accept).
        sr: Sample rate in Hz, used to convert sample indices to seconds.
    """
    n_samples = len(y)
    # Sample k is played at k / sr seconds. np.linspace(0, n / sr, n) would put
    # the last sample at n / sr and stretch the whole axis by n / (n - 1).
    times = np.arange(n_samples) / sr
    plt.plot(times, y)
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.show()

In [None]:
def plot_fft(fft, sr):
    """Plot spectrum values against their frequencies in Hz.

    Args:
        fft: 1-D sequence of per-bin values (e.g. rfft magnitudes).
        sr: Sample rate in Hz of the signal the spectrum was computed from.

    The frequency axis is built as if the spectrum came from a signal of
    ``2 * len(fft) - 1`` samples, which yields exactly ``len(fft)`` bins.
    """
    n_bins = len(fft)
    sample_spacing = 1 / sr
    bin_freqs = np.fft.rfftfreq(2 * n_bins - 1, sample_spacing)
    plt.plot(bin_freqs, fft)
    plt.show()

In [None]:
# Load the example violin recording, resampled to 16 kHz.
# `sr` is passed by keyword: recent librosa releases reject it as a positional argument.
# NOTE(review): this absolute path is machine-specific — point it at your own copy before running.
signal, sample_rate = librosa.load('/home/kureta/Music/violin/Violin Samples/yee_bach_dance_D#52.wav', sr=16000)

# Reverb
Data coming into the reverb has the shape (batch, features, sequence) (features = num_channels?).

When operating in realtime mode, we keep a buffer of reverb tails to add onto the next piece of audio.

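When the reverb size is 16000 samples, the tail turns out to be 16001 samples. I don't know what to do with that 1 residual sample, but I guess I'll just discard the last one. It is 1 sample after all; it shouldn't make much difference. (The extra sample most likely comes from the padding: padding an `L`-sample signal by the full IR length `N` on both sides yields `L + N + 1` output samples, whereas padding by `N - 1` yields the exact `L + N - 1`-sample full convolution.)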

In [None]:
# Add batch and channel axes -> (1, 1, samples).
# as_tensor + reshape is a no-op when `signal` is already a (1, 1, samples) tensor,
# so re-running this cell no longer raises the way torch.from_numpy does on a tensor.
signal = torch.as_tensor(signal).reshape(1, 1, -1)

In [None]:
# Listen to the unprocessed recording (batch 0, channel 0).
Audio(signal[0, 0], rate=sample_rate)

In [None]:
# Waveform of the unprocessed recording.
plot_wave(signal[0, 0].numpy(), sample_rate)

In [None]:
# Toy impulse response: a unit impulse at sample 0 plus a single echo at sample 5000.
ir = torch.zeros(1, 1, 16000)
ir[..., [0, 5000]] = 1.0
# Stored time-reversed for the convolution call below.
# NOTE(review): presumably because fft_conv correlates like torch's conv ops — confirm.
ir = ir.flip(-1)

In [None]:
# Flip back before plotting so the impulses appear in their original order (samples 0 and 5000).
plt.plot(ir[0, 0].flip(-1).numpy())

In [None]:
# Sanity-check the shapes before convolving; both are 3-D with a leading (1, 1).
signal.shape, ir.shape

In [None]:
# Pure demo, so skip autograd bookkeeping.
with torch.no_grad():
    # Pad by the full IR length so the output extends past the input and keeps the reverb tail.
    # NOTE(review): assumes fft_conv pads both ends like torch's conv functions — confirm.
    result = fft_conv(signal, ir, padding=ir.shape[-1])

In [None]:
# Shape of the convolved output; expected to be longer than the input because of the padding.
result.shape

In [None]:
# Listen to the convolved signal (dry sound plus its echo).
Audio(result[0, 0], rate=sample_rate)

In [None]:
# Waveform of the convolved signal; the x-axis is in samples, not seconds.
plt.plot(result[0, 0].numpy())

# Final Reverb Class
Works both offline and in realtime.

In [None]:
class Reverb(nn.Module):
    """Learnable convolution reverb that runs offline or chunk by chunk.

    The impulse response ``ir`` is a trainable parameter, initialised with
    uniform noise in [-1, 1). Inputs are expected as ``(batch, 1, samples)``:
    the kernel is built with a single input and a single output channel.

    Args:
        sample_rate: Sample rate in Hz, used to convert ``duration`` to samples.
        duration: Length of the impulse response in seconds.
        batch_size: Batch size of the tail buffer used in live mode.
        live: If True, ``forward`` runs the stateful, gradient-free streaming
            path; otherwise it runs the differentiable offline path.

    Raises:
        ValueError: If the impulse response would be shorter than one sample.
    """

    def __init__(self, sample_rate=16000, duration=1.0, batch_size=1, live=False):
        super().__init__()

        # NOTE: despite its name, ``self.duration`` is the IR length in samples.
        self.duration = int(sample_rate * duration)
        if self.duration < 1:
            raise ValueError('sample_rate * duration must be at least one sample')
        self.sample_rate = sample_rate
        self.batch_size = batch_size
        self.live = live

        self.ir = nn.Parameter(torch.rand(self.duration) * 2.0 - 1.0)
        # Reverb tail carried over between live chunks. It is runtime state, not a
        # learned weight, so it is excluded from the state dict.
        self.register_buffer('buffer', torch.zeros(self.batch_size, 1, self.duration), persistent=False)

    def forward(self, signal):
        """Dispatch to the live or offline path depending on ``self.live``."""
        if self.live:
            # The live path mutates tensors in place and is inference-only.
            with torch.no_grad():
                return self.forward_live(signal)
        else:
            return self.forward_learn(signal)

    def _convolve(self, signal):
        """Return the full causal convolution of ``signal`` with the IR.

        The kernel is flipped because ``fft_conv`` follows torch's conv
        convention (cross-correlation). Padding by ``duration - 1`` on both
        sides yields exactly ``signal_length + duration - 1`` samples. Padding
        by ``duration`` instead would prepend a spurious zero, delaying the
        output by one sample and producing one extra residual sample.
        """
        kernel = self.ir[None, None, :].flip(-1)
        return fft_conv(signal, kernel, padding=self.duration - 1)

    def forward_learn(self, signal):
        """Offline, differentiable path: reverberate and crop to the input length."""
        signal_length = signal.shape[-1]
        return self._convolve(signal)[..., :signal_length]

    def forward_live(self, signal):
        """Streaming path: reverberate one chunk, carrying the tail to the next call.

        Returns a tensor with the same length as ``signal``. The part of the
        reverb that extends past the chunk is accumulated in ``self.buffer``
        and mixed into the following chunks (overlap-add).
        """
        signal_length = signal.shape[-1]
        result = self._convolve(signal)

        # Separate the reverberated chunk from its tail (duration - 1 samples).
        out = result[..., :signal_length]
        tail = result[..., signal_length:]

        # Mix in AT MOST signal_length samples of the stored tail: the reverb may
        # be shorter than the chunk, in which case the whole buffer is consumed.
        n_used = min(signal_length, self.duration)
        out[..., :n_used] += self.buffer[..., :n_used]

        # Clear the consumed samples and rotate them to the end, so the unconsumed
        # remainder lines up with the start of the next chunk.
        self.buffer[..., :n_used] = 0.0
        self.buffer = self.buffer.roll(-n_used, dims=-1)

        # Accumulate the new tail on top of what is left of the old one.
        self.buffer[..., :tail.shape[-1]] += tail

        return out

In [None]:
# Defaults: 16 kHz sample rate, 1 s impulse response, batch size 1, offline mode.
reverb = Reverb()

In [None]:
# Stream the recording through the reverb in small, realtime-sized chunks.
reverb.live = True
# Start from an empty tail buffer so re-running this cell gives the same output.
reverb.buffer.zero_()
chunk_size = 855
# split() walks the whole signal, so the number of chunks is not hard-coded.
result = torch.cat([reverb(chunk) for chunk in signal.split(chunk_size, dim=-1)], dim=-1)

In [None]:
# Offline (differentiable) mode on a batch of five equal-length segments.
reverb.live = False
# -1 lets torch infer the segment length instead of hard-coding the sample count;
# the signal length must still be divisible by 5.
result = reverb(signal.reshape(5, 1, -1))

In [None]:
# Shape of the offline result: one row per segment of the reshaped signal.
result.shape

In [None]:
# Listen to the third segment; detach() because the offline path tracks gradients.
Audio(result[2, 0].detach(), rate=sample_rate)

In [None]:
# Waveform of the first reverberated segment (x-axis in samples).
plt.plot(result[0, 0].detach().numpy())

# Noise
For every control input, generate `hop_size` samples of band-filtered noise.

In [None]:
# One hop (512 samples) of uniform noise in [-1, 1) ...
noise = torch.rand(1, 1, 512) * 2.0 - 1.0
# ... placed at the end of a buffer that starts with 16000 - 1 samples of silence.
buffer = torch.cat([torch.zeros(1, 1, 16000 - 1), noise], dim=-1)

In [None]:
# One frequency-domain mask per band: 8001 rfft bins, i.e. a 16000-sample response.
bands = torch.zeros(100, 1, 8001)
for band_idx in range(100):
    # Each band passes two 110-bin windows, starting at 20 * band_idx and 40 * band_idx.
    for start in (band_idx * 20, band_idx * 40):
        bands[band_idx, 0, start:start + 110] = 1.0
# Back to the time domain, then rotate each response by half its length.
nir = torch.fft.fftshift(torch.fft.irfft(bands, dim=-1), dim=-1)
nir.shape

In [None]:
# Frequency-domain mask of the last band (index 99).
plt.plot(bands[99, 0].numpy())

In [None]:
# Time-domain impulse response of the first band.
plot_wave(nir[0, 0].numpy(), sample_rate)

In [None]:
# Compare the shape of one noise hop with the bank of band impulse responses.
noise.shape, nir.shape

In [None]:
# Generate 100 hops of noise, each filtered by a different band's impulse response.
with torch.no_grad():
    frames = []
    for band_idx in range(100):
        kernel = nir[band_idx:band_idx + 1]
        # No padding: the buffer is hop - 1 samples longer than the kernel,
        # so each call should yield one hop (512 samples) — confirm against fft_conv.
        frames.append(fft_conv(buffer, kernel))
        # Slide the window: drop the oldest hop and append a fresh hop of noise.
        fresh_noise = torch.rand(1, 1, 512) * 2.0 - 1.0
        buffer = torch.cat([buffer[..., 512:], fresh_noise], dim=-1)

    result = torch.cat(frames, -1)

In [None]:
# Shape of the concatenated noise frames.
result.shape

In [None]:
# Listen to the generated band-filtered noise.
Audio(result[0, 0], rate=sample_rate)

In [None]:
# Zoom in on the first 2048 samples of the generated noise.
plot_wave(result[0, 0, :2048].numpy(), sample_rate)

In [None]:
# Magnitude spectrum of the whole generated noise signal.
fft = np.fft.rfft(result[0, 0].numpy())
plot_fft(np.abs(fft), sample_rate)

In [None]:
# Loop one second of uniform white noise four times to hear how audible the repetition is.
noise_period = torch.rand(16000) * 2.0 - 1.0
looped_noise = noise_period.repeat(4)
Audio(looped_noise, rate=16000)