In [1]:
pip install torch torchaudio soundfile TorchCodec

Collecting TorchCodec
  Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m81.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: TorchCodec
Successfully installed TorchCodec-0.9.1


In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchcodec

In [3]:
class SingleFileDataset(torch.utils.data.Dataset):
    """Fixed-length mono segments cut from a single audio file.

    The file is loaded once in the constructor, averaged down to one channel
    when it has several, resampled to ``sample_rate`` when its native rate
    differs, and then sliced into consecutive, non-overlapping windows of
    ``int(sample_rate * segment_duration)`` samples. A trailing window that
    comes up short is zero-padded on the right to the full window length.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        sample_rate: Target sampling rate the waveform is resampled to.
        segment_duration: Segment length in seconds.
    """

    def __init__(self, file_path, sample_rate=16000, segment_duration=1.0):
        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * segment_duration)

        waveform, native_rate = torchaudio.load(file_path)

        # Collapse multi-channel audio to a single channel by averaging.
        has_several_channels = waveform.shape[0] > 1
        if has_several_channels:
            waveform = waveform.mean(dim=0, keepdim=True)

        if native_rate != sample_rate:
            waveform = torchaudio.functional.resample(waveform, native_rate, sample_rate)

        window = self.num_samples
        self.segments = [
            self._right_pad(waveform[:, begin:begin + window], window)
            for begin in range(0, waveform.shape[1], window)
        ]

    @staticmethod
    def _right_pad(chunk, length):
        """Return ``chunk`` zero-padded along its last axis up to ``length`` samples."""
        missing = length - chunk.shape[1]
        return F.pad(chunk, (0, missing)) if missing > 0 else chunk

    def __len__(self):
        """Number of segments cut from the file."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return the segment stored at position ``idx``."""
        return self.segments[idx]


In [4]:
class Encoder(nn.Module):
    """Three-stage strided 1-D convolutional encoder.

    Each stage is a ``Conv1d`` (kernel 7, stride 2, padding 3) followed by a
    ``ReLU``. Channels go 1 -> 64 -> 128 -> 16, and each stride-2 stage
    roughly halves the length of the last axis.
    """

    def __init__(self):
        super().__init__()
        channel_plan = [(1, 64), (64, 128), (128, 16)]
        stages = []
        # Build conv/activation pairs in order so the Sequential indices
        # (0, 2, 4 for the convolutions) stay the same.
        for in_channels, out_channels in channel_plan:
            stages.append(nn.Conv1d(in_channels, out_channels, 7, stride=2, padding=3))
            stages.append(nn.ReLU())
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        latent = self.net(x)
        return latent


In [5]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        num_codes: Number of codebook entries.
        code_dim: Dimensionality of each entry; must equal the size of axis 1
            of the tensor passed to ``forward``.
        beta: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_codes=512, code_dim=16, beta=0.25):
        super().__init__()
        self.num_codes = num_codes
        self.code_dim = code_dim
        self.beta = beta

        self.codebook = nn.Embedding(num_codes, code_dim)
        init_bound = 1 / num_codes
        self.codebook.weight.data.uniform_(-init_bound, init_bound)

    def forward(self, z):
        """Quantize ``z`` and return ``(quantized, loss)``.

        ``z`` is permuted with ``(0, 2, 1)`` so axis 1 becomes the last axis,
        then flattened to rows of ``code_dim`` values. Each row is replaced by
        the codebook entry with the smallest squared Euclidean distance. The
        returned tensor has the same layout as ``z`` and uses the
        straight-through form ``input + (quantized - input).detach()``.
        """
        channels_last = z.permute(0, 2, 1).contiguous()
        vectors = channels_last.view(-1, self.code_dim)
        codes = self.codebook.weight

        # ||v - c||^2 expanded as ||v||^2 - 2 v.c + ||c||^2 for every (row, code) pair.
        vector_norms = vectors.pow(2).sum(1, keepdim=True)
        cross_terms = (2 * vectors) @ codes.t()
        code_norms = codes.pow(2).sum(1)
        distances = vector_norms - cross_terms + code_norms

        nearest = torch.argmin(distances, dim=1)
        quantized = self.codebook(nearest).view(channels_last.shape)

        # The codebook term pulls the entries towards the (detached) inputs;
        # the commitment term pulls the inputs towards the (detached) entries.
        codebook_loss = F.mse_loss(quantized, channels_last.detach())
        commit_loss = F.mse_loss(quantized.detach(), channels_last)
        loss = codebook_loss + self.beta * commit_loss

        straight_through = channels_last + (quantized - channels_last).detach()
        return straight_through.permute(0, 2, 1), loss


