In [None]:
from pathlib import Path
from typing import Callable, Optional

import IPython.display as ipd
import matplotlib.pyplot as plt
import torch
import torchaudio

import timbreremap.feature as feature
from timbreremap.data import OnsetFeatureDataModule
from timbreremap.loss import FeatureDifferenceLoss
from timbreremap.np import OnsetFrames
from timbreremap.synth import Snare808

%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide configuration: the input recording and the sample rate (Hz)
# that the recording is resampled to and that playback/synthesis use.
data_path = "audio/carson_gant_drums/performance.wav"
sample_rate = 48_000

In [None]:
# Read the recording; `sr` is the file's native sample rate.
waveform, sr = torchaudio.load(data_path)

# Convert to the notebook's working rate, then keep only the first row
# (first channel) so downstream cells operate on a single-channel signal.
to_working_rate = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
drums = to_working_rate(waveform)[:1]
print(drums.shape)

In [None]:
# Audition the resampled recording. The playback rate comes from the shared
# `sample_rate` constant so it cannot drift from the rate used for resampling.
ipd.Audio(drums.squeeze().numpy(), rate=sample_rate)

In [None]:
# `drums` has already been resampled to `sample_rate`, so the detector must be
# configured with that rate rather than the file's native rate `sr`.
onset_detection = OnsetFrames(sample_rate, 256)
# Onset positions; they are used below as x-coordinates on the waveform plot.
onset_times = onset_detection.onset(drums)

In [None]:
# Waveform of the (single-channel) recording with every detected onset
# overlaid as a translucent red vertical marker.
fig, ax = plt.subplots()
ax.plot(drums[0].numpy())
ax.vlines(onset_times, -0.5, 0.5, color="red", alpha=0.5)
ax.set_title("Detected Onsets")
plt.show()

In [None]:
def get_loudness_extractor(
    scaling_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
):
    """Build a frame-based feature extractor that contains a single loudness feature.

    Args:
        scaling_function: Optional callable forwarded unchanged to
            ``feature.Loudness`` as its ``scaling_function`` argument.
            ``None`` is forwarded as-is (NOTE(review): presumably this selects
            the library's default scaling -- confirm against ``feature.Loudness``).

    Returns:
        A ``feature.CascadingFrameExtactor`` wrapping the loudness extractor.

    Note:
        The sample rate is read from the notebook-level ``sample_rate`` variable.
    """
    loudness_extractor = feature.Loudness(
        sample_rate=sample_rate, scaling_function=scaling_function
    )
    # Positional arguments after the extractor list are kept exactly as in the
    # original call: [2], 2048, 512.
    # NOTE(review): presumably frame counts, frame size and hop size -- confirm
    # against the CascadingFrameExtactor signature.
    frame_extractor = feature.CascadingFrameExtactor(
        [loudness_extractor],
        [
            2,
        ],
        2048,
        512,
    )
    return frame_extractor

# Detector settings for slicing the recording into per-onset frames; each
# frame is `sample_rate` samples long.
onset_frame_settings = dict(
    frame_size=sample_rate,
    on_thresh=16.0,
    wait=1323,
    backtrack=16,
    overlap_buffer=512,
)
onset_frames = OnsetFrames(sample_rate, **onset_frame_settings)

# The detector returns a NumPy array; convert it to a float tensor for the
# torch-based feature extractor.
frames = torch.from_numpy(onset_frames(drums)).float()

# Loudness of every frame, with no custom scaling function supplied.
frame_extractor = get_loudness_extractor(None)
loudness = frame_extractor(frames)
print(loudness.shape)

In [None]:
# Loudness feature extracted from each frame, in frame order. Title, axis
# labels and plt.show() match the other plotting cells so the figure stands
# alone and no Line2D repr is echoed.
plt.plot(loudness.numpy())
plt.title("Loudness per Frame")
plt.xlabel("Frame index")
plt.ylabel("Loudness")
plt.show()

In [None]:
# Choose a preset file name (Colab form field), then expand it to its path
# inside the repository's preset directory.
preset = "808_snare_1.json"  # @param ["808_snare_1.json", "808_snare_2.json", "808_snare_3.json", "808_noisy_snare.json", "808_open_snare.json"]
preset = f"../cfg/presets/{preset}"

# 808-style snare synthesizer; output length and noise buffer are both
# `sample_rate` samples.
synth = Snare808(
    sample_rate=sample_rate,
    num_samples=sample_rate,
    buffer_noise=True,
    buffer_size=sample_rate,
)

# Load the preset's parameters (the second returned value is not needed),
# render the sound and audition it.
parameters, _ = synth.load_params_json(preset)
audio = synth(parameters)
ipd.Audio(audio, rate=sample_rate)

