In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import librosa
from librosa.display import specshow, waveshow

from IPython.display import Audio

from pathlib import Path

from einops import rearrange

In [None]:
from models.modules.dsp import HarmonicOscillator, FilteredNoise
from models.modules.loss import Loudness, CrepeFeaturesAndCents
from models.modules.utils import pad_audio, get_frames
from models.modules.crepe import cents_to_frequency, bins_to_cents, PITCH_BINS
from models.modules.controller import Controller

In [None]:
# Audio framing: shared by the DSP modules, the loaders and the analysis helpers.
sample_rate = 16_000
hop_length = 64
window_length = 1024

# Synthesiser sizes.
n_channels = 2
n_harmonics = 64
n_bands = 128

# Shape of the hand-made control signals used for the smoke tests.
batch_size = 8
time_steps = 1000

In [None]:
# Harmonic oscillator under test, configured from the constants defined above.
osc = HarmonicOscillator(sample_rate, hop_length, n_harmonics, n_channels)

In [None]:
# Hand-made control signals for a smoke test of the harmonic oscillator.
# Resulting shapes (they follow from the stacking below):
#   base_pitch            (batch_size, time_steps, 2)
#   amplitude             (batch_size, time_steps, n_channels)
#   harmonic_distribution (batch_size, time_steps, n_channels, n_harmonics)

# Two slightly detuned pitch ramps, one per channel.
base_pitch = torch.stack(
    [torch.linspace(55, 440, time_steps), torch.linspace(55.1, 441, time_steps)],
    dim=-1,
)[None]
base_pitch = torch.tile(base_pitch, (batch_size, 1, 1))
amplitude = torch.ones(batch_size, time_steps, n_channels)

# Seeded generator so that Restart & Run All reproduces the same envelopes
# (the global NumPy RNG was previously used without a seed).
envelope_rng = np.random.default_rng(0)

harmonic_distribution = []
for _ in range(n_harmonics):
    # One sinusoidal envelope in [0, 2] per channel, each with a random rate.
    channel_envelopes = [
        torch.sin(torch.linspace(0, float(envelope_rng.uniform()) * 40, time_steps)) + 1
        for _ in range(n_channels)
    ]
    harmonic_distribution.append(torch.stack(channel_envelopes, dim=-1))

# Harmonics go on the last axis; the same envelopes are shared by every batch item.
harmonic_distribution = torch.stack(harmonic_distribution, dim=-1)[None]
harmonic_distribution = torch.tile(harmonic_distribution, (batch_size, 1, 1, 1))

In [None]:
# Render the test signal. Gradient tracking is disabled because this is
# inference only, which also lets the result be converted to NumPy below.
with torch.no_grad():
    y = osc(base_pitch, amplitude, harmonic_distribution)

In [None]:
# First batch item as a NumPy array, transposed for the plotting/playback calls
# below (presumably (samples, channels) -> (channels, samples) — TODO confirm
# the oscillator's output layout).
y_np = y[0].numpy().T

In [None]:
# Plot the waveform of the rendered oscillator output.
waveshow(y_np, sr=sample_rate)

In [None]:
# Playback widget; normalize=False plays the samples as rendered instead of
# letting IPython rescale them to full range.
Audio(y_np, rate=sample_rate, normalize=False)

In [None]:
# Filtered-noise generator under test, configured from the constants defined above.
noise = FilteredNoise(sample_rate, window_length, hop_length, n_bands, n_channels)

In [None]:
# Hand-made per-band gains for a smoke test of the filtered-noise generator.
# Resulting shape: (1, time_steps, n_channels, n_bands).

# Seeded generator so that Restart & Run All reproduces the same envelopes
# (the global NumPy RNG was previously used without a seed).
noise_rng = np.random.default_rng(1)

filter_bands = []
for _ in range(n_bands):
    # One cosine envelope per channel with a random rate, scaled from [0, 2]
    # down by 2 ** 5.
    channel_envelopes = [
        torch.cos(torch.linspace(0, float(noise_rng.uniform()) * 40, time_steps)) + 1
        for _ in range(n_channels)
    ]
    filter_bands.append(torch.stack(channel_envelopes, dim=-1) / (2 ** 5))

# Bands go on the last axis; a leading batch axis of size 1 is added.
filter_bands = torch.stack(filter_bands, dim=-1)[None]

In [None]:
# Render the noise. Inference only, so gradient tracking is disabled — this
# matches the oscillator demo and guarantees the result can be converted to
# NumPy in the next cell.
with torch.no_grad():
    y = noise(filter_bands)

In [None]:
# First batch item as a NumPy array, transposed for the plotting/playback calls
# below (presumably (samples, channels) -> (channels, samples) — TODO confirm).
y_np = y[0, :].numpy().T

In [None]:
# Plot the waveform of the rendered noise.
waveshow(y_np, sr=sample_rate)

In [None]:
# Playback widget; normalize=False plays the samples as rendered instead of
# letting IPython rescale them to full range.
Audio(y_np, rate=sample_rate, normalize=False)

## Prepare data

In [None]:
# Loudness extractor from models.modules.loss, built for the analysis window length.
rms = Loudness(window_length)