In [6]:
class ResidualVQ(nn.Module):
    """Chain of vector quantizers, each one applied to what the previous left over.

    Args:
        num_quantizers: Number of default-configured ``VectorQuantizer``
            stages to stack.
    """

    def __init__(self, num_quantizers=4):
        super().__init__()
        self.num_quantizers = num_quantizers
        stages = [VectorQuantizer() for _ in range(num_quantizers)]
        self.quantizers = nn.ModuleList(stages)

    def forward(self, z):
        """Quantize ``z`` stage by stage.

        Every stage receives the running remainder (the input minus all
        earlier stage outputs). Returns the sum of the stage outputs and the
        sum of the stage losses; with zero stages both sums are the integer 0.
        """
        remainder = z
        stage_outputs = []
        stage_losses = []

        for stage in self.quantizers:
            approximation, stage_loss = stage(remainder)
            remainder = remainder - approximation
            stage_outputs.append(approximation)
            stage_losses.append(stage_loss)

        # sum() starts from 0 and adds left to right, matching a running total.
        return sum(stage_outputs), sum(stage_losses)


In [7]:
class Decoder(nn.Module):
    """Three-stage transposed-convolution decoder.

    Each stage is a ``ConvTranspose1d`` (kernel 7, stride 2, padding 3,
    output_padding 1), which doubles the length of the last axis. Channels go
    16 -> 128 -> 64 -> 1. The first two stages are followed by ``ReLU`` and
    the last by ``Tanh``, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        stage_plan = [
            (16, 128, nn.ReLU),
            (128, 64, nn.ReLU),
            (64, 1, nn.Tanh),
        ]
        stages = []
        # Build layer/activation pairs in order so the Sequential indices
        # (0, 2, 4 for the transposed convolutions) stay the same.
        for in_channels, out_channels, activation in stage_plan:
            stages.append(
                nn.ConvTranspose1d(
                    in_channels, out_channels, 7, stride=2, padding=3, output_padding=1
                )
            )
            stages.append(activation())
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stack and return the result."""
        waveform = self.net(x)
        return waveform


In [8]:
class RVQSpeechCodec(nn.Module):
    """Encoder -> residual vector quantizer -> decoder autoencoder.

    Args:
        num_quantizers: Number of stages passed to ``ResidualVQ``.
    """

    def __init__(self, num_quantizers=4):
        super().__init__()
        self.encoder = Encoder()
        self.rvq = ResidualVQ(num_quantizers)
        self.decoder = Decoder()

    def forward(self, x):
        """Return ``(reconstruction, vq_loss)`` for the input ``x``.

        ``x`` is encoded, the encoding is quantized by ``self.rvq`` (which
        also yields the quantization loss), and the quantized encoding is
        decoded.
        """
        quantized, vq_loss = self.rvq(self.encoder(x))
        return self.decoder(quantized), vq_loss


In [9]:
# Seed PyTorch's global RNG before anything random happens (weight and
# codebook initialisation, DataLoader shuffling) so that Restart & Run All
# starts from the same state every time.
# NOTE(review): some CUDA kernels may still be non-deterministic; confirm if
# bit-exact runs on GPU are required.
SEED = 0
torch.manual_seed(SEED)

# Input recording and compute device.
audio_path = "/content/tts_output_2.wav"
device = "cuda" if torch.cuda.is_available() else "cpu"

# One item per fixed-length segment of the file; visited in shuffled order,
# one segment per optimisation step.
dataset = SingleFileDataset(audio_path)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

