In [None]:
from pathlib import Path

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Audio

from src.datamodules.components.ddsp_dataset import DDSPDataset
from src.models.ddsp_module import DDSP
from src.utils.features import Loudness, get_f0

In [None]:
# Relative path to the violin model checkpoint, loaded into `vln` below.
vln_ckpt = "../checkpoints/violin_baseline.ckpt"

In [None]:
# Relative path to the cello model checkpoint, loaded into `vlc` below.
vlc_ckpt = "../checkpoints/cello_baseline.ckpt"

In [None]:
# `load_from_checkpoint` is a classmethod, so call it on the class. Calling it on
# a throwaway `DDSP()` instance builds an extra model for nothing, and newer
# Lightning releases refuse instance calls outright.
vln = DDSP.load_from_checkpoint(vln_ckpt, map_location='cuda').cuda()
# Switch to eval mode for inference; the trailing `;` suppresses the module repr.
vln.eval();

In [None]:
# Violin examples used to inspect the controls the model was trained on.
dataset = DDSPDataset('../data/violin_samples.pth')

In [None]:
# First dataset example, unpacked as (f0_, amp_, audio_). The trailing underscores
# keep these separate from the `f0` / `amp` names the notebook builds later.
f0_, amp_, audio_ = dataset[0]

In [None]:
# Shape of `amp_`, the second element of dataset[0].
amp_.shape

In [None]:
# Plot the amplitude control of the first dataset example.
# The synthetic `amp` is only built further down the notebook, so it is not
# overlaid here; referencing it in this cell fails on a fresh kernel.
fig, ax = plt.subplots()
ax.plot(amp_[0].cpu().numpy(), label='dataset[0] amp_')
ax.set(title='Amplitude control of dataset[0]', xlabel='frame index', ylabel='amp_')
ax.legend()
plt.show()

In [None]:
# `load_from_checkpoint` is a classmethod, so call it on the class. Calling it on
# a throwaway `DDSP()` instance builds an extra model for nothing, and newer
# Lightning releases refuse instance calls outright.
vlc = DDSP.load_from_checkpoint(vlc_ckpt, map_location='cuda').cuda()
# Switch to eval mode for inference; the trailing `;` suppresses the module repr.
vlc.eval();

In [None]:
def sec_to_ctrl(sec: float, sample_rate: int = 48000, hop_length: int = 192) -> int:
    """Return the number of control frames that cover ``sec`` seconds of audio.

    Computed as ``int(sample_rate * sec / hop_length + 1)``: one frame per hop
    plus one for the end point, truncated to an int.

    Args:
        sec: Duration in seconds; fractional values are allowed.
        sample_rate: Audio sample rate in Hz. Defaults to 48000.
        hop_length: Samples between consecutive control frames. Defaults to 192,
            which equals 48000 // 250. Literal defaults are used because the
            SAMPLE_RATE / HOP_LENGTH constants are defined in a later cell.

    Returns:
        The frame count.
    """
    return int(sample_rate * sec / hop_length + 1)

def midi_to_freq(m):
    """Convert a MIDI note number (may be fractional) to Hz, with MIDI 69 = 440 Hz in 12-TET."""
    semitones_from_a4 = (m - 69) / 12
    return 440. * 2 ** semitones_from_a4

In [None]:
# Synthetic test input: an 8-note C-major scale (MIDI 36..48) transposed up three
# octaves. Each note is `sec` seconds long, with a Gaussian loudness swell and a
# 4 Hz vibrato. It is followed by a 3-second fade towards -80, and the whole
# loudness curve is then shifted by -40.
sec = 2
ctrl = sec_to_ctrl(sec)
pitches = [36, 38, 40, 41, 43, 45, 47, 48]
transpose = 36            # semitones added to every pitch
vibrato_hz = 4.
vibrato_semitones = 0.25  # peak vibrato deviation

