In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np
from scipy.stats import norm
import librosa

import torch
import torch.nn as nn
import torchaudio.functional as AF
import torchcrepe

from einops import rearrange
from pathlib import Path
import math

In [None]:
from performer.datamodules.components.ddsp_dataset import DDSPDataset
from performer.utils.constants import *
from performer.utils.helpers import freqs_to_cents, cents_to_bins
from performer.utils.features import Loudness, get_f0

In [None]:
from performer.models.components.harmonic_oscillator import HarmonicOscillator
from performer.models.components.controller import Controller, TransformerController

In [None]:
shit = HarmonicOscillator()

In [None]:
bok = Controller(64, 4)

In [None]:
kaka = TransformerController(n_harmonics=64, n_noise_filters=4)

In [None]:
f0 = torch.ones(1, 1, 250) * 440
f0 += torch.randn_like(f0) * 10.
loudness = torch.ones(1, 1, 250) * -40
loudness += torch.randn_like(loudness) * 12.

with torch.no_grad():
    (_, master, harms), _ = kaka(f0, loudness)

In [None]:
plt.plot(master[0, 0])

In [None]:
def modified_sigmoid(x):
    """Scaled, sharpened sigmoid: ``2 * sigmoid(x) ** 2.3 + 1e-7``.

    The output lies roughly in (0, 2]; the 1e-7 offset keeps it strictly positive
    even when ``sigmoid`` saturates to 0.
    NOTE(review): the 2.3 exponent presumably approximates ln(10), as in DDSP's
    exp_sigmoid — confirm.
    """
    squashed = torch.sigmoid(x)
    return 2 * squashed.pow(2.3) + 1e-7

In [None]:
# Plot against the actual inputs so the x-axis shows x rather than the sample index.
sigmoid_inputs = torch.linspace(-1, 1, 512)
plt.plot(sigmoid_inputs, modified_sigmoid(sigmoid_inputs))
plt.xlabel("x")
plt.ylabel("modified_sigmoid(x)")
plt.title("modified_sigmoid on [-1, 1]");

In [None]:
# Load the same file directly at 48 kHz with librosa, and at 44.1 kHz followed by
# torchaudio resampling to 48 kHz, to compare the two resampling paths.
# Path.home() avoids hard-coding one user's absolute home directory.
wav_path = Path.home() / 'Music' / 'Cello Samples' / 'ArpAm-00000-.wav'
y1, _ = librosa.load(wav_path, sr=48000, mono=False, dtype='float32')
y2, _ = librosa.load(wav_path, sr=44100, mono=False, dtype='float32')
y2 = AF.resample(torch.from_numpy(y2), 44100, 48000)

y1.shape, y2.shape

In [None]:
# Load the preprocessed cello dataset (path is relative to the notebook's directory).
dataset = DDSPDataset('../data/cello_samples.pth')

In [None]:
dataset.f0.shape, dataset.loudness.shape, dataset.audio.shape

In [None]:
# Sliding windows over dim 1 of dataset.audio: 48000*4 samples long with a 48000-sample
# step, i.e. 4 s windows with a 1 s hop if the audio is 48 kHz (TODO confirm vs SAMPLE_RATE).
dataset.audio.unfold(1, 48000*4, 48000*1).transpose(0, 1).shape

In [None]:
# Inspect one item: dataset[i] unpacks to (f0, amp, audio). Plot the first row of
# each control curve, then play the clip.
f0, amp, audio = dataset[13]
plt.plot(f0[0])
plt.show()
plt.plot(amp[0])
plt.show()
Audio(data=audio, rate=48000)

In [None]:
loudness = torch.cat([l['loudness'][0] for l in dataset.features]).numpy()

In [None]:
loudness.min(), loudness.max()

In [None]:
calc = Loudness().cuda()

In [None]:
audios = torch.stack([l['audio'] for l in dataset.features]).cuda()

In [None]:
audios.shape

In [None]:
f0s = []
for wav in audios:
    f0s.append(get_f0(wav.unsqueeze(0).cuda()))

In [None]:
freqs = torch.cat(f0s)

In [None]:
plt.plot(freqs[0, 0].cpu().numpy())

In [None]:
amps = []
for i in range(0, 6075, 25):
    amps.append(calc.get_amp(audios[i:i+25].cuda()))

In [None]:
loudness = torch.cat(amps)

In [None]:
loudness = loudness.cpu().numpy()

In [None]:
loudness.max() - loudness.min(), loudness.min(), loudness.max()

In [None]:
loudness = loudness.cpu().numpy().flatten()

In [None]:
plt.rcParams['figure.figsize'] = [16, 8]

n, bins, patches = plt.hist(loudness, 128)
plt.title("Loudness Histogram")
plt.xlabel("Db")
plt.ylabel("Frequency")

l_min = loudness.min()
l_max = loudness.max()
mean = loudness.mean()
std = loudness.std()
start = mean - std
end = mean + std

plt.xticks([mean, l_min, l_max, start, end, start-std, end+std])
plt.grid(axis='x')

plt.axvline(x=l_min, linewidth=2, label=f'min={l_min:.2f}', color='k')
plt.axvline(x=l_max, linewidth=2, label=f'max={l_max:.2f}', color='k')
plt.axvline(x=mean, linewidth=2, label=f'mean={mean:.2f}', color='k', linestyle='dashed')
plt.axvline(x=start, linewidth=2, label=f'-sigma={start:.2f}', color='g', linestyle='dashed')
plt.axvline(x=end, linewidth=2, label=f'+sigma={end:.2f}', color='g', linestyle='dashed')
plt.axvline(x=start-std, linewidth=2, label=f'-2sigma={start-std:.2f}', color='y', linestyle='dashed')
plt.axvline(x=end+std, linewidth=2, label=f'+2sigma={end+std:.2f}', color='y', linestyle='dashed')

plt.legend(loc='upper left')
plt.show()

In [None]:
n, bins, patches = plt.hist(norm.cdf((loudness - mean) / std), 128)
plt.title("Loudness Histogram")
plt.xlabel("Normalized Db")
plt.ylabel("Frequency")

In [None]:
# Concatenate the first row of every item's f0 curve into a single tensor (dim 0).
f0 = torch.cat([item['f0'][0] for item in dataset.features])

In [None]:
# Pitch-unit conversions: cents are measured relative to 10 Hz, and pitch bins
# are 20 cents wide with bin 0 at _CENTS_AT_BIN_0.
# NOTE(review): `freqs_to_cents` and `cents_to_bins` redefine (and from here on
# shadow) the helpers imported from performer.utils.helpers — confirm they agree
# and keep only one copy.
_CENTS_REFERENCE_HZ = 10.0
_CENTS_PER_BIN = 20
# 1200 * log2(31.7 / 10) ≈ 1997.38 cents, i.e. bin 0 sits at ~31.7 Hz (the fmin
# passed to torchcrepe elsewhere in this notebook).
_CENTS_AT_BIN_0 = 1997.3794084376191


def bins_to_cents(bins):
    """Convert pitch-bin indices to cents (relative to 10 Hz).

    No dithering noise is added: the result is exactly linear in ``bins``.
    """
    return _CENTS_PER_BIN * bins + _CENTS_AT_BIN_0


def cents_to_frequency(cents):
    """Convert cents (relative to 10 Hz) to frequency in Hz. Accepts floats or tensors."""
    return _CENTS_REFERENCE_HZ * 2 ** (cents / 1200)


def freqs_to_cents(freq):
    """Convert a frequency tensor in Hz to cents relative to 10 Hz.

    ``freq`` must be a torch tensor (uses ``torch.log2``); 0 Hz maps to -inf.
    """
    return 1200 * torch.log2(freq / _CENTS_REFERENCE_HZ)


def cents_to_bins(cents):
    """Convert cents (relative to 10 Hz) to fractional pitch-bin indices."""
    return (cents - _CENTS_AT_BIN_0) / _CENTS_PER_BIN

In [None]:
f0 = cents_to_bins(freqs_to_cents(f0)) / 359

In [None]:
f0 = f0.numpy()

In [None]:
plt.rcParams['figure.figsize'] = [16, 8]

n, bins, patches = plt.hist(f0, 360)
plt.title("F0 Histogram")
plt.xlabel("Normalized pitch")
plt.ylabel("Frequency")

l_min = f0.min()
l_max = f0.max()
mean = f0.mean()
std = f0.std()
start = mean - std
end = mean + std

plt.xticks([mean, l_min, l_max, start, end, start-std, end+std])
plt.grid(axis='x')

plt.axvline(x=l_min, linewidth=2, label=f'min={l_min:.2f}', color='k')
plt.axvline(x=l_max, linewidth=2, label=f'max={l_max:.2f}', color='k')
plt.axvline(x=mean, linewidth=2, label=f'mean={mean:.2f}', color='k', linestyle='dashed')
plt.axvline(x=start, linewidth=2, label=f'-sigma={start:.2f}', color='g', linestyle='dashed')
plt.axvline(x=end, linewidth=2, label=f'+sigma={end:.2f}', color='g', linestyle='dashed')
plt.axvline(x=start-std, linewidth=2, label=f'-2sigma={start-std:.2f}', color='y', linestyle='dashed')
plt.axvline(x=end+std, linewidth=2, label=f'+2sigma={end+std:.2f}', color='y', linestyle='dashed')

plt.legend(loc='upper left')
plt.show()

In [None]:
def get_amp(example, window_length=19200):
    """Frame-wise loudness of a batch of audio using ``torchaudio.functional.loudness``.

    The signal is zero-padded by half a window on each side, cut into windows of
    ``window_length`` samples every ``HOP_LENGTH`` samples, and each window (all
    channels together) is measured with ``AF.loudness`` at ``SAMPLE_RATE``.

    Args:
        example: tensor of shape (batch, channels, time).
        window_length: samples per analysis window. The default 19200 keeps the
            original behaviour (0.4 s if SAMPLE_RATE is 48 kHz — confirm).

    Returns:
        Tensor of shape (batch, 1, n_frames). Frames that ``AF.loudness`` cannot
        measure (e.g. silence) may come back as NaN; this notebook replaces them
        with ``nan_to_num`` downstream.
    """
    batch, channels, _ = example.shape
    flat = rearrange(example, "b c t -> (b c) t")
    half_window = window_length // 2
    flat = torch.nn.functional.pad(flat, (half_window, half_window))
    # (b c) t -> (b c) n_frames window_length
    windows = flat.unfold(1, window_length, HOP_LENGTH)
    n_frames = windows.shape[1]
    # Regroup so each frame is one (channels, time) signal for AF.loudness.
    windows = rearrange(windows, "(b c) f t -> (b f) c t", b=batch, c=channels, f=n_frames)

    amp = AF.loudness(windows, SAMPLE_RATE)
    return rearrange(amp, "(b f) -> b f", b=batch, f=n_frames).unsqueeze(1)


def get_pitch(x, device='cuda'):
    """Estimate f0 for a batch of audio with torchcrepe.

    Averages ``x`` over dim 1 to get mono, resamples from SAMPLE_RATE to
    CREPE_SAMPLE_RATE and decodes with weighted argmax (fmin 31.7 Hz).

    Args:
        x: audio tensor; dim 1 is averaged away (presumably (batch, channels,
            time) — confirm).
        device: device torchcrepe runs on.

    Returns:
        torchcrepe's pitch prediction with a singleton dimension inserted at index 1.
    """
    mono = x.mean(1)
    resampled = AF.resample(mono, SAMPLE_RATE, CREPE_SAMPLE_RATE)
    pitch = torchcrepe.predict(
        resampled,
        sample_rate=CREPE_SAMPLE_RATE,
        hop_length=CREPE_HOP_LENGTH,
        fmin=31.7,
        decoder=torchcrepe.decode.weighted_argmax,
        device=device,
        return_periodicity=False,
    )
    return pitch.unsqueeze(1)

In [None]:
f, amp, audio = dataset[0]

In [None]:
audio = audio.unsqueeze(0)

In [None]:
shit = get_amp(audio)

In [None]:
bok = get_pitch(audio)

In [None]:
bok.shape, shit.shape

In [None]:
# Reference A-weighted loudness of the single clip `audio` computed with librosa;
# the following cells re-implement the same pipeline in torch.
# audio[0].mean(0) averages over the first axis of the clip (presumably channels).
# librosa.stft returns (freq, frames); `.T` makes it (frames, freq).
s = librosa.stft(audio[0].mean(0).numpy(), n_fft=N_FFT, hop_length=HOP_LENGTH, pad_mode='reflect').T
print(s.shape)
# Compute power.
amplitude = np.abs(s)
power = amplitude**2

# A-weighting curve in dB per FFT bin, converted to a linear power gain;
# [None, :] broadcasts it across frames.
frequencies = librosa.fft_frequencies(sr=SAMPLE_RATE, n_fft=N_FFT)
a_weighting = librosa.A_weighting(frequencies)[None, :]
weighting = 10**(a_weighting/10)
power = power * weighting

# Average the weighted power over frequency bins -> one value per frame.
power = np.mean(power, axis=-1)
# Unused alternative: loudness = np.log(power*100 + 1)
# power_to_db with librosa's defaults (ref=1.0, amin=1e-10, top_db=80.0).
loudness = librosa.power_to_db(power)

In [None]:
batch = torch.stack([dataset[0][2], dataset[1][2], dataset[2][2], dataset[3][2]])

In [None]:
s = torch.stft(batch.mean(1), n_fft=N_FFT, hop_length=HOP_LENGTH, window=torch.hann_window(N_FFT), return_complex=True, pad_mode='reflect')
s = s.transpose(1, 2)

# Compute power.
amplitude = torch.abs(s)
power = amplitude**2

frequencies = torch.from_numpy(librosa.fft_frequencies(sr=SAMPLE_RATE, n_fft=N_FFT).astype('float32'))
a_weighting = torch.from_numpy(librosa.A_weighting(frequencies)[None, None, :].astype('float32'))
weighting = 10**(a_weighting/10)
power = power * weighting

power = torch.mean(power, axis=-1)
torchness = 10.0 * np.log10(np.maximum(1e-10, power))
torchness = np.maximum(torchness, torchness.max() - 80.)

In [None]:
audio.shape

In [None]:
loudness.shape, loudness.min(), loudness.max(), torchness.shape, torchness.min(), torchness.max()

In [None]:
torchness.dtype

In [None]:
def normalize(x):
    """Min-max scale ``x`` to [0, 1].

    Works on NumPy arrays and torch tensors alike (only ``min``/``max`` and
    arithmetic are used). A constant input has zero range; instead of dividing by
    zero (which yields NaN) it is mapped to all zeros.
    """
    shifted = x - x.min()
    value_range = x.max() - x.min()
    if value_range == 0:
        return shifted  # all zeros, same type and shape as x
    return shifted / value_range

In [None]:
plt.rcParams['figure.figsize'] = [8, 4]
plt.plot((torch.nan_to_num(shit[0, 0], nan=-70)))
plt.plot((loudness))
plt.plot((torchness[0]))
plt.show()
plt.plot(bok[0, 0])
plt.show()
Audio(data=audio[0], rate=48000)

In [None]:
plt.rcParams['figure.figsize'] = [8, 4]
plt.plot(torch.nan_to_num(shit[0], nan=-70))
plt.show()
plt.plot(f[0])
plt.show()
Audio(data=audio[0], rate=48000)

In [None]:
audios = torch.stack([f['audio'] for f in dataset.features])

In [None]:
audios.shape

In [None]:
amps = []
for i in range(0, 6075, 25):
    print(i)
    amps.append(get_amp(audios[i:i+25].cuda()))

In [None]:
amps[0].shape

In [None]:
shit = torch.cat([a.cpu() for a in amps], dim=0)

In [None]:
torch.nan_to_num_(shit, nan=100.)

In [None]:
shit[shit==100.] = -70.

In [None]:
shit.min(), shit.max()

In [None]:
idx = 8
plt.rcParams['figure.figsize'] = [8, 4]
plt.plot((shit[idx] + 70) / 70)
plt.show()
plt.plot(dataset.features[idx]['f0'][0])
plt.show()
Audio(data=audios[idx, 0], rate=48000)

In [None]:
dynamic_range: int = 70  # loudness dynamic range in dB (values are normalized as (db + 70) / 70)

In [None]:
frame_rate = 250  # control frames per second
hop_size = 48000 // frame_rate  # audio samples per control frame at 48 kHz
hop_size, 48000 * 5 // hop_size  # (hop size in samples, number of frames in 5 seconds)

In [None]:
# NOTE(review): the per-group counts (60, 65) and the factor of 3 are not explained
# here — document where they come from.
n_harmonics, n_noise = 3 * 60, 3 * 65
n_harmonics, n_noise

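Normalization conventions:

- normalize f0 (Hz → [0, 1]; pitch-bin indices 0–359 divided by 359):
  - `f0 = cents_to_bins(freqs_to_cents(f0)) / 359`
- un-normalize f0:
  - `f0 = cents_to_frequency(bins_to_cents(f0 * 359))`
- normalize dB (maps the [-70, 0] dB range to [0, 1]):
  - `db = (db + 70) / 70`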

In [None]:
x = AF.resample(dataset.features[0]['audio'].unsqueeze(0).mean(1), SAMPLE_RATE, CREPE_SAMPLE_RATE)
f0 = torchcrepe.predict(x,
                        sample_rate=16000,
                        hop_length=CREPE_HOP_LENGTH,
                        fmin=31.7,
                        decoder=torchcrepe.decode.weighted_argmax,
                        device='cuda', return_periodicity=False)

In [None]:
f0.shape

In [None]:
plt.rcParams['figure.figsize'] = [12, 3]
plt.plot(cents_to_bins(freqs_to_cents(f0[0]).cpu().numpy()))
plt.show()
# plt.matshow(prod[0].cpu().numpy(), origin='lower')
# plt.show()
Audio(data=dataset.features[0]['audio'].numpy(), rate=48000)

In [None]:
# 500 uniform samples in [0, 1), seeded so the comparison below is reproducible.
tt = np.random.default_rng(0).random(500)

In [None]:
# log is concave, so by Jensen's inequality log(mean(tt + 1)) >= mean(log(tt + 1)).
np.log(np.mean(tt + 1)), np.mean(np.log(tt + 1))