In [1]:
from math import prod
from typing import Iterable, cast

import torch
from torch import Tensor
from cattrs import structure
from attrs import asdict
import IPython.display as ipd
from einops import rearrange

from src.model import VAE
from src.pqmf import PQMF
from src.cli import parse_hydra_config
from src.train import TrainingConfig

  from .autonotebook import tqdm as notebook_tqdm
Disabling PyTorch because PyTorch >= 2.1 is required but found 1.12.1+cu113
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# --- Experiment selection ---------------------------------------------------
# CONFIG is handed to parse_hydra_config together with the configs directory;
# CHECKPOINT is the file the model-loading cell reads its weights from.
CONFIG = "configs/experiment/snares.yaml"
CHECKPOINT = "/home/jp/mgr2024/logs/snares/version_7/checkpoints/last.ckpt"

# Parse the Hydra config first, then let cattrs structure it into TrainingConfig.
raw_cfg = parse_hydra_config("/home/jp/mgr2024/configs", CONFIG)
cfg = structure(raw_cfg, TrainingConfig)

# --- Values that must agree with the model configuration --------------------
SAMPLING_RATE = cfg.dataset.expected_sample_rate
N_BANDS = cfg.model.n_bands
LATENT_SIZE = cfg.model.latent_size

# Single-latent version: each example is one latent time step.
# The alternative kept for reference is
#   cfg.dataset.zero_pad_cut // (N_BANDS * prod(cfg.model.strides))
DOWNSAMPLED_LENGTH = 1
# TARGET_LENGTH (SAMPLING_RATE * 3, assuming the single latent version) is
# currently disabled.

DEVICE = torch.device("cpu")
AUDIO_CHANNELS = 2 if not cfg.dataset.mono else 1



In [3]:
# Filterbank whose inverse() is called by join_bands below.
# NOTE(review): the literal 16 is presumably the band count and would then
# duplicate N_BANDS - confirm against src.pqmf.PQMF before replacing it.
filterbank = PQMF(100, 16, n_channels=AUDIO_CHANNELS)
pqmf = filterbank.to(DEVICE)

def _random_z(n: int) -> Tensor:
    """Sample ``n`` latent tensors from a standard normal distribution.

    Args:
        n: Number of latents to draw (size of the leading dimension).

    Returns:
        A tensor of shape ``(n, LATENT_SIZE, DOWNSAMPLED_LENGTH)`` moved to
        ``DEVICE`` after sampling.
    """
    latent_shape = (n, LATENT_SIZE, DOWNSAMPLED_LENGTH)
    standard_normal = torch.distributions.Normal(
        loc=torch.zeros(latent_shape),
        scale=torch.ones(latent_shape),
    )
    return standard_normal.sample().to(DEVICE)

def _iterate_batched(x: Tensor, batch_size: int) -> Iterable[Tensor]:
    for idx in range(0, x.shape[0], batch_size):
        yield x[idx:idx + batch_size]

def join_bands(multiband_x: Tensor) -> Tensor:
    """Run the module-level ``pqmf`` inverse on a multiband tensor.

    The second axis of ``multiband_x`` is treated as ``AUDIO_CHANNELS`` groups
    of ``c`` entries each. The groups are folded into the batch axis before
    calling ``pqmf.inverse`` and unfolded again afterwards.

    Args:
        multiband_x: Tensor laid out as ``b (chs c) t``.

    Returns:
        The output of ``pqmf.inverse`` rearranged back to ``b (chs c) t``.
    """
    fold_channels = "b (chs c) t -> (b chs) c t"
    unfold_channels = "(b chs) c t -> b (chs c) t"

    folded = cast(Tensor, rearrange(multiband_x, fold_channels, chs=AUDIO_CHANNELS))
    joined = pqmf.inverse(folded)
    unfolded = rearrange(joined, unfold_channels, chs=AUDIO_CHANNELS)
    return cast(Tensor, unfolded)

In [4]:
# Restore the VAE from CHECKPOINT, passing the model section of the config as
# keyword arguments alongside the noise configuration.
model_kwargs = asdict(cfg.model)
model = VAE.load_from_checkpoint(CHECKPOINT, noise_config=cfg.noise, **model_kwargs)
model = model.to(DEVICE)
model.eval()
# Last expression of the cell: show the module structure.
model



