In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np
from scipy.stats import norm

import torch
import torch.nn as nn
import torchaudio.functional as AF

from einops import rearrange

In [None]:
from src.datamodules.components.ddsp_dataset import DDSPDataset
from src.utils.crepe_loss import CrepeLoss

In [None]:
# Load the preprocessed cello data. The path is relative to the notebook's
# working directory (assumes a sibling `data/` folder one level up — confirm).
dataset = DDSPDataset('../data/cello.pth')

In [None]:
# Flatten index 0 of every example's 'loudness' entry into one NumPy array
# so the whole dataset can be histogrammed at once.
loudness = (
    torch.cat([feat['loudness'][0] for feat in dataset.features])
    .numpy()
)

In [None]:
plt.rcParams['figure.figsize'] = [16, 8]

# Distribution of per-frame loudness over the whole dataset.
n, bins, patches = plt.hist(loudness, 128)
plt.title("Loudness Histogram")
plt.xlabel("Db")
plt.ylabel("Frequency")

# Summary statistics, bound as module-level names (the F0 histogram cell
# further down rebinds these same names for F0).
l_min, l_max = loudness.min(), loudness.max()
mean, std = loudness.mean(), loudness.std()
start, end = mean - std, mean + std

# Fixed reference ticks in dB plus one tick per statistic drawn below.
plt.xticks([-70.0, -65.0, -60.0, -30.0, mean, l_min, l_max, start, end, start - std, end + std])
plt.grid(axis='x')

# One vertical reference line per statistic: (x, legend label, colour, extra style).
for _x, _label, _color, _extra in [
    (l_min, f'min={l_min:.2f}', 'k', {}),
    (l_max, f'max={l_max:.2f}', 'k', {}),
    (mean, f'mean={mean:.2f}', 'k', {'linestyle': 'dashed'}),
    (start, f'-sigma={start:.2f}', 'g', {'linestyle': 'dashed'}),
    (end, f'+sigma={end:.2f}', 'g', {'linestyle': 'dashed'}),
    (start - std, f'-2sigma={start-std:.2f}', 'y', {'linestyle': 'dashed'}),
    (end + std, f'+2sigma={end+std:.2f}', 'y', {'linestyle': 'dashed'}),
]:
    plt.axvline(x=_x, linewidth=2, label=_label, color=_color, **_extra)

plt.legend(loc='upper left')
plt.show()

In [None]:
# Push loudness through the CDF of a normal distribution with the same mean
# and standard deviation, which squashes it into [0, 1]. The statistics are
# computed from `loudness` here instead of reading the module-level
# `mean` / `std`: the F0 histogram cell rebinds those names to F0 statistics,
# so re-running this cell after it would silently mix the two datasets.
n, bins, patches = plt.hist(norm.cdf((loudness - loudness.mean()) / loudness.std()), 128)
plt.title("Loudness Histogram")
plt.xlabel("Normalized Db")
plt.ylabel("Frequency")
plt.show()

In [None]:
# Flatten index 0 of every example's 'f0' entry into one tensor; presumably
# in Hz, since it is fed to freqs_to_cents further down.
f0 = torch.cat(
    [feat['f0'][0] for feat in dataset.features]
)

In [None]:
def bins_to_cents(bins):
    """Map (possibly fractional) pitch-bin indices to cents.

    Bins are 20 cents apart and bin 0 sits at 1997.3794084376191 cents, i.e.
    about 31.7 Hz with the 10 Hz reference used by ``cents_to_frequency``.
    Plain arithmetic, so it works element-wise on numbers, arrays and tensors.
    """
    offset_cents = 1997.3794084376191
    return offset_cents + bins * 20

def cents_to_frequency(cents):
    """Convert a pitch in cents (relative to 10 Hz) to a frequency in Hz."""
    octaves = cents / 1200
    return 10 * 2 ** octaves

def freqs_to_cents(freq):
    """Convert frequencies in Hz to cents relative to 10 Hz.

    Uses ``torch.log2``, so ``freq`` must be a torch tensor; 0 Hz maps to -inf.
    """
    ratio_to_reference = freq / 10.
    return torch.log2(ratio_to_reference) * 1200

