In [None]:
import sys
from pathlib import Path

# Make the `src` directory that sits next to the current working directory importable.
# It is inserted only when missing, so re-running this cell does not add duplicates.
src_dir = str(Path.cwd().parent / 'src')
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 10]
from IPython.display import Audio

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import librosa
from librosa.display import specshow
from librosa.filters import get_window
import torchcrepe
from einops import rearrange, parse_shape
import opt_einsum as oe

In [None]:
from models.modules.harmonic_oscillator import OscillatorBank
from models.modules.loss import MorletTransform
from models.modules.utils import get_frames, pad_audio

In [None]:
# Time axis covering one period of a 110 * 50 = 5500 Hz tone.
# 16000 is presumably the sample rate in Hz (TODO confirm); at that rate one period
# truncates to 2 samples, and linspace places them on the two endpoints 0 and `dur`.
tone_hz = 110 * 50
dur = 1 / tone_hz
wave_length = int(16000 / tone_hz)
t = np.linspace(0, dur, wave_length)

In [None]:
dur, wave_length

In [None]:
plt.plot(np.sin(2*np.pi*110*50*t))

In [None]:
def build_sin(f, frame_size=1024, sample_rate=16000):
    """Return a zero-phase sine of frequency ``f`` that is one period longer than a frame.

    The result has ``int(sample_rate / f) + frame_size`` samples, so a
    ``np.correlate(..., mode='valid')`` against a ``frame_size``-long signal
    slides that signal across roughly one full period of the sine.

    Args:
        f: Frequency in Hz. Used as a divisor, so it must be non-zero.
        frame_size: Length in samples of the frame the sine will be compared with.
        sample_rate: Sampling rate in Hz.

    Returns:
        1-D ``numpy.ndarray`` holding ``sin(2 * pi * f * k / sample_rate)`` for
        ``k = 0 .. n_samples - 1``.
    """
    wave_length = int(sample_rate / f)
    n_samples = wave_length + frame_size
    # Sample k must sit at exactly k / sample_rate seconds. The previous
    # np.linspace(0, dur, n_samples) included the endpoint, which stretched the
    # spacing to dur / (n_samples - 1) and slightly detuned the generated sine.
    t = np.arange(n_samples) / sample_rate
    return np.sin(2 * np.pi * f * t)

In [None]:
sin = build_sin(110*50)

In [None]:
plt.plot(sin[100:104])

In [None]:
# NOTE(review): `audio` is first assigned by the oscillator cell further down the
# notebook, so this cell raises NameError when the notebook is run top to bottom on
# a fresh kernel.
# Takes the first 1024 samples of item 0, channel 0
# (assumes `audio` is laid out as [batch, time, ch] — TODO confirm).
frame = audio[0, :1024, 0]

In [None]:
plt.plot(frame)

In [None]:
corre = np.correlate(sin, frame, mode='valid')

In [None]:
plt.plot(corre)

In [None]:
sin.shape, frame.shape, corre.shape, sin.shape[0] - frame.shape[0]

In [None]:
corre.max() / (1024/2)

In [None]:
dist[0, 0, 0, 49]

In [None]:
# For each of the 64 multiples of 110 Hz, slide `frame` over a matching sine and keep
# the peak of the 'valid' cross-correlation, scaled by 1 / 512.
shit = []
for harmonic in range(1, 65):
    sin = build_sin(110.0 * harmonic)
    peak = np.correlate(sin, frame, mode='valid').max()
    shit.append(peak / 512)

In [None]:
plt.plot(np.array(shit))

In [None]:
plt.plot(dist[0, 0, 0])

In [None]:
plt.plot(dist[0, 0, 0] / np.array(shit))

In [None]:
# Shared settings for the synthesis and analysis cells below.
n_harmonics = 64  # number of harmonics, passed to OscillatorBank and MorletTransform
n_noise = 64  # number of noise filter bands (last axis of `filter_bands`)
sample_rate = 16000  # audio sample rate, also used as the playback rate for Audio(...)
hop_length = 64  # samples per control frame; the noise buffer is dur * hop_length samples long
win_length = 1024  # analysis frame length in samples
f0 = 110  # constant pitch that fills the `freq` control tensor
# Number of control frames (time axis of freq / amps / dist).
# NOTE: this rebinds `dur`, which an earlier cell set to 1 / (110 * 50).
dur = 250
batch_size = 4
ch = 2  # size of the channel axis of the control tensors

