# RVQ Breakdown

In [1]:
import torch
import math

## Input

1 second audio file at a 44.1kHz sample rate.

Batch size of 1 (first dim), 1 channel mono audio (second dim) and exactly 1 second of audio (third dim).

In [2]:
# Seed the global RNG so the random "audio" below (and every randomly
# initialised layer created later) is reproducible under Restart & Run All.
torch.manual_seed(0)

batch_size = 1
channels = 1
sample_rate = 44100  # samples per second (44.1 kHz)
audio_length = 1 * sample_rate  # 1 second of audio
# Stand-in for a real waveform: Gaussian noise of shape (batch, channels, samples).
audio_batch = torch.randn(batch_size, channels, audio_length, device="cpu")

In [3]:
audio_batch.shape

torch.Size([1, 1, 44100])

## Encoder

In [4]:
from torch.nn.utils import weight_norm
import torch.nn as nn
from dac.nn.layers import Snake1d

def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and wrap it with weight normalisation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``; the return value is whatever ``weight_norm`` returns for
    that convolution.
    """
    return weight_norm(nn.Conv1d(*args, **kwargs))

First convolution

In [5]:
d_model = 64  # channel width produced by the first convolution
# kernel_size=7 with padding=3 leaves the time axis length unchanged, so only
# the channel count changes (1 -> d_model).
block1 = WNConv1d(1, d_model, kernel_size=7, padding=3)



In [6]:
block1_out = block1(audio_batch)

In [7]:
block1_out.shape

torch.Size([1, 64, 44100])

Encoder blocks

In [8]:
class ResidualUnit(nn.Module):
    """Residual unit: Snake1d -> dilated 7-tap conv -> Snake1d -> 1x1 conv, added to its input.

    Args:
        dim: Number of channels; identical for input and output.
        dilation: Dilation of the 7-tap convolution.
    """

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        # Padding that offsets the length the dilated kernel would otherwise remove.
        same_pad = ((kernel_size - 1) * dilation) // 2
        layers = [
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=kernel_size, dilation=dilation, padding=same_pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``, centre-cropping ``x`` when the block output is shorter."""
        residual = self.block(x)
        # Half of the length difference is trimmed from each end of the skip path.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + residual


