In [1]:
pip install torch torchaudio soundfile



In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


In [3]:
class SingleFileDataset(torch.utils.data.Dataset):
    """Equal-length mono segments cut from one audio file.

    The file is read once at construction time. Multi-channel audio is averaged
    over dim 0 into a single channel, the signal is resampled when the file's
    rate differs from ``sample_rate``, and the result is sliced into
    consecutive, non-overlapping windows of
    ``int(sample_rate * segment_duration)`` samples. A final window that comes
    up short is right-padded with zeros so every item has the same length.

    Args:
        file_path: Path passed to ``torchaudio.load``.
        sample_rate: Target sampling rate; also stored as ``self.sample_rate``.
        segment_duration: Window length, multiplied by ``sample_rate`` to get
            the number of samples per segment (``self.num_samples``).
    """

    def __init__(self, file_path, sample_rate=16000, segment_duration=1.0):
        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * segment_duration)

        waveform, native_rate = torchaudio.load(file_path)

        # Collapse multi-channel audio to one channel by averaging over dim 0.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Bring the signal to the requested rate only when it actually differs.
        if native_rate != sample_rate:
            waveform = torchaudio.functional.resample(waveform, native_rate, sample_rate)

        window = self.num_samples
        self.segments = []
        for offset in range(0, waveform.size(1), window):
            chunk = waveform[:, offset:offset + window]
            shortfall = window - chunk.size(1)
            if shortfall > 0:
                # Only the trailing chunk can be short; zero-fill it on the right.
                chunk = F.pad(chunk, (0, shortfall))
            self.segments.append(chunk)

    def __len__(self):
        """Return the number of stored segments."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return the segment tensor stored at position ``idx``."""
        return self.segments[idx]


In [4]:
class Encoder(nn.Module):
    """Strided 1-D convolutional encoder mapping 1 channel to 32 channels.

    Three ``Conv1d`` stages (1 -> 64 -> 128 -> 32 channels), each with kernel
    size 7, stride 2 and padding 3, and each followed by a ReLU. With these
    settings every stage halves the last dimension (rounding up).
    """

    def __init__(self):
        super().__init__()
        channel_plan = (1, 64, 128, 32)

        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv1d(c_in, c_out, kernel_size=7, stride=2, padding=3))
            stages.append(nn.ReLU())

        # Same layer order as a hand-written Sequential, so the sub-module
        # indices (net.0, net.2, net.4 for the convolutions) are unchanged.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        latent = self.net(x)
        return latent