model = RVQSpeechCodec(num_quantizers=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Waveform-domain reconstruction objective (mean absolute error).
recon_loss_fn = nn.L1Loss()


In [10]:
epochs = 200

# The export cell further down switches the model to eval mode; put it back
# into training mode explicitly so re-running this cell afterwards behaves the
# same as on a fresh kernel.
model.train()

for epoch in range(epochs):
    # Sum of per-step losses, averaged over the number of batches when printed.
    total_loss = 0
    for seg in loader:
        seg = seg.to(device)

        recon, vq_loss = model(seg)
        recon_loss = recon_loss_fn(recon, seg)

        # Objective = waveform reconstruction error + quantizer losses.
        loss = recon_loss + vq_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(loader):.4f}")


Epoch 1/200 | Loss: 0.1342
Epoch 2/200 | Loss: 0.1233
Epoch 3/200 | Loss: 0.1192
Epoch 4/200 | Loss: 0.1211
Epoch 5/200 | Loss: 0.1307
Epoch 6/200 | Loss: 0.1527
Epoch 7/200 | Loss: 0.1915
Epoch 8/200 | Loss: 0.2428
Epoch 9/200 | Loss: 0.2646
Epoch 10/200 | Loss: 0.2427
Epoch 11/200 | Loss: 0.1947
Epoch 12/200 | Loss: 0.1451
Epoch 13/200 | Loss: 0.1056
Epoch 14/200 | Loss: 0.0795
Epoch 15/200 | Loss: 0.0656
Epoch 16/200 | Loss: 0.0610
Epoch 17/200 | Loss: 0.0605
Epoch 18/200 | Loss: 0.0591
Epoch 19/200 | Loss: 0.0562
Epoch 20/200 | Loss: 0.0530
Epoch 21/200 | Loss: 0.0511
Epoch 22/200 | Loss: 0.0505
Epoch 23/200 | Loss: 0.0506
Epoch 24/200 | Loss: 0.0501
Epoch 25/200 | Loss: 0.0493
Epoch 26/200 | Loss: 0.0488
Epoch 27/200 | Loss: 0.0487
Epoch 28/200 | Loss: 0.0489
Epoch 29/200 | Loss: 0.0490
Epoch 30/200 | Loss: 0.0489
Epoch 31/200 | Loss: 0.0489
Epoch 32/200 | Loss: 0.0490
Epoch 33/200 | Loss: 0.0494
Epoch 34/200 | Loss: 0.0496
Epoch 35/200 | Loss: 0.0498
Epoch 36/200 | Loss: 0.0499
E

In [15]:
os.makedirs("outputs_rvq", exist_ok=True)

# Inference only: eval mode and no gradient tracking.
model.eval()
with torch.no_grad():
    for i, seg in enumerate(dataset):
        # Dataset items have no batch axis; add one for the model.
        seg = seg.unsqueeze(0).to(device)
        recon, _ = model(seg)

        # Write with the rate the dataset resampled to rather than a
        # hard-coded 16000, so the files stay correct if the dataset is built
        # with a different sample_rate.
        torchaudio.save(f"outputs_rvq/original_{i}.wav", seg.cpu()[0], dataset.sample_rate)
        torchaudio.save(f"outputs_rvq/reconstructed_{i}.wav", recon.cpu()[0], dataset.sample_rate)

In [16]:
from IPython.display import Audio, display
import os

# Same relative directory the export cell writes to. The previous absolute
# "/content/..." path only matched when the working directory was /content.
output_dir = "outputs_rvq"


def _playback_order(filename):
    """Sort key: numeric segment index first, then the clip kind.

    Splits the name (without extension) at its last underscore, so
    ``original_3.wav`` yields ``(3, "original")``. Names without a numeric
    suffix sort first with index -1.
    """
    kind, _, index = os.path.splitext(filename)[0].rpartition("_")
    return (int(index) if index.isdigit() else -1, kind)


# os.listdir returns entries in arbitrary order; sort so each original is
# followed by its reconstruction, in segment order.
wav_files = [name for name in os.listdir(output_dir) if name.endswith(".wav")]

for filename in sorted(wav_files, key=_playback_order):
    file_path = os.path.join(output_dir, filename)
    print(f"Playing: {filename}")
    display(Audio(file_path))

Playing: original_0.wav


Playing: reconstructed_0.wav


Playing: original_1.wav


Playing: reconstructed_1.wav