# Gaussian envelope over [-sec/2, sec/2]: peaks at -10 at the note centre
# (about -79.5 at the edges for sec = 2). One shared envelope is used per note.
envelope = torch.exp(-5. * torch.linspace(-sec / 2., sec / 2, ctrl) ** 2)[None, None, :] * 70. - 80.
amps = [envelope] * len(pitches)

time_s = torch.linspace(0., sec, ctrl)
f0s = []
for pitch in pitches:
    note = pitch + transpose
    # Measure the vibrato depth at the transposed pitch. Measuring it at the
    # untransposed pitch would make it 8x shallower after the 3-octave shift.
    depth = midi_to_freq(note + vibrato_semitones) - midi_to_freq(note)
    f = torch.ones(1, 1, ctrl) * midi_to_freq(note)
    f += torch.sin(time_s * 2 * np.pi * vibrato_hz) * depth
    f0s.append(f)

# Tail: loudness fades from the last envelope value to -80 while f0 holds the last pitch.
n_silence = sec_to_ctrl(3)
silence = torch.linspace(amps[-1][0, 0, -1].item(), -80., n_silence)[None, None, :]
f_silence = torch.ones(1, 1, n_silence) * f0s[-1][0, 0, -1]
f0 = torch.cat(f0s + [f_silence], dim=-1).cuda()
amp = (torch.cat(amps + [silence], dim=-1) - 40.).cuda()

In [None]:
# Shape check for the synthetic amplitude control; the cell above concatenates
# (1, 1, n) pieces along the last axis, so this should be (1, 1, total_frames).
amp.shape

In [None]:
# Plot the synthetic amplitude curve (batch 0, channel 0).
plt.plot(amp[0, 0].cpu().numpy())

In [None]:
# Render the synthetic scale with the violin model. The result gets its own
# name so it does not clobber `y`, which later cells use for the loaded waveform.
with torch.no_grad():
    y_scale = vln(f0, amp)

# Literal rate: SAMPLE_RATE is only defined in a later cell.
Audio(data=y_scale[0].cpu().numpy(), rate=48000)

In [None]:
# --- Analysis configuration ---------------------------------------------------
# Audio in this notebook is loaded and played at SAMPLE_RATE. The CREPE_* values
# describe the same analysis at CREPE_SAMPLE_RATE, so sizes scale by the integer
# ratio between the two rates.
SAMPLE_RATE = 48000
CREPE_SAMPLE_RATE = 16000
SR_RATIO = SAMPLE_RATE // CREPE_SAMPLE_RATE  # 3

CREPE_N_FFT = 1024
N_FFT = CREPE_N_FFT * SR_RATIO  # 3072 samples: same duration as CREPE_N_FFT at 16 kHz

# Control frames per second.
# TODO: FRAME_RATE should be adjustable but valid values depend on audio example duration
FRAME_RATE = 250
HOP_LENGTH = SAMPLE_RATE // FRAME_RATE     # 192 samples
CREPE_HOP_LENGTH = HOP_LENGTH // SR_RATIO  # 64 samples

