In [None]:
import torch
import torch.nn as nn
from performer.models.ddsp_module import DDSP
from performer.datamodules.components.ddsp_dataset import DDSPDataset
from performer.utils.features import Loudness, get_f0
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl

import matplotlib.pyplot as plt
from IPython.display import Audio

import librosa
import soundfile as sf

In [None]:
# Relative paths to the trained checkpoints, one per instrument model.
vln_ckpt = '../checkpoints/violin_longrun.ckpt'
vlc_ckpt = '../checkpoints/cello_longrun.ckpt'
flt_ckpt = '../checkpoints/flute_longrun.ckpt'

In [None]:
# Restore the flute checkpoint on the GPU and put it in eval mode.
# Loading happens under inference_mode, so no autograd state is recorded.
with torch.inference_mode():
    model = DDSP.load_from_checkpoint(flt_ckpt, map_location='cuda').to('cuda')
    model.eval()

In [None]:
# Build the dataset from the pre-computed flute samples file.
dataset = DDSPDataset("../data/flute_samples.pth")

In [None]:
# Take item 10 from the dataset; the third element of the tuple is discarded.
f0, amp, _ = dataset[10]
# Move both control signals to the GPU, where the model lives.
f0, amp = f0.cuda(), amp.cuda()

In [None]:
# Run the full model on the example. [None, ...] prepends a leading axis
# (presumably the batch axis — confirm against DDSP.forward).
with torch.inference_mode():
    y = model(f0[None, ...], amp[None, ...])

In [None]:
# Play the first item of the model output at 48 kHz.
Audio(y[0].cpu(), rate=48000)

In [None]:
# Run the model stage by stage: the controller produces the control signals,
# then the harmonic and noise synths are called on them separately.
# model.reverb is not applied in this cell.
with torch.inference_mode():
    harmonic_controls, noise_controls = model.controller(f0[None, ...], amp[None, ...])
    harmonics = model.harmonics(*harmonic_controls)
    noise = model.noise(noise_controls)

In [None]:
# Play the harmonic synth output on its own.
Audio(harmonics[0].cpu(), rate=48000)

In [None]:
# Play the noise synth output on its own.
Audio(noise[0].cpu(), rate=48000)

In [None]:
# Play the sum of the harmonic and noise outputs (no reverb applied).
Audio(harmonics[0].cpu() + noise[0].cpu(), rate=48000)

In [None]:
# Play the reverb impulse response. After squeeze(), [1:] drops index 0
# along the first remaining axis.
# NOTE(review): confirm which axis that is for the IR layout.
Audio(model.reverb.ir.detach().squeeze().cpu()[1:], rate=48000)

In [None]:
# Show the shape of the sliced impulse response used in the cell above.
model.reverb.ir.detach().squeeze().cpu()[1:].shape

In [None]:
# Plot the first row of the sliced impulse response.
plt.plot(model.reverb.ir.detach().squeeze().cpu()[1:][0])

In [None]:
def adsr(ts: float, a: float, d: float, s: float, r: float,
         frame_rate: int = 250, device='cuda'):
    """Build a piecewise-linear ADSR envelope sampled at the control rate.

    Args:
        ts: Duration of the sustain segment, in seconds.
        a: Attack duration in seconds; the envelope rises from 0 to 1.
        d: Decay duration in seconds; the envelope falls from 1 to ``s``.
        s: Sustain level.
        r: Release duration in seconds; the envelope falls from ``s`` to 0.
        frame_rate: Control frames per second. Defaults to 250, the value
            that used to be hard-coded.
        device: Device the envelope is moved to. Defaults to ``'cuda'``,
            matching the previous behaviour.

    Returns:
        Tensor of shape ``(1, 1, n_frames)`` where ``n_frames`` is the sum of
        the four segment lengths, each computed as ``int(seconds * frame_rate)``.
    """
    attack = torch.linspace(0., 1., int(a * frame_rate))
    # The decay has to start at the attack peak and end at the sustain level.
    # It previously ramped 0 -> 1 a second time, producing a discontinuity.
    decay = torch.linspace(1., s, int(d * frame_rate))
    sustain = torch.ones(int(ts * frame_rate)) * s
    release = torch.linspace(s, 0., int(r * frame_rate))

    env = torch.cat([attack, decay, sustain, release])[None, None, :]

    return env.to(device)

