# TODO
- Remove the gin config dependency
    - replace every binding in v1 with plain Python variables and reference those variables
    - same for v2
- Make a copy and reduce it to the entry point
- Compile the entry point
- Re-add the next section of the graph

## directory & imports

In [1]:
# IPython magic (not plain Python): change the kernel's working directory.
%cd /home/ubuntu/rave-compile

/home/ubuntu/rave-compile


In [2]:
import hashlib
import os
import sys

import gin
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

import pathlib
import cached_conv as cc

from functools import partial

import cached_conv as cc
import gin
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from torchaudio.transforms import Spectrogram

from rave.core import amp_to_impulse_response, fft_convolve, mod_sigmoid

import json
from pathlib import Path
from random import random
from typing import Callable, Optional, Sequence, Union

import GPUtil as gpu
import librosa as li
import lmdb
import numpy as np
import torch.fft as fft
import torch.nn as nn
import torchaudio
from einops import rearrange
from scipy.signal import lfilter

import base64
import logging
import math
import os
import subprocess
from random import random
from typing import Dict, Iterable, Optional, Sequence, Callable, Tuple, Type, Any, Union

import numpy as np
import requests
import yaml
from scipy.signal import lfilter
from torch.utils import data
from tqdm import tqdm
from udls import AudioExample as AudioExampleWrapper
from udls import transforms
from udls.generated import AudioExample

from time import time

from sklearn.decomposition import PCA

## config

In [17]:
# NOTE(review): this cell is gin-config syntax (macro definitions followed by
# `configurable: param = value` bindings), not Python. Executing it as a Python
# cell raises the SyntaxError shown in the cell output, at the first
# `name:` block. Presumably it should live in a .gin file or be passed to
# gin.parse_config(...) as a string.
SAMPLING_RATE = 44100
CAPACITY = 64
N_BAND = 16
LATENT_SIZE = 128
RATIOS = [4, 4, 4, 2]
PHASE_1_DURATION = 1000000

# CORE CONFIGURATION
core.AudioDistanceV1:
    multiscale_stft = @core.MultiScaleSTFT
    log_epsilon = 1e-7

core.MultiScaleSTFT:
    scales = [2048, 1024, 512, 256, 128]
    sample_rate = %SAMPLING_RATE
    magnitude = True

dataset.split_dataset.max_residual = 1000

# CONVOLUTION CONFIGURATION
cc.Conv1d.bias = False
cc.ConvTranspose1d.bias = False

# PQMF
pqmf.CachedPQMF:
    attenuation = 100
    n_band = %N_BAND

blocks.normalization.mode = 'weight_norm'

# ENCODER
blocks.Encoder:
    data_size = %N_BAND
    capacity = %CAPACITY
    latent_size = %LATENT_SIZE
    ratios = %RATIOS
    sample_norm = False
    repeat_layers = 1

variational/blocks.Encoder.n_out = 2

blocks.VariationalEncoder:
    encoder = @variational/blocks.Encoder

# DECODER
blocks.Generator:
    latent_size = %LATENT_SIZE
    capacity = %CAPACITY
    data_size = %N_BAND
    ratios = %RATIOS
    loud_stride = 1
    use_noise = True

blocks.ResidualStack:
    kernel_sizes = [3]
    dilations_list = [[1, 1], [3, 1], [5, 1]]

blocks.NoiseGenerator:
    ratios = [4, 4, 4]
    noise_bands = 5

# DISCRIMINATOR
discriminator.ConvNet:
    in_size = 1
    out_size = 1
    capacity = %CAPACITY
    n_layers = 4
    stride = 4

scales/discriminator.ConvNet:
    conv = @torch.nn.Conv1d
    kernel_size = 15

discriminator.MultiScaleDiscriminator:
    n_discriminators = 3
    convnet = @scales/discriminator.ConvNet

feature_matching/core.mean_difference:
    norm = 'L1'

# BALANCER
balancer.Balancer:
    ema_averager = @balancer.EMA
    weights = {
        'regularization': .1,
        'feature_matching': 10,
    }
    deny_list = [
        'regularization'
    ]
    scale_gradients = False

balancer.EMA:
    beta = 0.999

# MODEL ASSEMBLING
# NOTE(review): unlike the sections above, these bindings have no scope header
# (e.g. a `RAVE:` block). Without one, gin would presumably read each line as a
# macro definition rather than as a RAVE constructor argument. Confirm before
# turning this back into a gin file.
latent_size = %LATENT_SIZE
pqmf = @pqmf.CachedPQMF
sampling_rate = %SAMPLING_RATE
encoder = @blocks.VariationalEncoder  
decoder = @blocks.Generator
discriminator = @discriminator.MultiScaleDiscriminator
phase_1_duration = %PHASE_1_DURATION
gan_loss = @core.hinge_gan
valid_signal_crop = False
feature_matching_fun = @feature_matching/core.mean_difference
num_skipped_features = 0
audio_distance = @core.AudioDistanceV1
multiband_audio_distance = @core.AudioDistanceV1
balancer = @balancer.Balancer

SyntaxError: invalid syntax (3050905911.py, line 9)

## classes

In [3]:
class EMA:
    """Exponential moving average of named tensors.

    Each key is tracked independently. The first value seen for a key
    initialises its shadow; later values are blended in as
    ``shadow = beta * shadow + (1 - beta) * value``.

    Args:
        beta: decay applied to the previous shadow on every update.
        device: device the shadow tensors are stored on. Defaults to
            ``"cuda"``, which was previously hard-coded.
    """

    def __init__(self, beta: float = 0.999, device: Union[str, torch.device] = "cuda") -> None:
        self.shadows = {}
        self.beta = beta
        self.device = device

    def __call__(self, inputs: Dict[str, torch.Tensor]):
        """Update the shadows with ``inputs``.

        Returns:
            A dict with the same keys as ``inputs``, mapping each key to a
            copy of its updated shadow.
        """
        outputs = {}
        for k, v in inputs.items():
            if k not in self.shadows:
                # Force a copy: `.to()` returns `v` itself when it is already
                # on `device`. The in-place updates below would then mutate
                # the caller's tensor.
                self.shadows[k] = v.to(self.device, copy=True)
            else:
                self.shadows[k] *= self.beta
                self.shadows[k] += (1 - self.beta) * v

            outputs[k] = self.shadows[k].clone()
        return outputs
    
    
class Balancer:
    """Combine several losses into one weighted gradient on a shared model output.

    For every loss not in ``deny_list``, the gradient with respect to
    ``model_output`` is computed separately and NaNs are replaced by zeros.
    Each gradient is then multiplied by a per-loss scale. All scaled
    gradients are summed and back-propagated through ``model_output`` in a
    single call. Losses in ``deny_list`` skip the balancing: they are
    multiplied by their weight and back-propagated directly.

    Args:
        ema_averager: zero-argument factory returning a callable that smooths
            the dict of per-loss gradient norms (e.g. ``EMA``).
        weights: per-loss weights; losses without an entry use 1.
        scale_gradients: if True, each gradient is rescaled so its smoothed
            norm becomes ``weight / sum(weights)``. Otherwise it is simply
            multiplied by its weight.
        deny_list: loss names that bypass the balancing.
    """

    def __init__(
        self,
        ema_averager: Callable[[], "EMA"],
        weights: Dict[str, float],
        scale_gradients: bool = False,
        deny_list: Optional[Sequence[str]] = None,
    ) -> None:
        self.ema_averager = ema_averager()
        self.weights = weights
        self.scale_gradients = scale_gradients
        self.deny_list = deny_list

    def backward(
        self,
        losses: Dict[str, torch.Tensor],
        model_output: torch.Tensor,
        logger: Optional[Callable[[str, float], None]] = None,
        profiler: Optional[Any] = None,
    ):
        """Back-propagate the balanced combination of ``losses``.

        Args:
            losses: named scalar losses. Every non-denied loss must depend on
                ``model_output`` in its autograd graph.
            model_output: tensor the per-loss gradients are taken against.
            logger: optional ``logger(name, value)`` callback.
            profiler: optional ``profiler(message)`` callback.
        """
        grads = {}
        norms = {}

        for k, v in losses.items():
            if self.deny_list is not None and k in self.deny_list:
                continue

            # Differentiate against `model_output` itself. A device copy made
            # with `.to()` would not be part of the loss graph.
            (grads[k],) = torch.autograd.grad(
                v,
                [model_output],
                retain_graph=True,
            )

            if (nans := torch.isnan(grads[k])).any():
                nan_ratio = nans.float().mean()
                grads[k] = torch.where(nans, torch.zeros_like(grads[k]), grads[k])
                if logger is not None:
                    logger(f"{k}_nan_ratio", nan_ratio)

            # Norm over every non-batch dimension, averaged over the batch.
            norms[k] = grads[k].norm(dim=tuple(range(1, grads[k].dim()))).mean()

            if profiler is not None:
                profiler(f"partial backward {k}")

        avg_norms = self.ema_averager(norms)

        if profiler is not None:
            profiler("grad norm estimation")

        sum_weights = sum(self.weights.get(k, 1) for k in avg_norms)

        for name, norm in avg_norms.items():
            if self.scale_gradients:
                ratio = self.weights.get(name, 1) / sum_weights
                scale = ratio / (norm + 1e-6)
            else:
                scale = self.weights.get(name, 1)

            # Apply the scale. Without this step the configured weights have
            # no effect on the gradient.
            grads[name] = grads[name] * scale

            if logger is not None:
                logger(f"scale_{name}", scale)
                logger(f"grad_norm_{name}", grads[name].norm())
                if self.scale_gradients:
                    logger(f"target_norm_{name}", ratio)

        if profiler is not None:
            profiler("norm scaling")

        # Nothing to propagate when every loss was denied.
        if avg_norms:
            full_grad = sum(grads[name] for name in avg_norms)
            model_output.backward(full_grad, retain_graph=True)

        if profiler is not None:
            profiler("scaled backward")

        if self.deny_list is not None:
            for k in self.deny_list:
                if k not in losses:
                    continue
                weight = self.weights.get(k, 1)
                loss = losses[k] * weight
                if logger is not None:
                    logger(f"scale_{k}", weight)
                if loss.requires_grad:
                    loss.backward(retain_graph=True)

        if profiler is not None:
            profiler("denied backward")

In [4]:
# Point `__file__` at the rave package. The gin search paths below then
# resolve to that package directory and its `configs/` subfolder
# (presumably needed because a notebook cell has no module `__file__`).
__file__ = "/home/ubuntu/rave-training/rave/__init__.py"

# Put gin in interactive mode, then register the two config search paths.
gin.enter_interactive_mode()
gin.add_config_file_search_path(os.path.dirname(__file__))
gin.add_config_file_search_path(
    os.path.join(
        os.path.dirname(__file__),
        'configs',
    ))

# Rebind the cached_conv helpers/layers to gin-configurable wrappers under the
# `cc` module name. Bindings such as `cc.Conv1d.bias = False` can then target
# them. Later `cc.Conv1d(...)` calls in this notebook go through these wrappers.
cc.get_padding = gin.external_configurable(cc.get_padding, module="cc")
cc.Conv1d = gin.external_configurable(cc.Conv1d, module="cc")
cc.ConvTranspose1d = gin.external_configurable(cc.ConvTranspose1d, module="cc")

In [5]:
def spectrogram(n_fft: int):
    """Build a normalized STFT transform with window ``n_fft`` and hop ``n_fft // 4``.

    ``power=None`` keeps the complex-valued STFT, so callers can read its
    ``.real`` and ``.imag`` parts. No centering is applied (``center=False``).
    """
    stft_options = dict(
        hop_length=n_fft // 4,
        power=None,
        normalized=True,
        center=False,
        pad_mode=None,
    )
    return torchaudio.transforms.Spectrogram(n_fft, **stft_options)


