In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from  stft.spectrograms import convert_npy_to_audio
from scipy.interpolate import interp2d

import logging

# Configure logging
logging.basicConfig(level=logging.DEBUG)

In [3]:
# Function to normalize and preprocess the magnitude and phase
def preprocess_spectrograms(magnitude_path):
    # Load the magnitude and phase
    magnitude = np.load(magnitude_path)
    phase = np.load(magnitude_path.replace("magnitude", "phase"))

    # Original shape
    original_shape = magnitude.shape

    # Normalize magnitude
    scaler = MinMaxScaler()
    magnitude = scaler.fit_transform(magnitude.reshape(-1, magnitude.shape[-1])).reshape(magnitude.shape)

    # Normalize phase
    phase = (phase + np.pi) / (2 * np.pi)  # Normalize to [0, 1]

    # Concatenate magnitude and phase into a 2-channel image
    spectrogram_image = np.stack([magnitude, phase], axis=0)
    return spectrogram_image, scaler, original_shape

def denormalize_spectrograms(spectrogram_image, scaler, original_shape):
    # Separate magnitude and phase
    magnitude = spectrogram_image[0]
    phase = spectrogram_image[1]

    # Get the original shape
    original_height, original_width = original_shape
    current_height, current_width = magnitude.shape

    # Check if magnitude is 2D or 3D
    if magnitude.ndim == 2:
        magnitude = np.expand_dims(magnitude, axis=-1)  # Add a dummy channel dimension
        phase = np.expand_dims(phase, axis=-1)

    # Rescale magnitude and phase to the original dimensions
    x = np.arange(current_width)
    y = np.arange(current_height)

    magnitude_interp = interp2d(x, y, magnitude[:, :, 0], kind='cubic')
    phase_interp = interp2d(x, y, phase[:, :, 0], kind='cubic')

    new_x = np.linspace(0, current_width - 1, original_width)
    new_y = np.linspace(0, current_height - 1, original_height)

    magnitude_rescaled = magnitude_interp(new_x, new_y)
    phase_rescaled = phase_interp(new_x, new_y)

    # Check if the scaler's shape matches the magnitude data
    if magnitude_rescaled.shape[-1] != scaler.data_max_.shape[0]:
        raise ValueError("Scaler shape mismatch with magnitude data")

    # Reshape magnitude to 2D for scaling
    magnitude_rescaled_reshaped = magnitude_rescaled.reshape(-1, magnitude_rescaled.shape[-1])

    # Denormalize magnitude
    magnitude_denorm = scaler.inverse_transform(magnitude_rescaled_reshaped).reshape(original_height, original_width, -1)

    # Denormalize phase
    phase_denorm = phase_rescaled * 2 * np.pi - np.pi

    # If phase is 2D, add a channel dimension for consistency
    if phase_denorm.ndim == 2:
        phase_denorm = np.expand_dims(phase_denorm, axis=-1)
        
    # Squeeze la tercera dimensión si es de tamaño 1
    magnitude_denorm = np.squeeze(magnitude_denorm, axis=-1) if magnitude_denorm.shape[-1] == 1 else magnitude_denorm
    phase_denorm = np.squeeze(phase_denorm, axis=-1) if phase_denorm.shape[-1] == 1 else phase_denorm

    return magnitude_denorm, phase_denorm