In [None]:
def sin(ts: float, f: float, frame_rate: int = 250, device='cuda'):
    """Generate a sine wave sampled at the control rate.

    Args:
        ts: Duration in seconds; the output has ``int(ts * frame_rate)`` samples.
        f: Frequency in cycles per second.
        frame_rate: Samples per second. Defaults to 250, the value that used
            to be hard-coded.
        device: Device the signal is created on. Defaults to ``'cuda'``,
            matching the previous behaviour.

    Returns:
        1-D float32 tensor holding ``sin(2 * pi * f * t)``.
    """
    # Time axis in seconds: sample index divided by the sampling rate.
    t = torch.arange(int(ts * frame_rate), dtype=torch.float32, device=device) / frame_rate

    return torch.sin(2 * np.pi * f * t)

In [None]:
# NOTE(review): with the two lines below commented out, `amp` is whatever an
# earlier cell left in the kernel; re-enable them to use a synthetic envelope.
# amp = 60 * adsr(2, 0.01, 0.02, 0.5, 1) - 80 + 2 * sin(3.03, 4)
# amp += torch.randn_like(amp) * 0.1
# Constant pitch track with the same shape as `amp`, filled with 880*2 = 1760.
f0 = torch.ones_like(amp, device='cuda') * 880*2

In [None]:
# Synthesize from the constant-pitch f0 and the current amp.
# NOTE(review): unlike the earlier call, no leading axis is added here —
# confirm f0/amp already have the shape the model expects.
with torch.inference_mode():
    y = model(f0, amp)

# Copy to host memory as a NumPy array and drop size-1 axes for playback.
_y = y.cpu().numpy().squeeze()

Audio(data=_y, rate=48000)

In [None]:
def export_controller_jit(ckpt, name, map_location='cuda'):
    """Export a DDSP checkpoint as TorchScript modules plus a reverb IR WAV.

    Despite the name, three sub-modules are scripted and saved: the
    controller, the harmonic synth and the noise synth. The reverb impulse
    response is written as a 24-bit, 48 kHz WAV file.

    Args:
        ckpt: Path to the Lightning checkpoint.
        name: Prefix used for the output file names.
        map_location: Device the model is loaded onto; it is also embedded in
            the TorchScript file names.

    Output files (all under ``../out``, created if missing):
        ``{name}-{map_location}-controller.pt``
        ``{name}-{map_location}-harmonics.pt``
        ``{name}-{map_location}-noise.pt``
        ``{name}-ir.wav``
    """
    import os

    out_dir = '../out'
    # Saving used to fail when the output directory did not exist yet.
    os.makedirs(out_dir, exist_ok=True)

    with torch.inference_mode():
        model = DDSP.load_from_checkpoint(ckpt, map_location=map_location)
        model = model.to(map_location)
        model.eval()

        # Script and save each sub-module, in the same order as before.
        for part in ('controller', 'harmonics', 'noise'):
            scripted = torch.jit.script(getattr(model, part))
            scripted.save(f'{out_dir}/{name}-{map_location}-{part}.pt')

        # detach() is required before numpy(): when the model already lives on
        # the CPU, .cpu() returns the parameter itself, which requires grad.
        ir = model.reverb.ir.detach().cpu().numpy()[:, 0, :]
        # Prepend a unit sample to every channel. The channel count is read
        # from the IR instead of being hard-coded to 2.
        # NOTE(review): the leading 1.0 presumably stands for the dry signal — confirm.
        ir = np.concatenate([np.ones((ir.shape[0], 1), dtype='float32'), ir], axis=1)

        # soundfile expects (frames, channels), hence the transpose.
        sf.write(f'{out_dir}/{name}-ir.wav', ir.T, 48000, subtype='PCM_24')

In [None]:
# Export the violin checkpoint's scripted modules and reverb IR, loaded on CUDA.
export_controller_jit(vln_ckpt, 'violin', 'cuda')

In [None]:
# Inspect the raw shape of the reverb impulse response parameter.
model.reverb.ir.shape

In [None]:
# Play the impulse response. The tensor must be detached and copied to host
# memory first: Audio converts its input to a NumPy array, which is not
# possible for a CUDA tensor or one that requires grad.
Audio(data=model.reverb.ir.detach().cpu()[:, 0, :], rate=48000)

