In [21]:
import os
from math import prod
from pathlib import Path
from typing import Iterable, cast

import torch
from torch import Tensor
from cattrs import structure
from attrs import asdict
import IPython.display as ipd
from einops import rearrange

from src.model import VAE
from src.pqmf import PQMF
from src.configs import parse_hydra_config
from src.train import TrainingConfig

In [46]:
# Root of the project checkout. Set the MGR2024_ROOT environment variable to run
# this notebook on another machine instead of editing the paths below.
PROJECT_ROOT = Path(os.environ.get("MGR2024_ROOT", "/home/jp/mgr2024"))

CONFIG = "configs/experiment/snares2.yaml"
cfg = structure(
    parse_hydra_config(str(PROJECT_ROOT / "configs"), CONFIG),
    TrainingConfig,
)
CHECKPOINT = str(PROJECT_ROOT / "logs/snares/version_11/checkpoints/last.ckpt")

# these values should match the model configuration
SAMPLING_RATE = cfg.dataset.expected_sample_rate
N_BANDS = cfg.model.n_bands
LATENT_SIZE = cfg.model.latent_size
# Time length of the latent fed to the decoder. Hard-coded to 1, presumably because
# this checkpoint maps each example to a single latent frame (the "single latent
# version" noted below) -- TODO confirm. For a time-varying latent it would be:
# DOWNSAMPLED_LENGTH = cfg.dataset.zero_pad_cut // (N_BANDS * prod(cfg.model.strides))
DOWNSAMPLED_LENGTH = 1
# TARGET_LENGTH = SAMPLING_RATE * 3  # assuming we're doing the single latent version
DEVICE = torch.device('cpu')
AUDIO_CHANNELS = 1 if cfg.dataset.mono else 2

In [47]:
# Sub-band filter bank used to split/merge audio. The band count must equal the
# model's configured number of bands, so it is read from the config rather than
# hard-coded. NOTE(review): the first positional argument (100) is passed through
# unchanged; its meaning is defined in src.pqmf.
pqmf = PQMF(100, N_BANDS, n_channels=AUDIO_CHANNELS).to(DEVICE)

def _random_z(n: int) -> Tensor:
    """Draw ``n`` latent vectors from a standard normal distribution.

    Args:
        n: Number of latent vectors to sample.

    Returns:
        Tensor of shape ``(n, LATENT_SIZE, DOWNSAMPLED_LENGTH)`` on ``DEVICE``.
    """
    # Sample directly on the target device instead of allocating full-size
    # mean/std tensors on the CPU and copying the sample over afterwards.
    return torch.randn(n, LATENT_SIZE, DOWNSAMPLED_LENGTH, device=DEVICE)

def _iterate_batched(x: Tensor, batch_size: int) -> Iterable[Tensor]:
    for idx in range(0, x.shape[0], batch_size):
        yield x[idx:idx + batch_size]

def join_bands(multiband_x: Tensor) -> Tensor:
    """Merge multi-band signals back into full-band audio using ``pqmf.inverse``.

    The channel axis of ``multiband_x`` is read as ``(chs c)`` in einops notation,
    with ``chs = AUDIO_CHANNELS`` as the outer factor. Each audio channel is folded
    into the batch axis so ``pqmf.inverse`` handles one audio channel at a time.
    The result is unfolded back to ``(b, chs * c_out, t)``, where ``c_out`` is
    whatever channel count ``pqmf.inverse`` returns (presumably 1 -- confirm
    against src.pqmf).
    """
    # (b, chs*c, t) -> (b*chs, c, t): one row per (example, audio channel)
    per_channel = rearrange(multiband_x, "b (chs c) t -> (b chs) c t", chs=AUDIO_CHANNELS)
    merged = pqmf.inverse(cast(Tensor, per_channel))
    # (b*chs, c, t) -> (b, chs*c, t): restore the audio-channel axis
    restored = rearrange(merged, "(b chs) c t -> b (chs c) t", chs=AUDIO_CHANNELS)
    return cast(Tensor, restored)

In [48]:
# Restore the trained VAE from CHECKPOINT, passing the experiment's noise and model
# configs as constructor keyword arguments, then move it to DEVICE and switch it
# to evaluation mode (nn.Module.to/.eval both return the module itself).
model = (
    VAE.load_from_checkpoint(CHECKPOINT, noise_config=cfg.noise, **asdict(cfg.model))
    .to(DEVICE)
    .eval()
)
model