In [5]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping 32 channels back to 1 channel.

    Three ``ConvTranspose1d`` stages (32 -> 128 -> 64 -> 1 channels), each with
    kernel size 7, stride 2, padding 3 and output_padding 1, so every stage
    doubles the last dimension. The first two stages are followed by ReLU and
    the last by Tanh, which bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Hyperparameters shared by all three upsampling layers.
        upsample = dict(kernel_size=7, stride=2, padding=3, output_padding=1)

        self.net = nn.Sequential(
            nn.ConvTranspose1d(32, 128, **upsample),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, **upsample),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 1, **upsample),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through the upsampling stack and return the result."""
        signal = self.net(x)
        return signal


In [6]:
!pip install TorchCodec
import torchcodec



In [7]:
class BottleneckSpeechAutoEncoder(nn.Module):
    """Autoencoder that chains ``Encoder`` and ``Decoder``.

    The encoder's three stride-2 stages reduce the last dimension to
    ceil(L / 8), and the decoder's three stride-2 stages multiply it by 8, so
    the raw reconstruction has length 8 * ceil(L / 8). That equals L only when
    L is a multiple of 8; otherwise the reconstruction is up to 7 samples too
    long. ``forward`` therefore trims the reconstruction to the input length
    so it can always be compared element-wise with the input.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode and decode ``x``.

        Args:
            x: Input tensor; its last dimension is treated as the sample axis.

        Returns:
            The decoder output, cropped along the last dimension to at most
            ``x.shape[-1]`` samples. For input lengths divisible by 8 the crop
            is a no-op and the result is exactly the decoder output.
        """
        z = self.encoder(x)
        recon = self.decoder(z)
        # Drop the surplus samples introduced by the encoder's ceil rounding.
        return recon[..., : x.shape[-1]]


In [8]:
# Location of the single clip the model is trained on.
audio_path = "/content/tts_output_2.wav"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global RNG before the model and DataLoader are created so weight
# initialisation and shuffle order are repeatable under Restart & Run All.
torch.manual_seed(0)

# Default arguments: 16 kHz target rate, 1-second segments.
dataset = SingleFileDataset(audio_path)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

model = BottleneckSpeechAutoEncoder().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# Mean absolute error between reconstruction and input waveform.
criterion = nn.L1Loss()


In [9]:
epochs = 30

# Put the model in training mode explicitly: the export cell further down calls
# model.eval(), so re-running this cell afterwards would otherwise train a
# model that is still in evaluation mode.
model.train()

for epoch in range(epochs):
    # Running sum of per-batch losses, averaged over batches when reported.
    total_loss = 0.0
    for seg in loader:
        seg = seg.to(device)

        # Reconstruct the segment and score it against itself.
        recon = model(seg)
        loss = criterion(recon, seg)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} | L1 Loss: {total_loss/len(loader):.5f}")


Epoch 1/30 | L1 Loss: 0.09717
Epoch 2/30 | L1 Loss: 0.07740
Epoch 3/30 | L1 Loss: 0.06178
Epoch 4/30 | L1 Loss: 0.05532
Epoch 5/30 | L1 Loss: 0.05319
Epoch 6/30 | L1 Loss: 0.05391
Epoch 7/30 | L1 Loss: 0.05261
Epoch 8/30 | L1 Loss: 0.04990
Epoch 9/30 | L1 Loss: 0.04936
Epoch 10/30 | L1 Loss: 0.04959
Epoch 11/30 | L1 Loss: 0.04902
Epoch 12/30 | L1 Loss: 0.04763
Epoch 13/30 | L1 Loss: 0.04766
Epoch 14/30 | L1 Loss: 0.04762
Epoch 15/30 | L1 Loss: 0.04645
Epoch 16/30 | L1 Loss: 0.04602
Epoch 17/30 | L1 Loss: 0.04593
Epoch 18/30 | L1 Loss: 0.04476
Epoch 19/30 | L1 Loss: 0.04402
Epoch 20/30 | L1 Loss: 0.04318
Epoch 21/30 | L1 Loss: 0.04180
Epoch 22/30 | L1 Loss: 0.04035
Epoch 23/30 | L1 Loss: 0.03834
Epoch 24/30 | L1 Loss: 0.03557
Epoch 25/30 | L1 Loss: 0.03342
Epoch 26/30 | L1 Loss: 0.03025
Epoch 27/30 | L1 Loss: 0.02765
Epoch 28/30 | L1 Loss: 0.02454
Epoch 29/30 | L1 Loss: 0.02236
Epoch 30/30 | L1 Loss: 0.02197


In [10]:
# Directory (relative to the working directory) that receives the WAV files.
save_dir = "outputs_bottleneck"
os.makedirs(save_dir, exist_ok=True)

model.eval()
with torch.no_grad():
    for i, seg in enumerate(dataset):
        # Dataset items have no batch dimension; add one for the model.
        seg = seg.unsqueeze(0).to(device)
        recon = model(seg)

        # Write with the rate the dataset was actually built at, so the files
        # stay correct if SingleFileDataset is constructed with another rate.
        torchaudio.save(f"{save_dir}/original_{i}.wav", seg.cpu()[0], dataset.sample_rate)
        torchaudio.save(f"{save_dir}/reconstructed_{i}.wav", recon.cpu()[0], dataset.sample_rate)


In [11]:
from IPython.display import Audio, display

# Read from the same relative directory the export cell writes to
# ("outputs_bottleneck"); an absolute /content/... path only matches when the
# working directory happens to be /content.
output_dir = "outputs_bottleneck"

# List all files in the directory, in lexicographic order.
files = sorted(os.listdir(output_dir))

for file_name in files:
    if file_name.endswith('.wav'):
        file_path = os.path.join(output_dir, file_name)
        print(f"Playing: {file_name}")
        display(Audio(file_path))

Playing: original_0.wav


Playing: original_1.wav


Playing: reconstructed_0.wav


Playing: reconstructed_1.wav