def rectified_2d_conv_block(
    capacity,
    kernel_sizes,
    strides: Optional[Tuple[int, int]] = None,
    dilations: Optional[Tuple[int, int]] = None,
    in_size: Optional[int] = None,
    out_size: Optional[int] = None,
    activation: bool = True,
):
    """Build a "same"-padded, normalized ``nn.Conv2d``, optionally followed by LeakyReLU(0.2).

    Args:
        capacity: channel count used for any of ``in_size``/``out_size``
            that is not given (or is falsy).
        kernel_sizes: 2-element kernel size ``(kh, kw)``.
        strides: conv stride; ``(1, 1)`` when falsy.
        dilations: conv dilation; ``(1, 1)`` when falsy. Also changes the
            padding to ``((k - 1) * d) // 2`` per axis.
        in_size, out_size: channel overrides.
        activation: append a LeakyReLU(0.2) when True.

    Returns:
        The normalized conv, or ``nn.Sequential(conv, LeakyReLU)``.
    """
    kh, kw = kernel_sizes[0], kernel_sizes[1]
    if dilations is not None:
        paddings = ((kh - 1) * dilations[0]) // 2, ((kw - 1) * dilations[1]) // 2
    else:
        paddings = kh // 2, kw // 2

    conv2d = nn.Conv2d(
        in_size or capacity,
        out_size or capacity,
        kernel_size=kernel_sizes,
        stride=strides or (1, 1),
        dilation=dilations or (1, 1),
        padding=paddings,
    )
    conv = normalization(conv2d)

    return nn.Sequential(conv, nn.LeakyReLU(.2)) if activation else conv


class EncodecConvNet(nn.Module):
    """Stack of 2-D conv blocks that returns every intermediate activation.

    The first block takes 2 input channels. The middle three use stride
    (2, 1) with dilations (1, 1), (1, 2) and (1, 4). The last block projects
    to one channel with no activation.
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()
        layers = [rectified_2d_conv_block(capacity, (9, 3), in_size=2)]
        for dilation in (1, 2, 4):
            layers.append(
                rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, dilation)))
        layers.append(rectified_2d_conv_block(capacity, (3, 3)))
        layers.append(
            rectified_2d_conv_block(capacity, (3, 3), out_size=1, activation=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply each layer in order and return the list of all outputs."""
        activations = []
        for module in self.net:
            x = module(x)
            activations.append(x)
        return activations


