In [None]:
import sys
from pathlib import Path

# Make the `src` directory next to the current working directory importable,
# putting it at the front of the module search path unless it is already listed.
src_dir = str(Path.cwd().parent / 'src')
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Default figure size for every plot in this notebook (matplotlib reads it as
# width, height in inches); large enough for the side-by-side image comparison.
plt.rcParams['figure.figsize'] = [20, 10]
from IPython.display import Audio

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import librosa
from librosa.display import specshow
from librosa.filters import get_window
import torchcrepe
from einops import rearrange, parse_shape
import opt_einsum as oe

In [None]:
from models.modules.harmonic_oscillator import OscillatorBank
from models.modules.loss import MorletTransform, STFT
from models.modules.utils import pad_audio, get_frames, pad_audio_basic

In [None]:
# Analysis / synthesis settings: a 16 kHz base configuration scaled up by 3.
sample_rate = 3 * 16000   # 48000
hop_length = 3 * 64       # 192
win_length = 3 * 1024     # 3072
n_harmonics = 256

# Synthetic test-signal settings.
f0 = 110.0                # scales the generated frequency curve
dur = 1000                # length of the time axis of the control tensors
batch_size, ch = 4, 2     # batch items and audio channels

In [None]:
# Oscillator bank used below to synthesize audio from a frequency tensor, an
# amplitude tensor and a harmonic-distribution tensor (see the osc(...) calls).
osc = OscillatorBank(n_harmonics, sample_rate, hop_length)

In [None]:
# Analysis transform: further down it is called with framed audio and a
# frequency track and returns a harmonic distribution and an amplitude.
# The Morlet-based alternative is kept here, disabled:
# morlet_transform = MorletTransform(sample_rate, win_length, n_harmonics)
stft_transform = STFT(sample_rate, win_length, n_harmonics)

In [None]:
with torch.no_grad():
    # Draw the random timbre from a locally seeded generator so the synthesized
    # signal (and every plot derived from it) is identical on each run, without
    # altering torch's global RNG state.
    dist_rng = torch.Generator().manual_seed(0)

    # One non-negative weight per (channel, harmonic), repeated unchanged for
    # every batch item and every step of the time axis.
    dist = torch.abs(torch.randn(1, 1, ch, n_harmonics, generator=dist_rng))
    dist = torch.tile(dist, (batch_size, dur, 1, 1))

    # Alternative distributions kept for experimentation. Note that they use
    # 3-D or 1-D shapes without the channel axis, unlike the active code above.
    # dist = torch.zeros(1, dur, n_harmonics)
    # dist[..., 0] = 1.0

    # dist = 61. - torch.arange(1, 61)

    # dist = torch.ones(1, dur, 60)

    # Normalize so the harmonic weights sum to 1 along the last axis.
    dist /= dist.sum(-1, keepdim=True)

    amp = 1.0

    # Slow sinusoidal glide. The time axis runs from 0 to
    # hop_length * dur / sample_rate, and sin(pi/2 * t) + 2 lies in [1, 3],
    # so the generated frequency moves between f0 / 3 and f0.
    glide_time = torch.linspace(0, hop_length * dur / sample_rate, dur)
    freq = (torch.sin(glide_time * np.pi * 0.5).unsqueeze(0).unsqueeze(-1) + 2) * f0 / 3
    freq = torch.tile(freq, (batch_size, 1, ch))
    # Offset the second channel by 3 (same unit as f0) so the channels differ.
    freq[:, :, 1] += 3.

    # freq = torch.ones(1, dur, ch) * f0
    audio = osc(
        freq,
        torch.ones(batch_size, dur, ch) * amp,
        dist
    )

In [None]:
# Listen to the first batch item of the synthesized signal. normalize=False
# asks IPython not to rescale the samples for playback.
# NOTE(review): the transpose assumes audio[0] is (samples, channels) — confirm.
Audio(data=audio[0].T, rate=sample_rate, normalize=False)

- TODO: during synthesis (and possibly also in the STFT transform), as higher overtones rise above Nyquist and disappear during an upward glissando, the remaining harmonics are renormalized among
  themselves, so the sound perceptually seems to get louder.
- TODO: given crepe pitch, learn inharmonicity factor by maximizing real sound's total energy in this new transform
- TODO: amplitude doesn't seem right
- TODO: noise component transform

In [None]:
# Pad the synthesized audio before framing, using the same window and hop sizes
# that the framing step receives. The padding scheme itself is defined by
# pad_audio_basic.
padded_audio = pad_audio_basic(audio, win_length, hop_length)

In [None]:
# Cut the padded audio into frames; get_frames receives win_length and hop_length.
framed_audio = get_frames(padded_audio, win_length, hop_length)

In [None]:
# Analyze the framed synthetic audio against the known frequency track.
# Note: this rebinds `amp`, which until now held the scalar synthesis amplitude.
new_dist, amp = stft_transform(framed_audio, freq)

In [None]:
# Side-by-side comparison of the harmonic distribution used for synthesis and
# the one recovered by the STFT transform, for batch item 0 and channel 1.
# Each image is flipped along the harmonic axis and transposed, so the time
# axis is horizontal and the first harmonic sits at the bottom.
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(dist[0, :, 1, :].flip(1).T)
ax1.set_title('Original distribution (batch 0, channel 1)')
ax1.set_xlabel('frame')
ax1.set_ylabel('harmonic index (reversed)')
ax2.imshow(new_dist[0, :, 1, :].flip(1).T)
ax2.set_title('STFT-estimated distribution (batch 0, channel 1)')
ax2.set_xlabel('frame')
ax2.set_ylabel('harmonic index (reversed)');

In [None]:
# Resynthesize from the analysis results (same frequency track, recovered
# amplitude and recovered harmonic distribution) for comparison with `audio`.
# NOTE(review): the sqrt(2) amplitude factor is not explained here — presumably
# an RMS-to-peak correction; confirm against the STFT transform's amplitude.
with torch.no_grad():
    new_audio = osc(freq, amp * np.sqrt(2), new_dist)

In [None]:
# Listen to the first batch item of the resynthesized test signal.
Audio(data=new_audio[0].T, rate=sample_rate, normalize=False)

In [None]:
# Input recording, resolved relative to the current user's home directory
# instead of one hard-coded account. Kept as a plain string, as before.
path = str(Path.home() / 'Music' / 'cello' / 'Cello Samples' / 'BrahmsSonata1-00002-.wav')

In [None]:
# Load the first 4 seconds of the recording twice, each transposed so the
# channel axis comes last:
#   - at 16 kHz, the rate later passed to torchcrepe for pitch tracking;
#   - at `sample_rate`, for harmonic analysis and resynthesis.
# The [:, :N] slice needs a 2-D array, so the file must be multi-channel.
# NOTE(review): the variables say "violin" but the path points at a cello sample.
timbre_violin = librosa.load(path, sr=16000, mono=False)[0][:, :16000*4].T
violin = librosa.load(path, sr=sample_rate, mono=False)[0][:, :sample_rate*4].T

In [None]:
# Listen to the loaded excerpt at the synthesis sample rate (transposed back so
# the channel axis comes first).
Audio(data=violin.T, rate=sample_rate, normalize=False)

In [None]:
# add batch dimension
# The 16 kHz copy is padded with a window of 1024 (= win_length // 3) and a hop
# of hop_length // 3, i.e. the full-rate settings scaled down by the same factor
# of 3 as the sample rate.
padded_violin = pad_audio_basic(torch.from_numpy(violin).unsqueeze(0), win_length, hop_length)
padded_timbre_violin = pad_audio_basic(torch.from_numpy(timbre_violin).unsqueeze(0), 1024, hop_length//3)

In [None]:
# Remember the sizes of axes 0 and 2 (used as batch and channel below) so the
# pitch track can be unflattened after torchcrepe has run.
b, c = padded_timbre_violin.shape[0], padded_timbre_violin.shape[2]

In [None]:
# Fold the channels into the batch axis so each channel is pitch-tracked as its
# own signal, then restore the (batch, time, channel) layout afterwards.
stacked_violin = rearrange(padded_timbre_violin, 'b t c -> (b c) t')
# Run on the GPU when one is available; otherwise fall back to the CPU instead
# of failing on machines without CUDA.
freqs = torchcrepe.predict(
    stacked_violin,
    16000,
    hop_length//3,
    decoder=torchcrepe.decode.weighted_argmax,
    pad=False,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
freqs = rearrange(freqs, '(b c) t -> b t c', b=b, c=c)

In [None]:
# Cut the padded full-rate recording into frames with the synthesis window/hop.
violin_frames = get_frames(padded_violin, win_length, hop_length)

In [None]:
# Analyze the recording using the CREPE pitch track. This overwrites `new_dist`
# and `amp` from the synthetic experiment above. The pitch track uses a hop of
# hop_length // 3 at 16 kHz and the frames a hop of hop_length at `sample_rate`,
# so both advance by the same 4 ms per step.
# NOTE(review): assumes both end up with the same number of steps — confirm.
new_dist, amp = stft_transform(violin_frames, freqs)

In [None]:
# First 100 steps of the estimated harmonic distribution for batch item 0 and
# channel 0, flipped along the harmonic axis and transposed so time runs
# horizontally and the first harmonic sits at the bottom.
plt.imshow(new_dist[0, :100, 0].flip(1).T)
plt.title('Estimated harmonic distribution (first 100 frames, channel 0)')
plt.xlabel('frame')
plt.ylabel('harmonic index (reversed)');

In [None]:
# Estimated amplitude over time for both channels of batch item 0, labelled so
# the two curves can be told apart.
plt.plot(amp[0, :, 0], label='channel 0')
plt.plot(amp[0, :, 1], label='channel 1')
plt.title('Estimated amplitude (batch item 0)')
plt.xlabel('frame')
plt.ylabel('amplitude')
plt.legend();

In [None]:
# Resynthesize the recording from its CREPE pitch track and the amplitude and
# harmonic distribution estimated by the STFT transform.
# NOTE(review): the sqrt(2) amplitude factor is not explained here — presumably
# an RMS-to-peak correction; confirm against the STFT transform's amplitude.
with torch.no_grad():
    new_audio = osc(freqs, amp * np.sqrt(2), new_dist)

In [None]:
# Listen to the first batch item of the resynthesized recording.
Audio(data=new_audio[0].T, rate=sample_rate, normalize=False)

In [None]:
# noise bands center bin and bandwidth calculations
# Placeholder: the loop body is disabled, so running this cell only leaves `n`
# bound to 99. The commented-out print shows the triple
# (80 * n, 80 * n + 40, 80 * n + 80) — presumably a band's start, center and
# end bin; TODO confirm.
for n in range(100):
   #  print(80 * n, 80 * n + 40, 80 * n + 80)
    pass