In [None]:
# Recording used as the analysis/resynthesis target.
# NOTE(review): absolute, machine-specific path — point this at a local sample
# (ideally relative to a configurable data directory) before running.
sample_path = Path('/home/kureta/Music/cello/Cello Samples/BrahmsSonata1-00110-.wav')

In [None]:
# Load the recording resampled to the working sample rate. mono=False keeps the
# file's channels separate; the returned sample rate is discarded.
np_audio, _ = librosa.load(sample_path, sr=sample_rate, mono=False)

In [None]:
# Listen to the original recording (no automatic rescaling).
Audio(np_audio, rate=sample_rate, normalize=False)

In [None]:
# Transpose the loaded array, add a leading batch axis, wrap it as a tensor and
# pad it for framing with the analysis window and hop.
audio = pad_audio(
    torch.from_numpy(np_audio.T[None, :, :]),
    window_length,
    hop_length,
    strict=False,
)

In [None]:
# Slice the padded audio into analysis frames. The result is unpacked below as
# (batch, n_frames, n_channels, window_length).
frames = get_frames(audio, window_length, hop_length)

In [None]:
# flatten frames
# NOTE: this unpacking rebinds the notebook-level `n_channels` and
# `window_length` to the dimensions of `frames`.
batch, n_frames, n_channels, window_length = frames.shape
# Merge batch, frame and channel axes so that each row is one analysis window.
flat_frames = rearrange(frames, 'b f c w -> (b f c) w')

In [None]:
# calculate rms: one loudness value per flattened analysis window
# (the next cell reshapes the result from a 1-D '(b f c)' layout).
loudness = rms(flat_frames)

In [None]:
# unflatten loudness: restore the (batch, frames, channels) axes that were
# merged when the frames were flattened.
loudness = rearrange(loudness, '(b f c) -> b f c', b=batch, f=n_frames, c=n_channels)

In [None]:
# loudness[0] is (frames, channels), so this draws one line per channel.
plt.plot(loudness[0])
plt.title("Loudness of the first batch item")
plt.xlabel("frame")
plt.ylabel("loudness")
plt.show()

In [None]:
# Pitch/feature extractor from models.modules.loss; called below to obtain a
# pitch estimate in cents together with a feature map for each window.
crepe = CrepeFeaturesAndCents()

In [None]:
# Analyse every flattened window. Returns pitch (named `cents`) and a feature
# map; both are still in the flattened '(b f c)' layout here.
cents, features = crepe(flat_frames)

In [None]:
# Restore (batch, frames, channels) and drop the trailing singleton axis of the pitch.
cents = rearrange(cents, '(b f c) 1 -> b f c', b=batch, f=n_frames, c=n_channels)
# Same for the features, additionally merging their x and y axes into one feature axis.
features = rearrange(features, '(b f c) x y 1 -> b f c (x y)', b=batch, f=n_frames, c=n_channels)

In [None]:
# cents[0] is (frames, channels), so this draws one line per channel.
plt.plot(cents[0])
# Zoom in on the range of interest for this recording.
plt.ylim(5200, 7000)
plt.title("Estimated pitch of the first batch item")
plt.xlabel("frame")
plt.ylabel("pitch (cents)")
plt.show()

In [None]:
# Oscillator controls derived from the analysis of the recording.
f0 = cents_to_frequency(cents)
# Fixed gain of 5 applied to the measured loudness.
amps = 5.0 * loudness
# One-hot harmonic distribution: all weight on the first harmonic slot, zero elsewhere.
overtones = torch.zeros((*f0.shape, n_harmonics))
overtones[..., 0] = 1.0

In [None]:
# Resynthesise from the analysed controls. Inference only, so gradient tracking
# is disabled — this matches the first oscillator demo and guarantees the
# result can be converted to NumPy in the next cell.
with torch.no_grad():
    y = osc(f0, amps, overtones)

In [None]:
# First batch item as a NumPy array, transposed for playback
# (presumably (samples, channels) -> (channels, samples) — TODO confirm).
y_np = y[0].numpy().T

In [None]:
# Listen to the resynthesised signal (no automatic rescaling).
Audio(y_np, rate=sample_rate, normalize=False)

In [None]:
# Rebuild the audio tensor from the loaded array (transposed, with a leading
# batch axis) and pad it for slicing into two-second windows with a one-second hop.
audio = pad_audio(
    torch.from_numpy(np_audio.T[None, :, :]),
    sample_rate*2,
    sample_rate,
    strict=False,
)

In [None]:
# Cut the recording into windows of sample_rate*2 samples (two seconds) taken
# every sample_rate samples (one second). The first batch item is kept and its
# axes 1 and 2 are swapped so that each window can be treated as one batch item
# (presumably (windows, samples, channels) — TODO confirm get_frames' layout).
batches = get_frames(audio, sample_rate*2, sample_rate)[0].transpose(1, 2)

In [None]:
# Frame every excerpt for analysis. This rebinds `frames`, which previously
# held the frames of the whole recording.
frames = get_frames(batches, window_length, hop_length)