class ConvNet(nn.Module):
    """Strided conv stack that returns the output of every convolution.

    Channels grow as ``capacity * 2**i`` for ``i`` in ``range(n_layers)``.
    Each strided conv is normalized and followed by LeakyReLU(0.2). A final
    un-normalized kernel-1 conv maps to ``out_size``.

    Args:
        in_size, out_size: input / output channel counts.
        capacity: base channel count.
        n_layers: number of strided conv layers.
        kernel_size: int for 1-D-style kernels. For a tuple, the stride and
            padding apply to its first axis only.
        stride: int (shared by all layers) or per-layer sequence.
        conv: conv class to instantiate (e.g. ``nn.Conv1d``).
    """

    def __init__(self, in_size, out_size, capacity, n_layers, kernel_size,
                 stride, conv) -> None:
        super().__init__()
        channels = [in_size] + list(capacity * 2**np.arange(n_layers))
        strides = n_layers * [stride] if isinstance(stride, int) else stride
        multi_axis_kernel = not isinstance(kernel_size, int)

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            step = strides[idx]
            if multi_axis_kernel:
                padding = (cc.get_padding(kernel_size[0], step,
                                          mode="centered")[0], 0)
                layer_stride = (step, 1)
            else:
                padding = cc.get_padding(kernel_size, step, mode="centered")[0]
                layer_stride = step
            layers.append(
                normalization(
                    conv(c_in, c_out, kernel_size, stride=layer_stride,
                         padding=padding)))
            layers.append(nn.LeakyReLU(.2))
        layers.append(conv(channels[-1], out_size, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the stack and collect every convolution output, in order."""
        conv_outputs = []
        for module in self.net:
            x = module(x)
            if isinstance(module, nn.modules.conv._ConvNd):
                conv_outputs.append(x)
        return conv_outputs


class MultiScaleDiscriminator(nn.Module):
    """Apply independent discriminators to progressively downsampled input.

    Discriminator ``i`` sees the input average-pooled (kernel 2) ``i`` times.
    """

    def __init__(self, n_discriminators, convnet) -> None:
        super().__init__()
        self.layers = nn.ModuleList([convnet() for _ in range(n_discriminators)])

    def forward(self, x):
        """Return one discriminator output per scale, finest scale first."""
        outputs = []
        signal = x
        for disc in self.layers:
            outputs.append(disc(signal))
            signal = nn.functional.avg_pool1d(signal, 2)
        return outputs


class MultiScaleSpectralDiscriminator(nn.Module):
    """One discriminator per STFT size, fed real and imaginary parts as channels.

    For each scale, the complex spectrogram's ``.real`` and ``.imag`` parts
    are concatenated along dim 1 before being passed to that scale's network.
    """

    def __init__(self, scales: Sequence[int],
                 convnet: Callable[[], nn.Module]) -> None:
        super().__init__()
        self.specs = nn.ModuleList(spectrogram(n_fft) for n_fft in scales)
        self.nets = nn.ModuleList(convnet() for _ in scales)

    def forward(self, x):
        """Return the list of per-scale discriminator outputs."""
        outputs = []
        for stft, net in zip(self.specs, self.nets):
            z = stft(x)
            outputs.append(net(torch.cat((z.real, z.imag), 1)))
        return outputs


class MultiScaleSpectralDiscriminator1d(nn.Module):
    """Spectral discriminator that treats frequency bins as 1-D channels.

    For each STFT size ``n``, dim 1 of the spectrogram is squeezed out. Its
    real and imaginary parts are concatenated along dim 1 and fed to a
    network built with ``convnet(n + 2)`` (presumably 2 * (n // 2 + 1) bins
    for even ``n``; confirm).
    """

    def __init__(self, scales: Sequence[int],
                 convnet: Callable[[int], nn.Module]) -> None:
        super().__init__()
        self.specs = nn.ModuleList(spectrogram(n_fft) for n_fft in scales)
        self.nets = nn.ModuleList(convnet(n_fft + 2) for n_fft in scales)

    def forward(self, x):
        """Return the list of per-scale discriminator outputs."""
        outputs = []
        for stft, net in zip(self.specs, self.nets):
            z = stft(x).squeeze(1)
            outputs.append(net(torch.cat((z.real, z.imag), 1)))
        return outputs


class MultiPeriodDiscriminator(nn.Module):
    """One discriminator per period, each applied to the input folded by that period."""

    def __init__(self, periods, convnet) -> None:
        super().__init__()
        self.periods = periods
        self.layers = nn.ModuleList([convnet() for _ in periods])

    def forward(self, x):
        """Return one output per period, in the order of ``self.periods``."""
        return [
            disc(self.fold(x, period))
            for disc, period in zip(self.layers, self.periods)
        ]

    def fold(self, x, n):
        """Zero-pad the last axis to a multiple of ``n``, then reshape it to ``(-1, n)``.

        The first two dimensions of ``x`` are kept as-is.
        """
        remainder = x.shape[-1] % n
        pad = 0 if remainder == 0 else n - remainder
        padded = nn.functional.pad(x, (0, pad))
        return padded.reshape(*padded.shape[:2], -1, n)


class CombineDiscriminators(nn.Module):
    """Run several discriminators on the same input and flatten their outputs.

    Each entry of ``discriminators`` is called with no arguments to build a
    module. The per-discriminator output lists are concatenated in order.
    """

    def __init__(self, discriminators: Sequence[Type[nn.Module]]) -> None:
        super().__init__()
        self.discriminators = nn.ModuleList([make() for make in discriminators])

    def forward(self, x):
        """Return the concatenation of every discriminator's output list."""
        combined = []
        for disc in self.discriminators:
            combined += disc(x)
        return combined

In [6]:
class Profiler:
    """Record timestamped ticks and render the elapsed time between them."""

    def __init__(self):
        # The first entry is the start time and carries no message.
        self.ticks = [[time(), None]]

    def tick(self, msg):
        """Record the current time under ``msg``."""
        self.ticks.append([time(), msg])

    def __repr__(self):
        separator = 80 * "=" + "\n"
        body = "".join(
            msg + f": {(stamp - previous) * 1000:.2f}ms\n"
            for (previous, _), (stamp, msg) in zip(self.ticks, self.ticks[1:])
        )
        return separator + body + separator + "\n\n"


class WarmupCallback(pl.Callback):
    """Set ``pl_module.warmed_up`` once ``pl_module.warmup`` batches have been seen.

    The batch counter lives in ``self.state``, which is exposed through
    ``state_dict`` / ``load_state_dict``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.state = dict(training_steps=0)

    def on_train_batch_start(self, trainer, pl_module, batch,
                             batch_idx) -> None:
        steps_so_far = self.state['training_steps']
        if steps_so_far >= pl_module.warmup:
            pl_module.warmed_up = True
        self.state['training_steps'] = steps_so_far + 1

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)


class QuantizeCallback(WarmupCallback):
    """Enable a ``DiscreteEncoder``'s ``enabled`` flag after ``warmup_quantize`` batches.

    Does nothing, and does not count batches, while
    ``pl_module.warmup_quantize`` is None.
    """

    def on_train_batch_start(self, trainer, pl_module, batch,
                             batch_idx) -> None:
        if pl_module.warmup_quantize is None:
            return

        reached = self.state['training_steps'] >= pl_module.warmup_quantize
        if reached and isinstance(pl_module.encoder, DiscreteEncoder):
            current = pl_module.encoder.enabled
            pl_module.encoder.enabled = torch.tensor(
                1, device="cuda").type_as(current)
        self.state['training_steps'] += 1


@gin.configurable
class RAVE(pl.LightningModule):
    """RAVE autoencoder trained with reconstruction distances and a GAN.

    Every sub-module is passed in as a zero-argument factory and instantiated
    on CUDA. Optimisation is manual (``automatic_optimization = False``).
    While ``warmed_up`` is False, only the encoder/decoder are trained on the
    reconstruction distances. Once it is True, the discriminator is updated
    on every ``update_discriminator_every``-th batch and the generator on the
    remaining batches.
    """

    def __init__(
        self,
        latent_size,
        sampling_rate,
        encoder,
        decoder,
        discriminator,
        phase_1_duration,
        gan_loss,
        valid_signal_crop,
        feature_matching_fun,
        num_skipped_features,
        audio_distance: Callable[[], nn.Module],
        multiband_audio_distance: Callable[[], nn.Module],
        balancer: Callable[[], Balancer],
        warmup_quantize: Optional[int] = None,
        pqmf: Optional[Callable[[], nn.Module]] = None,
        update_discriminator_every: int = 2,
        enable_pqmf_encode: bool = True,
        enable_pqmf_decode: bool = True,
    ):
        super().__init__()

        self.pqmf = None
        if pqmf is not None:
            self.pqmf = pqmf().to("cuda")

        self.encoder = encoder().to("cuda")
        self.decoder = decoder().to("cuda")
        self.discriminator = discriminator().to("cuda")

        self.audio_distance = audio_distance().to("cuda")
        self.multiband_audio_distance = multiband_audio_distance().to("cuda")

        self.gan_loss = gan_loss

        # Latent statistics, overwritten by the PCA in validation_epoch_end.
        self.register_buffer("latent_pca", torch.eye(latent_size, device="cuda"))
        self.register_buffer("latent_mean", torch.zeros(latent_size, device="cuda"))
        self.register_buffer("fidelity", torch.zeros(latent_size, device="cuda"))

        self.latent_size = latent_size

        # Optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        # SCHEDULE
        # NOTE(review): hard-coded override. `phase_1_duration` is accepted
        # but ignored here; restore `self.warmup = phase_1_duration` once done
        # debugging.
        self.warmup = 500 # phase_1_duration
        self.warmup_quantize = warmup_quantize
        self.balancer = balancer()

        self.warmed_up = False

        # CONSTANTS
        self.sr = sampling_rate
        self.valid_signal_crop = valid_signal_crop
        self.feature_matching_fun = feature_matching_fun
        self.num_skipped_features = num_skipped_features
        self.update_discriminator_every = update_discriminator_every

        self.eval_number = 0
        self.integrator = None

        self.enable_pqmf_encode = enable_pqmf_encode
        self.enable_pqmf_decode = enable_pqmf_decode

        # (left, right) receptive field in samples; computed lazily in
        # validation_epoch_end while it is still all zeros.
        self.register_buffer("receptive_field", torch.tensor([0, 0], device="cuda").long())

    def configure_optimizers(self):
        """Return Adam optimizers for the generator (encoder + decoder) and the discriminator."""
        gen_p = list(self.encoder.parameters())
        gen_p += list(self.decoder.parameters())
        dis_p = list(self.discriminator.parameters())

        gen_opt = torch.optim.Adam(gen_p, 1e-4, (.5, .9))
        dis_opt = torch.optim.Adam(dis_p, 1e-4, (.5, .9))

        return gen_opt, dis_opt

    def split_features(self, features):
        """Split every discriminator feature along dim 0 into (real, fake) halves.

        training_step concatenates real then fake along the batch axis, so
        the first half is real and the second half is fake.
        """
        feature_real = []
        feature_fake = []
        for scale in features:
            true, fake = zip(*map(
                lambda x: torch.split(x, x.shape[0] // 2, 0),
                scale,
            ))
            feature_real.append(true)
            feature_fake.append(fake)
        return feature_real, feature_fake

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step.

        ``batch`` gets a channel axis via ``unsqueeze(1)`` (presumably
        ``(batch, time)`` audio; confirm against the dataloader). On
        discriminator turns, ``loss_dis`` is back-propagated and only the
        discriminator optimizer steps. Otherwise the generator losses go
        through the balancer and the generator optimizer steps.
        """
        batch = batch.to("cuda")
        p = Profiler()
        gen_opt, dis_opt = self.optimizers()

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            x = batch.unsqueeze(1)

            if self.pqmf is not None:
                x_multiband = self.pqmf(x)
            else:
                x_multiband = x
            p.tick('decompose')

            self.encoder.set_warmed_up(self.warmed_up)
            self.decoder.set_warmed_up(self.warmed_up)

            # ENCODE INPUT
            if self.enable_pqmf_encode:
                z_pre_reg = self.encoder(x_multiband)
            else:
                z_pre_reg = self.encoder(x)

            z, reg = self.encoder.reparametrize(z_pre_reg)[:2]
            p.tick('encode')

            # DECODE LATENT
            y_multiband = self.decoder(z)
            p.tick('decode')

            if self.valid_signal_crop and self.receptive_field.sum():
                # This notebook never binds the name `rave` at top level.
                import rave.core
                x_multiband = rave.core.valid_signal_crop(
                    x_multiband,
                    *self.receptive_field,
                )
                y_multiband = rave.core.valid_signal_crop(
                    y_multiband,
                    *self.receptive_field,
                )
            p.tick('crop')

            # DISTANCE BETWEEN INPUT AND OUTPUT
            distances = {}

            if self.pqmf is not None:
                multiband_distance = self.multiband_audio_distance(
                    x_multiband, y_multiband)
                p.tick('mb distance')

                x = self.pqmf.inverse(x_multiband)
                y = self.pqmf.inverse(y_multiband)
                p.tick('recompose')

                for k, v in multiband_distance.items():
                    distances[f'multiband_{k}'] = v
            else:
                x = x_multiband
                y = y_multiband

            fullband_distance = self.audio_distance(x, y)
            p.tick('fb distance')

            for k, v in fullband_distance.items():
                distances[f'fullband_{k}'] = v

            feature_matching_distance = 0.

            if self.warmed_up:  # DISCRIMINATION
                xy = torch.cat([x, y], 0)
                features = self.discriminator(xy)

                feature_real, feature_fake = self.split_features(features)

                loss_dis = 0
                loss_adv = 0

                pred_real = 0
                pred_fake = 0

                for scale_real, scale_fake in zip(feature_real, feature_fake):
                    current_feature_distance = sum(
                        map(
                            self.feature_matching_fun,
                            scale_real[self.num_skipped_features:],
                            scale_fake[self.num_skipped_features:],
                        )) / len(scale_real[self.num_skipped_features:])

                    feature_matching_distance = feature_matching_distance + current_feature_distance

                    _dis, _adv = self.gan_loss(scale_real[-1], scale_fake[-1])

                    pred_real = pred_real + scale_real[-1].mean()
                    pred_fake = pred_fake + scale_fake[-1].mean()

                    loss_dis = loss_dis + _dis
                    loss_adv = loss_adv + _adv

                feature_matching_distance = feature_matching_distance / len(
                    feature_real)

            else:
                pred_real = torch.tensor(0., device="cuda").to(x)
                pred_fake = torch.tensor(0., device="cuda").to(x)
                loss_dis = torch.tensor(0., device="cuda").to(x)
                loss_adv = torch.tensor(0., device="cuda").to(x)
            p.tick('discrimination')

            # COMPOSE GEN LOSS
            loss_gen = {}
            loss_gen.update(distances)
            p.tick('update loss gen dict')

            if reg.item():
                loss_gen['regularization'] = reg

            if self.warmed_up:
                loss_gen['feature_matching'] = feature_matching_distance
                loss_gen['adversarial'] = loss_adv

        # OPTIMIZATION
        if not (batch_idx %
                self.update_discriminator_every) and self.warmed_up:
            dis_opt.zero_grad(set_to_none=True)
            # `y` is not detached, so this also fills generator grads. They
            # are discarded by gen_opt.zero_grad before the next generator step.
            loss_dis.backward()
            dis_opt.step()
            p.tick('dis opt')
        else:
            gen_opt.zero_grad(set_to_none=True)
            self.balancer.backward(loss_gen, y_multiband, self.log, p.tick)
            gen_opt.step()

        # LOGGING
        if self.warmed_up:
            self.log("loss_dis", loss_dis)
            self.log("pred_real", pred_real.mean())
            self.log("pred_fake", pred_fake.mean())

        self.log_dict(loss_gen)
        p.tick('logging')

    def encode(self, x):
        """Encode ``x`` (PQMF-decomposed first when enabled) and return the sampled latent."""
        if self.pqmf is not None and self.enable_pqmf_encode:
            x = self.pqmf(x)
        z, = self.encoder.reparametrize(self.encoder(x))[:1]
        return z

    def decode(self, z):
        """Decode ``z``, recomposing through the PQMF when enabled."""
        y = self.decoder(z)
        if self.pqmf is not None and self.enable_pqmf_decode:
            y = self.pqmf.inverse(y)
        return y

    def forward(self, x):
        return self.decode(self.encode(x))

    def validation_step(self, batch, batch_idx):
        """Reconstruct ``batch`` and log the full-band distance.

        Returns:
            ``(torch.cat([x, y], -1), mean)``. ``mean`` is the first half
            (along dim 1) of the encoder output for a ``VariationalEncoder``,
            otherwise None.
        """
        batch = batch.to("cuda")
        x = batch.unsqueeze(1)

        if self.pqmf is not None:
            x_multiband = self.pqmf(x)
        else:
            # Mirrors training_step; previously left unbound without a PQMF.
            x_multiband = x

        if self.enable_pqmf_encode:
            z = self.encoder(x_multiband)
        else:
            z = self.encoder(x)

        if isinstance(self.encoder, VariationalEncoder):
            mean = torch.split(z, z.shape[1] // 2, 1)[0]
        else:
            mean = None

        z = self.encoder.reparametrize(z)[0]
        y = self.decoder(z)

        if self.pqmf is not None:
            x = self.pqmf.inverse(x_multiband)
            y = self.pqmf.inverse(y)

        distance = self.audio_distance(x, y)

        full_distance = sum(distance.values())

        if self.trainer is not None:
            self.log('validation', full_distance)

        return torch.cat([x, y], -1), mean

    def validation_epoch_end(self, out):
        """Compute the receptive field once, analyse the latent space, and log audio.

        While not warmed up, and only for a ``VariationalEncoder``, a PCA of
        the collected latent means updates ``latent_mean``, ``latent_pca``
        and ``fidelity`` (cumulative explained-variance ratio). The first
        8 rows of the concatenated validation outputs are logged as audio.
        """
        if not self.receptive_field.sum():
            # This notebook never binds the name `rave` at top level.
            import rave.core
            print("Computing receptive field for this configuration...")
            lrf, rrf = rave.core.get_rave_receptive_field(self)
            self.receptive_field[0] = lrf
            self.receptive_field[1] = rrf
            print(
                f"Receptive field: {1000*lrf/self.sr:.2f}ms <-- x --> {1000*rrf/self.sr:.2f}ms"
            )

        if not len(out): return

        audio, z = list(zip(*out))
        audio = list(map(lambda x: x.cpu(), audio))

        # LATENT SPACE ANALYSIS
        if not self.warmed_up and isinstance(self.encoder, VariationalEncoder):
            z = torch.cat(z, 0)
            z = rearrange(z, "b c t -> (b t) c")

            self.latent_mean.copy_(z.mean(0))
            z = z - self.latent_mean

            pca = PCA(z.shape[-1]).fit(z.cpu().numpy())

            components = pca.components_
            components = torch.from_numpy(components).to("cuda").to(z)
            self.latent_pca.copy_(components)

            var = pca.explained_variance_ / np.sum(pca.explained_variance_)
            var = np.cumsum(var)

            self.fidelity.copy_(torch.from_numpy(var).to("cuda").to(self.fidelity))

            var_percent = [.8, .9, .95, .99]
            for p in var_percent:
                self.log(
                    f"fidelity_{p}",
                    np.argmax(var > p).astype(np.float32),
                )

        y = torch.cat(audio, 0)[:8].reshape(-1).numpy()

        if self.integrator is not None:
            y = self.integrator(y)

        self.logger.experiment.add_audio("audio_val", y, self.eval_number,
                                         self.sr)
        self.eval_number += 1

    def on_fit_start(self):
        """Log the operative gin config and the model summary as TensorBoard text."""
        tb = self.logger.experiment

        config = gin.operative_config_str()
        config = config.split('\n')
        config = ['```'] + config + ['```']
        config = '\n'.join(config)
        tb.add_text("config", config)

        model = str(self)
        model = model.split('\n')
        model = ['```'] + model + ['```']
        model = '\n'.join(model)
        tb.add_text("model", model)


In [7]:
# import rave
from rave.pqmf import *

try:
    from .__version__ import *
except ImportError:
    # No package context (e.g. a notebook) or no generated __version__ module:
    # fall back to "unknown". Catching only ImportError keeps real errors in
    # the version module (syntax errors etc.) visible.
    __version__ = None
    __commit__ = None

In [8]:
# import rave.blocks

@gin.configurable
def normalization(module: nn.Module, mode: str = "identity"):
    """Optionally wrap ``module`` with weight normalization.

    Args:
        module: layer to wrap.
        mode: ``"identity"`` returns ``module`` unchanged. ``"weight_norm"``
            returns ``weight_norm(module)``.

    Returns:
        The (possibly wrapped) module.

    Raises:
        ValueError: for any other ``mode``. ValueError subclasses Exception,
            so existing ``except Exception`` handlers still catch it.
    """
    if mode == "identity":
        return module
    if mode == "weight_norm":
        return weight_norm(module)
    raise ValueError(f"Normalization mode {mode} not supported")


class SampleNorm(nn.Module):
    """Divide the input by its L2 norm taken along dimension 1."""

    def forward(self, x):
        l2 = x.norm(2, 1, keepdim=True)
        return x / l2


class Residual(nn.Module):
    """Residual wrapper: ``module(x) + x``, with both branches combined by ``cc.AlignBranches``.

    The branches are given delays ``[module.cumulative_delay, 0]``. The
    reported ``cumulative_delay`` is the module's delay plus the incoming
    ``cumulative_delay``.
    """

    def __init__(self, module, cumulative_delay=0):
        super().__init__()
        branch_delay = module.cumulative_delay
        self.aligned = cc.AlignBranches(
            module,
            nn.Identity(),
            delays=[branch_delay, 0],
        )
        self.cumulative_delay = branch_delay + cumulative_delay

    def forward(self, x):
        processed, skip = self.aligned(x)
        return processed + skip


class ResidualLayer(nn.Module):
    """Residual stack of (LeakyReLU, normalized dilated ``cc.Conv1d``) pairs, one per dilation.

    Each conv receives the previous conv's ``cumulative_delay``. The whole
    chain is wrapped in ``Residual``.
    """

    def __init__(self, dim, kernel_size, dilations, cumulative_delay=0):
        super().__init__()
        layers = []
        running_delay = 0
        for dilation in dilations:
            layers.append(nn.LeakyReLU(0.2))
            conv = normalization(
                cc.Conv1d(
                    dim,
                    dim,
                    kernel_size,
                    dilation=dilation,
                    padding=cc.get_padding(kernel_size, dilation=dilation),
                    cumulative_delay=running_delay,
                ))
            layers.append(conv)
            running_delay = conv.cumulative_delay

        chain = cc.CachedSequential(*layers)
        self.net = Residual(chain, cumulative_delay=cumulative_delay)
        self.cumulative_delay = self.net.cumulative_delay

    def forward(self, x):
        return self.net(x)


class DilatedUnit(nn.Module):
    """LeakyReLU, normalized dilated conv, LeakyReLU, normalized 1x1 conv.

    The reported ``cumulative_delay`` is the dilated conv's.
    """

    def __init__(self, dim: int, kernel_size: int, dilation: int) -> None:
        super().__init__()
        dilated = normalization(
            cc.Conv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=cc.get_padding(kernel_size, dilation=dilation),
            ))
        pointwise = normalization(cc.Conv1d(dim, dim, kernel_size=1))

        self.net = cc.CachedSequential(
            nn.LeakyReLU(0.2),
            dilated,
            nn.LeakyReLU(0.2),
            pointwise,
        )
        self.cumulative_delay = dilated.cumulative_delay

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ResidualBlock(nn.Module):
    """Sequence of ``ResidualLayer``s, one per entry of ``dilations_list``.

    Each layer is chained to the previous layer's ``cumulative_delay``.
    """

    def __init__(self, dim, kernel_size, dilations_list, cumulative_delay=0) -> None:
        super().__init__()
        residual_layers = []
        running_delay = 0
        for dilations in dilations_list:
            layer = ResidualLayer(
                dim,
                kernel_size,
                dilations,
                cumulative_delay=running_delay,
            )
            residual_layers.append(layer)
            running_delay = layer.cumulative_delay

        self.net = cc.CachedSequential(
            *residual_layers,
            cumulative_delay=cumulative_delay,
        )
        self.cumulative_delay = self.net.cumulative_delay

    def forward(self, x):
        return self.net(x)


@gin.configurable
class ResidualStack(nn.Module):
    """Run one ``ResidualBlock`` per kernel size in parallel and sum their outputs.

    The parallel branches are combined through ``cc.AlignBranches``.
    """

    def __init__(self, dim, kernel_sizes, dilations_list, cumulative_delay=0) -> None:
        super().__init__()
        blocks = []
        for k in kernel_sizes:
            blocks.append(ResidualBlock(dim, k, dilations_list))
        self.net = cc.AlignBranches(*blocks, cumulative_delay=cumulative_delay)
        self.cumulative_delay = self.net.cumulative_delay

    def forward(self, x):
        x = self.net(x)
        # Sum on the branches' own device. A forced `.to("cuda")` was a no-op
        # on GPU and broke CPU execution.
        x = torch.stack(x, 0).sum(0)
        return x


class UpsampleLayer(nn.Module):
    """LeakyReLU followed by a normalized upsampling conv.

    For ``ratio > 1`` this is a transposed conv (kernel ``2 * ratio``, stride
    ``ratio``). Otherwise it is a kernel-3 conv. The incoming
    ``cumulative_delay`` is scaled by ``ratio``.
    """

    def __init__(self, in_dim, out_dim, ratio, cumulative_delay=0):
        super().__init__()
        activation = nn.LeakyReLU(0.2)
        if ratio <= 1:
            conv = cc.Conv1d(in_dim, out_dim, 3, padding=cc.get_padding(3))
        else:
            conv = cc.ConvTranspose1d(
                in_dim, out_dim, 2 * ratio, stride=ratio, padding=ratio // 2
            )

        self.net = cc.CachedSequential(activation, normalization(conv))
        self.cumulative_delay = self.net.cumulative_delay + cumulative_delay * ratio

    def forward(self, x):
        return self.net(x)


@gin.configurable
class NoiseGenerator(nn.Module):
    """Filtered-noise synthesiser.

    A strided ``cc.Conv1d`` stack predicts ``data_size * noise_bands``
    amplitude channels. ``mod_sigmoid(. - 5)`` maps them to amplitudes, which
    become impulse responses via ``amp_to_impulse_response``. Those responses
    filter uniform noise in [-1, 1) via ``fft_convolve``.

    Args:
        in_size: channels of the input feature map.
        data_size: number of output bands.
        ratios: per-layer conv strides. Their product is the target impulse
            response size.
        noise_bands: amplitude channels predicted per output band.
    """

    def __init__(self, in_size, data_size, ratios, noise_bands):
        super().__init__()
        net = []
        channels = [in_size] * len(ratios) + [data_size * noise_bands]
        cum_delay = 0
        for i, r in enumerate(ratios):
            net.append(
                cc.Conv1d(
                    channels[i],
                    channels[i + 1],
                    3,
                    padding=cc.get_padding(3, r),
                    stride=r,
                    cumulative_delay=cum_delay,
                )
            )
            cum_delay = net[-1].cumulative_delay
            # No activation after the last conv.
            if i != len(ratios) - 1:
                net.append(nn.LeakyReLU(0.2))

        self.net = cc.CachedSequential(*net)
        self.data_size = data_size
        self.cumulative_delay = self.net.cumulative_delay * int(np.prod(ratios))

        # Created without an explicit device: as a registered buffer it
        # follows the module's .to()/.cuda() calls. A hard-coded "cuda" made
        # construction fail on CPU-only machines.
        self.register_buffer(
            "target_size", torch.tensor(np.prod(ratios)).long()
        )

    def forward(self, x):
        amp = mod_sigmoid(self.net(x) - 5)
        amp = amp.permute(0, 2, 1)
        # (batch, time, data_size, noise_bands) after this reshape.
        amp = amp.reshape(amp.shape[0], amp.shape[1], self.data_size, -1)

        ir = amp_to_impulse_response(amp, self.target_size)
        # Uniform noise in [-1, 1), on the same device/dtype as `ir`.
        noise = torch.rand_like(ir) * 2 - 1

        noise = fft_convolve(noise, ir).permute(0, 2, 1, 3)
        noise = noise.reshape(noise.shape[0], noise.shape[1], -1)
        return noise


def normalize_dilations(
    dilations: Union[Sequence[int], Sequence[Sequence[int]]], ratios: Sequence[int]
):
    """Return one dilation list per entry of ``ratios``.

    A flat list of ints is shared by every stage (the same list object is
    repeated ``len(ratios)`` times); a list of lists is returned unchanged.
    """
    if not isinstance(dilations[0], int):
        return dilations
    return [dilations] * len(ratios)


class EncoderV2(nn.Module):
    """Convolutional encoder mapping a signal (or its log-spectrogram) to latent frames.

    Layout: a stem convolution, then for every ``(ratio, dilations)`` stage a
    run of residual dilated units followed by a strided downsampling
    convolution, then a projection to ``latent_size * n_out`` channels,
    optionally followed by ``recurrent_layer``.

    Args:
        data_size: input channels of the stem convolution.
        capacity: channels after the stem convolution.
        ratios: downsampling stride of each stage.
        latent_size: channels per output group.
        n_out: number of ``latent_size``-wide groups in the output
            (NOTE(review): presumably 2 for mean/scale when wrapped by
            ``VariationalEncoder`` — confirm).
        kernel_size: kernel of the dilated units and final projection; the
            stem uses ``2 * kernel_size + 1``.
        dilations: shared or per-stage dilations (see ``normalize_dilations``).
        keep_dim: if True a stage multiplies channels by its ratio, else by 2.
        recurrent_layer: optional factory called with the output channel count.
        spectrogram: optional factory; when given, ``forward`` feeds
            ``log1p`` of the spectrogram of channel 0 to the network.
    """

    def __init__(
        self,
        data_size: int,
        capacity: int,
        ratios: Sequence[int],
        latent_size: int,
        n_out: int,
        kernel_size: int,
        dilations: Sequence[int],
        keep_dim: bool = False,
        recurrent_layer: Optional[Callable[[], nn.Module]] = None,
        spectrogram: Optional[Callable[[], Spectrogram]] = None,
    ) -> None:
        super().__init__()
        dilations_list = normalize_dilations(dilations, ratios)

        # Sub-modules are built on the default device; the single
        # self.to("cuda") at the end moves every registered sub-module
        # (including the optional spectrogram) at once.
        if spectrogram is not None:
            self.spectrogram = spectrogram()
        else:
            self.spectrogram = None

        net = [
            normalization(
                cc.Conv1d(
                    data_size,
                    capacity,
                    kernel_size=kernel_size * 2 + 1,
                    padding=cc.get_padding(kernel_size * 2 + 1),
                )
            ),
        ]

        num_channels = capacity
        for r, stage_dilations in zip(ratios, dilations_list):
            # Residual dilated units at the current resolution.
            for d in stage_dilations:
                net.append(
                    Residual(
                        DilatedUnit(
                            dim=num_channels,
                            kernel_size=kernel_size,
                            dilation=d,
                        )
                    )
                )

            # Strided convolution: downsample time by r and widen channels.
            net.append(nn.LeakyReLU(0.2))
            out_channels = num_channels * r if keep_dim else num_channels * 2
            net.append(
                normalization(
                    cc.Conv1d(
                        num_channels,
                        out_channels,
                        kernel_size=2 * r,
                        stride=r,
                        padding=cc.get_padding(2 * r, r),
                    )
                )
            )
            num_channels = out_channels

        net.append(nn.LeakyReLU(0.2))
        net.append(
            normalization(
                cc.Conv1d(
                    num_channels,
                    latent_size * n_out,
                    kernel_size=kernel_size,
                    padding=cc.get_padding(kernel_size),
                )
            )
        )

        if recurrent_layer is not None:
            net.append(recurrent_layer(latent_size * n_out))

        self.net = cc.CachedSequential(*net)
        self.to("cuda")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (assumed (batch, channels, time) — TODO confirm)."""
        if self.spectrogram is not None:
            # Spectrogram of channel 0 with its last time frame dropped,
            # log-compressed. The result already lives on the spectrogram's
            # device, so no extra .to("cuda") is needed.
            x = torch.log1p(self.spectrogram(x[:, 0])[..., :-1])
        return self.net(x)


class GeneratorV2(nn.Module):
    """Convolutional decoder mapping latent frames back to a signal.

    It walks the ``ratios``/``dilations`` schedule in reverse. Layout:
    optional recurrent layer, input projection, then for each stage a
    LeakyReLU + normalized transposed convolution (upsampling by the stage
    ratio) followed by residual dilated units, and finally a LeakyReLU +
    output convolution. The output passes through ``tanh``; with
    ``amplitude_modulation`` the second half of the channels gates the first
    half through a sigmoid before the ``tanh``.

    Args:
        data_size: output channels.
        capacity: base channel count (the widest layer is ``capacity``
            times ``prod(ratios)`` or ``2 ** len(ratios)``).
        ratios: upsampling factors, given in encoder order.
        latent_size: input channels.
        kernel_size: kernel of the dilated units / input projection; the
            output convolution uses ``2 * kernel_size + 1``.
        dilations: shared or per-stage dilations (see ``normalize_dilations``).
        keep_dim: if True a stage divides channels by its ratio, else by 2.
        recurrent_layer: optional factory called with ``latent_size``.
        amplitude_modulation: enable the sigmoid gating described above.
    """

    def __init__(
        self,
        data_size: int,
        capacity: int,
        ratios: Sequence[int],
        latent_size: int,
        kernel_size: int,
        dilations: Sequence[int],
        keep_dim: bool = False,
        recurrent_layer: Optional[Callable[[], nn.Module]] = None,
        amplitude_modulation: bool = False,
    ) -> None:
        super().__init__()
        # Decoder stages run in the opposite order of the encoder's.
        stage_dilations = normalize_dilations(dilations, ratios)[::-1]
        stage_ratios = ratios[::-1]

        if keep_dim:
            width = np.prod(stage_ratios) * capacity
        else:
            width = 2 ** len(stage_ratios) * capacity

        layers = [] if recurrent_layer is None else [recurrent_layer(latent_size)]
        layers.append(
            normalization(
                cc.Conv1d(
                    latent_size,
                    width,
                    kernel_size=kernel_size,
                    padding=cc.get_padding(kernel_size),
                )
            )
        )

        for ratio, unit_dilations in zip(stage_ratios, stage_dilations):
            next_width = width // (ratio if keep_dim else 2)
            layers += [
                nn.LeakyReLU(0.2),
                normalization(
                    cc.ConvTranspose1d(
                        width, next_width, 2 * ratio, stride=ratio, padding=ratio // 2
                    )
                ),
            ]
            width = next_width

            layers.extend(
                Residual(
                    DilatedUnit(
                        dim=width,
                        kernel_size=kernel_size,
                        dilation=dilation,
                    )
                )
                for dilation in unit_dilations
            )

        out_channels = data_size * 2 if amplitude_modulation else data_size
        layers.append(nn.LeakyReLU(0.2))
        layers.append(
            normalization(
                cc.Conv1d(
                    width,
                    out_channels,
                    kernel_size=kernel_size * 2 + 1,
                    padding=cc.get_padding(kernel_size * 2 + 1),
                )
            )
        )

        self.net = cc.CachedSequential(*layers)
        self.amplitude_modulation = amplitude_modulation
        self.to("cuda")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode ``x`` into a ``tanh``-bounded signal."""
        out = self.net(x)

        if self.amplitude_modulation:
            signal, gate = out.split(out.shape[1] // 2, 1)
            out = signal * torch.sigmoid(gate)

        return torch.tanh(out)

    def set_warmed_up(self, state: bool):
        """No-op: the generator has no warm-up dependent behaviour."""
        pass


class VariationalEncoder(nn.Module):
    """Wrap an encoder factory and provide Gaussian reparametrization.

    ``forward`` only runs the wrapped encoder (detaching its output once
    ``warmed_up`` is set); sampling and the KL term are computed separately
    by ``reparametrize``.

    Args:
        encoder: zero-argument factory returning the encoder module.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder()
        # 0/1 flag kept as a buffer so it is stored in the state dict.
        self.register_buffer("warmed_up", torch.tensor(0))
        # One device move for the encoder and the buffer.
        self.to("cuda")

    def reparametrize(self, z):
        """Sample a latent from the (mean, scale) pair packed in ``z``.

        ``z`` is split in two along dim 1: the first half is the mean, the
        second is mapped through ``softplus(.) + 1e-4`` to a standard
        deviation.

        Returns:
            ``(sample, kl)``. ``sample = mean + std * eps`` with
            ``eps ~ N(0, 1)`` drawn on ``z``'s device and dtype. ``kl`` is
            ``mean² + var - log(var) - 1`` summed over dim 1 and averaged
            over the remaining dims, i.e. twice the usual Gaussian KL to
            N(0, 1).
        """
        mean, scale = z.chunk(2, 1)
        std = nn.functional.softplus(scale) + 1e-4
        var = std * std
        logvar = torch.log(var)

        # randn_like already follows mean's device; forcing "cuda" broke CPU use.
        z = torch.randn_like(mean) * std + mean
        kl = (mean * mean + var - logvar - 1).sum(1).mean()

        return z, kl

    def set_warmed_up(self, state: bool):
        """Set the warm-up flag (int tensor on the buffer's current device)."""
        state = torch.tensor(int(state), device=self.warmed_up.device)
        self.warmed_up = state

    def forward(self, x: torch.Tensor):
        """Encode ``x``; the output is detached from the graph once warmed up."""
        z = self.encoder(x)
        if self.warmed_up:
            z = z.detach()
        return z



def unit_norm_vector_to_angles(x: torch.Tensor) -> torch.Tensor:
    """Convert vectors along dim 1 to hyperspherical angles scaled to [-1, 1].

    For ``d`` components this returns ``d - 1`` angles. Angle ``k`` is
    ``arccos(x_k / ||x[k:]||)``. The last angle is extended to ``[0, 2*pi)``
    using the sign of the last component. Polar angles are divided by
    ``pi`` and the last one by ``2*pi`` before being mapped from ``[0, 1]``
    to ``[-1, 1]``.
    """
    # Squared components in reverse order along dim 1.
    squares = x.flip(1).pow(2)
    # Fold the two trailing components together so the last angle sees both.
    squares[:, 1] += squares[:, 0]
    # tail_norm[:, k] == ||x[:, k:]|| for k = 0 .. d-2
    tail_norm = squares[:, 1:].cumsum(1).flip(1).sqrt()

    angles = torch.arccos(x[:, :-1] / tail_norm)
    last = angles[:, -1]
    angles[:, -1] = torch.where(x[:, -1] >= 0, last, 2 * np.pi - last)

    angles[:, :-1] = angles[:, :-1] / np.pi
    angles[:, -1] = angles[:, -1] / (2 * np.pi)
    return 2 * (angles - 0.5)


def angles_to_unit_norm_vector(angles: torch.Tensor) -> torch.Tensor:
    """Map [-1, 1]-scaled hyperspherical angles back to unit vectors.

    Intended as the inverse of ``unit_norm_vector_to_angles``. ``angles`` is
    expected as (batch, n_angles, time): the padding rows below are 3-D and
    concatenated on dim 1. The result has ``n_angles + 1`` components on
    dim 1, with the same device and dtype as the input.
    """
    # Undo the [-1, 1] scaling; the modulo also makes a fresh tensor, so the
    # in-place edits below never touch the caller's input.
    angles = (angles / 2 + 0.5) % 1
    angles[:, :-1] = angles[:, :-1] * np.pi
    angles[:, -1] = angles[:, -1] * (2 * np.pi)
    cos = angles.cos()
    sin = angles.sin().cumprod(dim=1)
    # Build the padding row directly on the input's device/dtype (it was
    # previously created on "cuda", which fails on CPU-only machines).
    ones = torch.ones(
        cos.shape[0], 1, cos.shape[-1], device=cos.device, dtype=cos.dtype
    )
    cos = torch.cat([cos, ones], 1)
    sin = torch.cat([ones, sin], 1)
    return cos * sin


def wrap_around_value(x: torch.Tensor, value: float = 1) -> torch.Tensor:
    """Wrap ``x`` periodically into ``[-value, value)`` (period ``2 * value``).

    Uses ``%``, so it works on tensors and plain Python numbers alike.
    """
    period = 2 * value
    shifted = x + value
    return shifted % period - value


In [9]:
# import rave.core

def mod_sigmoid(x):
    """Strictly positive squashing ``2 * sigmoid(x) ** 2.3 + 1e-7``, in (1e-7, 2 + 1e-7)."""
    squashed = torch.sigmoid(x)
    return 2 * squashed.pow(2.3) + 1e-7


def random_angle(min_f=20, max_f=8000, sr=24000):
    """Draw a log-uniform frequency in [min_f, max_f] Hz and return it in radians/sample."""
    log_lo, log_hi = np.log(min_f), np.log(max_f)
    freq = np.exp(log_lo + random() * (log_hi - log_lo))
    return 2 * np.pi * freq / sr


def get_augmented_latent_size(latent_size: int, noise_augmentation: int):
    """Latent width after appending ``noise_augmentation`` extra channels."""
    total = latent_size
    total += noise_augmentation
    return total


def pole_to_z_filter(omega, amplitude=.9):
    """Second-order all-pass coefficients for a pole at ``amplitude * e^{i*omega}``.

    Returns ``(b, a)`` for ``scipy.signal.lfilter``. ``b`` is ``a`` reversed,
    which makes the filter all-pass: it alters phase, not magnitude.
    """
    pole = amplitude * np.exp(1j * omega)
    radius_sq = abs(pole)**2
    cross = -2 * np.real(pole)
    return [radius_sq, cross, 1], [1, cross, radius_sq]


def random_phase_mangle(x, min_f, max_f, amp, sr):
    """Filter ``x`` with a pole-based all-pass filter at a random frequency.

    The pole angle comes from ``random_angle(min_f, max_f, sr)`` and its
    radius is ``amp``.
    """
    b, a = pole_to_z_filter(random_angle(min_f, max_f, sr), amp)
    return lfilter(b, a, x)


def amp_to_impulse_response(amp, target_size):
    """Turn frequency amplitudes (last dim) into windowed impulse responses.

    The real, zero-phase amplitudes for ``n`` bins are inverse-rFFT'd into a
    length ``2 * (n - 1)`` filter. The filter is then centred,
    Hann-windowed, zero-padded to ``target_size`` and rotated back so the
    response starts at index 0.

    Args:
        amp: real amplitudes with frequency bins on the last axis.
        target_size: output length (int or 0-d tensor; read via ``int``).

    Returns:
        Real tensor with ``amp``'s leading shape and last axis ``target_size``.
    """
    # The zero imaginary part is built on amp's own device/dtype (it was
    # hard-coded to "cuda", which broke CPU inputs).
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    amp = fft.irfft(amp)

    filter_size = amp.shape[-1]

    # Centre the zero-phase filter so the window is applied around its peak.
    amp = torch.roll(amp, filter_size // 2, -1)
    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)

    amp = amp * win

    amp = nn.functional.pad(
        amp,
        (0, int(target_size) - int(filter_size)),
    )
    # filter_size is even (irfft output length), so this undoes the roll above.
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp


def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` along the last axis via FFT.

    The signal is zero-padded on the right and the kernel on the left, each
    to twice its length. The circular product of their spectra then contains
    the linear convolution, and the second half (``signal``'s original
    length) is returned. The two spectra must be broadcast-compatible,
    presumably meaning equal lengths — TODO confirm with callers.
    """
    n_signal = signal.shape[-1]
    n_kernel = kernel.shape[-1]
    padded_signal = nn.functional.pad(signal, (0, n_signal))
    padded_kernel = nn.functional.pad(kernel, (n_kernel, 0))

    spectrum = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    full = fft.irfft(spectrum)
    half = full.shape[-1] // 2
    return full[..., half:]


def search_for_run(run_path, mode="last"):
    """Locate a checkpoint file.

    Returns ``None`` for ``None``, returns ``run_path`` unchanged when it
    already names a ``.ckpt``, and otherwise returns the lexicographically
    greatest ``*.ckpt`` path under ``run_path`` whose file name contains
    ``mode``. Returns ``None`` if there is no such file.
    """
    if run_path is None:
        return None
    if ".ckpt" in run_path:
        return run_path

    candidates = sorted(
        path
        for path in (str(p) for p in Path(run_path).rglob("*.ckpt"))
        if mode in os.path.basename(path)
    )
    return candidates[-1] if candidates else None


def setup_gpu():
    """Ask GPUtil for available GPUs, treating any GPU above 5% memory use as busy."""
    available = gpu.getAvailable(maxMemory=.05)
    return available


def get_beta_kl(step, warmup, min_beta, max_beta):
    """Log-linearly ramp the KL weight from ``min_beta`` to ``max_beta``.

    For ``0 <= step < warmup`` the weight is interpolated in log space; from
    ``step >= warmup`` on it is exactly ``max_beta``. Using ``>=`` also
    avoids a 0/0 when ``warmup == 0`` (e.g. a cyclic schedule with
    ``cycle_size == 1``). Before the change this raised ZeroDivisionError.
    """
    if step >= warmup:
        return max_beta
    t = step / warmup
    min_beta_log = np.log(min_beta)
    max_beta_log = np.log(max_beta)
    beta_log = t * (max_beta_log - min_beta_log) + min_beta_log
    return np.exp(beta_log)


def get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta):
    """Restart the KL ramp every ``cycle_size`` steps.

    The ramp runs over the first half of each cycle, using ``get_beta_kl``
    with ``cycle_size // 2`` warm-up steps.
    """
    position_in_cycle = step % cycle_size
    ramp_length = cycle_size // 2
    return get_beta_kl(position_in_cycle, ramp_length, min_beta, max_beta)


def get_beta_kl_cyclic_annealed(step, cycle_size, warmup, min_beta, max_beta):
    """Cyclic KL schedule whose lower bound itself ramps up over ``warmup`` steps."""
    annealed_floor = get_beta_kl(step, warmup, min_beta, max_beta)
    return get_beta_kl_cyclic(step, cycle_size, annealed_floor, max_beta)


def n_fft_to_num_bands(n_fft: int) -> int:
    """Number of one-sided (rFFT) frequency bins for an FFT of size ``n_fft``."""
    half = n_fft // 2
    return half + 1


def hinge_gan(score_real, score_fake):
    """Hinge GAN losses; returns ``(discriminator_loss, generator_loss)``."""
    margins = torch.relu(1 - score_real) + torch.relu(1 + score_fake)
    return margins.mean(), -score_fake.mean()


def ls_gan(score_real, score_fake):
    """Least-squares GAN losses; returns ``(discriminator_loss, generator_loss)``."""
    real_term = (score_real - 1).pow(2)
    fake_term = score_fake.pow(2)
    dis_loss = (real_term + fake_term).mean()
    gen_loss = (score_fake - 1).pow(2).mean()
    return dis_loss, gen_loss


def nonsaturating_gan(score_real, score_fake):
    """Non-saturating (logistic) GAN losses on raw logits.

    Probabilities are clamped to ``[1e-7, 1 - 1e-7]`` before the logs.
    Returns ``(discriminator_loss, generator_loss)``.
    """
    eps = 1e-7
    p_real = torch.clamp(torch.sigmoid(score_real), eps, 1 - eps)
    p_fake = torch.clamp(torch.sigmoid(score_fake), eps, 1 - eps)
    dis_loss = -(p_real.log() + (1 - p_fake).log()).mean()
    gen_loss = -p_fake.log().mean()
    return dis_loss, gen_loss


@torch.enable_grad()
def get_rave_receptive_field(model: nn.Module):
    """Empirically measure the receptive field of ``model.decode(model.encode(x))``.

    The probe feeds random noise of length ``N`` (starting at ``2**15`` and
    doubling until the input gradient is zero at both ends). It
    backpropagates from output sample ``N // 2`` and counts the input
    samples with non-zero gradient before ``N // 2`` (left) and from
    ``N // 2`` on (right).

    Modules exposing ``gru_state`` or ``temporal`` are disabled during the
    probe and re-enabled afterwards, even if the probe raises. Every
    sub-module's train/eval flag is restored and parameter gradients are
    cleared.

    Returns:
        ``(left_receptive_field, right_receptive_field)`` in input samples.
    """
    N = 2**15
    # Remember each sub-module's mode so it can be restored exactly
    # (the original left the whole model in eval mode).
    training_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    device = next(iter(model.parameters())).device

    # NOTE(review): presumably stateful/streaming layers whose state would
    # leak between probe passes — confirm.
    toggled = [
        module for module in model.modules()
        if hasattr(module, 'gru_state') or hasattr(module, 'temporal')
    ]
    for module in toggled:
        module.disable()

    try:
        while True:
            x = torch.randn(1, 1, N, requires_grad=True, device=device)

            z = model.encode(x)
            y = model.decode(z)

            y[0, 0, N // 2].backward()
            assert x.grad is not None, "input has no grad"

            grad = x.grad.reshape(-1)
            left_grad, right_grad = grad.chunk(2, 0)
            # The probe is long enough once the gradient vanishes at both ends.
            if left_grad[0] == 0 and right_grad[-1] == 0:
                break
            N *= 2

        left_receptive_field = int(torch.count_nonzero(left_grad))
        right_receptive_field = int(torch.count_nonzero(right_grad))
    finally:
        model.zero_grad()
        for module in toggled:
            module.enable()
        for module, mode in training_modes:
            module.training = mode

    return left_receptive_field, right_receptive_field


def valid_signal_crop(x, left_rf, right_rf):
    """Trim the receptive-field-affected edges of ``x`` along its last axis.

    The receptive fields (0-d tensors) are divided by ``x.shape[1]`` before
    cropping. The right crop uses ``(-right) // n`` (floor of a negative
    number), so it rounds the crop *up*; this also keeps a small non-zero
    ``right_rf`` from producing an empty ``[:0]`` slice.
    """
    n_channels = x.shape[1]
    cropped = x[..., left_rf.item() // n_channels:]
    right = right_rf.item()
    if right:
        cropped = cropped[..., :(-right) // n_channels]
    return cropped


def relative_distance(
    x: torch.Tensor,
    y: torch.Tensor,
    norm: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Return ``norm(x - y) / norm(x)``, the error of ``y`` relative to ``x``."""
    error = norm(x - y)
    reference = norm(x)
    return error / reference


def mean_difference(target: torch.Tensor,
                    value: torch.Tensor,
                    norm: str = 'L1',
                    relative: bool = False):
    """Mean absolute ('L1') or mean squared ('L2') difference of two tensors.

    With ``relative=True`` the result is divided by the same statistic of
    ``target`` alone. Any other ``norm`` raises ``Exception``.
    """
    diff = target - value

    if norm == 'L1':
        def reduce(t):
            return t.abs().mean()
    elif norm == 'L2':
        def reduce(t):
            return (t * t).mean()
    else:
        raise Exception(f'Norm must be either L1 or L2, got {norm}')

    result = reduce(diff)
    if relative:
        result = result / reduce(target)
    return result


class MelScale(nn.Module):
    """Project a linear-frequency spectrogram onto a librosa mel filterbank."""

    def __init__(self, sample_rate: int, n_fft: int, n_mels: int) -> None:
        super().__init__()
        mel = li.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
        # Registered on the default device: the buffer follows the module's
        # .to(...), and forward() also casts it to x via type_as. The previous
        # hard-coded .to("cuda") made construction fail without CUDA.
        self.register_buffer('mel', torch.from_numpy(mel).float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (batch, freq, time) to (batch, n_mels, time)."""
        mel = self.mel.type_as(x)
        y = torch.einsum('bft,mf->bmt', x, mel)
        return y


class MultiScaleSTFT(nn.Module):
    """Complex STFTs at several window sizes, optionally mel-projected.

    Each scale uses ``n_fft = win_length = scale`` and ``hop = scale // 4``.
    ``forward`` takes (batch, channels, time), folds the channels into the
    batch, and returns one tensor per scale. That tensor is the magnitude
    when ``magnitude`` is True, otherwise real/imag parts stacked on a new
    last axis. The mel projection (if ``num_mels`` is set) is applied to the
    complex STFT, before the magnitude.
    """

    def __init__(self,
                 scales: Sequence[int],
                 sample_rate: int,
                 magnitude: bool = True,
                 normalized: bool = False,
                 num_mels: Optional[int] = None) -> None:
        super().__init__()
        self.scales = scales
        self.magnitude = magnitude
        self.num_mels = num_mels

        spectrograms = []
        mel_projections = []
        for n_fft in scales:
            spectrograms.append(
                torchaudio.transforms.Spectrogram(
                    n_fft=n_fft,
                    win_length=n_fft,
                    hop_length=n_fft // 4,
                    normalized=normalized,
                    power=None,
                ))
            mel_projections.append(
                None if num_mels is None else MelScale(
                    sample_rate=sample_rate,
                    n_fft=n_fft,
                    n_mels=num_mels,
                ))

        self.stfts = nn.ModuleList(spectrograms)
        self.mel_scales = nn.ModuleList(mel_projections)

    def forward(self, x: torch.Tensor) -> Sequence[torch.Tensor]:
        flat = rearrange(x, "b c t -> (b c) t")
        outputs = []
        for stft, mel in zip(self.stfts, self.mel_scales):
            spec = stft(flat)
            if mel is not None:
                spec = mel(spec)
            if self.magnitude:
                outputs.append(spec.abs())
            else:
                outputs.append(torch.stack([spec.real, spec.imag], -1))
        return outputs


class AudioDistanceV1(nn.Module):
    """Multi-scale spectral distance: relative L2 on magnitudes + L1 on log-magnitudes.

    ``multiscale_stft`` is a factory whose module returns one magnitude
    spectrogram per scale. The per-scale terms are summed and returned under
    ``'spectral_distance'``.
    """

    def __init__(self, multiscale_stft: Callable[[], nn.Module],
                 log_epsilon: float) -> None:
        super().__init__()
        self.multiscale_stft = multiscale_stft()
        self.log_epsilon = log_epsilon

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        eps = self.log_epsilon
        total = 0.

        for spec_x, spec_y in zip(self.multiscale_stft(x), self.multiscale_stft(y)):
            linear_term = mean_difference(spec_x, spec_y, norm='L2', relative=True)
            log_term = mean_difference(
                torch.log(spec_x + eps), torch.log(spec_y + eps), norm='L1')
            total = total + linear_term + log_term

        return {'spectral_distance': total}


class WeightedInstantaneousSpectralDistance(nn.Module):
    """Multi-scale spectral distance plus an instantaneous-frequency distance.

    The STFT module built by ``multiscale_stft`` must return real/imag
    pairs on the last axis (asserted in ``forward``). With ``weighted`` the
    instantaneous-frequency error is masked by ``clip(log1p(|target|), 0, 1)``.
    """

    def __init__(self,
                 multiscale_stft: Callable[[], MultiScaleSTFT],
                 weighted: bool = False) -> None:
        super().__init__()
        self.multiscale_stft = multiscale_stft()
        self.weighted = weighted

    def phase_to_instantaneous_frequency(self,
                                         x: torch.Tensor) -> torch.Tensor:
        """Unwrap phase along the last axis, then differentiate once more."""
        return self.derivative(self.unwrap(x))

    def derivative(self, x: torch.Tensor) -> torch.Tensor:
        """First difference along the last axis (one element shorter)."""
        return x[..., 1:] - x[..., :-1]

    def unwrap(self, x: torch.Tensor) -> torch.Tensor:
        """Wrap successive phase differences into [-pi, pi) and re-accumulate."""
        step = (self.derivative(x) + np.pi) % (2 * np.pi)
        return (step - np.pi).cumsum(-1)

    def forward(self, target: torch.Tensor, pred: torch.Tensor):
        specs_target = self.multiscale_stft(target)
        specs_pred = self.multiscale_stft(pred)
        spectral_distance = 0.
        phase_distance = 0.

        for spec_t, spec_p in zip(specs_target, specs_pred):
            assert spec_t.shape[-1] == 2

            complex_t = torch.view_as_complex(spec_t)
            complex_p = torch.view_as_complex(spec_p)

            # Amplitude terms.
            mag_t = complex_t.abs()
            mag_p = complex_p.abs()
            linear_term = mean_difference(mag_t, mag_p, norm='L2', relative=True)
            log_term = mean_difference(torch.log1p(mag_t), torch.log1p(mag_p),
                                       norm='L1')
            spectral_distance = spectral_distance + linear_term + log_term

            # Phase terms: instantaneous frequency is two frames shorter than
            # the magnitude, hence the [..., 2:] offset on the mask.
            if_t = self.phase_to_instantaneous_frequency(complex_t.angle())
            if_p = self.phase_to_instantaneous_frequency(complex_p.angle())
            if self.weighted:
                weights = torch.clip(torch.log1p(mag_t[..., 2:]), 0, 1)
                if_t = if_t * weights
                if_p = if_p * weights

            phase_distance = phase_distance + mean_difference(
                if_t, if_p, norm='L2')

        return {
            'spectral_distance': spectral_distance,
            'phase_distance': phase_distance
        }


class EncodecAudioDistance(nn.Module):
    """L1 waveform distance plus spectral distances summed over several scales.

    Args:
        scales: one spectral distance is built per entry (the previous
            ``int`` annotation was wrong; the value is iterated).
        spectral_distance: factory called with each scale.

    ``forward`` returns ``{'waveform_distance', 'spectral_distance'}``.
    """

    def __init__(self, scales: Sequence[int],
                 spectral_distance: Callable[[int], nn.Module]) -> None:
        super().__init__()
        self.waveform_distance = WaveformDistance(norm='L1')
        self.spectral_distances = nn.ModuleList(
            [spectral_distance(scale) for scale in scales])

    def forward(self, x, y):
        waveform_distance = self.waveform_distance(x, y)
        spectral_distance = 0
        for dist in self.spectral_distances:
            spectral_distance = spectral_distance + dist(x, y)

        return {
            'waveform_distance': waveform_distance,
            'spectral_distance': spectral_distance
        }


class WaveformDistance(nn.Module):
    """Sample-wise distance between two signals using ``mean_difference``.

    Note the argument order: ``y`` is the target and ``x`` the value.
    """

    def __init__(self, norm: str) -> None:
        super().__init__()
        self.norm = norm

    def forward(self, x, y):
        return mean_difference(target=y, value=x, norm=self.norm)


class SpectralDistance(nn.Module):
    """Distance between (mel-)spectrograms of two signals.

    Args:
        n_fft: FFT size; hop is ``n_fft // 4`` and frames are not centred.
        sampling_rate: used only by the mel variant.
        norm: a norm name or a sequence of names (``'L1'``/``'L2'``);
            per-norm distances are summed.
        power: spectrogram exponent forwarded to torchaudio.
        normalized: forwarded to torchaudio.
        mel: number of mel bands; a falsy value selects a linear spectrogram.
    """

    def __init__(
        self,
        n_fft: int,
        sampling_rate: int,
        norm: Union[str, Sequence[str]],
        power: Union[int, None],
        normalized: bool,
        mel: Optional[int] = None,
    ) -> None:
        super().__init__()
        shared = dict(
            hop_length=n_fft // 4,
            power=power,
            normalized=normalized,
            center=False,
            pad_mode=None,
        )
        if mel:
            self.spec = torchaudio.transforms.MelSpectrogram(
                sampling_rate, n_fft, n_mels=mel, **shared)
        else:
            self.spec = torchaudio.transforms.Spectrogram(n_fft, **shared)

        self.norm = (norm, ) if isinstance(norm, str) else norm

    def forward(self, x, y):
        spec_x = self.spec(x)
        spec_y = self.spec(y)
        return sum(mean_difference(spec_y, spec_x, name) for name in self.norm)


class ProgressLogger:
    """Persist a JSON state dict per ``name`` in the ``status`` LMDB database."""

    def __init__(self, name: str) -> None:
        self.env = lmdb.open("status")
        self.name = name

    @staticmethod
    def _decode(raw):
        """Decode a stored JSON blob; missing entries become an empty dict."""
        return {} if raw is None else json.loads(raw.decode())

    def update(self, **new_state):
        """Merge ``new_state`` into the stored state.

        Reading and writing happen in the same write transaction, so the
        merge cannot interleave with another writer between the two steps.
        """
        key = self.name.encode()
        with self.env.begin(write=True) as txn:
            current_state = self._decode(txn.get(key))
            current_state.update(new_state)
            txn.put(key, json.dumps(current_state).encode())

    def __call__(self):
        """Return the stored state dict, or ``{}`` if nothing was stored yet."""
        # Read-only transaction: reading does not need the writer lock.
        with self.env.begin() as txn:
            raw = txn.get(self.name.encode())
        return self._decode(raw)


class LoggerCallback(pl.Callback):
    """Lightning callback that counts training steps and reports progress.

    Every 100 training batches the step count and the module's
    ``warmed_up`` flag are pushed to the given ``ProgressLogger``. The state
    is part of the callback's checkpointed state dict.
    """

    def __init__(self, logger: ProgressLogger) -> None:
        super().__init__()
        self.state = {'step': 0, 'warmed': False}
        self.logger = logger

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx) -> None:
        state = self.state
        state['step'] += 1
        state['warmed'] = pl_module.warmed_up

        if state['step'] % 100 == 0:
            self.logger.update(**state)

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)


In [10]:
# import rave.dataset
def get_derivator_integrator(sr: int):
    """Return ``(derivator, integrator)`` filter functions for sample rate ``sr``.

    The derivator is the FIR first difference ``0.5 * (x[n] - x[n-1])``.
    The integrator is a second-order IIR with a double pole at ``alpha``,
    where ``alpha`` is derived from a 10 Hz constant and ``sr``. Both apply
    ``scipy.signal.lfilter`` to their input.
    """
    alpha = 1 / (1 + 1 / sr * 2 * np.pi * 10)
    derivator_b, derivator_a = [.5, -.5], [1]
    integrator_b = [alpha**2, -alpha**2]
    integrator_a = [1, -2 * alpha, alpha**2]

    def derivator(x):
        return lfilter(derivator_b, derivator_a, x)

    def integrator(x):
        return lfilter(integrator_b, integrator_a, x)

    return derivator, integrator


class AudioDataset(data.Dataset):
    """Map-style dataset over int16 waveforms stored in an LMDB database.

    Each value is a serialized ``AudioExample``. ``__getitem__`` decodes the
    buffer stored under ``audio_key``, scales it to float32 by
    ``1 / (2**15 - 1)`` and applies ``transforms`` if given. The environment
    and key list are opened lazily on first use (NOTE(review): presumably so
    each DataLoader worker opens its own handle — confirm).
    """

    @property
    def env(self) -> lmdb.Environment:
        env = self._env
        if env is None:
            env = self._env = lmdb.open(self._db_path, lock=False)
        return env

    @property
    def keys(self) -> Sequence[str]:
        if self._keys is None:
            with self.env.begin() as txn:
                cursor = txn.cursor()
                self._keys = list(cursor.iternext(values=False))
        return self._keys

    def __init__(self,
                 db_path: str,
                 audio_key: str = 'waveform',
                 transforms: Optional[transforms.Transform] = None) -> None:
        super().__init__()
        self._db_path = db_path
        self._audio_key = audio_key
        self._transforms = transforms
        self._env = None
        self._keys = None

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        with self.env.begin() as txn:
            example = AudioExample.FromString(txn.get(self.keys[index]))

        buffer = example.buffers[self._audio_key]
        assert buffer.precision == AudioExample.Precision.INT16

        samples = np.frombuffer(buffer.data, dtype=np.int16)
        audio = samples.astype(np.float32) / (2**15 - 1)

        if self._transforms is None:
            return audio
        return self._transforms(audio)


class LazyAudioDataset(data.Dataset):
    """Dataset of fixed-length chunks decoded on the fly with ffmpeg.

    The LMDB database stores one ``AudioExample`` per file whose metadata
    holds ``path`` and ``length`` (seconds). Each file contributes
    ``floor(length * sampling_rate) // n_signal`` chunks of ``n_signal``
    samples. ``self.items`` is the cumulative chunk count per file, and
    ``__getitem__`` decodes one chunk with ``extract_audio``.
    """

    @property
    def env(self) -> lmdb.Environment:
        """LMDB environment, opened on first access."""
        if self._env is None:
            self._env = lmdb.open(self._db_path, lock=False)
        return self._env

    @property
    def keys(self) -> Sequence[str]:
        """All database keys, read once and cached."""
        if self._keys is None:
            with self.env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
        return self._keys

    def __init__(self,
                 db_path: str,
                 n_signal: int,
                 sampling_rate: int,
                 transforms: Optional[transforms.Transform] = None) -> None:
        super().__init__()
        self._db_path = db_path
        self._env = None
        self._keys = None
        self._transforms = transforms
        self._n_signal = n_signal
        self._sampling_rate = sampling_rate

        self.parse_dataset()

    def parse_dataset(self):
        """Count the chunks of every file and store their cumulative sum in ``self.items``."""
        keys = self.keys  # load keys before opening the scan transaction
        chunk_counts = []
        # One read transaction for the whole scan instead of one per key.
        with self.env.begin() as txn:
            for key in tqdm(keys, desc='Discovering dataset'):
                ae = AudioExample.FromString(txn.get(key))
                length = float(ae.metadata['length'])
                n_samples = int(math.floor(length * self._sampling_rate))
                chunk_counts.append(n_samples // self._n_signal)
        # Explicit integer dtype: an empty list would otherwise become float64.
        self.items = np.cumsum(np.asarray(chunk_counts, dtype=np.int64))

    def __len__(self):
        # An empty database no longer raises IndexError.
        return int(self.items[-1]) if len(self.items) else 0

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} out of range for dataset of size {len(self)}")

        # First file whose cumulative chunk count exceeds index; files with
        # zero chunks are skipped automatically (O(log n) lookup).
        audio_id = int(np.searchsorted(self.items, index, side='right'))
        if audio_id:
            index = index - int(self.items[audio_id - 1])

        key = self.keys[audio_id]

        with self.env.begin() as txn:
            ae = AudioExample.FromString(txn.get(key))

        audio = extract_audio(
            ae.metadata['path'],
            self._n_signal,
            self._sampling_rate,
            index * self._n_signal,
        )

        if self._transforms is not None:
            audio = self._transforms(audio)

        return audio


class HTTPAudioDataset(data.Dataset):
    """Dataset served by a remote HTTP endpoint.

    The size is read once from ``<db_path>/len``. Each item is fetched from
    ``<db_path>/get/<index>`` as base64 text and decoded with
    ``AudioExampleWrapper``, and a copy of its ``audio`` field is returned.
    """

    def __init__(self, db_path: str):
        super().__init__()
        self.db_path = db_path
        logging.info("starting remote dataset session")
        length_text = requests.get("/".join((db_path, "len"))).text
        self.length = int(length_text)
        logging.info("connection established !")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        url = "/".join((self.db_path, "get", f"{index}"))
        payload = base64.b64decode(requests.get(url).text)
        audio = AudioExampleWrapper(payload).get("audio")
        return audio.copy()


def normalize_signal(x: np.ndarray, max_gain_db: int = 30):
    """Scale ``x`` so its peak reaches 0 dBFS, boosting by at most ``max_gain_db`` dB.

    Signals whose peak is above 0 dBFS are attenuated to 0 dBFS. An all-zero
    input is returned unchanged (the same object).
    """
    peak = np.max(abs(x))
    if peak == 0:
        return x

    gain_db = min(max_gain_db, -20 * np.log10(peak))
    return x * 10**(gain_db / 20)


def get_dataset(db_path,
                sr,
                n_signal,
                derivative: bool = False,
                normalize: bool = False):
    """Build the training dataset for ``db_path``.

    A path starting with ``http`` yields an ``HTTPAudioDataset`` with no
    transforms. Otherwise ``metadata.yaml`` in ``db_path`` decides between
    ``LazyAudioDataset`` (``lazy: true``) and ``AudioDataset``, both fed
    this transform pipeline:
    float32 cast, random crop to ``n_signal``, random phase mangling
    (p=0.8), 16-bit dequantization, optional peak normalization, optional
    derivative filter, and a final float32 cast.
    """
    if db_path[:4] == "http":
        return HTTPAudioDataset(db_path=db_path)

    with open(os.path.join(db_path, 'metadata.yaml'), 'r') as handle:
        lazy = yaml.safe_load(handle)['lazy']

    pipeline = [
        lambda x: x.astype(np.float32),
        transforms.RandomCrop(n_signal),
        transforms.RandomApply(
            lambda x: random_phase_mangle(x, 20, 2000, .99, sr),
            p=.8,
        ),
        transforms.Dequantize(16),
    ]

    if normalize:
        pipeline.append(normalize_signal)

    if derivative:
        derivator, _ = get_derivator_integrator(sr)
        pipeline.append(derivator)

    pipeline.append(lambda x: x.astype(np.float32))
    composed = transforms.Compose(pipeline)

    if lazy:
        return LazyAudioDataset(db_path, n_signal, sr, composed)
    return AudioDataset(db_path, transforms=composed)


@gin.configurable
def split_dataset(dataset, percent, max_residual: Optional[int] = None):
    """Split ``dataset`` into (first, second) parts with a fixed seed (42).

    The first part gets ``percent`` % of the items (at least one). The
    second part gets the rest, capped at ``max_residual`` when given, with
    the surplus moved back to the first part.
    """
    total = len(dataset)
    n_first = max((percent * total) // 100, 1)
    n_second = total - n_first
    if max_residual is not None:
        n_second = min(max_residual, n_second)
        n_first = total - n_second

    seeded = torch.Generator().manual_seed(42)
    first, second = data.random_split(dataset, [n_first, n_second],
                                      generator=seeded)
    return first, second


def random_angle(min_f=20, max_f=8000, sr=24000):
    """Log-uniform random frequency in [min_f, max_f] Hz, returned in radians/sample.

    NOTE: this notebook defines ``random_angle`` twice; being later, this
    definition is the one bound at module level once this cell runs.
    """
    log_min = np.log(min_f)
    log_span = np.log(max_f) - log_min
    freq_hz = np.exp(random() * log_span + log_min)
    return 2 * np.pi * freq_hz / sr


def pole_to_z_filter(omega, amplitude=.9):
    """All-pass biquad ``(b, a)`` whose pole sits at ``amplitude * exp(1j * omega)``.

    ``b`` is ``a`` reversed, so the filter leaves magnitudes unchanged and
    only shifts phase. (Duplicate of an earlier definition in this notebook.)
    """
    z0 = amplitude * np.exp(1j * omega)
    squared_radius = abs(z0)**2
    linear_coef = -2 * np.real(z0)
    a = [1, linear_coef, squared_radius]
    b = a[::-1]
    return b, a


def random_phase_mangle(x, min_f, max_f, amp, sr):
    """Apply a random-frequency all-pass filter (pole radius ``amp``) to ``x``.

    Duplicate of an earlier definition in this notebook.
    """
    omega = random_angle(min_f, max_f, sr)
    return lfilter(*pole_to_z_filter(omega, amp), x)


def extract_audio(path: str, n_signal: int, sr: int,
                  start_sample: int) -> np.ndarray:
    """Decode ``n_signal`` mono samples at ``sr`` Hz from ``path`` with ffmpeg.

    Decoding starts at ``start_sample`` (converted to seconds).

    Returns:
        1-D array of exactly ``n_signal`` samples scaled by ``1 / 2**15``.
        Missing samples are zero-filled, including when ffmpeg fails (a
        warning is logged in that case). The dtype is float64 because of
        the zero padding. (The old ``Iterable`` return annotation was wrong.)
    """
    start_sec = start_sample / sr
    # Ask for 0.1 s more than needed, then trim (NOTE(review): presumably to
    # absorb seek/resampling imprecision — confirm).
    length = n_signal / sr + 0.1
    result = subprocess.run(
        [
            'ffmpeg',
            '-v',
            'error',
            '-ss',
            str(start_sec),
            '-i',
            path,
            '-ar',
            str(sr),
            '-ac',
            '1',
            '-t',
            str(length),
            '-f',
            's16le',
            '-',
        ],
        stdout=subprocess.PIPE,
        check=False,
    )
    if result.returncode != 0:
        logging.warning(
            "ffmpeg exited with code %d while reading %s; missing samples "
            "are zero-filled", result.returncode, path)

    chunk = np.frombuffer(result.stdout, dtype=np.int16).astype(np.float32) / 2**15
    chunk = np.concatenate([chunk, np.zeros(n_signal)], -1)
    return chunk[:n_signal]


In [11]:
import torch, time, gc

# Timing utilities
# NOTE: the `import ... time ...` line above rebinds `time` to the module
# (HEAD imported the function via `from time import time`).
# Timestamp recorded by start_timer() and read by end_timer_and_print();
# None until start_timer() has run.
start_time = None

def start_timer():
    """Start a timing section.

    Collects garbage, empties the CUDA cache, resets CUDA peak-memory
    statistics, synchronizes, and stores ``time.time()`` in the global
    ``start_time``.
    """
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated and only forwards to this call.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    """Finish a timing section started by ``start_timer`` and print a report.

    Prints ``local_msg``, the elapsed wall-clock time and the peak CUDA
    memory allocated by tensors since the last reset.

    Raises:
        RuntimeError: if ``start_timer`` has not been called (previously
            this surfaced as an opaque TypeError on ``None`` arithmetic).
    """
    if start_time is None:
        raise RuntimeError("start_timer() must be called before end_timer_and_print()")
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

In [12]:
# Run-name prefix; setup() appends a hash of the operative gin config.
NAME = "compile"
# Gin config names; ".gin" is appended by add_gin_extension when missing.
CONFIG = ["v2"]
# Dataset location for get_dataset (an "http..." prefix selects HTTPAudioDataset).
DB_PATH = "/home/ubuntu/preprocessed/"
# NOTE(review): not used in the visible part of setup(); presumably the trainer's max steps.
MAX_STEPS = 1
# Passed as val_check_interval when the train loader has at least this many
# batches, otherwise converted to check_val_every_n_epoch.
VAL_EVERY = 1
# Samples per training example (n_signal for get_dataset).
N_SIGNAL = 131072
# DataLoader batch size for training and validation.
BATCH = 8
# NOTE(review): not referenced in the visible code; presumably a checkpoint to resume from.
ckpt = None
# Extra gin bindings parsed after the CONFIG files.
OVERRIDE = []
# DataLoader workers (forced to 0 on Windows/macOS in setup()).
WORKERS = 8
# None: choose via setup_gpu(); [-1]: no accelerator (CPU).
GPU = None
# Apply the derivative filter to the data and attach the matching integrator to the model.
DERIVATIVE = False
# Add normalize_signal to the dataset transforms.
NORMALIZE = False
# NOTE(review): not used in the visible code.
PROGRESS = True

In [13]:
def add_gin_extension(config_name: str) -> str:
    """Return ``config_name`` with a ``.gin`` suffix, adding it only if missing.

    >>> add_gin_extension("v2")
    'v2.gin'
    >>> add_gin_extension("configs/v1.gin")
    'configs/v1.gin'
    """
    if config_name.endswith('.gin'):
        return config_name
    return config_name + '.gin'

In [14]:
def _make_dataloaders(sr):
    """Return ``(train, val)`` DataLoaders over the dataset at ``DB_PATH``.

    Args:
        sr: Sample rate passed to ``get_dataset`` (the model's ``sr``).
    """
    dataset = get_dataset(
        DB_PATH, sr, N_SIGNAL, derivative=DERIVATIVE, normalize=NORMALIZE
    )
    # NOTE(review): 98 is presumably the training percentage; confirm in split_dataset.
    train_set, val_set = split_dataset(dataset, 98)
    # DataLoader worker processes are disabled on Windows and macOS.
    num_workers = 0 if os.name == "nt" or sys.platform == "darwin" else WORKERS
    train = DataLoader(
        train_set, BATCH, True, drop_last=True, num_workers=num_workers
    )
    val = DataLoader(val_set, BATCH, False, num_workers=num_workers)
    return train, val


def _val_check_kwargs(n_train_batches):
    """Map ``VAL_EVERY`` (in training steps) onto ``pl.Trainer`` keyword args.

    Returns ``{"val_check_interval": VAL_EVERY}`` when ``VAL_EVERY`` fits in
    one epoch of ``n_train_batches`` batches. Otherwise it returns
    ``{"check_val_every_n_epoch": VAL_EVERY // n_train_batches}``.

    Raises:
        ValueError: if the training DataLoader yields no batches, which would
            otherwise cause a division by zero.
    """
    if n_train_batches == 0:
        raise ValueError(
            f"the training DataLoader is empty: the training split has fewer "
            f"than BATCH={BATCH} examples and drop_last=True"
        )
    if n_train_batches >= VAL_EVERY:
        return {"val_check_interval": VAL_EVERY}
    return {"check_val_every_n_epoch": VAL_EVERY // n_train_batches}


def _select_devices():
    """Resolve the GPU selection once and return ``(accelerator, devices)``.

    ``GPU == [-1]`` forces CPU. Otherwise ``GPU`` is used as-is, or
    ``setup_gpu()`` chooses when ``GPU`` is falsy. Both returned values are
    ``None`` when training on CPU.
    """
    # Resolved exactly once, so the printed selection and the devices handed
    # to the Trainer always agree (setup_gpu() is not called a second time).
    gpu = 0 if GPU == [-1] else (GPU or setup_gpu())
    print('selected gpu:', gpu)
    if GPU == [-1]:
        return None, None
    if torch.cuda.is_available():
        return "cuda", gpu
    if torch.backends.mps.is_available():
        print(
            "Training on mac is not available yet. Set GPU = [-1] to train on CPU (not recommended)."
        )
        exit()
    return None, None


def setup():
    """Parse the gin config, then build the model, DataLoaders and Trainer.

    Reads the notebook-level settings CONFIG, OVERRIDE, DB_PATH, N_SIGNAL,
    BATCH, WORKERS, VAL_EVERY, MAX_STEPS, GPU, DERIVATIVE, NORMALIZE,
    PROGRESS, NAME and ckpt. It creates ``runs/<NAME>_<hash>/`` and writes
    the operative gin config there as ``config.gin``.

    Returns:
        ``(trainer, model, train, val, run)``. ``train`` and ``val`` are
        DataLoaders. ``run`` is the value returned by
        ``search_for_run(ckpt)`` and may be ``None``.

    Raises:
        ValueError: if the training DataLoader yields no batches.
    """
    torch.backends.cudnn.benchmark = True
    gin.parse_config_files_and_bindings(
        map(add_gin_extension, CONFIG),
        OVERRIDE,
    )
    model = RAVE()
    # model = torch.compile(model, mode="reduce-overhead"); torch._dynamo.config.verbose=True
    if DERIVATIVE:
        model.integrator = get_derivator_integrator(model.sr)[1]

    train, val = _make_dataloaders(model.sr)

    # CHECKPOINT CALLBACKS
    validation_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor="validation", filename="best")
    last_checkpoint = pl.callbacks.ModelCheckpoint(filename="last")
    val_check = _val_check_kwargs(len(train))

    # The run directory is named after the first 10 hex digits of the MD5 of
    # the operative gin config, so different configs normally get different
    # directories.
    gin_hash = hashlib.md5(
        gin.operative_config_str().encode()).hexdigest()[:10]
    run_name = f'{NAME}_{gin_hash}'
    os.makedirs(os.path.join("runs", run_name), exist_ok=True)

    accelerator, devices = _select_devices()

    trainer = pl.Trainer(
        logger=pl.loggers.TensorBoardLogger(
            "runs",
            name=run_name,
        ),
        accelerator=accelerator,
        devices=devices,
        callbacks=[
            validation_checkpoint,
            last_checkpoint,
            WarmupCallback(),
            QuantizeCallback(),
            LoggerCallback(ProgressLogger(run_name)),
        ],
        max_epochs=100000,
        max_steps=MAX_STEPS,
        profiler="simple",
        enable_progress_bar=PROGRESS,
        **val_check,
    )
    run = search_for_run(ckpt)
    if run is not None:
        # Seed the trainer's step counter (a private Lightning attribute)
        # with the checkpoint's global_step.
        step = torch.load(run, map_location='cpu')["global_step"]
        trainer.fit_loop.epoch_loop._batches_that_stepped = step

    # Re-read the operative config rather than reusing the string hashed
    # above, so the saved file reflects every configurable called so far.
    with open(os.path.join("runs", run_name, "config.gin"), "w") as config_out:
        config_out.write(gin.operative_config_str())

    return trainer, model, train, val, run

In [15]:
# Build the Trainer, model, train/val DataLoaders and the resume target from the settings cell.
# NOTE(review): in the recorded run below this raised TypeError because gin bound
# no values for RAVE. The log shows the .gin config files were not found, which
# is the likely cause; confirm the config search path.
trainer, model, train, val, run = setup()

ERROR:root:Path not found: v2.gin
ERROR:root:Path not found: /home/ubuntu/rave-training/rave/v2.gin
ERROR:root:Path not found: configs/v1.gin


TypeError: RAVE.__init__() missing 13 required positional arguments: 'latent_size', 'sampling_rate', 'encoder', 'decoder', 'discriminator', 'phase_1_duration', 'gan_loss', 'valid_signal_crop', 'feature_matching_fun', 'num_skipped_features', 'audio_distance', 'multiband_audio_distance', and 'balancer'
  No values supplied by Gin or caller for arguments: ['audio_distance', 'balancer', 'decoder', 'discriminator', 'encoder', 'feature_matching_fun', 'gan_loss', 'latent_size', 'multiband_audio_distance', 'num_skipped_features', 'phase_1_duration', 'sampling_rate', 'valid_signal_crop']
  Gin had values bound for: []
  Caller supplied values for: ['self']
  In call to configurable 'RAVE' (<class '__main__.RAVE'>)

In [16]:
# Train on `train`, validating on `val`. `run` (from search_for_run(ckpt)) is the
# checkpoint to resume from and may be None.
trainer.fit(model, train, val, ckpt_path=run)