In [30]:
import sys
import matplotlib.pyplot as plt
import torch
torch.backends.cudnn.benchmark = True
import numpy as np
import lightning as lt
torch.set_float32_matmul_precision('medium')
import soundfile as sf
from torch.utils.data import DataLoader, Dataset
import h5py
from tqdm.auto import tqdm
import dac
from audiotools import AudioSignal
import torchaudio
from torch import nn, optim
import torch.nn.functional as F



class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, n_embeddings, embedding_dim):
        super(Encoder, self).__init__()
        self.dropout = nn.Dropout(p=0.2)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden_channels, embedding_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x = self.dropout(F.relu(self.conv1(x)))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x


class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super(Decoder, self).__init__()
        self.dropout = nn.Dropout(p=0.2)
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x = self.dropout(F.relu(self.conv1(x)))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv3(x))
        return x


class VectorQuantizer(nn.Module):
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.n_embeddings = n_embeddings
        self.embeddings = nn.Embedding(n_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / n_embeddings, 1 / n_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, inputs):
        flat_input = inputs.view(-1, self.embedding_dim)
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embeddings.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embeddings.weight.t()))

        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        quantized = torch.index_select(self.embeddings.weight, 0, encoding_indices.view(-1))
        quantized = quantized.view(inputs.shape)

        # Commitment Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        return quantized, loss, encoding_indices


class VQVAE2(nn.Module):
    def __init__(self, in_channels=1, hidden_channels=64, n_embeddings=512, embedding_dim=64):
        super(VQVAE2, self).__init__()
        self.encoder = Encoder(in_channels, hidden_channels, n_embeddings, embedding_dim)
        self.vq = VectorQuantizer(n_embeddings, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_channels, in_channels)

    def forward(self, x):
        z = self.encoder(x)
        quantized, vq_loss, _ = self.vq(z)
        x_recon = self.decoder(quantized)
        x_recon = F.interpolate(x_recon, size=x.shape[-2:], mode='bilinear', align_corners=False)
        return x_recon, vq_loss



class SnippetDatasetHDF(Dataset):
    def __init__(self, hdf, scaling='minmax'):
        self.num_rows = 0
        self.size = int(3.4 * 24000)  # fixed size for samples
        self.scaling = scaling
        self.data = self.create_data(hdf)
        
        if scaling == 'standard':
            self.mean = self.data.mean()
            self.std =  self.data.std()
            self.data = (self.data - self.mean) / self.std
        
        elif scaling == 'minmax':
            self.min = self.data.min()
            self.max = self.data.max()
            self.data = (self.data - self.min) / (self.max - self.min)
    
    def __len__(self):
        return self.num_rows
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def create_data(self, hdf):
        data = []
        keys = list(hdf.keys())
        self.num_rows = len(keys)
        for key in tqdm(keys):
            sample = hdf[key]['audio'][:]
            if len(sample) > self.size:
                self.num_rows -= 1
                continue

            if len(sample) < self.size:
                sample = np.pad(sample, (0, self.size - len(sample)), 'constant')

            data.append(sample)
         
        return torch.tensor(np.array(data)).float()

    def retransform(self, data):
        if self.scaling == 'standard':
            return data * self.std + self.mean
        elif self.scaling == 'minmax':
            return data * (self.max - self.min) + self.min

hdf = h5py.File('../XCM.hdf5', 'r')
dataset = SnippetDatasetHDF(hdf)
hdf.close()

dac_model_path = dac.utils.download(model_type='24kHz')
dac_model = dac.DAC.load(dac_model_path)

def generate_latents(dataset, model):
    latents_list = []
    clen = len(dataset)
    clen = 2
    for i in range(clen):
        print(i)
        signal = AudioSignal(dataset.retransform(dataset[i]), sample_rate=24000)
        wav_dac = model.preprocess(signal.audio_data, signal.sample_rate)
        z, codes, latents, _, _ = model.encode(wav_dac)
        latents = torch.nn.functional.pad(latents, (0, 1))  
        latents_list.append(latents)
    
    return latents_list

latents_list = generate_latents(dataset, dac_model)

def train_vqvae2(model, dataloader, epochs=10, lr=1e-4):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for latents in dataloader:
            optimizer.zero_grad()
            outputs, vq_loss = model(latents)
            loss = criterion(outputs, latents) + vq_loss
            loss.backward(retain_graph=True)
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(dataloader)}')


for idx, latent in enumerate(latents_list):
    print(f"Latent {idx} shape: {latent.shape}")
latents_tensor = torch.stack(latents_list)
    
import torch
from torch.utils.data import Dataset, DataLoader

class LatentDataset(Dataset):
    def __init__(self, latents_tensor):
        self.latents_tensor = latents_tensor

    def __len__(self):
        return len(self.latents_tensor)

    def __getitem__(self, idx):
        return self.latents_tensor[idx]

latent_dataset = LatentDataset(latents_tensor)
train_loader = DataLoader(latent_dataset, batch_size=8, shuffle=True)

vqvae2_model = VQVAE2()
train_vqvae2(vqvae2_model, train_loader)

def generate_samples(model, num_samples=5):
    model.eval() 
    with torch.no_grad():
        encoding_indices = torch.randint(0, model.vq.n_embeddings, (num_samples, 1)).to(next(model.parameters()).device)
        quantized_latents = model.vq.embeddings(encoding_indices).view(num_samples, model.vq.embedding_dim, 1, 1)
        generated_samples = model.decoder(quantized_latents)
        generated_samples = F.interpolate(generated_samples, size=(1, 24000), mode='bilinear', align_corners=False)  # Example size
        return generated_samples

generated_audio = generate_samples(vqvae2_model, num_samples=5)

for i, sample in enumerate(generated_audio):
    sample_np = sample.squeeze(0).cpu().numpy()
    sample_np = sample_np.squeeze(0)
    sf.write(f'generated_sample.wav', sample_np, 24000)




  0%|          | 0/6675 [00:00<?, ?it/s]

0
1
Latent 0 shape: torch.Size([1, 256, 256])
Latent 1 shape: torch.Size([1, 256, 256])
Epoch 1/10, Loss: 2.8837718963623047
Epoch 2/10, Loss: 2.8790552616119385
Epoch 3/10, Loss: 2.8757424354553223
Epoch 4/10, Loss: 2.8735742568969727
Epoch 5/10, Loss: 2.8723886013031006
Epoch 6/10, Loss: 2.8720521926879883
Epoch 7/10, Loss: 2.87251353263855
Epoch 8/10, Loss: 2.8737995624542236
Epoch 9/10, Loss: 2.875980854034424
Epoch 10/10, Loss: 2.87921404838562


LibsndfileError: Error opening 'generated_sample.wav': Format not recognised.