In [None]:
# The trainable model: called in the training loop with pitch and loudness to
# produce the oscillator controls.
ctrl = Controller(n_harmonics, n_bands)

In [None]:
# Move the framed audio to the GPU when one is available; fall back to the CPU
# so the notebook still runs on machines without CUDA.
frames = frames.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the controller to the GPU when one is available; fall back to the CPU
# so the notebook still runs on machines without CUDA.
ctrl = ctrl.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the loudness extractor to the GPU when one is available; fall back to
# the CPU so the notebook still runs on machines without CUDA.
rms = rms.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the pitch/feature extractor to the GPU when one is available; fall back
# to the CPU so the notebook still runs on machines without CUDA.
crepe = crepe.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the oscillator to the GPU when one is available; fall back to the CPU
# so the notebook still runs on machines without CUDA.
osc = osc.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# flatten frames
# NOTE: this unpacking rebinds the notebook-level `n_channels` and
# `window_length` to the dimensions of `frames`; `batch` is now the number of
# excerpts.
batch, n_frames, n_channels, window_length = frames.shape
# Merge batch, frame and channel axes so that each row is one analysis window.
flat_frames = rearrange(frames, 'b f c w -> (b f c) w')

In [None]:
# calculate rms: target loudness, one value per flattened analysis window.
loudness = rms(flat_frames)

In [None]:
# unflatten loudness: restore the (batch, frames, channels) axes.
loudness = rearrange(loudness, '(b f c) -> b f c', b=batch, f=n_frames, c=n_channels)

In [None]:
# Target pitch and features of the recording, still in the flattened
# '(b f c)' layout here.
cents, features = crepe(flat_frames)

In [None]:
# Restore (batch, frames, channels) and drop the trailing singleton axis of the pitch.
cents = rearrange(cents, '(b f c) 1 -> b f c', b=batch, f=n_frames, c=n_channels)
# Same for the features, additionally merging their x and y axes into one feature axis.
features = rearrange(features, '(b f c) x y 1 -> b f c (x y)', b=batch, f=n_frames, c=n_channels)

In [None]:
# Rescale the pitch in cents for use as a controller input: shift by the cents
# of bin 0, then divide by the cents of the last bin.
# NOTE(review): the divisor is bins_to_cents(PITCH_BINS-1), not the span
# bins_to_cents(PITCH_BINS-1) - bins_to_cents(0), so the last bin maps to 1.0
# only if bins_to_cents(0) == 0 — confirm whether a 0..1 range is intended.
pitch = (cents - bins_to_cents(0)) / bins_to_cents(PITCH_BINS-1)

In [None]:
# SGD with momentum over the controller's parameters only; no other module is
# handed to the optimizer.
optimizer = torch.optim.SGD(ctrl.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Display the shapes of the two controller inputs before they are flattened.
pitch.shape, loudness.shape

In [None]:
# Fold the channel axis into the batch axis and add a trailing axis of size 1,
# so every channel is fed to the controller as its own sequence.
flat_pitch = rearrange(pitch, 'b t c -> (b c) t 1')
flat_loudness = rearrange(loudness, 'b t c -> (b c) t 1')

In [None]:
# Fundamental frequency for the oscillator, converted from the analysed cents.
f0 = cents_to_frequency(cents)

In [None]:
# Fit the controller so that the oscillator output matches the recording in
# CREPE feature space and in loudness.

# The targets are constants for the optimisation. Detaching them guarantees
# that backward() never walks (and frees) whatever graph produced them, which
# would break the second iteration if the extractors track gradients.
target_features = features.detach()
target_loudness = loudness.detach()

for step in range(100):
    # Gradients accumulate by default, so clear them before every step.
    optimizer.zero_grad()

    # Controller: normalised pitch + loudness -> oscillator controls.
    (_, amps, overtones), _ = ctrl(flat_pitch, flat_loudness)
    # Undo the channel-into-batch folding applied to the controller inputs.
    amps = rearrange(amps, '(b c) t 1 -> b t c', c=n_channels)
    overtones = rearrange(overtones, '(b c) t o -> b t c o', c=n_channels)

    # Synthesise and frame the prediction exactly like the target audio.
    sound = osc(f0, amps, overtones)
    sound = pad_audio(sound, window_length, hop_length)
    p_frames = get_frames(sound, window_length, hop_length)
    p_flat_frames = rearrange(p_frames, 'b f c w -> (b f c) w')

    # Analyse the *predicted* frames. The loudness used to be measured on the
    # target frames, which made the loudness loss identically zero.
    _, p_features = crepe(p_flat_frames)
    p_loudness = rms(p_flat_frames)
    # The frame count is inferred; batch and channel sizes come from the
    # target framing instead of being hard-coded.
    p_loudness = rearrange(p_loudness, '(b f c) -> b f c', b=batch, c=n_channels)
    p_features = rearrange(p_features, '(b f c) x y 1 -> b f c (x y)', b=batch, c=n_channels)

    # Compute the loss and its gradients
    feature_loss = F.mse_loss(p_features, target_features)
    loudness_loss = F.mse_loss(p_loudness, target_loudness)
    loss = feature_loss + loudness_loss
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    print(loss.item())