class ConvVAE(nn.Module):
    def __init__(self, in_channels, latent_dim, img_height, img_width, device, debug=False):
        super(ConvVAE, self).__init__()
        self.latent_dim = latent_dim
        self.img_height = img_height
        self.img_width = img_width
        self.device = device
        self.debug = debug

        # Encoder
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=(4, 8), stride=(2, 4), padding=1)  # Output: (32, img_height//2, img_width//4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(4, 8), stride=(2, 4), padding=1)          # Output: (64, img_height//4, img_width//16)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(4, 8), stride=(2, 4), padding=1)         # Output: (128, img_height//8, img_width//64)

        # Compute the size of the flattened layer
        flattened_size = 128 * (img_height // 8) * (img_width // 64)
        self.fc_mean = nn.Linear(flattened_size, latent_dim)
        self.fc_log_var = nn.Linear(flattened_size, latent_dim)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim, flattened_size)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 8), stride=(2, 4), padding=1)  # Output: (64, img_height//4, img_width//16)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=(4, 8), stride=(2, 4), padding=1)   # Output: (32, img_height//2, img_width//4)
        self.deconv3 = nn.ConvTranspose2d(32, in_channels, kernel_size=(4, 8), stride=(2, 4), padding=1)  # Output: (in_channels, img_height, img_width)

    def encode(self, x):
        logging.debug(f"Input shape to encoder: {x.shape}") if self.debug else None
        e = F.relu(self.conv1(x))
        logging.debug(f"Shape after conv1: {e.shape}") if self.debug else None
        e = F.relu(self.conv2(e))
        logging.debug(f"Shape after conv2: {e.shape}") if self.debug else None
        e = F.relu(self.conv3(e))
        logging.debug(f"Shape after conv3: {e.shape}") if self.debug else None
        e = e.view(e.size(0), -1)
        logging.debug(f"Shape after flattening: {e.shape}") if self.debug else None
        z_mean = self.fc_mean(e)
        z_log_var = self.fc_log_var(e)
        logging.debug(f"Shape of z_mean: {z_mean.shape}") if self.debug else None
        logging.debug(f"Shape of z_log_var: {z_log_var.shape}") if self.debug else None
        return z_mean, z_log_var

    def reparameterize(self, mean, log_var):
        epsilon = torch.randn_like(mean)
        return mean + torch.exp(0.5 * log_var) * epsilon

    def decode(self, z):
        logging.debug(f"Input shape to decoder: {z.shape}") if self.debug else None
        d = self.fc_decode(z)
        logging.debug(f"Shape after fc_decode: {d.shape}") if self.debug else None
        d = d.view(d.size(0), 128, self.img_height // 8, self.img_width // 64)  # Reshape
        logging.debug(f"Shape after reshape: {d.shape}") if self.debug else None
        d = F.relu(self.deconv1(d))
        logging.debug(f"Shape after deconv1: {d.shape}") if self.debug else None
        d = F.relu(self.deconv2(d))
        logging.debug(f"Shape after deconv2: {d.shape}") if self.debug else None
        outputs = torch.sigmoid(self.deconv3(d))
        logging.debug(f"Shape after deconv3: {outputs.shape}") if self.debug else None
        return outputs

    def forward(self, x):
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_reconstructed = self.decode(z)
        return x_reconstructed, z_mean, z_log_var

    def compute_loss(self, x, x_reconstructed, z_mean, z_log_var):
        # Debugging logs
        logging.debug(f"Shape of input: {x.shape}") if self.debug else None
        logging.debug(f"Shape of reconstructed: {x_reconstructed.shape}") if self.debug else None
        logging.debug(f"x min: {x.min().item()}, x max: {x.max().item()}") if self.debug else None
        logging.debug(f"x_reconstructed min: {x_reconstructed.min().item()}, x_reconstructed max: {x_reconstructed.max().item()}") if self.debug else None
        # Compute the reconstruction loss using MSE
        reconstruction_loss = F.mse_loss(x_reconstructed, x, reduction='sum')
        kl_loss = -0.5 * torch.mean(z_log_var - torch.square(z_mean) - torch.exp(z_log_var) + 1)
        total_loss = reconstruction_loss + kl_loss
        # More debugging logs
        logging.debug(f"Reconstruction loss: {reconstruction_loss.item()}") if self.debug else None
        logging.debug(f"KL loss: {kl_loss.item()}") if self.debug else None
        logging.debug(f"Total loss: {total_loss.item()}") if self.debug else None
        return total_loss

    def training_step(self, x):
        self.optimizer.zero_grad()
        x_reconstructed, z_mean, z_log_var = self(x)
        loss = self.compute_loss(x, x_reconstructed, z_mean, z_log_var)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def generate_images(self, num_samples):
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim).to(self.device)
            return self.decode(z).cpu().numpy()

    def plot_images(self, training_images, generated_images, scaler, original_shape):
        num_images = min(5, training_images.shape[0], generated_images.shape[0])
        fig, axs = plt.subplots(2, num_images, figsize=(15, 6))

        for i in range(num_images):
            train_mag, train_phase = denormalize_spectrograms(training_images[i], scaler, original_shape)
            gen_mag, gen_phase = denormalize_spectrograms(generated_images[i], scaler, original_shape)

            train_combined = np.concatenate([train_mag, train_phase], axis=-1)
            gen_combined = np.concatenate([gen_mag, gen_phase], axis=-1)

            axs[0, i].imshow(train_combined, aspect='auto', cmap='viridis')
            axs[0, i].set_title('Training Image')
            axs[0, i].axis('off')

            axs[1, i].imshow(gen_combined, aspect='auto', cmap='viridis')
            axs[1, i].set_title('Generated Image')
            axs[1, i].axis('off')

        plt.suptitle('Training vs Generated Images')
        plt.savefig('cnn_vae/generated_images.png')
        plt.close()


In [4]:
def get_image_paths(clean_data_dir):
    image_paths = []
    for filename in os.listdir(clean_data_dir):
        if filename.endswith('_magnitude.npy'):
            ex = filename.split('_')[0]
            magnitude_path = os.path.join(clean_data_dir, filename)
            phase_path = os.path.join(clean_data_dir, f'{ex}_phase.npy')
            if os.path.exists(phase_path):
                image_paths.append((magnitude_path, phase_path))
    return image_paths

class SpectrogramDataset(Dataset):
    def __init__(self, image_paths, subsample_factor=2):
        self.image_paths = image_paths
        self.subsample_factor = subsample_factor
        
        # Fit the scaler on the magnitude of the first item
        first_magnitude_path, _ = self.image_paths[0]
        first_magnitude = np.load(first_magnitude_path)
        self.scaler = MinMaxScaler().fit(first_magnitude.reshape(-1, first_magnitude.shape[-1]))

        # # Store the shape for later use
        self.original_shape = (first_magnitude.shape[0], first_magnitude.shape[1])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        magnitude_path, phase_path = self.image_paths[idx]
        spectrogram_image, _, _ = preprocess_spectrograms(magnitude_path)

        # Subsample the spectrogram image
        spectrogram_image_subsampled = spectrogram_image[:, :, ::self.subsample_factor]
        spectrogram_image_cropped = spectrogram_image_subsampled[:, :-1, :-self.subsample_factor-1]

        return torch.tensor(spectrogram_image_cropped, dtype=torch.float32)

Se subsamplea la imagen original (únicamente en el eje horizontal) con el fin the disminuir dimensionalidad, después se reconstruira esta parte usando una interpolación.

In [5]:
clean_data_dir = 'spectrogram_npy' #drive/MyDrive/spectrogram_npy'
batch_size = 64
latent_dim = 32
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

device: cuda


In [6]:
subsample_factor = 2
#  Get all image paths
image_paths = get_image_paths(clean_data_dir)
dataset = SpectrogramDataset(image_paths, subsample_factor=subsample_factor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [7]:
# Image dimensions
img_height = 128
img_width = 10328//subsample_factor

# Create the model instance
vae = ConvVAE(in_channels=2, latent_dim=latent_dim, img_height=img_height, img_width=img_width, device=device, debug=False).to(device)
vae.optimizer = optim.Adam(vae.parameters(), lr=1e-4)

In [8]:
# Training loop with loss history
loss_history = []
training_images_batch = None
for epoch in range(epochs):
    epoch_loss = 0
    print(f"epoch {epoch+1}")
    for batch_x in tqdm(dataloader):
        batch_x = batch_x.to(device)
        loss = vae.training_step(batch_x)
        epoch_loss += loss

        # Save a batch of training images for plotting
        if training_images_batch is None:
            training_images_batch = batch_x.cpu().numpy()

    epoch_loss /= len(dataloader)
    loss_history.append(epoch_loss)
    if epoch % 1 == 0:
        print(f'Epoch {epoch}, Loss: {epoch_loss}')

epoch 1


100%|██████████| 11/11 [00:15<00:00,  1.45s/it]


Epoch 0, Loss: 12384765.727272727
epoch 2


100%|██████████| 11/11 [00:15<00:00,  1.44s/it]


Epoch 1, Loss: 10266858.454545455
epoch 3


100%|██████████| 11/11 [00:16<00:00,  1.52s/it]


Epoch 2, Loss: 7753004.045454546
epoch 4


100%|██████████| 11/11 [00:16<00:00,  1.51s/it]


Epoch 3, Loss: 5989357.545454546
epoch 5


100%|██████████| 11/11 [00:15<00:00,  1.37s/it]


Epoch 4, Loss: 4989014.818181818
epoch 6


100%|██████████| 11/11 [00:15<00:00,  1.38s/it]


Epoch 5, Loss: 4461197.9772727275
epoch 7


100%|██████████| 11/11 [00:14<00:00,  1.35s/it]


Epoch 6, Loss: 4224992.5
epoch 8


100%|██████████| 11/11 [00:14<00:00,  1.35s/it]


Epoch 7, Loss: 4103906.0454545454
epoch 9


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 8, Loss: 4041791.840909091
epoch 10


100%|██████████| 11/11 [00:15<00:00,  1.38s/it]


Epoch 9, Loss: 4003557.3636363638
epoch 11


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 10, Loss: 3978384.0454545454
epoch 12


100%|██████████| 11/11 [00:14<00:00,  1.27s/it]


Epoch 11, Loss: 3959497.0454545454
epoch 13


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 12, Loss: 3944723.6363636362
epoch 14


100%|██████████| 11/11 [00:14<00:00,  1.32s/it]


Epoch 13, Loss: 3932326.909090909
epoch 15


100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


Epoch 14, Loss: 3921376.0454545454
epoch 16


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 15, Loss: 3911301.6363636362
epoch 17


100%|██████████| 11/11 [00:13<00:00,  1.24s/it]


Epoch 16, Loss: 3902666.4545454546
epoch 18


100%|██████████| 11/11 [00:13<00:00,  1.27s/it]


Epoch 17, Loss: 3894544.159090909
epoch 19


100%|██████████| 11/11 [00:13<00:00,  1.25s/it]


Epoch 18, Loss: 3888059.022727273
epoch 20


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 19, Loss: 3882662.8863636362
epoch 21


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 20, Loss: 3878243.7045454546
epoch 22


100%|██████████| 11/11 [00:14<00:00,  1.36s/it]


Epoch 21, Loss: 3873821.75
epoch 23


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 22, Loss: 3870093.5681818184
epoch 24


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 23, Loss: 3866348.0681818184
epoch 25


100%|██████████| 11/11 [00:14<00:00,  1.35s/it]


Epoch 24, Loss: 3863084.727272727
epoch 26


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 25, Loss: 3860053.8636363638
epoch 27


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 26, Loss: 3857330.3181818184
epoch 28


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 27, Loss: 3854793.4545454546
epoch 29


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 28, Loss: 3852393.022727273
epoch 30


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 29, Loss: 3850360.4545454546
epoch 31


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 30, Loss: 3848155.159090909
epoch 32


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 31, Loss: 3845872.590909091
epoch 33


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 32, Loss: 3844128.2954545454
epoch 34


100%|██████████| 11/11 [00:15<00:00,  1.38s/it]


Epoch 33, Loss: 3841947.090909091
epoch 35


100%|██████████| 11/11 [00:15<00:00,  1.39s/it]


Epoch 34, Loss: 3839990.7954545454
epoch 36


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 35, Loss: 3838144.7045454546
epoch 37


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 36, Loss: 3836295.409090909
epoch 38


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 37, Loss: 3834531.0681818184
epoch 39


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 38, Loss: 3832573.590909091
epoch 40


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 39, Loss: 3830475.5681818184
epoch 41


100%|██████████| 11/11 [00:14<00:00,  1.36s/it]


Epoch 40, Loss: 3828083.8863636362
epoch 42


100%|██████████| 11/11 [00:15<00:00,  1.37s/it]


Epoch 41, Loss: 3826022.022727273
epoch 43


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 42, Loss: 3824486.6818181816
epoch 44


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 43, Loss: 3824072.977272727
epoch 45


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 44, Loss: 3822459.7954545454
epoch 46


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 45, Loss: 3820503.772727273
epoch 47


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 46, Loss: 3819129.8636363638
epoch 48


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 47, Loss: 3817915.4318181816
epoch 49


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 48, Loss: 3816831.727272727
epoch 50


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 49, Loss: 3815695.272727273
epoch 51


100%|██████████| 11/11 [00:14<00:00,  1.32s/it]


Epoch 50, Loss: 3814671.022727273
epoch 52


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 51, Loss: 3813678.772727273
epoch 53


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 52, Loss: 3812874.022727273
epoch 54


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 53, Loss: 3811893.659090909
epoch 55


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 54, Loss: 3811179.3181818184
epoch 56


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 55, Loss: 3809823.8181818184
epoch 57


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 56, Loss: 3808721.1136363638
epoch 58


100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Epoch 57, Loss: 3807746.4318181816
epoch 59


100%|██████████| 11/11 [00:16<00:00,  1.47s/it]


Epoch 58, Loss: 3806710.9545454546
epoch 60


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 59, Loss: 3805749.022727273
epoch 61


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 60, Loss: 3804762.3863636362
epoch 62


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 61, Loss: 3803753.159090909
epoch 63


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 62, Loss: 3802739.022727273
epoch 64


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 63, Loss: 3801632.5
epoch 65


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 64, Loss: 3800626.4545454546
epoch 66


100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Epoch 65, Loss: 3799458.727272727
epoch 67


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 66, Loss: 3798539.5
epoch 68


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 67, Loss: 3798368.522727273
epoch 69


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 68, Loss: 3797295.9545454546
epoch 70


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 69, Loss: 3795832.8181818184
epoch 71


100%|██████████| 11/11 [00:15<00:00,  1.38s/it]


Epoch 70, Loss: 3794230.0
epoch 72


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 71, Loss: 3793197.5681818184
epoch 73


100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Epoch 72, Loss: 3792064.5
epoch 74


100%|██████████| 11/11 [00:16<00:00,  1.52s/it]


Epoch 73, Loss: 3791643.0681818184
epoch 75


100%|██████████| 11/11 [00:16<00:00,  1.54s/it]


Epoch 74, Loss: 3790530.659090909
epoch 76


100%|██████████| 11/11 [00:17<00:00,  1.62s/it]


Epoch 75, Loss: 3789504.3181818184
epoch 77


100%|██████████| 11/11 [00:17<00:00,  1.55s/it]


Epoch 76, Loss: 3787494.909090909
epoch 78


100%|██████████| 11/11 [00:16<00:00,  1.49s/it]


Epoch 77, Loss: 3785834.772727273
epoch 79


100%|██████████| 11/11 [00:16<00:00,  1.49s/it]


Epoch 78, Loss: 3784389.8181818184
epoch 80


100%|██████████| 11/11 [00:16<00:00,  1.47s/it]


Epoch 79, Loss: 3782784.25
epoch 81


100%|██████████| 11/11 [00:17<00:00,  1.56s/it]


Epoch 80, Loss: 3781599.909090909
epoch 82


100%|██████████| 11/11 [00:16<00:00,  1.47s/it]


Epoch 81, Loss: 3779909.7954545454
epoch 83


100%|██████████| 11/11 [00:18<00:00,  1.66s/it]


Epoch 82, Loss: 3778664.022727273
epoch 84


100%|██████████| 11/11 [00:17<00:00,  1.62s/it]


Epoch 83, Loss: 3777130.909090909
epoch 85


100%|██████████| 11/11 [00:16<00:00,  1.48s/it]


Epoch 84, Loss: 3775965.022727273
epoch 86


100%|██████████| 11/11 [00:17<00:00,  1.56s/it]


Epoch 85, Loss: 3774917.522727273
epoch 87


100%|██████████| 11/11 [00:16<00:00,  1.53s/it]


Epoch 86, Loss: 3773762.7954545454
epoch 88


100%|██████████| 11/11 [00:17<00:00,  1.57s/it]


Epoch 87, Loss: 3772309.227272727
epoch 89


100%|██████████| 11/11 [00:17<00:00,  1.58s/it]


Epoch 88, Loss: 3770771.25
epoch 90


100%|██████████| 11/11 [00:17<00:00,  1.59s/it]


Epoch 89, Loss: 3769381.8863636362
epoch 91


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 90, Loss: 3768442.4545454546
epoch 92


100%|██████████| 11/11 [00:15<00:00,  1.37s/it]


Epoch 91, Loss: 3767146.3863636362
epoch 93


100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Epoch 92, Loss: 3765962.340909091
epoch 94


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 93, Loss: 3764858.1136363638
epoch 95


100%|██████████| 11/11 [00:14<00:00,  1.32s/it]


Epoch 94, Loss: 3763873.977272727
epoch 96


100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Epoch 95, Loss: 3763034.25
epoch 97


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Epoch 96, Loss: 3762145.5454545454
epoch 98


100%|██████████| 11/11 [00:14<00:00,  1.36s/it]


Epoch 97, Loss: 3762084.5681818184
epoch 99


100%|██████████| 11/11 [00:14<00:00,  1.35s/it]


Epoch 98, Loss: 3761044.75
epoch 100


100%|██████████| 11/11 [00:15<00:00,  1.41s/it]

Epoch 99, Loss: 3760440.8863636362





In [9]:
# Plot loss history
plt.figure(figsize=(10, 5))
plt.plot(loss_history, label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.legend()
plt.savefig('cnn_vae/loss_history.png')
plt.close()

DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.
DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.
DEBUG:matplotlib.font_manager:findfont: Matching sans\-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0.
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='c:\\Users\\USUARIO\\anaconda3\\envs\\torch\\Lib\\site-packages\\matplotlib\\mpl-data\\fonts\\ttf\\DejaVuSans.ttf', name='DejaVu Sans', style='normal', variant='normal', weight=400, stretch='normal', size='scalable')) = 0.05
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='c:\\Users\\USUARIO\\anaconda3\\envs\\torch\\Lib\\site-packages\\matplotlib\\mpl-data\\fonts\\ttf\\DejaVuSansMono-Oblique.ttf', name='DejaVu Sans Mono', style='oblique', variant='normal', weight=400, stretch='normal', size='scalable')) = 11.05
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='c:\\Users\\USUARIO\

In [10]:
# Save the model's state dictionary
torch.save(vae.state_dict(), 'cnn_vae/cnn_vae_model.pth')
print("Model saved to 'cnn_vae_model.pth'")

Model saved to 'cnn_vae_model.pth'


In [11]:
# Load the model's state dictionary
vae = ConvVAE(in_channels=2, latent_dim=latent_dim, img_height=img_height, img_width=img_width, device=device).to(device)
vae.load_state_dict(torch.load('cnn_vae/cnn_vae_model.pth'))
vae.eval()
print("Model loaded from 'cnn_vae_model.pth'")

# Generate images using the loaded model
num_samples = 5
generated_images = vae.generate_images(num_samples)

Model loaded from 'cnn_vae_model.pth'


In [12]:
def plot_generated_images(generated_images, scaler, original_shape):
    num_images = min(5, generated_images.shape[0])
    fig, axs = plt.subplots(2, num_images, figsize=(15, 12))

    for i in range(num_images):
        gen_mag, gen_phase = denormalize_spectrograms(generated_images[i], scaler, original_shape)

        # Ensure that the magnitude and phase are both 2D
        if gen_mag.ndim == 3:
            gen_mag = gen_mag[:, :, 0]
        if gen_phase.ndim == 3:
            gen_phase = gen_phase[:, :, 0]

        # Plot magnitude
        axs[0, i].imshow(gen_mag, aspect='auto', cmap='viridis')
        axs[0, i].set_title('Magnitude')
        axs[0, i].axis('off')

        # Plot phase
        axs[1, i].imshow(gen_phase, aspect='auto', cmap='hsv')  # Use 'hsv' for phase to show colors
        axs[1, i].set_title('Phase')
        axs[1, i].axis('off')

    plt.suptitle('Generated Images')
    plt.savefig('cnn_vae/generated_images.png')
    plt.close()

scaler, _ = dataset.scaler, dataset.original_shape
original_shape = (129, 10329)
print(f"original shape:{original_shape}")

# Assuming 'generated_images' is available
plot_generated_images(generated_images, scaler, original_shape)

original shape:(129, 10329)



For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  magnitude_interp = interp2d(x, y, magnitude[:, :, 0], kind='cubic')

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  phase_interp = interp2d(x, y, phase[:, :, 0], kind='cubic')

For legacy code, nearly bug-for-bug compatible replacements a

In [13]:
# Define the output directory
output_dir = 'cnn_vae/generated_spectrograms'
os.makedirs(output_dir, exist_ok=True)

# Save generated images
for i, spectrogram in enumerate(generated_images):
    # magnitude, phase = spectrogram
    magnitude, phase = denormalize_spectrograms(spectrogram, scaler, original_shape)
    magnitude_path = os.path.join(output_dir, f'generated_magnitude_{i}.npy')
    phase_path = os.path.join(output_dir, f'generated_phase_{i}.npy')
    np.save(magnitude_path, magnitude)
    np.save(phase_path, phase)
    
    print(f'Saved magnitude to {magnitude_path}')
    print(f'Saved phase to {phase_path}')


For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  magnitude_interp = interp2d(x, y, magnitude[:, :, 0], kind='cubic')

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  phase_interp = interp2d(x, y, phase[:, :, 0], kind='cubic')

