In [11]:
import torch
from sonification.models.models import PlFMParamEstimator
from sonification.utils.misc import midi2frequency
from IPython.display import Audio, display
import numpy as np

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Args:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

args = Args(
    sr=48000,
    length_s=0.25,
    n_fft=4096,
    f_min=midi2frequency(38),
    f_max=midi2frequency(86),
    n_mels=512,
    power=1,
    normalized=1,
    max_harm_ratio=6,
    max_mod_idx=6,
    latent_size=128,
    encoder_kernels=[4, 16],
    n_res_block=24,
    n_res_channel=128,
    hidden_dim=32,
    num_layers=3,
    batch_size=512,
    lr=0.0001,
    lr_decay=0.75,
    warmup_epochs=1,
    train_epochs=100000,
    steps_per_epoch=1000,
    param_loss_weight_start=9.5,
    param_loss_weight_end=9.5,
    param_loss_weight_ramp_start_epoch=0,
    param_loss_weight_ramp_end_epoch=1,
    ckpt_path="./ckpt/fm_ddsp",
    ckpt_name=f"grad_test_16",
    logdir="./logs/fm_ddsp",
    comment=""
)

In [5]:
ckpt_path = "../../ckpt/fm_ddsp/grad_test_16/grad_test_16_last_epoch=1813.ckpt"
ckpt = torch.load(ckpt_path, map_location=device)
model = PlFMParamEstimator(args).to(device)
model.load_state_dict(ckpt['state_dict'])
model.eval()

  ckpt = torch.load(ckpt_path, map_location=device)


PlFMParamEstimator(
  (input_synth): FMSynth(
    (modulator_sine): Sinewave(
      (phasor): Phasor()
    )
    (carrier_sine): Sinewave(
      (phasor): Phasor()
    )
  )
  (mel_spectrogram): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
  (model): FMParamEstimator(
    (encoder): MultiScaleEncoder(
      (lanes): ModuleList(
        (0): Sequential(
          (0): Conv2d(1, 64, kernel_size=(4, 16), stride=(2, 2), padding=(1, 7))
          (1): LeakyReLU(negative_slope=0.2, inplace=True)
          (2): Conv2d(64, 128, kernel_size=(4, 16), stride=(2, 2), padding=(1, 7))
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): ResBlock(
            (conv): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): BatchNorm2d(128, e

In [60]:
# forward pass: estimate params
with torch.no_grad():
    norm_params, freqs, ratios, indices = model.sample_fm_params(1)
    input_params = [freqs[0][0].item(), ratios[0][0].item(), indices[0][0].item()]
    x = model.input_synth(freqs, ratios, indices)
    predicted_params = model(x)
    # now repeat on the samples dimension
    predicted_freqs = predicted_params[:, 0].unsqueeze(1).repeat(1, model.n_samples)
    predicted_ratios = predicted_params[:, 1].unsqueeze(1).repeat(1, model.n_samples)
    predicted_indices = predicted_params[:, 2].unsqueeze(1).repeat(1, model.n_samples)
    # generate the output
    y = model.output_synth(predicted_freqs, predicted_ratios, predicted_indices)
    out_wf = y.unsqueeze(1)

print("input params:", input_params)
print("predicted params:", predicted_params.squeeze().tolist())
print("input audio:")
display(Audio(x.squeeze().cpu().numpy(), rate=args.sr))
print("predicted audio:")
display(Audio(out_wf.squeeze().cpu().numpy(), rate=args.sr))

input params: [137.6983184814453, 3.1533679962158203, 1.9362244606018066]
predicted params: [435.6393127441406, 0.04395894706249237, 1.0492754881852306e-05]
input audio:


predicted audio:
