In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [17, 6]
from IPython.display import Audio

In [None]:
import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Dict, Tuple

In [None]:
import librosa
import librosa.display

In [None]:
class OscillatorBank(nn.Module):
    """Additive (harmonic) synthesiser that keeps its phase between calls.

    Every call to ``forward`` renders ``hop_size`` audio samples per control
    frame as a sum of sinusoids at integer multiples of the fundamental.  The
    phase reached at the end of a call is stored in ``last_phases`` and used as
    the starting offset of the next call, so audio can be rendered chunk by
    chunk.

    Args:
        batch_size: First dimension of the stored phase state; inputs passed
            to ``forward`` must be broadcast-compatible with it.
        sample_rate: Output sample rate in Hz.
        n_harmonics: Number of sinusoidal partials.
        hop_size: Number of audio samples generated per control frame.
    """

    def __init__(self, batch_size=4, sample_rate=16000, n_harmonics=100, hop_size=512):
        super().__init__()

        self.n_harmonics = n_harmonics
        self.sample_rate = sample_rate
        self.hop_size = hop_size

        # Harmonic numbers 1..n_harmonics (fixed, never trained).
        self.harmonics = nn.Parameter(
            torch.arange(1, self.n_harmonics + 1, step=1), requires_grad=False
        )
        # Running phase per (batch item, harmonic), drawn uniformly in [-pi, pi).
        self.last_phases = nn.Parameter(
            torch.rand(batch_size, n_harmonics) * 2. * np.pi - np.pi, requires_grad=False
        )

    def prepare_harmonics(self, f0: Tensor, harm_amps: Tensor) -> Tuple[Tensor, Tensor]:
        """Turn frame-rate f0 and harmonic weights into sample-rate controls.

        Args:
            f0: Fundamental frequency in Hz, shaped (batch, frames, 1).
            harm_amps: Per-harmonic weights, shaped (batch, frames, n_harmonics).

        Returns:
            A tuple ``(harmonics, harm_amps)``: the per-sample angular
            increments in radians per sample, upsampled by ``hop_size``, and
            the frame-rate weights with partials above Nyquist zeroed and the
            remainder normalised to sum to one.
        """
        # Hz (cycles per second); broadcasting expands the (n_harmonics,)
        # harmonic numbers against the trailing singleton axis of f0.
        harmonics = self.harmonics * f0
        # zero out above nyquist
        mask = harmonics > self.sample_rate // 2
        harm_amps = harm_amps.masked_fill(mask, 0.0)
        # Normalise the surviving partials.  If nothing survives (or the
        # distribution is all zero) the total is 0; divide by 1 instead so the
        # frame comes out silent rather than NaN.
        total = harm_amps.sum(-1, keepdim=True)
        harm_amps = harm_amps / torch.where(total == 0, torch.ones_like(total), total)
        harmonics *= 2 * np.pi  # radians per second
        harmonics /= self.sample_rate  # radians per sample
        harmonics = self.rescale(harmonics)
        return harmonics, harm_amps

    @staticmethod
    def generate_phases(harmonics: Tensor) -> Tensor:
        """Integrate per-sample angular increments along dim 1, wrapped to [0, 2*pi)."""
        phases = torch.cumsum(harmonics, dim=1)
        phases %= 2 * np.pi
        return phases

    def generate_signal(
        self, harm_amps: Tensor, loudness: Tensor, phases: Tensor
    ) -> Tensor:
        """Mix the partials into one waveform.

        ``harm_amps`` and ``loudness`` are given at frame rate and are upsampled
        here; ``phases`` is already at sample rate.  The harmonic axis (dim 2)
        is summed away, leaving a (batch, samples) tensor.
        """
        loudness = self.rescale(loudness)
        harm_amps = self.rescale(harm_amps)
        signal = loudness * harm_amps * torch.sin(phases)
        signal = torch.sum(signal, dim=2)
        return signal

    def rescale(self, x: Tensor) -> Tensor:
        """Linearly upsample a (batch, frames, channels) tensor by ``hop_size`` along frames."""
        return F.interpolate(
            x.permute(0, 2, 1),
            scale_factor=float(self.hop_size),
            mode='linear',
            align_corners=False,
        ).permute(0, 2, 1)

    def forward(self, x: Dict[str, Tensor]) -> Tensor:
        """Render audio for a chunk of control frames.

        Args:
            x: Dict with keys ``'f0_hz'``, ``'harmonic_distribution'`` and
                ``'amplitudes'``.

        Returns:
            Audio tensor shaped (batch, frames * hop_size).
        """
        f0 = x['f0_hz']
        harm_amps = x['harmonic_distribution']
        loudness = x['amplitudes']

        harmonics, harm_amps = self.prepare_harmonics(f0, harm_amps)
        harmonics[:, 0, :] += self.last_phases  # phase offset from last sample
        phases = self.generate_phases(harmonics)
        # Update the phase offset without recording it in autograd: the stored
        # state must stay a plain leaf, otherwise each call would chain onto
        # the graph of the previous one.
        with torch.no_grad():
            self.last_phases.copy_(phases[:, -1, :])
        signal = self.generate_signal(harm_amps, loudness, phases)

        return signal

