# Loading pre-trained models

In [None]:
from einops import rearrange
import IPython.display as ipd
from lightning.pytorch.cli import LightningCLI
import yaml
from jsonargparse import ArgumentParser, ActionConfigFile
import sys
from unittest.mock import patch
import torch

from synthmap.task import SynthMapTask

In [None]:
# Opt-in switch: CUDA is used only when this is True *and* a CUDA device is
# actually available; otherwise everything runs on the CPU.
USE_GPU = False

if torch.cuda.is_available() and USE_GPU:
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Training config and the checkpoint produced by a run of that config.
config = "../cfg/vae_discrete.yaml"
ckpt = "../lightning_logs/version_9/checkpoints/epoch=20-step=32823.ckpt"

# LightningCLI reads its arguments from sys.argv, so argv is swapped out
# temporarily to feed it the config file and accelerator from inside the
# notebook. The first entry ("fit") sits in the argv[0] program-name slot.
# NOTE(review): with run=False the CLI should only instantiate the configured
# objects and not start a Trainer method — confirm against the Lightning docs.
cli_argv = ["fit", "-c", str(config), "--trainer.accelerator", device]
with patch.object(sys, "argv", cli_argv):
    cli = LightningCLI(run=False)

# The model instance built from the config (weights not yet loaded here).
model = cli.model

In [None]:
# Read the checkpoint file, remapping its tensors onto `device`.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(ckpt, map_location=device)

# Only the weights stored under the "state_dict" key are used.
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict)

# Switch to evaluation mode before running inference.
model = model.eval()

In [None]:
# Pull the decoder half out of the model's autoencoder and show the size of
# the latent vector it expects as input.
autoencoder = model.autoencoder
decoder = autoencoder.decoder
print(decoder.in_size)

In [None]:
# The synth object is owned by the datamodule that the CLI built from the
# config; it is called below to turn parameters into audio.
datamodule = cli.datamodule
synth = datamodule.synth

In [None]:
# Sample two random latent vectors from a standard normal distribution.
z = torch.randn(2, decoder.in_size)

# Inference only: disable autograd so no graph is recorded and the outputs
# never require grad (a tensor that requires grad cannot be passed to
# .numpy() directly, which the playback cell relies on).
with torch.no_grad():
    p = decoder(z)

    # Convert to synth params
    p = model.param_discretizer.group_parameters(p)
    p = model.param_discretizer.inverse(p)

    # Render the parameters with the synth; y[0] and y[1] are played below.
    y = synth(p)

In [None]:
# Play the two rendered samples one after the other.
# detach() + cpu() make the NumPy conversion safe even if `y` still requires
# grad or lives on an accelerator; for a plain CPU tensor both are no-ops.
for clip in (y[0], y[1]):
    ipd.display(ipd.Audio(clip.detach().cpu().numpy(), rate=synth.sample_rate))

In [None]:
# Interpolate between the two z vectors over 10 evenly spaced weights.
# Direction: weight 0 yields z[1] and weight 1 yields z[0].
steps = torch.linspace(0, 1, 10)
w = steps[:, None]  # column vector so each weight broadcasts across one latent row
z_interp = z[0] * w + z[1] * (1 - w)

# Divide each blended vector by the Euclidean norm of its weight pair
# (w, 1 - w). NOTE(review): presumably this keeps the blend at roughly unit
# variance for independent standard-normal endpoints — confirm intent.
z_interp = z_interp / torch.sqrt((1 - w) ** 2 + w ** 2)

# Inference only: disable autograd so no graph is recorded and the result can
# be converted to NumPy without detaching.
with torch.no_grad():
    p_interp = decoder(z_interp)
    p_interp = model.param_discretizer.group_parameters(p_interp)
    p_interp = model.param_discretizer.inverse(p_interp)

    y_interp = synth(p_interp)

# Concatenate the rendered steps along the sample axis so the whole
# interpolation plays back as a single continuous clip.
y_interp = rearrange(y_interp, "b n -> (b n)")
ipd.display(ipd.Audio(y_interp.cpu().numpy(), rate=synth.sample_rate))

# Alternative: play each step separately. This only works on the un-flattened
# (steps, samples) tensor, i.e. before the rearrange above.
# for i in range(10):
#     ipd.display(ipd.Audio(y_interp[i].cpu().numpy(), rate=synth.sample_rate))