def cents_to_bins(cents):
    """Inverse of ``bins_to_cents``: cents -> fractional bin index (20 cents per bin)."""
    cents_above_bin0 = cents - 1997.3794084376191
    return cents_above_bin0 / 20

In [None]:
# Normalised F0: Hz -> cents -> fractional pitch bin, divided by 359 so that
# bins 0..359 land on [0, 1]. Rebuilt from `dataset.features` instead of the
# current `f0`, so re-running this cell neither normalises twice nor fails
# once a later cell has turned `f0` into a NumPy array.
f0 = cents_to_bins(freqs_to_cents(torch.cat([feat['f0'][0] for feat in dataset.features]))) / 359

In [None]:
# Convert to NumPy for matplotlib. `np.asarray` accepts both a CPU tensor and
# an array that is already NumPy, so re-running this cell is harmless
# (a second `.numpy()` call would raise AttributeError on the ndarray).
f0 = np.asarray(f0)

In [None]:
plt.rcParams['figure.figsize'] = [16, 8]

# Distribution of normalised F0 over the whole dataset (one bar per pitch bin).
n, bins, patches = plt.hist(f0, 360)
plt.title("F0 Histogram")
plt.xlabel("Normalized pitch")
plt.ylabel("Frequency")

# Summary statistics. NOTE: these rebind the same module-level names
# (`mean`, `std`, ...) that the loudness histogram cell also defines.
l_min, l_max = f0.min(), f0.max()
mean, std = f0.mean(), f0.std()
start, end = mean - std, mean + std

# One tick per statistic drawn below.
plt.xticks([mean, l_min, l_max, start, end, start - std, end + std])
plt.grid(axis='x')

# One vertical reference line per statistic: (x, legend label, colour, extra style).
for _x, _label, _color, _extra in [
    (l_min, f'min={l_min:.2f}', 'k', {}),
    (l_max, f'max={l_max:.2f}', 'k', {}),
    (mean, f'mean={mean:.2f}', 'k', {'linestyle': 'dashed'}),
    (start, f'-sigma={start:.2f}', 'g', {'linestyle': 'dashed'}),
    (end, f'+sigma={end:.2f}', 'g', {'linestyle': 'dashed'}),
    (start - std, f'-2sigma={start-std:.2f}', 'y', {'linestyle': 'dashed'}),
    (end + std, f'+2sigma={end+std:.2f}', 'y', {'linestyle': 'dashed'}),
]:
    plt.axvline(x=_x, linewidth=2, label=_label, color=_color, **_extra)

plt.legend(loc='upper left')
plt.show()

In [None]:
# Take one example (index 2). The triple is presumably (f0, loudness, audio)
# given how the names are used below (`f[0]` is plotted, `audio` is fed to
# get_amp) — confirm against DDSPDataset.__getitem__.
f, amp, audio = dataset[2]

In [None]:
# get_amp unpacks a (batch, channels, time) tensor. Add the batch axis only
# when it is missing, so re-running this cell does not keep stacking leading
# axes (a 4-D tensor would break get_amp's 3-way shape unpack).
if audio.dim() == 2:
    audio = audio.unsqueeze(0)

In [None]:
def get_amp(example, sample_rate=48000, window_length=19200, hop_length=3 * 256):
    """Compute a frame-wise loudness curve for a batch of multichannel audio.

    The waveform is zero-padded by half a window on each side and cut into
    overlapping windows of ``window_length`` samples every ``hop_length``
    samples. Each window is measured with ``torchaudio.functional.loudness``
    (ITU-R BS.1770 loudness, in LUFS).

    Args:
        example: audio tensor of shape (batch, channels, time).
        sample_rate: sample rate of ``example`` in Hz.
        window_length: analysis window length in samples (19200 = 0.4 s at 48 kHz).
        hop_length: distance between consecutive window starts in samples.

    Returns:
        Tensor of shape (batch, frames); for an even ``window_length`` there are
        ``time // hop_length + 1`` frames. Windows torchaudio cannot measure
        may come out as NaN (later cells replace NaNs with a -70 floor).
    """
    batch, channels, _ = example.shape
    half_window = window_length // 2

    # Frame every channel of every example as an independent 1-D signal.
    signals = rearrange(example, "b c t -> (b c) t")
    signals = torch.nn.functional.pad(signals, (half_window, half_window))
    frames = signals.unfold(1, window_length, hop_length)  # ((b c), frames, window)
    n_frames = frames.shape[1]

    # Regroup so each window carries all of its channels: AF.loudness reduces
    # over the trailing (channel, time) axes and keeps the leading one.
    windows = rearrange(frames, "(b c) f t -> (b f) c t", b=batch, c=channels, f=n_frames)

    loudness = AF.loudness(windows, sample_rate)
    return rearrange(loudness, "(b f) -> b f", b=batch, f=n_frames)