In [None]:
osc = OscillatorBank(n_harmonics, sample_rate, hop_length)

In [None]:
class MorletTransform(nn.Module):
    """Pitch-synchronous harmonic analysis with Gaussian-windowed complex exponentials.

    For every frame, one kernel per harmonic is built at integer multiples of that
    frame's f0. The magnitude of each kernel's projection onto the frame is taken
    as the strength of that harmonic.

    NOTE: this notebook-local class shadows the ``MorletTransform`` imported from
    ``models.modules.loss`` in an earlier cell.

    Args:
        sample_rate: Audio sample rate in Hz.
        win_length: Number of samples per analysis frame.
        n_harmonics: Number of harmonics (k = 1 .. n_harmonics) analysed per frame.
        half_bandwidth: Stored inverted as ``self.tp``; larger values give a
            narrower Gaussian envelope in time.
        eps: Lower bound applied to the per-frame magnitude sum before it is used
            as a divisor, so silent frames yield zeros instead of NaN.
    """

    def __init__(self, sample_rate, win_length, n_harmonics, half_bandwidth=1.0, eps=1e-8):
        super().__init__()
        self.sample_rate = sample_rate
        self.win_length = win_length
        self.eps = eps
        # Sample indices of one window and the harmonic numbers 1..n_harmonics.
        # Registered as buffers so they follow the module across devices.
        n = torch.arange(win_length, dtype=torch.float32)
        k = torch.arange(1, n_harmonics + 1, dtype=torch.float32)
        self.register_buffer("n", n)
        self.register_buffer("k", k)
        self.tp = 1.0 / half_bandwidth

    def generate_morlet_matrix(self, f0):
        """Build the analysis kernels for every frame.

        Args:
            f0: Fundamental frequency in Hz, shape [batch, time, ch].

        Returns:
            Complex tensor of shape [batch, time, ch, n_harmonics, win_length].
            Kernels whose centre frequency lies above Nyquist are all zero.
        """
        # f0.shape = [batch, time, ch]
        tp = self.tp * self.sample_rate
        # Centre frequency of every harmonic in cycles per sample.
        fc = torch.einsum("btc,k->btck", f0, self.k) / self.sample_rate
        # Phase (in cycles) of every harmonic at every sample of the window.
        fc_n = torch.einsum("btck,n->btckn", fc, self.n)

        # Plain Python float: multiplying a NumPy scalar by a tensor only works
        # through operator-priority dispatch, which is easy to break.
        normalizer = float(1.0 / np.sqrt(np.pi * tp))
        # Gaussian envelope centred on the middle of the window.
        gauss = torch.exp(-((self.n - self.win_length // 2) ** 2) / tp)
        exp = torch.exp(-2j * np.pi * fc_n)
        result = normalizer * gauss * exp

        # Cut above nyquist
        result[fc > 0.5] = 0.0

        # result.shape = [batch, time, ch, n_harmonics, win_length]
        return result

    def forward(self, audio_frames, f0):
        """Analyse framed audio into a harmonic distribution and an overall amplitude.

        Args:
            audio_frames: Real tensor of shape [batch, time, ch, win_length].
            f0: Fundamental frequency in Hz, shape [batch, time, ch].

        Returns:
            Tuple ``(harmonic_distribution, amp)``:
            harmonic_distribution has shape [batch, time, ch, n_harmonics] and holds
            the harmonic magnitudes divided by their per-frame sum (all zeros for a
            silent frame); amp has shape [batch, time, ch] and holds twice that
            sum, clipped to [0, 1].
        """
        # audio_frames.shape = [batch, time, ch, win_length]
        # f0.shape = [batch, time, ch]
        morlet = self.generate_morlet_matrix(f0)
        transform = torch.einsum("btckn,btcn->btck", morlet, audio_frames.type(torch.complex64))
        transform = torch.abs(transform)
        total = torch.sum(transform, dim=-1, keepdim=True)
        # Bound the divisor away from zero: a silent frame used to give 0 / 0 = NaN.
        harmonic_distribution = transform / total.clamp_min(self.eps)
        # Out-of-place scaling. The previous in-place `amp *= 2.0` modified the very
        # tensor autograd had saved for the division above, which makes backward()
        # raise a "modified by an inplace operation" error.
        amp = torch.clip(total * 2.0, 0.0, 1.0).squeeze(-1)

        # harmonic_distribution.shape = [batch, time, ch, n_harmonics]
        # amp.shape = [batch, time, ch]
        return harmonic_distribution, amp

In [None]:
morlet_transform = MorletTransform(sample_rate, win_length, n_harmonics)

In [None]:
shit = np.array(shit)

In [None]:
shit.shape

In [None]:
with torch.no_grad():
    # Harmonic weights measured in the correlation cell, repeated over batch, time and
    # channel, then normalised so that every frame's weights sum to 1.
    harmonic_weights = torch.from_numpy(shit.astype('float32'))
    dist = torch.tile(harmonic_weights, (batch_size, dur, ch, 1))
    dist = dist / dist.sum(-1, keepdim=True)

    # Constant pitch and constant amplitude on every frame and channel.
    control_shape = (batch_size, dur, ch)
    freq = torch.ones(control_shape) * f0
    amp = 1.0
    amps = torch.ones(control_shape) * amp

    audio = osc(freq, amps, dist)

8.84 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [None]:
Audio(data=audio[0].T, rate=sample_rate, normalize=False)

In [None]:
Audio(data=audio[0].T, rate=sample_rate, normalize=False)

- TODO: during synthesis (and maybe also in the STFT transform), as higher overtones rise above Nyquist and disappear during an upward glissando, the remaining harmonics are renormalized among
  themselves, so perceptually the sound seems to get louder.
- TODO: given crepe pitch, learn inharmonicity factor by maximizing real sound's total energy in this new transform
- TODO: amplitude doesn't seem right
- TODO: noise component transform

In [None]:
padded_audio = pad_audio(audio, win_length, hop_length)

In [None]:
framed_audio = get_frames(padded_audio, win_length, hop_length)

In [None]:
new_dist, new_amps = morlet_transform(framed_audio, freq)

In [None]:
new_amps.max(), new_dist.max(), amps.max(), dist.max()

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(dist[0, :, 0, :].flip(1).T)
ax2.imshow(new_dist[0, :, 0, :].flip(1).T)

In [None]:
plt.plot(amps[0, :, 0])
plt.plot(new_amps[0, :, 0])

In [None]:
with torch.no_grad():
    # freq = torch.ones(2, dur, 1) * f0
    new_audio = osc(
        freq,
        new_amps,
        new_dist
    )

In [None]:
plt.bar(np.arange(len(new_dist[0, 128, 0])), new_dist[0, 128, 0])
plt.bar(np.arange(len(dist[0, 128, 0])), dist[0, 128, 0])

In [None]:
Audio(data=new_audio[0].T, rate=sample_rate, normalize=False)

In [None]:
# NOTE(review): absolute path into one user's home directory — the cells that load
# this file will not run on another machine until it is changed.
path = '/home/kureta/Music/cello/Cello Samples/BrahmsSonata1-00002-.wav'

In [None]:
# Load the recording twice and keep the first second of each copy, transposed to
# [time, ch]: one copy at a fixed 16 kHz (the copy later handed to torchcrepe) and one
# at the notebook's `sample_rate`. The 2-D slicing relies on mono=False returning
# [ch, time].
samples_16k, _ = librosa.load(path, sr=16000, mono=False)
timbre_violin = samples_16k[:, :16000].T
samples_sr, _ = librosa.load(path, sr=sample_rate, mono=False)
violin = samples_sr[:, :sample_rate].T

In [None]:
Audio(data=violin.T, rate=sample_rate, normalize=True)

In [None]:
# add batch dimension
padded_violin = pad_audio(torch.from_numpy(violin).unsqueeze(0), win_length, hop_length)
padded_timbre_violin = pad_audio(torch.from_numpy(timbre_violin).unsqueeze(0), 1024, hop_length)

In [None]:
# Pitch-track every channel with torchcrepe: channels are folded into the batch axis,
# analysed, then unfolded back to [batch, time, ch].
b, c = padded_timbre_violin.shape[0], padded_timbre_violin.shape[2]

stacked_violin = rearrange(padded_timbre_violin, 'b t c -> (b c) t')
# Use the GPU when one is present instead of hard-coding 'cuda', which raises on
# CPU-only machines.
crepe_device = 'cuda' if torch.cuda.is_available() else 'cpu'
freqs = torchcrepe.predict(stacked_violin, 16000, hop_length, decoder=torchcrepe.decode.weighted_argmax, pad=False, device=crepe_device)
freqs = rearrange(freqs, '(b c) t -> b t c', b=b, c=c)

In [None]:
violin_frames = get_frames(padded_violin, win_length, hop_length)

In [None]:
morlet_transform = MorletTransform(sample_rate, win_length, n_harmonics, 1/3)

In [None]:
new_dist, amp = morlet_transform(violin_frames, freqs)

In [None]:
plt.imshow(new_dist[0, :100, 0].flip(1).T)

In [None]:
plt.plot(amp[0, :, 0])
plt.plot(amp[0, :, 1])

In [None]:
with torch.no_grad():
    new_audio = osc(
        freqs,
        amp,
        new_dist
    )

In [None]:
Audio(data=new_audio[0].T, rate=sample_rate, normalize=False)

# Filtered noise

- shape = [batch, time, ch, n_bands]

In [None]:
# Per-band noise gains with shape [batch, time, ch, n_bands], initialised to ones.
filter_bands = torch.ones(batch_size, dur, ch, n_noise)
# NOTE(review): both assignments below write 1.0 into a tensor that is already all
# ones, so they currently change nothing — presumably left over from a version that
# started from zeros; confirm the intended filter shape.
filter_bands[:, :, :, 1] = 1.
filter_bands[:, dur//2:, :, 2] = 1.

In [None]:
base_noise = torch.rand(batch_size, dur*hop_length, ch) * 2 - 1
base_noise.min(), base_noise.max()

In [None]:
Audio(data=base_noise[0].T, rate=sample_rate, normalize=False)

## hop_length * 2 size windows for 50% overlap

In [None]:
padded_noise = pad_audio(base_noise, win_length, hop_length)
noise_frames = get_frames(padded_noise, win_length, hop_length)
windowed_noise_frames = noise_frames

## filter

In [None]:
# Expand the per-band gains to one gain per rFFT bin: every band covers an equal run of
# the win_length // 2 non-DC bins, and a zero gain is prepended for the DC bin.
bins_per_band = (win_length // 2) // filter_bands.shape[-1]
real_filter = filter_bands.repeat_interleave(bins_per_band, -1)
# Size the DC column from the filter itself rather than a hard-coded (4, 250, 2, 1),
# which stops matching as soon as batch_size, dur or ch are changed.
dc_gain = real_filter.new_zeros((*real_filter.shape[:-1], 1))
real_filter = torch.concat([dc_gain, real_filter], -1)

In [None]:
# Shape the noise in the frequency domain: rFFT every frame, scale each bin by its
# filter gain, transform back and taper the frame with a symmetric Hann window.
synthesis_window = torch.hann_window(win_length, periodic=False)
fft_noise = torch.fft.rfft(windowed_noise_frames)
filtered_noise_fft_frames = fft_noise * real_filter
filtered_noise_frames = torch.fft.irfft(filtered_noise_fft_frames) * synthesis_window

## overlap add

In [None]:
# Overlap-add the filtered frames back into a continuous signal using F.fold.
b, c = filtered_noise_frames.shape[0], filtered_noise_frames.shape[2]
# Merge channels into the batch axis and put frames last, so every frame becomes one
# column of length win_length for fold.
stacked_noise = rearrange(filtered_noise_frames, 'b t c w -> (b c) w t')
# Columns are placed hop_length samples apart on a (1, padded length) canvas;
# fold sums the samples where neighbouring frames overlap.
filtered_noise = F.fold(stacked_noise, (1, padded_noise.shape[1]), (1, win_length), stride=(1, hop_length))
# Drop the two singleton axes produced by fold and restore [batch, time, ch].
filtered_noise = rearrange(filtered_noise, '(b c) 1 1 t -> b t c', b=b, c=c)
# Discard the first hop_length samples
# (presumably part of the padding added by pad_audio — TODO confirm).
filtered_noise = filtered_noise[:, hop_length:, :]
# Displayed as the cell output to check the value range.
filtered_noise.max(), filtered_noise.min()

In [None]:
Audio(data=filtered_noise[0].T, rate=sample_rate, normalize=True)

In [None]:
torch.abs(filtered_noise_fft_frames).max()

# Extract noise

In [None]:
vln = torch.from_numpy(violin).unsqueeze(0)

In [None]:
new_audio.shape

In [None]:
padded_new_audio = pad_audio(new_audio, win_length, hop_length)
new_audio_frames = get_frames(padded_new_audio, win_length, hop_length)
windowed_new_audio_frames = new_audio_frames
fft_new_audio = torch.abs(torch.fft.rfft(windowed_new_audio_frames))

In [None]:
padded_vln = pad_audio(vln, win_length, hop_length)
vln_frames = get_frames(padded_vln, win_length, hop_length)
windowed_vln_frames = vln_frames
fft_vln = torch.abs(torch.fft.rfft(windowed_vln_frames))

In [None]:
librosa.display.specshow(fft_vln[0, :, 0, :].T.numpy())

In [None]:
librosa.display.specshow(fft_new_audio[0, :, 0, :].T.numpy())

In [None]:
# Collapse the 512 non-DC rFFT bins of every spectrum into 64 bands of 8 adjacent bins
# by summing the magnitudes inside each band (bin 0 is dropped first).
band_pattern = 'b t c (n g) -> b t c n g'
vln_bands = rearrange(fft_vln[..., 1:], band_pattern, n=64, g=8).sum(-1)
new_bands = rearrange(fft_new_audio[..., 1:], band_pattern, n=64, g=8).sum(-1)

In [None]:
vln_bands.max(), new_bands.max(), (vln_bands - new_bands * 0.009947).min(), (vln_bands - new_bands*0.009947).max()

In [None]:
librosa.display.specshow((vln_bands - new_bands * 0.009947)[0, :, 0, :].T.numpy())

In [None]:
# Band magnitudes of the recording minus the scaled band magnitudes of the harmonic
# resynthesis; this rebinds `filter_bands` from the synthetic filter defined earlier.
# NOTE(review): 0.009947 is an unexplained constant repeated in the cells above, and
# the difference is not clamped, so negative gains are possible — confirm.
filter_bands = vln_bands - new_bands * 0.009947

In [None]:
filter_bands.min(), filter_bands.max()

In [None]:
filter_bands.shape

In [None]:
# Re-run the filtered-noise pipeline with the band gains currently held in `filter_bands`.
padded_noise = pad_audio(base_noise, win_length, hop_length)
noise_frames = get_frames(padded_noise, win_length, hop_length)
windowed_noise_frames = noise_frames

# One gain per rFFT bin: every band covers an equal run of the win_length // 2 non-DC
# bins, and a zero gain is prepended for the DC bin. The DC column is sized from the
# filter itself rather than a hard-coded (1, 250, 2, 1), which stops matching as soon
# as the batch size, frame count or channel count of `filter_bands` changes.
bins_per_band = (win_length // 2) // filter_bands.shape[-1]
real_filter = filter_bands.repeat_interleave(bins_per_band, -1)
dc_gain = real_filter.new_zeros((*real_filter.shape[:-1], 1))
real_filter = torch.concat([dc_gain, real_filter], -1)

# Filter in the frequency domain, then taper every frame with a symmetric Hann window.
fft_noise = torch.fft.rfft(windowed_noise_frames)
filtered_noise_fft_frames = real_filter * fft_noise
filtered_noise_frames = torch.fft.irfft(filtered_noise_fft_frames) * torch.hann_window(win_length, periodic=False)

# Overlap-add with F.fold: channels are merged into the batch axis and every frame is
# one column of length win_length placed hop_length samples after the previous one.
b, c = filtered_noise_frames.shape[0], filtered_noise_frames.shape[2]
stacked_noise = rearrange(filtered_noise_frames, 'b t c w -> (b c) w t')
filtered_noise = F.fold(stacked_noise, (1, padded_noise.shape[1]), (1, win_length), stride=(1, hop_length))
filtered_noise = rearrange(filtered_noise, '(b c) 1 1 t -> b t c', b=b, c=c)
filtered_noise = filtered_noise[:, hop_length:, :]
# Displayed as the cell output to check the value range.
filtered_noise.max(), filtered_noise.min()

In [None]:
Audio(data=filtered_noise[0, 512:-384].T / 4096 + new_audio[0].T, rate=sample_rate, normalize=True)

In [None]:
filter_bands.shape