VAE(
  (encoder): Encoder(
    (net): Sequential(
      (0): Conv1d(16, 64, kernel_size=(7,), stride=(1,), padding=(3,))
      (1): _EncoderBlock(
        (net): Sequential(
          (0): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (1): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (2): _ResidualDilatedUnit(
            (net): Sequential(
              (0): LeakyReLU(negative_slope=0.01)
              (1)

In [5]:
# from pytorch_lightning.utilities.model_summary import summarize

# model = VAE(**asdict(cfg.model), noise_config=cfg.noise)
# summarize(model)

In [23]:
# Decode random latents into audio; gradients are not needed here.
# NOTE(review): no RNG seed is set in the visible cells, so every run produces
# different samples.
with torch.no_grad():
    z = _random_z(4)  # 4 examples
    print(z.shape)
    multiband = model.decoder(z)
    generated = join_bands(multiband)

# Player for the first generated example.
ipd.Audio(data=generated[0].cpu(), rate=SAMPLING_RATE)

torch.Size([4, 32, 1])


In [24]:
# Player for the second random sample.
ipd.Audio(data=generated[1].cpu(), rate=SAMPLING_RATE)

In [25]:
# Player for the third random sample.
ipd.Audio(data=generated[2].cpu(), rate=SAMPLING_RATE)

In [26]:
# Player for the fourth random sample.
ipd.Audio(data=generated[3].cpu(), rate=SAMPLING_RATE)

In [27]:
# interpolate

def _linear_interp(x: Tensor, y: Tensor, weight: float) -> Tensor:
    return (1 - weight) * x + weight * y

# Walk from the first to the second random latent in five evenly spaced steps
# and decode every intermediate point.
z1, z2 = z[0], z[1]
interp_weights = (0.0, 0.25, 0.5, 0.75, 1.0)
interpolated = [_linear_interp(z1, z2, w) for w in interp_weights]
zs = torch.stack(interpolated, dim=0)
print(zs.shape)

with torch.no_grad():
    multiband = model.decoder(zs)
    generated = join_bands(multiband)

# Player for the weight-0.0 end of the interpolation.
ipd.Audio(data=generated[0].cpu(), rate=SAMPLING_RATE)

torch.Size([5, 32, 1])


In [28]:
# Player for interpolation weight 0.25.
ipd.Audio(data=generated[1].cpu(), rate=SAMPLING_RATE)

In [29]:
# Player for interpolation weight 0.5.
ipd.Audio(data=generated[2].cpu(), rate=SAMPLING_RATE)

In [30]:
# Player for interpolation weight 0.75.
ipd.Audio(data=generated[3].cpu(), rate=SAMPLING_RATE)

In [31]:
# Player for interpolation weight 1.0.
ipd.Audio(data=generated[4].cpu(), rate=SAMPLING_RATE)

In [61]:
# randomly translating the whole vector

def _noise(n: int, loc: float, scale: float) -> Tensor:
    """Sample ``n`` latent-shaped tensors from ``Normal(loc, scale)``.

    Args:
        n: Number of tensors to draw (size of the leading dimension).
        loc: Mean applied to every element.
        scale: Standard deviation applied to every element.

    Returns:
        A tensor of shape ``(n, LATENT_SIZE, DOWNSAMPLED_LENGTH)`` moved to
        ``DEVICE`` after sampling.
    """
    latent_shape = (n, LATENT_SIZE, DOWNSAMPLED_LENGTH)
    # Start from float zeros/ones so the parameters stay floating point even
    # when ``loc`` or ``scale`` is passed as an int.
    mean = torch.zeros(latent_shape) + loc
    std = torch.ones(latent_shape) * scale
    shifted_normal = torch.distributions.Normal(mean, std)
    return shifted_normal.sample().to(DEVICE)

# Shift the first random latent by four independent noise draws and decode
# each shifted copy.
z0 = z[0]
noises = _noise(4, loc=-4, scale=1.5)
# Broadcasting adds z0 to every row of `noises`, which matches stacking
# `z0 + n` for each n in `noises`.
zs = z0 + noises
print(zs.shape)

with torch.no_grad():
    multiband = model.decoder(zs)
    generated = join_bands(multiband)

# Player for the first perturbed latent.
ipd.Audio(data=generated[0].cpu(), rate=SAMPLING_RATE)

torch.Size([4, 32, 1])


In [62]:
# Player for the second perturbed latent.
ipd.Audio(data=generated[1].cpu(), rate=SAMPLING_RATE)

In [63]:
# Player for the third perturbed latent.
ipd.Audio(data=generated[2].cpu(), rate=SAMPLING_RATE)

In [64]:
# Player for the fourth perturbed latent.
ipd.Audio(data=generated[3].cpu(), rate=SAMPLING_RATE)

In [None]:

# with torch.no_grad():
#     z = _random_z(1000)  # 1000 examples
#     print(z.shape)

# def _linear_interp(x: Tensor, y: Tensor, weight: float) -> Tensor:
#     return (1 - weight) * x + weight * y

# z1, z2 = z[1], z[3]
# zs = torch.stack([_linear_interp(z1, z2, w) for w in [0.0, 0.25, 0.5, 0.75, 1.0]], dim=0)
# print(zs.shape)

# with torch.no_grad():
#     generated = join_bands(model.decoder(zs))

# ipd.Audio(generated[0].to(torch.device('cpu')), rate=SAMPLING_RATE)