In [None]:
# Frame-wise loudness of the single example above (leading axis is the batch of 1).
shit = get_amp(example=audio)

In [None]:
# Sanity check: (audio shape, loudness-curve shape).
tuple(t.shape for t in (audio, shit))

In [None]:
plt.rcParams['figure.figsize'] = [8, 4]

# Two separate figures: the computed loudness curve (NaN frames drawn at -70),
# then the `f` curve that came with the example.
for series in (torch.nan_to_num(shit[0], nan=-70), f[0]):
    plt.plot(series)
    plt.show()

# Listen to the example at 48 kHz.
Audio(data=audio[0], rate=48000)

In [None]:
# Stack every example's waveform along a new leading axis (one row per example).
audios = torch.stack(tuple(feat['audio'] for feat in dataset.features))

In [None]:
# (examples, ...) shape of the stacked audio.
audios.size()

In [None]:
# Loudness curves for every example, computed in batches (on the GPU when one
# is available). The loop runs over len(audios) rather than a hard-coded
# example count. Each batch result is moved back to the CPU immediately, so
# GPU memory does not grow with the number of batches.
batch_size = 25
device = 'cuda' if torch.cuda.is_available() else 'cpu'
amps = []
for batch_start in range(0, len(audios), batch_size):
    amps.append(get_amp(audios[batch_start:batch_start + batch_size].to(device)).cpu())

In [None]:
# Shape of the first batch's loudness curves: (batch, frames).
amps[0].size()

In [None]:
# Join all per-batch loudness curves into a single (examples, frames) tensor on the CPU.
shit = torch.cat([chunk.cpu() for chunk in amps])

In [None]:
# Frames the loudness meter could not measure come back as NaN; any -inf
# (presumably zero-energy windows) means "no signal" as well. Clamp both
# directly to the -70 dB floor. With the default `neginf`, nan_to_num would
# turn -inf into the most negative float, which wrecks the (x + 70) / 70 scaling.
shit = torch.nan_to_num(shit, nan=-70., neginf=-70.)

In [None]:
# Map every value equal to the 100.0 sentinel onto the -70 dB floor.
shit = shit.masked_fill(shit == 100., -70.)

In [None]:
# Overall range of the cleaned loudness values.
(torch.min(shit), torch.max(shit))

In [None]:
idx = 8
plt.rcParams['figure.figsize'] = [8, 4]

# Two separate figures for example `idx`: computed loudness scaled with
# (dB + 70) / 70, then the dataset's stored 'f0' curve.
for curve in ((shit[idx] + 70) / 70, dataset.features[idx]['f0'][0]):
    plt.plot(curve)
    plt.show()

# Listen to channel 0 of the same example at 48 kHz.
Audio(data=audios[idx, 0], rate=48000)

In [None]:
# Presumably the same 70 dB span as the -70 floor and (dB + 70) / 70 scaling
# used above. Not referenced by any visible cell — confirm before relying on it.
dynamic_range = 70  # dB

In [None]:
sample_rate = 48000  # audio sample rate in Hz (the rate used throughout this notebook)
frame_rate = 250  # control frames per second
hop_size = sample_rate // frame_rate  # audio samples per control frame
# (hop size in samples, number of control frames in a 5-second clip)
hop_size, sample_rate * 5 // hop_size

In [None]:
# Synthesiser sizes, each written as base * 3; the factor 3 is not explained
# in this notebook — TODO confirm what it multiplies.
n_harmonics, n_noise = 60 * 3, 65 * 3
n_harmonics, n_noise

- normalize f0:
  - `f0 = cents_to_bins(freqs_to_cents(f0)) / 359`
- un-normalize f0:
  - `f0 = cents_to_frequency(bins_to_cents(f0 * 359))`
- normalize dB:
  - `db = (db + 70) / 70`