In [26]:
from math import prod
from typing import Iterable, cast

import torch
from torch import Tensor
from cattrs import structure
from attrs import asdict
import IPython.display as ipd
from einops import rearrange

from src.model import VAE
from src.pqmf import PQMF
from src.cli import parse_hydra_config
from src.train import TrainingConfig

In [27]:
# Notebook configuration: everything a reader may need to tune lives in this cell.
# Both the Hydra config directory and the checkpoint are derived from PROJECT_ROOT,
# so relocating the checkout only requires editing this single line.
PROJECT_ROOT = "/home/jp/mgr2024"
CONFIG = "configs/experiment/graz/graz_mini_base.yaml"
cfg = structure(
    parse_hydra_config(f"{PROJECT_ROOT}/configs", CONFIG),
    TrainingConfig,
)
CHECKPOINT = f"{PROJECT_ROOT}/logs/graz_mini/graz_mini_2025-12-05_13-09/version_0/checkpoints/last.ckpt"

# these values should match the model configuration
SAMPLING_RATE = cfg.dataset.expected_sample_rate
N_BANDS = cfg.model.n_bands
LATENT_SIZE = cfg.model.latent_size
# Length of the latent time axis: `zero_pad_cut` divided by the number of bands
# and by the product of the model strides.
# NOTE(review): presumably `zero_pad_cut` is the fixed per-example sample count — confirm.
DOWNSAMPLED_LENGTH = cfg.dataset.zero_pad_cut // (N_BANDS * prod(cfg.model.strides))
# TARGET_LENGTH = SAMPLING_RATE * 3  # assuming we're doing the single latent version

# Seed PyTorch once so that Restart Kernel -> Run All draws the same latents;
# re-running only the sampling cell afterwards still yields fresh samples.
SEED = 0
torch.manual_seed(SEED)
DEVICE = torch.device('cpu')



In [35]:
# Module-level PQMF instance; `join_bands` below calls its `inverse` method.
# NOTE(review): the positional arguments 100 and 16 are hard-coded. 16 presumably
# has to agree with N_BANDS from the config cell — confirm against src.pqmf.PQMF.
pqmf = PQMF(100, 16, n_channels=2).to(DEVICE)

def _random_z(n: int) -> Tensor:
    """Draw ``n`` random latents from a standard normal distribution.

    Args:
        n: Number of latents to draw (size of the leading dimension).

    Returns:
        A tensor of shape ``(n, LATENT_SIZE, DOWNSAMPLED_LENGTH)`` located on
        ``DEVICE`` whose entries are independent N(0, 1) samples.
    """
    # torch.randn samples N(0, 1) directly on the target device, so no
    # intermediate mean/std tensors have to be allocated and copied over.
    return torch.randn(n, LATENT_SIZE, DOWNSAMPLED_LENGTH, device=DEVICE)

def _iterate_batched(x: Tensor, batch_size: int) -> Iterable[Tensor]:
    for idx in range(0, x.shape[0], batch_size):
        yield x[idx:idx + batch_size]

def join_bands(multiband_x: Tensor) -> Tensor:
    """Apply the module-level ``pqmf`` inverse to a tensor with two stacked groups.

    Axis 1 of ``multiband_x`` is split into 2 groups (``chs=2``) that are folded
    into the batch axis, ``pqmf.inverse`` is applied to the folded tensor, and
    the two groups are then unfolded back onto axis 1 of the result.

    NOTE(review): presumably the 2 groups are stereo channels and the result is
    the full-band waveform — confirm against ``src.pqmf.PQMF``.
    """
    # Fold the two groups into the batch axis before calling the inverse.
    folded = rearrange(multiband_x, "b (chs c) t -> (b chs) c t", chs=2)
    merged = pqmf.inverse(cast(Tensor, folded))
    # Undo the fold so the leading axis is the original batch again.
    unfolded = rearrange(merged, "(b chs) c t -> b (chs c) t", chs=2)
    return cast(Tensor, unfolded)

In [29]:
# Restore the trained VAE; constructor arguments come from the config loaded above.
# map_location makes the checkpoint tensors deserialize straight onto DEVICE, so a
# checkpoint saved on a GPU machine also loads where that GPU is not available.
# NOTE(review): assumes VAE.load_from_checkpoint is Lightning's classmethod, which
# accepts `map_location` — confirm against src.model.VAE.
model = VAE.load_from_checkpoint(
    CHECKPOINT,
    map_location=DEVICE,
    noise_config=cfg.noise,
    **asdict(cfg.model),
).to(DEVICE)
model.eval()  # evaluation mode for generation
model



VAE(
  (encoder): Encoder(
    (net): Sequential(
      (0): Conv1d(32, 64, kernel_size=(7,), stride=(1,), padding=(3,))
      (1): _EncoderBlock(
        (net): Sequential(
          (0): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (1): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (2): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1)

In [37]:
# Decode a small batch of random latents and play back the first result.
with torch.no_grad():  # inference only, so no autograd graph is needed
    z = _random_z(4)  # 4 examples
    print(z.shape)
    generated = join_bands(model.decoder(z))

# The tensor is moved to the CPU before being handed to the audio widget.
ipd.Audio(generated[0].cpu(), rate=SAMPLING_RATE)

torch.Size([4, 128, 30])


In [38]:
# Second generated example (index 1), moved to the CPU for playback.
ipd.Audio(generated[1].cpu(), rate=SAMPLING_RATE)

In [39]:
# Third generated example (index 2), moved to the CPU for playback.
ipd.Audio(generated[2].cpu(), rate=SAMPLING_RATE)

In [None]:
# Fourth generated example (index 3), moved to the CPU for playback.
ipd.Audio(generated[3].cpu(), rate=SAMPLING_RATE)

: 

In [None]:
# nice