In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
from hydra import compose, initialize
import IPython.display as ipd
from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch

from neural_field_synth.data import NSynthDataset
from neural_field_synth.models import LightningWrapper, NeuralFieldSynth

# Inference-only notebook: disable autograd globally for the rest of the session.
# Trailing `;` keeps the grad-mode object's repr out of the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7faac07d0e80>

<IPython.core.display.Javascript object>

In [3]:
# Compose the config saved alongside the training run. Using `initialize` as a
# context manager clears Hydra's global state on exit, so re-running this cell
# does not fail with "GlobalHydra is already initialized".
with initialize(config_path="../outputs/2021-12-19/13-52-16/.hydra/"):
    cfg = compose(config_name="config")

<IPython.core.display.Javascript object>

In [4]:
# Build the bare synth from the run's config, then restore the checkpoint through
# the Lightning wrapper (which later cells access as `model.model`). No loss
# function is needed since the model is only used for inference here.
synth = NeuralFieldSynth(**cfg.model)
model = LightningWrapper.load_from_checkpoint(
    "../outputs/2021-12-19/13-52-16/wavespace/1iww9s2n/checkpoints/epoch=60-step=62463.ckpt",
    model=synth,
    loss_fn=None,
)

model.eval()
model = model.cuda()

<IPython.core.display.Javascript object>

In [6]:
nsynth_test_dir = "/import/c4dm-datasets/nsynth/nsynth-test"
ds = NSynthDataset(nsynth_test_dir)
# Seeded generator so the shuffled order, and therefore the examples drawn in
# the cells below, is reproducible after a kernel restart.
dl = torch.utils.data.DataLoader(
    ds, batch_size=2, shuffle=True, generator=torch.Generator().manual_seed(0)
)
dl_it = iter(dl)

<IPython.core.display.Javascript object>

## Reconstruction

In [45]:
batch = next(dl_it)
batch_size = batch["pitch"].shape[0]
sample_rate = model.model.sample_rate
# Shared [-1, 1] time axis, one column per batch element. Derived from the batch
# rather than hard-coded so a short final batch still lines up.
# NOTE(review): 64000 samples per clip is assumed (same value as the sampling cell).
time = torch.linspace(-1, 1, 64000)[:, None].expand(-1, batch_size).cuda()
output = model(
    time,
    batch["pitch"].float().cuda(),
    batch["velocity"].float().cuda(),
    batch["instrument"].float().cuda(),
    return_params=True,
)
# Column i of the model output corresponds to row i of the batch.
reconstruction = output.output.detach().cpu()
for i in range(batch_size):
    print(f"example {i + 1}")
    print(" -- target")
    ipd.display(ipd.Audio(batch["audio"][i, 0], rate=sample_rate))
    print(" -- reconstruction")
    ipd.display(ipd.Audio(reconstruction[:, i], rate=sample_rate))

example 1
 -- target


 -- reconstruction


example 2
 -- target


 -- reconstruction


<IPython.core.display.Javascript object>

## Interpolation

In [46]:
def lerp_pair(start, end, n_steps):
    """Linearly interpolate from `start` to `end` in `n_steps` evenly spaced steps.

    Both endpoints are flattened to 1-D and included in the result, which has
    shape (n_steps, start.numel()) and the dtype/device of `start`.
    """
    start, end = start.reshape(-1), end.reshape(-1)
    weights = torch.linspace(0, 1, n_steps)[:, None].to(start)
    return torch.lerp(start[None], end[None], weights)


n_steps = 4
# Rows of the embedding are the two examples from the reconstruction batch.
# Interpolate across ALL embedding dims (the last axis); iterating over
# shape[0] would count batch rows instead of embedding dimensions.
embed = output.instrument_embedding
interp_embed = lerp_pair(embed[0], embed[1], n_steps).cuda()
interp_pitch = lerp_pair(
    batch["pitch"][0].float(), batch["pitch"][1].float(), n_steps
).cuda()
interp_vel = lerp_pair(
    batch["velocity"][0].float(), batch["velocity"][1].float(), n_steps
).cuda()
interp_output = [
    model.model.forward_with_instrument_embed(
        time[:, 0:1],
        interp_pitch[i],
        interp_vel[i],
        interp_embed[i : i + 1],
    )[:, 0]
    for i in range(n_steps)
]
for o in interp_output:
    ipd.display(ipd.Audio(o.detach().cpu(), rate=model.model.sample_rate))

<IPython.core.display.Javascript object>

## Random sampling

In [62]:
n_samples = 4
# Seeded generator so the random draws are reproducible across re-runs.
gen = torch.Generator().manual_seed(0)
sample_time = torch.linspace(-1, 1, 64000)[:, None].cuda().expand(-1, n_samples)
pitch = torch.rand(n_samples, generator=gen).cuda() * 127
velocity = torch.rand(n_samples, generator=gen).cuda() * 127
# NOTE(review): embedding width 2 is hard-coded; confirm it matches the model config.
instrument_embed = torch.randn(n_samples, 2, generator=gen).cuda() * 0.1

# Distinct names so this cell neither overwrites nor deletes `time`/`output` from
# the reconstruction cell, which the interpolation cell still reads.
sample_output = model.model.forward_with_instrument_embed(
    sample_time, pitch, velocity, instrument_embed, return_params=True
)
samples = sample_output.output.detach().cpu()
for i in range(samples.shape[-1]):
    ipd.display(ipd.Audio(samples[:, i], rate=model.model.sample_rate))
# Release the full model output now that the audio has been copied to CPU.
del sample_output

<IPython.core.display.Javascript object>