VAE(
  (encoder): Encoder(
    (net): Sequential(
      (0): Conv1d(16, 64, kernel_size=(7,), stride=(1,), padding=(3,))
      (1): _EncoderBlock(
        (net): Sequential(
          (0): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.2)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))
              (2): LeakyReLU(negative_slope=0.2)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (1): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.2)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
              (2): LeakyReLU(negative_slope=0.2)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (2): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.2)
              (1): Con

In [49]:
from einops import rearrange


N_EXAMPLES = 8  # how many random latents to draw and decode

with torch.no_grad():
    z = _random_z(N_EXAMPLES)
    print(z.shape)
    generated = join_bands(model.decoder(z))

# Fold the batch axis into time so all examples play back to back as one clip.
ipd.display(ipd.Audio(rearrange(generated, "b c l -> c (b l)").cpu(), rate=SAMPLING_RATE))

torch.Size([8, 64, 1])


In [50]:
# interpolate

def _linear_interp(x: Tensor, y: Tensor, weight: float) -> Tensor:
    return (1 - weight) * x + weight * y

# Walk the straight line between the first two sampled latents.
z1, z2 = z[0], z[1]
interp_weights = (0.0, 0.25, 0.5, 0.75, 1.0)
zs = torch.stack([_linear_interp(z1, z2, w) for w in interp_weights])
print(zs.shape)

with torch.no_grad():
    generated = join_bands(model.decoder(zs))

# Concatenate the interpolation steps in time so they play as a single clip.
ipd.Audio(rearrange(generated, "b c l -> c (b l)").cpu(), rate=SAMPLING_RATE)

torch.Size([5, 64, 1])


In [41]:
# randomly translating the whole vector

def _noise(n: int, loc: float, scale: float) -> Tensor:
    """Draw ``n`` latent-shaped samples from ``Normal(loc, scale)``.

    Args:
        n: Number of samples.
        loc: Mean of every element.
        scale: Standard deviation of every element; must be non-negative
            (``0`` yields a constant tensor filled with ``loc``).

    Returns:
        Tensor of shape ``(n, LATENT_SIZE, DOWNSAMPLED_LENGTH)`` on ``DEVICE``.

    Raises:
        ValueError: If ``scale`` is negative.
    """
    if scale < 0:
        raise ValueError(f"scale must be non-negative, got {scale}")
    # Affine transform of a standard normal, sampled directly on DEVICE instead of
    # building full-size mean/std tensors on the CPU and copying the sample over.
    return torch.randn(n, LATENT_SIZE, DOWNSAMPLED_LENGTH, device=DEVICE) * scale + loc

# Each latent dimension of z0 is shifted by its own independent draw from
# Normal(TRANSLATION_LOC, TRANSLATION_SCALE), once per translated example.
N_TRANSLATIONS = 4
TRANSLATION_LOC = -4
TRANSLATION_SCALE = 1.5

z0 = z[0]
noises = _noise(N_TRANSLATIONS, loc=TRANSLATION_LOC, scale=TRANSLATION_SCALE)
# Broadcasting: (LATENT_SIZE, L) + (N, LATENT_SIZE, L) -> (N, LATENT_SIZE, L),
# identical to stacking z0 + noise for each noise row.
zs = z0 + noises
print(zs.shape)

with torch.no_grad():
    generated = join_bands(model.decoder(zs))

# Play all translated examples back to back, like the other generation cells.
ipd.Audio(rearrange(generated, "b c l -> c (b l)").cpu(), rate=SAMPLING_RATE)

torch.Size([4, 64, 1])


In [None]:

# with torch.no_grad():
#     z = _random_z(1000)  # 1000 examples
#     print(z.shape)

# def _linear_interp(x: Tensor, y: Tensor, weight: float) -> Tensor:
#     return (1 - weight) * x + weight * y

# z1, z2 = z[1], z[3]
# zs = torch.stack([_linear_interp(z1, z2, w) for w in [0.0, 0.25, 0.5, 0.75, 1.0]], dim=0)
# print(zs.shape)

# with torch.no_grad():
#     generated = join_bands(model.decoder(zs))

# ipd.Audio(generated[0].to(torch.device('cpu')), rate=SAMPLING_RATE)