class EncoderBlock(nn.Module):
    """Three residual units at ``dim // 2`` channels followed by a strided conv up to ``dim``.

    Args:
        dim: Output channel count; the layers before the final convolution
            operate on ``dim // 2`` channels.
        stride: Stride of the final convolution.
    """

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        in_dim = dim // 2
        # Residual units with dilations 1, 3 and 9, in that order.
        stages = [ResidualUnit(in_dim, dilation=d) for d in (1, 3, 9)]
        stages.append(Snake1d(in_dim))
        stages.append(
            WNConv1d(
                in_dim,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block's layers in order."""
        return self.block(x)

In [9]:
# d_model*2 = 128 output channels; stride=2 halves the time axis.
encblock1 = EncoderBlock(d_model*2, stride=2)

Audio is downsampled, while number of channels is doubled.

In [10]:
enc1out = encblock1(block1_out)

In [11]:
enc1out.shape

torch.Size([1, 128, 22050])

In [12]:
# d_model*4 = 256 output channels; stride=4 shrinks the time axis roughly 4x.
encblock2 = EncoderBlock(d_model*2*2, stride=4)

In [13]:
enc2out = encblock2(enc1out)

In [14]:
enc2out.shape

torch.Size([1, 256, 5512])

In [15]:
# d_model*8 = 512 output channels; stride=8.
encblock3 = EncoderBlock(d_model*2*2*2, stride=8)

In [16]:
# d_model*16 = 1024 output channels; stride=8.
encblock4 = EncoderBlock(d_model*2*2*2*2, stride=8)

In [17]:
x = encblock3(enc2out)
x.shape

torch.Size([1, 512, 689])

In [18]:
x = encblock4(x)
x.shape

torch.Size([1, 1024, 86])

In [19]:
lastact = Snake1d(d_model*2*2*2*2)

In [20]:
x = lastact(x)

In [21]:
x.shape

torch.Size([1, 1024, 86])

The last conv takes that downsampled representation and projects it back into the latent size of 64.

In [22]:
# Final projection from d_model*16 = 1024 channels down to 64 (hard-coded here);
# kernel_size=3 with padding=1 keeps the number of frames unchanged.
encout = WNConv1d(d_model*2*2*2*2, 64, kernel_size=3, padding=1)

In [23]:
out = encout(x)

In [24]:
out.shape

torch.Size([1, 64, 86])

So now we have an encoding of the audio. For 1 second of audio, we have 86 frames.

## Vector Quantize

Let's start by exploring basic vector quantising.

In [25]:
from dac.nn.quantize import VectorQuantize

In [26]:
# NOTE(review): the positional arguments appear to be (input dim, codebook size,
# codebook dim) — the encoder output has 64 channels and the codebook weight
# printed in the next cell is 1024 x 8. Confirm against VectorQuantize's signature.
vec_quantize = VectorQuantize(64, 1024, 8)

In [39]:
vec_quantize.codebook.weight.shape

torch.Size([1024, 8])

In [27]:
z_q, commitment_loss, codebook_loss, indices, z_e = vec_quantize(out)

In [29]:
indices.shape

torch.Size([1, 86])

In [28]:
indices

tensor([[326, 939, 428, 993, 557, 240, 745, 118, 468, 428,  78, 847, 444, 468,
         542, 660,  23, 428, 224, 993, 110, 548, 727, 212, 428, 847, 418, 750,
         280, 993, 685, 589, 550, 468, 479, 260, 770, 321, 619, 224, 260, 877,
         939, 428, 446, 260, 869, 554, 260, 849, 855,  26, 260, 848, 524, 847,
         997, 908, 226, 750, 826, 670, 847,  26, 554, 555, 750,  61,  36, 939,
         727, 316, 325, 849, 118,  89, 295, 595, 446, 555, 847, 997, 769, 212,
         428, 914]])

## Residual Vector Quantize

In [45]:
from dac.nn.quantize import ResidualVectorQuantize

In [46]:
quantizer = ResidualVectorQuantize(
    input_dim=64,  # channel count of the encoder output `out`
    n_codebooks=9,  # the codes tensor printed below has 9 rows, one per codebook
    codebook_dim=8,  # 9 * 8 = 72, the channel count of `latents` printed below
    # NOTE(review): a bool is passed here — confirm the type this parameter expects.
    quantizer_dropout=False
)

From the quantizer, we get 5 outputs:
- z: quantized continuous representation of the input.
- codes: codebook indices for each codebook - this is the quantized discrete representation.
- latents: projected latents (continuous representation of the input before quantization).
- commitment_loss: commitment loss to train the encoder to predict vectors closer to the codebook.
- codebook_loss: codebook loss to update the codebook.

In [49]:
z, codes, latents, commitment_loss, codebook_loss = quantizer(out)

In [50]:
z.shape

torch.Size([1, 64, 172])

In [51]:
codes.shape

torch.Size([1, 9, 172])

In [56]:
# Codes selected by the first codebook only; indexing with the list [0]
# (rather than the integer 0) keeps the codebook axis in the result.
codes[:,[0],]

tensor([[[ 101,  764,  495,  507,  187,  541,  661,  115,  187,  187,  708,
           494,  115,  187,  494,  928,  381,  639,  494,  541,  541,  516,
           928,  354,  115,  187,  494,  507,  106,  187,  708,  708,  661,
           629,  495,  740,  507,   19,  774,  928,  393,  115,  789,  347,
           740,  494,  115,  928,  886,  928,  187,  101,   46,  187,  187,
            68,  928,  642,  928,  393,  507,  187,  740,  187,  557,  187,
           507,  784,  502,   19,  187,  187,  187,  719,  834,  187,  106,
           187,  784,  928,  546,   46,  129,  507,  541,  347,  295,  928,
            46,  187,  494,  507,  516,  115,  708,  740,  347,  115,  115,
           187,  642,  766,  187,  115,  187,  928,  928,  928,  494,   97,
           834,  987,  187,  541,  494,  928,  912,  642,  101,  129,  642,
            46,  187,  784,  115,  642,  494,  115,  495,  240,  354,  507,
           347,  495,  507,  886,  494,  347,  600,  187,  494,  507,  187,
           9

In [57]:
latents.shape

torch.Size([1, 72, 172])

In [59]:
commitment_loss, codebook_loss

(tensor(3.7737, grad_fn=<AddBackward0>),
 tensor(3.7737, grad_fn=<AddBackward0>))