In [None]:
# Oscillator bank with the class defaults: batch of 4, 16 kHz sample rate,
# 100 harmonics and 512 samples per control frame.
synth = OscillatorBank()

In [None]:
def static_sawtooth_features(fundamental_frequency: float,
                             base_amplitude: float,
                             n_harmonics: int = 30,
                             n_frames: int = 1000,
                             batch_size: int = 3) -> Dict[str, torch.Tensor]:
    """Build time-invariant synthesiser features with a 1/k harmonic roll-off.

    Args:
        fundamental_frequency: Value written to every entry of ``'f0_hz'``.
        base_amplitude: Value written to every entry of ``'amplitudes'``.
        n_harmonics: Number of harmonic weights per frame.
        n_frames: Number of control frames.
        batch_size: Number of identical batch items.

    Returns:
        Dict with ``'amplitudes'`` and ``'f0_hz'`` shaped
        (batch_size, n_frames, 1) and ``'harmonic_distribution'`` shaped
        (batch_size, n_frames, n_harmonics).
    """
    frame_shape = (batch_size, n_frames, 1)

    # Harmonic k gets weight 1/k.  (Replacing this with an all-ones vector
    # gives the "impulse" variant of the features.)
    harmonic_numbers = torch.arange(1, n_harmonics + 1)
    rolloff = (1 / harmonic_numbers).unsqueeze(0).unsqueeze(0)

    return {
        'amplitudes': torch.zeros(*frame_shape) + base_amplitude,
        'harmonic_distribution': rolloff.repeat(batch_size, n_frames, 1),
        'f0_hz': torch.zeros(*frame_shape) + fundamental_frequency,
    }

In [None]:
# Constant 220 Hz tone at unit amplitude; the batch size matches the
# oscillator bank's default phase state (4).
params = static_sawtooth_features(
    fundamental_frequency=220.0,
    base_amplitude=1.0,
    n_harmonics=100,
    n_frames=500,
    batch_size=4,
)

In [None]:
# Move the synthesiser and its input features to the GPU when one is
# available; otherwise stay on the CPU so the notebook still runs end to end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
synth.to(device)
for key, value in params.items():
    params[key] = value.to(device)

In [None]:
# Render every control frame in a single call, without tracking gradients.
with torch.no_grad():
    signal = synth(params)
# Keep the first batch item as a NumPy array (copied to the CPU first).
np_signal = signal[0].cpu().numpy()

In [None]:
# Display the shape of the f0 feature tensor.
params['f0_hz'].shape

In [None]:
# Render the same features one control frame at a time, relying on the phase
# the oscillator bank carries over between calls, then join the chunks.
# The frame count is read from the features rather than hard-coded.
n_frames = params['f0_hz'].shape[1]
rt_signals = []
with torch.no_grad():
    for idx in range(n_frames):
        # idx:idx + 1 keeps the frame axis as a singleton dimension.
        frame = {k: v[:, idx:idx + 1] for k, v in params.items()}
        rt_signals.append(synth(frame))
signal = torch.cat(rt_signals, dim=1)
np_signal = signal[0].cpu().numpy()

In [None]:
# Listen to the rendered signal at 16 kHz, with IPython's normalisation
# switched off (normalize=False).
Audio(np_signal, rate=16000, normalize=False)

