In [1]:
pip install torch torchaudio soundfile TorchCodec

Collecting TorchCodec
  Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m36.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: TorchCodec
Successfully installed TorchCodec-0.9.1


In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchcodec

In [3]:
class SingleFileDataset(torch.utils.data.Dataset):
    def __init__(self, file_path, sample_rate=16000, segment_duration=1.0):
        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * segment_duration)

        wav, sr = torchaudio.load(file_path)

        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != sample_rate:
            wav = torchaudio.functional.resample(wav, sr, sample_rate)

        self.segments = []
        for start in range(0, wav.shape[1], self.num_samples):
            seg = wav[:, start:start + self.num_samples]
            if seg.shape[1] < self.num_samples:
                seg = F.pad(seg, (0, self.num_samples - seg.shape[1]))
            self.segments.append(seg)

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        return self.segments[idx]


In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 64, 7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, 7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 16, 7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.net(x)


In [5]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_codes=512, code_dim=16, beta=0.25):
        super().__init__()
        self.code_dim = code_dim
        self.beta = beta
        self.codebook = nn.Embedding(num_codes, code_dim)
        self.codebook.weight.data.uniform_(-1 / num_codes, 1 / num_codes)

    def forward(self, z):
        z_perm = z.permute(0, 2, 1).contiguous()
        flat_z = z_perm.view(-1, self.code_dim)

        distances = (
            flat_z.pow(2).sum(1, keepdim=True)
            - 2 * flat_z @ self.codebook.weight.t()
            + self.codebook.weight.pow(2).sum(1)
        )

        indices = torch.argmin(distances, dim=1)
        z_q = self.codebook(indices).view(z_perm.shape)

        commit = F.mse_loss(z_q.detach(), z_perm)
        codebook = F.mse_loss(z_q, z_perm.detach())
        loss = codebook + self.beta * commit

        z_q = z_perm + (z_q - z_perm).detach()
        return z_q.permute(0, 2, 1), loss


In [6]:
class ResidualVQ(nn.Module):
    def __init__(self, n_q=4):
        super().__init__()
        self.vqs = nn.ModuleList([VectorQuantizer() for _ in range(n_q)])

    def forward(self, z):
        residual = z
        out = 0
        loss = 0
        for vq in self.vqs:
            z_q, l = vq(residual)
            residual = residual - z_q
            out = out + z_q
            loss = loss + l
        return out, loss


In [7]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose1d(16, 128, 7, stride=2, padding=3, output_padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(128, 64, 7, stride=2, padding=3, output_padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(64, 1, 7, stride=2, padding=3, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)


In [8]:
class SpeechCodecGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.rvq = ResidualVQ()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)
        z_q, vq_loss = self.rvq(z)
        out = self.decoder(z_q)
        return out, vq_loss


In [9]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, 15, stride=1, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, 41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, 41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 256, 41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(256, 1, 5, stride=1, padding=2)
        )

    def forward(self, x):
        return self.net(x)


In [10]:
class MultiResolutionSTFTLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffts = [512, 1024, 2048]
        self.hops = [128, 256, 512]
        self.wins = [512, 1024, 2048]

    def forward(self, x, y):
        loss = 0
        for f, h, w in zip(self.ffts, self.hops, self.wins):
            win = torch.hann_window(w).to(x.device)
            X = torch.stft(x.squeeze(1), f, h, w, win, return_complex=True)
            Y = torch.stft(y.squeeze(1), f, h, w, win, return_complex=True)
            loss += F.l1_loss(torch.abs(X), torch.abs(Y))
        return loss


In [11]:
audio_path = "/content/tts_output_2.wav"   # CHANGE
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = SingleFileDataset(audio_path)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

G = SpeechCodecGAN().to(device)
D = Discriminator().to(device)

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.9))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.9))

l1 = nn.L1Loss()
stft = MultiResolutionSTFTLoss()


In [12]:
epochs = 200

for epoch in range(epochs):
    for real in loader:
        real = real.to(device)

        # ---- Train Discriminator ----
        fake, _ = G(real)
        d_real = D(real)
        d_fake = D(fake.detach())

        loss_D = torch.mean((d_real - 1)**2) + torch.mean(d_fake**2)

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # ---- Train Generator ----
        fake, vq_loss = G(real)
        d_fake = D(fake)

        adv_loss = torch.mean((d_fake - 1)**2)
        recon = l1(fake, real)
        spec = stft(fake, real)

        loss_G = recon + 0.5 * spec + vq_loss + 0.1 * adv_loss

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

    print(f"Epoch {epoch+1}/{epochs} | G: {loss_G.item():.4f} | D: {loss_D.item():.4f}")


Epoch 1/200 | G: 1.5867 | D: 0.8845
Epoch 2/200 | G: 1.5196 | D: 0.4874
Epoch 3/200 | G: 1.5138 | D: 0.1349
Epoch 4/200 | G: 1.2485 | D: 0.0562
Epoch 5/200 | G: 1.2501 | D: 0.1143
Epoch 6/200 | G: 1.2598 | D: 0.0513
Epoch 7/200 | G: 1.3158 | D: 0.0180
Epoch 8/200 | G: 1.7342 | D: 0.0181
Epoch 9/200 | G: 1.7448 | D: 0.0175
Epoch 10/200 | G: 2.5156 | D: 0.0401
Epoch 11/200 | G: 2.8560 | D: 1.8459
Epoch 12/200 | G: 4.0993 | D: 0.6856
Epoch 13/200 | G: 3.6908 | D: 0.5730
Epoch 14/200 | G: 2.9932 | D: 0.5085
Epoch 15/200 | G: 2.9853 | D: 0.4290
Epoch 16/200 | G: 2.8832 | D: 0.4405
Epoch 17/200 | G: 2.4288 | D: 0.4782
Epoch 18/200 | G: 2.5848 | D: 0.4175
Epoch 19/200 | G: 2.1780 | D: 0.4855
Epoch 20/200 | G: 2.3708 | D: 0.4108
Epoch 21/200 | G: 1.9242 | D: 0.4556
Epoch 22/200 | G: 2.1596 | D: 0.4078
Epoch 23/200 | G: 1.9946 | D: 0.4243
Epoch 24/200 | G: 1.9105 | D: 0.3742
Epoch 25/200 | G: 1.5424 | D: 0.5000
Epoch 26/200 | G: 1.7926 | D: 0.4912
Epoch 27/200 | G: 1.3509 | D: 0.5337
Epoch 28/2

In [13]:
os.makedirs("outputs_gan", exist_ok=True)

G.eval()
with torch.no_grad():
    for i, seg in enumerate(dataset):
        seg = seg.unsqueeze(0).to(device)
        out, _ = G(seg)

        torchaudio.save(f"outputs_gan/original_{i}.wav", seg.cpu()[0], 16000)
        torchaudio.save(f"outputs_gan/reconstructed_{i}.wav", out.cpu()[0], 16000)


In [14]:
from IPython.display import Audio
import os

output_dir = "/content/outputs_gan"

for filename in os.listdir(output_dir):
    if filename.endswith(".wav"):
        file_path = os.path.join(output_dir, filename)
        print(f"Playing: {filename}")
        display(Audio(file_path))

Playing: original_0.wav


Playing: reconstructed_0.wav


Playing: original_1.wav


Playing: reconstructed_1.wav