In [None]:
def optimize_gain(extractor, target_audio, audio, ref: int = 0, iters: int = 500):
    """Fit one gain per target frame so the scaled synth audio follows the target's feature changes.

    The gain is initialised to ones, multiplied onto ``audio`` and optimised
    with Adam (lr=1e-2) against ``FeatureDifferenceLoss``.

    Args:
        extractor: Callable mapping audio to features; it is applied to both
            ``target_audio`` and ``audio * gain``, and both results must have
            the same shape.
        target_audio: Target frames, indexed along the first dimension.
        audio: Synth audio to be scaled. It must broadcast against a gain of
            shape ``(num_frames, 1)``.
        ref: Index of the reference frame. Target features are expressed
            relative to ``y[ref]`` and the synth reference feature is taken
            from the same row.
        iters: Number of optimisation steps.

    Returns:
        Tuple ``(gain, y)``: the optimised gain (detached) and the features
        extracted from ``target_audio``.
    """
    # Setup needs no autograd graph; only the loop below is differentiated.
    with torch.no_grad():
        # Target features and their difference from the reference frame.
        y = extractor(target_audio)
        y_diff = y - y[ref]

        # One scalar gain per target frame. The shape is derived from the
        # function's own inputs rather than from a notebook-level variable, so
        # the function does not depend on which cells were run before it.
        gain = torch.ones(y.shape[0], 1, dtype=y.dtype, device=y.device)

        # Features of the unscaled synth; the reference row uses the same
        # index as the target reference.
        y_hat = extractor(audio * gain)
        synth_feat = y_hat[ref]

    assert y.shape == y_hat.shape

    gain = torch.nn.Parameter(gain)
    optimizer = torch.optim.Adam([gain], lr=1e-2)
    # NOTE(review): presumably compares (y_hat - synth_feat) with y_diff --
    # confirm against FeatureDifferenceLoss.
    loss = FeatureDifferenceLoss()

    for i in range(iters):
        optimizer.zero_grad()
        y_hat = extractor(audio * gain)

        err = loss(y_hat, synth_feat, y_diff)
        err.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Step {i}, Loss {err.item()}")

    return gain.detach(), y

In [None]:
def resynthesize(synth_audio, original_audio, timings):
    """Mix one synthesized hit per onset into a silent buffer shaped like ``original_audio``.

    Args:
        synth_audio: Indexable collection of hits; ``synth_audio[i]`` is placed
            at ``timings[i]`` and its last dimension is its length in samples.
        original_audio: Tensor whose shape, dtype and device the output copies.
            It is not modified.
        timings: Iterable of onset positions, in samples.

    Returns:
        A new tensor like ``original_audio``, zero everywhere except row 0,
        into which the hits are summed. A hit that runs past the end of the
        buffer is truncated; an onset that lies outside the buffer is skipped.
    """
    resynth = torch.zeros_like(original_audio)
    total = resynth.shape[-1]

    for i, onset in enumerate(timings):
        start = int(onset)
        # An onset outside the buffer would give an empty destination slice but
        # a non-empty source slice, which raises a shape error; skip it.
        if start < 0 or start >= total:
            continue

        hit = synth_audio[i]
        end = min(start + hit.shape[-1], total)
        resynth[0, start:end] += hit[: end - start]

    return resynth

In [None]:
def scaling_function(x):
    """Negative power-law loudness scaling; the small offset avoids evaluating the power at zero."""
    return -1.0 * torch.pow(x + 0.000001, -0.1)


# Alternatives that were tried:
#scaling_function = lambda x: -6.92 + 10.0*torch.log10(x + 1e-8)
#scaling_function = None
frame_extractor = get_loudness_extractor(scaling_function)

gain, y = optimize_gain(frame_extractor, frames, audio, iters=1000)

# `gain` is returned already detached, so multiply the tensors directly
# instead of round-tripping through NumPy and mixing Tensor/ndarray operands.
audio_hat = audio * gain
y_hat = frame_extractor(audio_hat)

In [None]:
# Target loudness features against those of the gain-adjusted synth audio.
plt.plot(y.numpy(), label="Original")
plt.plot(y_hat.numpy(), label="Predicted")
plt.legend()
plt.show()

# `audio_hat` holds one hit per frame produced by `onset_frames`, so the hits
# are placed at onsets from that same detector. `onset_times` came from a
# differently configured detector and is not guaranteed to line up.
# NOTE(review): assumes `onset_frames(drums)` yields exactly one frame per
# onset returned by `onset_frames.onset(drums)` -- confirm against OnsetFrames.
resynth_onsets = onset_frames.onset(drums)
assert len(resynth_onsets) == len(audio_hat), (
    f"{len(resynth_onsets)} onsets but {len(audio_hat)} synthesized hits"
)
resynth = resynthesize(audio_hat, drums, resynth_onsets)

# Original performance followed by the resynthesized version.
ipd.display(ipd.Audio(drums.squeeze().numpy(), rate=sample_rate))
ipd.display(ipd.Audio(resynth.squeeze().numpy(), rate=sample_rate))

In [None]:
def scaling_function(x):
    """Negative power-law scaling with a small offset added before the power."""
    return -1 * torch.pow(x + 0.0000001, -0.1)


def scaling_function_2(x):
    """Second copy of the power-law scaling (same formula as scaling_function)."""
    return -1 * torch.pow(x + 0.0000001, -0.1)


def scaling_function_db(x):
    """Log10-based scaling with a fixed -6.92 offset."""
    return -6.92 + 10.0*torch.log10(x + 1e-8)


# Evaluate the power-law curve on a small input range, once as-is and once
# with the input shifted by 0.05, and draw both in red.
x = torch.linspace(0.001, 0.1, 100)
shifted_x = x + 0.05

plt.plot(x, scaling_function(x), color="red")
# Further shifts and the dB curve can be compared by enabling these lines:
# plt.plot(x, scaling_function(x + 0.01))
# plt.plot(x, scaling_function(x + 0.02))
# plt.plot(x, scaling_function_db(x), color="blue")
# plt.plot(x, scaling_function_db(x + 0.01))
# plt.plot(x, scaling_function_db(x + 0.02))
# plt.plot(x, scaling_function_db(x + 0.05), color="blue")
plt.plot(x, scaling_function(shifted_x), color="red")