In [None]:
# Drive the harmonic synth directly with hand-made controls.
dur = 2
# Overtone amplitudes 1/k for k = 1..180, reshaped and repeated to (1, 1, 180, dur).
overtones = 1 / torch.arange(1, 181)
overtones = overtones[None, None, :, None]
overtones = overtones.repeat(1, 1, 1, dur)
# Constant amplitude (-40) and pitch (110) over `dur` control frames.
amp = torch.ones(1, 1, dur) * -40.0
f0 = torch.ones(1, 1, dur) * 110.0

# NOTE(review): these tensors are created on the CPU while `model` was loaded
# on 'cuda', and the call is not wrapped in inference_mode — confirm this runs.
y = model.harmonics.forward(f0, amp, overtones)
Audio(data=y[0], rate=48000)

In [None]:
# Call forward_live twice with the same controls and join the two outputs
# along the last axis, to hear how consecutive live blocks connect.
y = model.harmonics.forward_live(f0, amp, overtones)
w = model.harmonics.forward_live(f0, amp, overtones)
Audio(data=torch.cat((y, w), dim=-1)[0], rate=48000)

In [None]:
# Load a previously exported TorchScript controller.
# NOTE(review): no cell in this notebook writes 'cello_controller.pt' to the
# working directory — it has to exist from an earlier session.
ctrl = torch.jit.load('cello_controller.pt')

In [None]:
# Single-frame CPU inputs for the live controller: pitch 440, amplitude -32.
f0 = torch.ones(1, 1, 1, device='cpu') * 440.
amp = torch.ones(1, 1, 1, device='cpu') * -32.
# NOTE(review): hidden is (3, 1, 1) here but (3, 1, 512) in the later CUDA
# cell — confirm the size expected by forward_live.
hidden = torch.ones(3, 1, 1, device='cpu')

In [None]:
# One live step of the scripted controller; it returns three values.
p, n, h, = ctrl.forward_live(f0, amp, hidden)


In [None]:
# Load the violin model on the GPU in eval mode. load_from_checkpoint is a
# classmethod, so it is called on the class itself (as in the flute cell)
# rather than on a throwaway DDSP() instance, which recent Lightning
# versions reject.
vln = DDSP.load_from_checkpoint(vln_ckpt, map_location='cuda').cuda()
# Assigning the result keeps the cell from echoing the module repr.
vln = vln.eval()

In [None]:
# Rebind `ctrl` to the violin model's controller sub-module.
ctrl = vln.controller

In [None]:
# Feed a dummy (1, 8, 1024) input through the controller's GRU to discover
# the shape of the hidden state it returns.
x = torch.ones(1, 8, 1024, device='cuda')
with torch.inference_mode():
    y, h = ctrl.gru(x)

print(h.shape)

In [None]:
# 128-frame constant controls on the GPU plus an all-ones initial hidden state.
f0 = torch.ones(1, 1, 128, device='cuda') * 440.
amp = torch.ones(1, 1, 128, device='cuda') * -32.
hidden = torch.ones(3, 1, 512, device='cuda')

# One live step; `hidden` is overwritten with the returned state.
with torch.inference_mode():
    h_ctrl, n_ctrl, hidden = ctrl.forward_live(f0, amp, hidden)

# Print the shape of every harmonic control, the noise control and the state.
for param in h_ctrl:
    print(param.shape)
print(n_ctrl.shape)
print(hidden.shape)

In [None]:
# Compile the violin controller to TorchScript.
# NOTE(review): consider a descriptive name such as `ctrl_scripted`; the
# variable is reused by the following cells.
shits = torch.jit.script(ctrl)

In [None]:
# Check that the scripted controller runs forward_live with the same inputs.
with torch.inference_mode():
        h_ctrl, n_ctrl, hidden = shits.forward_live(f0, amp, hidden)

In [None]:
# Write the scripted controller to the current working directory.
shits.save('violin_controller.pt')

In [None]:
# Load the file back to verify the round trip.
bokumsel = torch.jit.load('violin_controller.pt')

In [None]:
# Run the reloaded controller on the same inputs.
with torch.inference_mode():
        h_ctrl, n_ctrl, hidden = bokumsel.forward_live(f0, amp, hidden)