In [None]:
def amp_to_impulse_response(amp: Tensor, target_size: int) -> Tensor:
    """Convert real magnitude responses into windowed impulse responses.

    The magnitudes along the last axis are treated as a spectrum with zero
    imaginary part and inverse-rFFT'd.  The result is rotated by half its
    length, multiplied by a Hann window, padded on the right by
    ``target_size`` minus the filter length, and rotated back.

    Args:
        amp: Real tensor whose last axis holds the frequency-band magnitudes.
        target_size: Requested length of the last axis of the result.
            NOTE(review): when this is smaller than the inverse-FFT length the
            pad amount is negative — confirm that case is intended.

    Returns:
        Tensor with the same leading axes as ``amp``.
    """
    zero_phase_spectrum = torch.view_as_complex(
        torch.stack((amp, torch.zeros_like(amp)), dim=-1)
    )
    kernel = fft.irfft(zero_phase_spectrum)
    n_taps = int(kernel.size(-1))

    # Rotate by half the length, then apply the window.
    kernel = kernel.roll(n_taps // 2, dims=-1)
    window = torch.hann_window(n_taps, dtype=kernel.dtype, device=kernel.device)
    kernel = kernel * window

    # Pad on the right, then undo the rotation.
    kernel = F.pad(kernel, (0, int(target_size) - n_taps))
    return kernel.roll(-n_taps // 2, dims=-1)


def fft_convolve(signal: Tensor, kernel: Tensor) -> Tensor:
    """Convolve ``signal`` with ``kernel`` along the last axis via the FFT.

    The signal is zero-padded on the right by its own length and the kernel on
    the left by its own length.  Their spectra are multiplied, and the second
    half of the inverse transform is returned.

    Args:
        signal: Tensor to be filtered.
        kernel: Filter taps; its padded spectrum must be broadcast-compatible
            with the signal's.

    Returns:
        The second half (along the last axis) of the inverse transform.
    """
    signal_length = signal.shape[-1]
    kernel_length = kernel.shape[-1]

    padded_signal = F.pad(signal, (0, signal_length))
    delayed_kernel = F.pad(kernel, (kernel_length, 0))

    spectrum_product = fft.rfft(padded_signal) * fft.rfft(delayed_kernel)
    circular = fft.irfft(spectrum_product)

    midpoint = circular.shape[-1] // 2
    return circular[..., midpoint:]


class FilteredNoise(nn.Module):
    """Noise synthesiser: uniform white noise shaped by per-frame filters.

    Args:
        hop_size: Number of audio samples generated per control frame; stored
            as ``block_size``.
    """

    def __init__(self, hop_size=512):
        super().__init__()
        self.block_size = hop_size

    def forward(self, x: Dict[str, Tensor]) -> Tensor:
        """Render one block of filtered noise per control frame.

        Args:
            x: Dict whose ``'noise_bands'`` entry holds the filter magnitudes,
                assumed to be shaped (batch, frames, bands).

        Returns:
            Tensor shaped (batch, -1): the per-frame blocks flattened along
            the sample axis.
        """
        param = x['noise_bands']

        impulse = amp_to_impulse_response(param, self.block_size)
        # Draw uniform noise in [-1, 1) directly on the filter's device, rather
        # than on the CPU followed by a copy on every call.
        noise = (
            torch.rand(
                impulse.shape[0],
                impulse.shape[1],
                self.block_size,
                device=impulse.device,
            )
            * 2
            - 1
        )

        noise = fft_convolve(noise, impulse)
        # reshape copies when needed, so no explicit .contiguous() is required.
        return noise.reshape(noise.shape[0], -1)

In [None]:
# Noise synthesiser producing 512 samples per control frame.
noise_synth = FilteredNoise(512)

In [None]:
# All-zero filter magnitudes with two active bins, for a batch of 4 and 20
# frames.
# NOTE(review): with 16000 bands the inverse rFFT in amp_to_impulse_response
# yields far more taps than the 512-sample block used by noise_synth, so its
# pad amount is negative — confirm this band count is intended.
noise_bands = torch.zeros(4, 20, 16000)
noise_bands[:, :, 2] = 0.5
noise_bands[:, :, 22] = 0.5
# Wrap the tensor under the key FilteredNoise.forward reads.  A separate name
# is used for the tensor so `noise_params` is only ever the feature dict.
noise_params = dict(noise_bands=noise_bands)

In [None]:
# Render all frames in one call and keep the first batch item as a NumPy array.
noise_signal = noise_synth(noise_params)
np_noise_signal = noise_signal[0].numpy()

In [None]:
# Listen to the one-shot noise render at 16 kHz, with IPython's normalisation
# switched off (normalize=False).
Audio(np_noise_signal, rate=16000, normalize=False)

In [None]:
# Magnitude STFT of the one-shot noise render.
noise_spectrum = np.abs(librosa.stft(np_noise_signal))

In [None]:
# Plot the magnitude spectrogram on a dB scale.
librosa.display.specshow(librosa.amplitude_to_db(noise_spectrum), sr=16000)

In [None]:
# Render the noise one control frame at a time: split the bands along the
# frame axis and give each slice back a singleton frame dimension.
results = []
for time_step in noise_params['noise_bands'].unbind(dim=1):
    time_step = time_step[:, None, :]
    results.append(noise_synth(dict(noise_bands=time_step)))

In [None]:
# Join the per-frame blocks along the sample axis.
result = torch.cat(results, dim=1)

In [None]:
# First batch item of the frame-by-frame render, as a NumPy array.
np_result = result[0].numpy()

In [None]:
# Listen to the frame-by-frame noise render at 16 kHz (no normalize argument
# is passed here, so IPython's default applies).
Audio(np_result, rate=16000)

In [None]:
# Magnitude STFT of the frame-by-frame noise render.
np_result_spectrum = np.abs(librosa.stft(np_result))

In [None]:
# Plot the magnitude spectrogram on a dB scale.
librosa.display.specshow(librosa.amplitude_to_db(np_result_spectrum), sr=16000)

In [None]:
# Magnitude of the rFFT taken over the whole frame-by-frame signal, first
# batch item only.
plt.plot(fft.rfft(result)[0].abs().numpy())