In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import lmdb
from udls.generated import AudioExample
import IPython.display as ipd
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import logging
import torchaudio
import librosa
from typing import Callable, Optional
import warnings

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s(%(name)s)\t%(message)s", level=logging.INFO
)

device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info("device: %s", device)

[2024-01-12 15:34:25,687] INFO(root)	device: cpu


In [2]:
# utils
plt.rcParams["savefig.bbox"] = "tight"


def show_image(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = torchvision.transforms.functional.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.imshow(
        librosa.power_to_db(specgram),
        origin="lower",
        aspect="auto",
        interpolation="nearest",
    )

In [3]:
# dataset definition
class LoopDataset(torch.utils.data.Dataset):
    FS = 44100
    SIZE_SAMPLES = 65536

    def __init__(self, db_path: str) -> None:
        super().__init__()

        self._db_path = db_path

        self.env = lmdb.open(self._db_path, lock=False)

        with self.env.begin(write=False) as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx: int):
        with self.env.begin(write=False) as txn:
            ae = AudioExample.FromString(txn.get(self.keys[idx]))

        buffer = ae.buffers["waveform"]
        assert buffer.precision == AudioExample.Precision.INT16
        assert buffer.sampling_rate == self.FS

        audio = torch.frombuffer(buffer.data, dtype=torch.int16)
        audio = audio.float() / (2**15 - 1)
        assert len(audio) == self.SIZE_SAMPLES

        return audio


# get 5 random examples
dataset = LoopDataset(db_path="../../data/loops/")
valid_ratio = 0.2
nb_valid = int(valid_ratio * len(dataset))
nb_train = len(dataset) - nb_valid
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(
    dataset, [nb_train, nb_valid]
)

print(nb_train, nb_valid)

num_threads = 0  # != 0 crashes on windows o_o
batch_size = 128

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, num_workers=num_threads
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads
)

for i in np.random.randint(len(train_dataset), size=5):
    print(f"example #{i}")
    ipd.display(ipd.Audio(train_dataset[i], rate=LoopDataset.FS))

9195 2298
example #7386


  audio = torch.frombuffer(buffer.data, dtype=torch.int16)


example #3835


example #284


example #2132


example #4340


In [33]:

# test: forward then backward
html = """<table><thead><tr>
<td>id</td>
<td>original</td>
<td>copyphase</td>
<td>copyphase diff</td>
<td>griffinlim</td>
<td>griffinlim diff</td>
</tr></thead>"""
t = CustomTransform(
    sample_rate=LoopDataset.FS,
    n_mels=32,
    n_fft=256,
    griffin_lim_iter=64,

).to(device)
for i in [7809, 2016, 8888, 1234]:
    spec, phase = t.forward(dataset[i].to(device))
    # print(spec.shape, phase.shape)
    # plot_spectrogram(spec.cpu())
    t_copyphase = t.backward(spec, phase)
    t_griffinlim = t.backward(spec)
    html += "<tr>"
    html += f"<td>{i}</td>"
    html += "<td>" + ipd.Audio(dataset[i], rate=LoopDataset.FS)._repr_html_() + "</td>"
    html += "<td>" + ipd.Audio(t_copyphase.cpu(), rate=LoopDataset.FS)._repr_html_() + "</td>"
    html += "<td>" + ipd.Audio(t_copyphase.cpu() - dataset[i], rate=LoopDataset.FS)._repr_html_() + "</td>"
    html += "<td>" + ipd.Audio(t_griffinlim.cpu(), rate=LoopDataset.FS)._repr_html_() + "</td>"
    html += "<td>" + ipd.Audio(t_griffinlim.cpu() - dataset[i], rate=LoopDataset.FS)._repr_html_() + "</td>"
    html += "</tr>"
html += "</table>"

ipd.display(ipd.HTML(html))



id,original,copyphase,copyphase diff,griffinlim,griffinlim diff
7809,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.
2016,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.
8888,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.
1234,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.,Your browser does not support the audio element.