In [None]:
# Rebind `dataset` to the violin samples file (it previously held the flute set).
dataset = DDSPDataset("../data/violin_samples.pth")

In [None]:
# First dataset item, unpacked into three values.
f0_, amp_, audio_ = dataset[0]

In [None]:
# Inspect the shape of the dataset amplitude track.
amp_.shape

In [None]:
# Overlay the dataset amplitude track with the hand-made `amp` left in the
# kernel by an earlier cell.
plt.plot(amp_[0].cpu().numpy())
plt.plot(amp[0, 0].cpu().numpy())
plt.show()

In [None]:
# Load the cello model on the GPU in eval mode. load_from_checkpoint is a
# classmethod, so it is called on the class itself (as in the flute cell)
# rather than on a throwaway DDSP() instance, which recent Lightning
# versions reject.
vlc = DDSP.load_from_checkpoint(vlc_ckpt, map_location='cuda').cuda()
# Assigning the result keeps the cell from echoing the module repr.
vlc = vlc.eval()

In [None]:
def sec_to_ctrl(sec: int):
    """Convert a duration in seconds to a number of control frames.

    The duration is scaled by 48000 and divided by 192, one extra frame is
    added, and the result is truncated to an ``int``.
    """
    frames = 48000 * sec / 192
    return int(frames + 1)

def midi_to_freq(m):
    """Convert a MIDI note number to a frequency.

    Note 69 maps to 440.0, and every 12 notes up or down doubles or halves
    the result. Fractional note numbers are accepted.
    """
    octaves_from_reference = (m - 69) / 12
    return 440. * 2 ** octaves_from_reference

In [None]:
# Build control signals for an eight-note scale followed by a fade-out.
sec = 2
# NOTE(review): this rebinds `ctrl` (previously the controller module) to a
# frame count.
ctrl = sec_to_ctrl(sec)
f0s = []
# One bell-shaped (Gaussian in time) amplitude curve per note, scaled to the
# range -80..-10; the same tensor object is repeated 8 times.
amps = [torch.exp(-5. * torch.linspace(-sec/2., sec/2, ctrl) ** 2)[None, None, :] * 70. - 80.] * 8
for pitch in [36, 38, 40, 41, 43, 45, 47, 48]:
    # Constant pitch, transposed up by 36 MIDI notes.
    f = torch.ones(1, 1, ctrl) * midi_to_freq(pitch + 36)
    # Add a 4-cycles-per-second sinusoidal vibrato.
    # NOTE(review): the depth is computed at `pitch`, not `pitch + 36` —
    # confirm this is the intended vibrato width.
    f += torch.sin(torch.linspace(0., sec, ctrl) * 2 * 3.14159265 * 4) * (midi_to_freq(pitch+0.25) - midi_to_freq(pitch))
    f0s.append(f)

# Three-second tail: amplitude ramps from the last note's final value to -80
# while the pitch is held at the last note's final value.
silence = torch.linspace(amps[-1][0, 0, -1], -80., sec_to_ctrl(3))[None, None, :]
f_silence = torch.ones(1, 1, sec_to_ctrl(3)) * f0s[-1][0, 0, -1]
f0 = torch.cat(f0s + [f_silence], dim=-1).cuda()
# The whole amplitude track is lowered by a further 40 before synthesis.
amp = (torch.cat(amps + [silence], dim=-1) - 40.).cuda()

In [None]:
# Inspect the shape of the concatenated amplitude track.
amp.shape

In [None]:
# Plot the amplitude track of the scale.
plt.plot(amp.cpu().numpy()[0][0])

In [None]:
# Synthesize the scale with the violin model and play it at 48 kHz.
with torch.no_grad():
    y = vln(f0, amp)

Audio(data=y[0].cpu().numpy(), rate=48000)

In [None]:
# Constants for feature extraction.
SAMPLE_RATE = 48000
CREPE_SAMPLE_RATE = 16000
# Integer ratio between the two sample rates: 48000 // 16000 = 3.
SR_RATIO = SAMPLE_RATE // CREPE_SAMPLE_RATE
CREPE_N_FFT = 1024
# FFT size scaled by the sample-rate ratio: 1024 * 3 = 3072.
N_FFT = 1024 * SR_RATIO