For legacy code, nearly bug-for-bug compatible replacements a

Saved magnitude to cnn_vae/generated_spectrograms\generated_magnitude_0.npy
Saved phase to cnn_vae/generated_spectrograms\generated_phase_0.npy
Saved magnitude to cnn_vae/generated_spectrograms\generated_magnitude_1.npy
Saved phase to cnn_vae/generated_spectrograms\generated_phase_1.npy



For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  magnitude_rescaled = magnitude_interp(new_x, new_y)

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  phase_rescaled = phase_interp(new_x, new_y)

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regu

Saved magnitude to cnn_vae/generated_spectrograms\generated_magnitude_2.npy
Saved phase to cnn_vae/generated_spectrograms\generated_phase_2.npy
Saved magnitude to cnn_vae/generated_spectrograms\generated_magnitude_3.npy
Saved phase to cnn_vae/generated_spectrograms\generated_phase_3.npy



For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  phase_interp = interp2d(x, y, phase[:, :, 0], kind='cubic')

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.

In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.

For more details see
`https://scipy.github.io/devdocs/notebooks/interp_transition_guide.html`

  magnitude_rescaled = magnitude_interp(new_x, new_y)

For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariat

Saved magnitude to cnn_vae/generated_spectrograms\generated_magnitude_4.npy
Saved phase to cnn_vae/generated_spectrograms\generated_phase_4.npy


In [14]:
# Define the output directory and files
output_wav_dir = 'cnn_vae/generated_audio'
os.makedirs(output_wav_dir, exist_ok=True)
# Convert and save audio files
for i in range(len(generated_images)):
    magnitude_path = os.path.join(output_dir, f'generated_magnitude_{i}.npy')
    phase_path = os.path.join(output_dir, f'generated_phase_{i}.npy')
    output_wav_path = os.path.join(output_wav_dir, f'generated_audio_{i}.wav')

    convert_npy_to_audio(magnitude_path, phase_path, output_wav_path)

Se guardó la señal de audio en cnn_vae/generated_audio\generated_audio_0.wav
Se guardó la señal de audio en cnn_vae/generated_audio\generated_audio_1.wav
Se guardó la señal de audio en cnn_vae/generated_audio\generated_audio_2.wav
Se guardó la señal de audio en cnn_vae/generated_audio\generated_audio_3.wav
Se guardó la señal de audio en cnn_vae/generated_audio\generated_audio_4.wav