In [22]:
# model definition
class AudioVAE(nn.Module):
    def __init__(self):
        super(AudioVAE, self).__init__()
        n_latent = 16
        sizes = [129, 64, 32, 16]
        encoder_modules = []
        decoder_modules = []
        for i in range(len(sizes) - 1):
            encoder_modules += [
                nn.Conv1d(
                    in_channels=sizes[i], out_channels=sizes[i + 1], kernel_size=4
                ),
                nn.BatchNorm1d(sizes[i + 1]),
                nn.LeakyReLU(),
            ]
            decoder_modules += [
                nn.Conv1d(
                    in_channels=sizes[-i - 1], out_channels=sizes[-i - 2], kernel_size=4
                ),
                nn.BatchNorm1d(sizes[-i - 2]),
                nn.LeakyReLU(),
            ]

        # decoder_modules.append(nn.Sigmoid())

        self.encoder = nn.Sequential(*encoder_modules)
        self.decoder = nn.Sequential(*decoder_modules)

        self.mu = nn.Conv1d(sizes[-1], n_latent, 4)
        self.sigma = nn.Sequential(nn.Conv1d(sizes[-1], n_latent, 4), nn.Softplus())

    def encode(self, x: torch.Tensor):
        h = self.encoder(x)
        return self.mu(h), self.sigma(h)

    def decode(self, z: torch.Tensor):
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        # Encode the inputs
        mu, log_var = self.encode(x)
        # Obtain latent samples and kl divergence
        z_tilde, kl_div = self.latent(mu, log_var)
        # Decode the samples
        x_tilde = self.decode(z_tilde)
        return x_tilde, kl_div

    def latent(self, mu: torch.Tensor, log_var: torch.Tensor):
        # reparametrization trick
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = eps * std + mu

        kl_div = torch.sum(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        return z, kl_div

In [23]:
def beta_warmup(epoch, interval, epoch_interval):
    if epoch < epoch_interval[0]:
        return interval[0]
    elif epoch > epoch_interval[1]:
        return interval[1]
    return (
        (epoch - epoch_interval[0])
        * (interval[1] - interval[0])
        / (epoch_interval[1] - epoch_interval[0])
    )

In [24]:
# training loop

# parameters
n_epochs = 50
beta_interval = (0, 1)  # min, max
beta_epoch_interval = (10, 40)  # start, end
evaluate_every_nth_epoch = 3
generate_every_nth_epoch = 3
n_latent = 16
n_mels = 512
transform = CustomTransform(
    sample_rate=LoopDataset.FS,
    n_mels=n_mels,
    n_fft=1024,
    griffin_lim_iter=64,
).to(device)
n_frames = transform.get_n_frames(LoopDataset.SIZE_SAMPLES)

filename = f"audiovae_{beta_interval[0]}_{beta_interval[1]}_{n_epochs}.pt"
logger.info("training %s", filename)
logger.info("n_epochs=%d", n_epochs)

model = AudioVAE().to(device)
logger.info(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
recons_criterion = torch.nn.MSELoss(reduction="sum")

if Path(filename).exists():
    model.load_state_dict(torch.load(filename))
    logger.info("loaded from", filename)

else:
    WRITER = SummaryWriter(comment=filename)

    for epoch in range(n_epochs):
        beta = beta_warmup(epoch, beta_interval, beta_epoch_interval)
        logger.info(
            f"epoch {epoch+1}/{n_epochs}; beta={beta:.2f}; {epoch * len(train_dataset)} examples seen"
        )
        WRITER.add_scalar("beta", beta, epoch)

        logger.info("training")
        model.train()
        full_loss = 0
        recons_loss = 0
        kl_div = 0
        for i, waveform in tqdm(enumerate(train_loader), desc="Train"):
            optimizer.zero_grad()
            waveform = waveform.to(device)
            mag, phase = transform.forward(waveform)
            mag = mag.reshape(-1, n_frames, n_mels)
            mag_tilde, kl_div_batch = model(mag)
            recons_loss_batch = recons_criterion(mag_tilde, mag)
            full_loss_batch = recons_loss_batch + beta * kl_div_batch
            recons_loss += recons_loss_batch
            full_loss += full_loss_batch
            kl_div += kl_div_batch
            full_loss_batch.backward()
            optimizer.step()
        WRITER.add_scalar("loss/train/full", full_loss, epoch)
        WRITER.add_scalar("loss/train/reconstruction", recons_loss, epoch)
        WRITER.add_scalar("loss/train/kl_div", kl_div, epoch)

        if epoch % evaluate_every_nth_epoch == 0 or epoch == n_epochs - 1:
            logger.info("evaluating")
            model.eval()
            full_loss = 0
            recons_loss = 0
            kl_div = 0
            for i, waveform in tqdm(enumerate(valid_loader), desc="Evaluation"):
                optimizer.zero_grad()
                waveform = waveform.to(device)
                mag, phase = transform.forward(waveform)
                mag = mag.reshape(-1, n_in)
                mag_tilde, kl_div_batch = model(mag)
                recons_loss_batch = recons_criterion(mag_tilde, mag)
                full_loss_batch = recons_loss_batch + beta * kl_div_batch
                recons_loss += recons_loss_batch
                full_loss += full_loss_batch
                kl_div += kl_div_batch
                full_loss_batch.backward()
                optimizer.step()
            WRITER.add_scalar("loss/eval/full", full_loss, epoch)
            WRITER.add_scalar("loss/eval/reconstruction", recons_loss, epoch)
            WRITER.add_scalar("loss/eval/kl_div", kl_div, epoch)

        if epoch % generate_every_nth_epoch == 0 or epoch == n_epochs - 1:
            logger.info("generating from dataset")
            with torch.no_grad():
                n_sounds = 4
                waveform = next(iter(valid_loader)).to(device)
                mag, phase = transform.forward(waveform[:n_sounds])
                mag_tilde, _ = model(mag.reshape(-1, n_in))
                mag_tilde = mag_tilde.reshape(-1, 1, n_mels, n_frames)
                grid = torchvision.utils.make_grid(mag_tilde, n_sounds)
                WRITER.add_image("gen/dataset/melspec", grid)

                waveform_tilde_copyphase = transform.backward(mag_tilde, phase)
                waveform_tilde_griffinlim = transform.backward(mag_tilde)
                WRITER.add_audio(
                    "gen/dataset/copyphase",
                    waveform_tilde_copyphase.reshape(-1),
                    epoch,
                    sample_rate=LoopDataset.FS,
                )
                WRITER.add_audio(
                    "gen/dataset/griffinlim",
                    waveform_tilde_griffinlim.reshape(-1),
                    epoch,
                    sample_rate=LoopDataset.FS,
                )

            logger.info("generating random from latent space")
            with torch.no_grad():
                n_sounds = 16
                z = torch.randn(n_sounds, n_latent).to(device)
                mag_tilde = model.decode(z).reshape(-1, 1, n_mels, n_frames)
                grid = torchvision.utils.make_grid(mag_tilde, n_sounds)
                WRITER.add_image("gen/rand_latent/melspec", grid)
                waveform_tilde = transform.backward(mag_tilde)
                WRITER.add_audio(
                    "gen/rand_latent/griffinlim",
                    waveform_tilde_griffinlim.reshape(-1),
                    epoch,
                    sample_rate=LoopDataset.FS,
                )
            logger.info("exploring latent space")
            with torch.no_grad():
                n_sounds_per_dimension = 8
                z = torch.zeros(n_sounds_per_dimension * n_latent, n_latent).to(device)
                for i in range(n_latent):
                    a = i * n_sounds_per_dimension
                    b = (i + 1) * n_sounds_per_dimension
                    z[a:b, i] = torch.linspace(-5, +5, n_sounds_per_dimension)
                mag_tilde = model.decode(z).reshape(-1, 1, n_mels, n_frames)
                grid = torchvision.utils.make_grid(mag_tilde, n_sounds_per_dimension)
                WRITER.add_image("gen/explo_latent/melspec", grid)
                waveform_tilde = transform.backward(mag_tilde)
                WRITER.add_audio(
                    "gen/explo_latent/griffinlim",
                    waveform_tilde_griffinlim.reshape(-1),
                    epoch,
                    sample_rate=LoopDataset.FS,
                )

    # save weights
    torch.save(model.state_dict(), filename)
    logger.info("saved to %s", filename)

[2024-01-02 18:13:00,741] INFO(__main__)	training audiovae_0_1_50.pt
[2024-01-02 18:13:00,741] INFO(__main__)	n_epochs=50
[2024-01-02 18:13:00,782] INFO(__main__)	AudioVAE(
  (encoder): Sequential(
    (0): Conv1d(129, 64, kernel_size=(4,), stride=(1,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv1d(64, 32, kernel_size=(4,), stride=(1,))
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Conv1d(32, 16, kernel_size=(4,), stride=(1,))
    (7): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.01)
  )
  (decoder): Sequential(
    (0): Conv1d(16, 32, kernel_size=(4,), stride=(1,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv1d(32, 64, kernel_size=

RuntimeError: The size of tensor a (491) must match the size of tensor b (512) at non-singleton dimension 2