# TODO: FRAME_RATE should be adjustable but valid values depend on audio example duration
FRAME_RATE = 250
# Samples per control frame: 48000 // 250 = 192.
HOP_LENGTH = SAMPLE_RATE // FRAME_RATE
# Hop length divided by the sample-rate ratio: 192 // 3 = 64.
CREPE_HOP_LENGTH = HOP_LENGTH // SR_RATIO

In [None]:
class Preprocess:
    """Extract pitch and loudness control signals from a waveform."""

    def __init__(self, device):
        """Create the loudness extractor and move it to ``device``."""
        self.ld = Loudness().to(device)

    def do(self, y):
        """Compute ``(f0, loudness)`` for one waveform.

        Args:
            y: Waveform tensor, indexed as 1-D (``len(y)`` is used as the
                sample count) — assumes mono audio; confirm with callers.

        Returns:
            Tuple ``(f0, loudness)`` as returned by ``get_f0`` and
            ``Loudness.get_amp``.
        """
        # Right-pad to a whole number of hops. F.pad is out-of-place, so the
        # result must be assigned back; it used to be discarded.
        if (diff := len(y) % HOP_LENGTH) != 0:
            y = F.pad(y, (0, HOP_LENGTH - diff))

        # Add two leading axes and pad half an FFT window on both sides.
        audio = F.pad(y[None, None, :], (N_FFT // 2, N_FFT // 2))
        loudness = self.ld.get_amp(audio)
        f0 = get_f0(audio)

        return f0, loudness

In [None]:
# Feature extractor with its loudness module on the GPU.
preprocessor = Preprocess('cuda')

In [None]:
# Load the first 10 seconds of a recording, resampled to 48 kHz.
# NOTE(review): absolute, user-specific path — this cell will not run on
# another machine; consider a path relative to a configurable data directory.
y, _ = librosa.load('/home/kureta/Music/Flute Samples/01. Air.wav', sr=48000, duration=10)

In [None]:
# Convert the NumPy waveform to a tensor on the GPU.
# NOTE(review): rebinds `y` in place, so re-running this cell alone fails.
y = torch.from_numpy(y).cuda()

In [None]:
# Extract the pitch and loudness tracks from the recording.
f0, amp = preprocessor.do(y)

In [None]:
# Plot the extracted loudness track.
plt.plot(amp.squeeze().cpu())

In [None]:
# Resynthesize the recording's controls with the violin model and play it.
with torch.no_grad():
    y = vln(f0, amp)

Audio(data=y[0].cpu(), rate=48000)

In [None]:
# Resynthesize the same controls with the cello model and play the result.
# The cello model loaded in this notebook is `vlc`; the name `vlc_alt` is not
# defined in any cell and raised NameError on a fresh kernel.
with torch.no_grad():
    y = vlc(f0, amp)

Audio(data=y[0].cpu(), rate=48000)

In [None]:
# Same as the violin resynthesis, with the pitch track doubled.
with torch.no_grad():
    y = vln(f0 * 2.0, amp)

Audio(data=y[0].cpu(), rate=48000)

In [None]:
from src.utils.multiscale_stft_loss import multiscale_stft, distance

In [None]:
# Load a recording at 48 kHz without down-mixing to mono.
# NOTE(review): absolute, user-specific path — not portable.
y, _ = librosa.load('/home/kureta/Music/Cello Samples/SchummTrau-00003-.wav', sr=48000, mono=False)
# Prepend a leading axis to the waveform tensor.
y = torch.from_numpy(y)[None, ...]

In [None]:
# Compare the signal with itself using `distance`.
# NOTE(review): presumably a sanity check that should give zero — confirm.
distance(y, y)

In [None]:
# Compute the multiscale STFT for six sizes from 4096 down to 128.
# NOTE(review): 0.75 is presumably the overlap ratio — confirm against
# multiscale_stft's signature.
ss = multiscale_stft(y, [4096, 2048, 1024, 512, 256, 128], 0.75)

In [None]:
# Shape of the first returned item (first size in the list, 4096).
ss[0].shape

In [None]:
# Show the first entry of the fourth-from-last item as an image.
plt.matshow(ss[-4][0])

In [None]:
# Shape of the item displayed in the previous cell.
ss[-4].shape