In [None]:
class Preprocess:
    """Extract the (f0, loudness) controls that the DDSP models are driven with.

    Relies on the module-level ``HOP_LENGTH`` and ``N_FFT`` constants and on the
    project helpers ``Loudness`` / ``get_f0``.
    """

    def __init__(self, device):
        # Loudness is moved to `device`; the audio passed to `do` should live there too.
        self.ld = Loudness().to(device)

    @staticmethod
    def _pad_to_multiple(y, multiple):
        """Right-pad the last axis of ``y`` with zeros to a length divisible by ``multiple``."""
        remainder = y.shape[-1] % multiple
        if remainder == 0:
            return y
        # F.pad returns a new tensor; it never pads in place.
        return F.pad(y, (0, multiple - remainder))

    def do(self, y):
        """Compute f0 and loudness for a waveform.

        Args:
            y: 1-D waveform tensor (indexed as ``y[None, None, :]`` below);
                presumably at SAMPLE_RATE and on the device given to __init__.

        Returns:
            ``(f0, loudness)`` as produced by ``get_f0`` and ``Loudness.get_amp``.
        """
        # Make the length an exact multiple of HOP_LENGTH. Rebinding `y` matters:
        # discarding F.pad's result would leave the audio unpadded.
        y = self._pad_to_multiple(y, HOP_LENGTH)

        # Add (batch, channel) axes and pad half a window on each side.
        audio = F.pad(y[None, None, :], (N_FFT // 2, N_FFT // 2))
        loudness = self.ld.get_amp(audio)
        f0 = get_f0(audio)

        return f0, loudness

In [None]:
# Feature extractor whose loudness module lives on the GPU.
preprocessor = Preprocess(device='cuda')

In [None]:
# Violin recording to re-synthesise. The path is built from the user's home
# directory rather than a hard-coded /home/<user> prefix, so it is not tied to one account.
VIOLIN_WAV = Path.home() / 'Music' / 'Violin Samples' / 'yee_bach_theme#33.wav'
y, _ = librosa.load(VIOLIN_WAV, sr=SAMPLE_RATE)

In [None]:
# torch.as_tensor accepts both the NumPy array from librosa and an existing tensor,
# so re-running this cell is harmless. torch.from_numpy would raise on the second run.
y = torch.as_tensor(y).cuda()

In [None]:
# Extract f0 and loudness from the violin recording. This rebinds `f0` / `amp`,
# which previously held the synthetic scale.
f0, amp = preprocessor.do(y)

In [None]:
# Re-synthesise the extracted controls with the violin model. The output gets
# its own name so `y` (the input waveform read by the preprocessing cell)
# survives if that cell is re-run.
with torch.no_grad():
    y_vln = vln(f0, amp)

Audio(data=y_vln[0].cpu().numpy(), rate=SAMPLE_RATE)

In [None]:
# Drive the cello model (`vlc`, loaded from vlc_ckpt) with the controls extracted
# from the violin recording. `vlc_alt` does not exist anywhere in the notebook.
with torch.no_grad():
    y_vlc = vlc(f0, amp)

Audio(data=y_vlc[0].cpu().numpy(), rate=SAMPLE_RATE)

In [None]:
# Same violin controls with f0 doubled, i.e. transposed up one octave. The output
# gets a distinct name so the input waveform `y` is left untouched.
with torch.no_grad():
    y_vln_octave = vln(f0 * 2.0, amp)

Audio(data=y_vln_octave[0].cpu().numpy(), rate=SAMPLE_RATE)

In [None]:
from src.utils.multiscale_stft_loss import multiscale_stft, distance

In [None]:
# Cello recording for the multi-scale STFT experiments. The path is built from the
# home directory instead of a hard-coded /home/<user> prefix.
CELLO_WAV = Path.home() / 'Music' / 'Cello Samples' / 'SchummTrau-00003-.wav'
# mono=False keeps channels separate for multi-channel files; [None, ...] adds a leading batch axis.
y, _ = librosa.load(CELLO_WAV, sr=SAMPLE_RATE, mono=False)
y = torch.from_numpy(y)[None, ...]

In [None]:
# Sanity check: distance between the signal and itself. Presumably 0 if
# `distance` behaves like a metric (TODO confirm against its definition).
distance(y, y)

In [None]:
# Multi-scale transform of the cello clip at six sizes, 4096 down to 128.
# NOTE(review): 0.75 is presumably the window overlap fraction; confirm in multiscale_stft.
ss = multiscale_stft(y, [4096, 2048, 1024, 512, 256, 128], 0.75)

In [None]:
# Shape of the first scale; presumably the 4096-point one if `ss` follows the size-list order.
ss[0].shape

In [None]:
# Show one scale for the first batch item. ss[-4] is index 2 of 6, i.e. the
# 1024-point scale if `ss` follows the size-list order.
plt.matshow(ss[-4][0])

In [None]:
# Shape of the scale plotted above.
ss[-4].shape