In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm

import logging

# Configure logging
# Sets the root logger to DEBUG so that the logging.debug(...) shape/loss messages
# emitted by the model code are not filtered out.
# NOTE(review): basicConfig does nothing when the root logger already has handlers
# (see the logging docs); the captured training output shows no DEBUG lines, so this
# call may be a no-op in this environment — confirm, and pass force=True if the
# debug output is actually wanted.
logging.basicConfig(level=logging.DEBUG)

In [2]:
# Mount Google Drive at /content/drive so later cells can read files stored there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Function to normalize and preprocess the magnitude and phase
def preprocess_spectrograms(magnitude_path, phase_path):
    """Load a magnitude/phase pair from ``.npy`` files and scale them into one array.

    Args:
        magnitude_path: Path to the ``.npy`` file holding the magnitude array.
        phase_path: Path to the ``.npy`` file holding the phase array.

    Returns:
        A tuple ``(spectrogram_image, scaler, original_shape)`` where
        ``spectrogram_image`` stacks the scaled magnitude and scaled phase along a
        new leading axis, ``scaler`` is the ``MinMaxScaler`` fitted on the
        magnitude, and ``original_shape`` is the shape of the loaded magnitude.
    """
    raw_magnitude = np.load(magnitude_path)
    raw_phase = np.load(phase_path)

    # Remember the shape so the 2-D view used for scaling can be folded back
    original_shape = raw_magnitude.shape

    # The scaler works on 2-D input: collapse everything except the last axis
    scaler = MinMaxScaler()
    flat_magnitude = raw_magnitude.reshape(-1, original_shape[-1])
    scaled_magnitude = scaler.fit_transform(flat_magnitude).reshape(original_shape)

    # Shift/scale the phase; this lands in [0, 1] assuming the stored phase
    # lies in [-pi, pi] — TODO confirm against the files on disk
    scaled_phase = (raw_phase + np.pi) / (2 * np.pi)

    # Channel 0 = magnitude, channel 1 = phase
    spectrogram_image = np.stack([scaled_magnitude, scaled_phase], axis=0)
    return spectrogram_image, scaler, original_shape

def denormalize_spectrograms(spectrogram_image, scaler, original_shape):
    """Undo the scaling applied to a stacked (magnitude, phase) array.

    Args:
        spectrogram_image: Array whose index 0 is the scaled magnitude and whose
            index 1 is the scaled phase.
        scaler: Object providing ``inverse_transform`` for the magnitude
            (the scaler returned by ``preprocess_spectrograms``).
        original_shape: Accepted for symmetry with ``preprocess_spectrograms``;
            not used by this function.

    Returns:
        A tuple ``(magnitude, phase)`` with the magnitude passed back through
        ``scaler.inverse_transform`` and the phase mapped from [0, 1] scale back
        to radians.
    """
    scaled_magnitude = spectrogram_image[0]
    scaled_phase = spectrogram_image[1]

    # inverse_transform expects a 2-D array: flatten all but the last axis,
    # then restore the magnitude's own shape afterwards
    n_columns = scaled_magnitude.shape[-1]
    restored = scaler.inverse_transform(scaled_magnitude.reshape(-1, n_columns))
    magnitude = restored.reshape(scaled_magnitude.shape)

    # Inverse of (phase + pi) / (2 * pi)
    phase = scaled_phase * 2 * np.pi - np.pi

    return magnitude, phase

class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for multi-channel 2-D inputs.

    The encoder is three stride-2 convolutions followed by two linear heads that
    produce the mean and log-variance of a diagonal Gaussian posterior. The decoder
    mirrors it: a linear layer, a reshape, and three stride-2 transposed
    convolutions ending in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        in_channels: Number of channels of the input and of the reconstruction.
        latent_dim: Size of the latent vector.
        img_height: Input height. Each convolution halves it (floor division) and
            each transposed convolution doubles it, so the reconstruction has
            height ``8 * (img_height // 8)``; use a multiple of 8 for the
            reconstruction to match the input.
        img_width: Input width; same multiple-of-8 consideration as the height.
        device: Device on which ``generate_images`` places its latent samples.

    Note:
        ``training_step`` uses ``self.optimizer``, which is not created here; attach
        one after construction (``model.optimizer = optim.Adam(model.parameters())``).
    """

    def __init__(self, in_channels, latent_dim, img_height, img_width, device):
        super(ConvVAE, self).__init__()
        self.latent_dim = latent_dim
        self.img_height = img_height
        self.img_width = img_width
        self.device = device

        # Encoder: every layer halves both spatial dimensions
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1)  # (32, H/2, W/2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)           # (64, H/4, W/4)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)          # (128, H/8, W/8)

        # Size of the flattened encoder output feeding the two latent heads
        flattened_size = 128 * (img_height // 8) * (img_width // 8)
        self.fc_mean = nn.Linear(flattened_size, latent_dim)
        self.fc_log_var = nn.Linear(flattened_size, latent_dim)

        # Decoder: every transposed convolution doubles both spatial dimensions
        self.fc_decode = nn.Linear(latent_dim, flattened_size)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)          # (64, H/4, W/4)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)           # (32, H/2, W/2)
        self.deconv3 = nn.ConvTranspose2d(32, in_channels, kernel_size=4, stride=2, padding=1)  # (C, H, W)

    def encode(self, x):
        """Map a batch of images to the posterior mean and log-variance.

        Args:
            x: Tensor of shape ``(batch, in_channels, img_height, img_width)``.

        Returns:
            ``(z_mean, z_log_var)``, each of shape ``(batch, latent_dim)``.
        """
        # Lazy %-style arguments: nothing is formatted unless DEBUG is enabled
        logging.debug("Input shape to encoder: %s", x.shape)
        e = F.relu(self.conv1(x))
        logging.debug("Shape after conv1: %s", e.shape)
        e = F.relu(self.conv2(e))
        logging.debug("Shape after conv2: %s", e.shape)
        e = F.relu(self.conv3(e))
        logging.debug("Shape after conv3: %s", e.shape)
        e = e.view(e.size(0), -1)
        logging.debug("Shape after flattening: %s", e.shape)
        z_mean = self.fc_mean(e)
        z_log_var = self.fc_log_var(e)
        logging.debug("Shape of z_mean: %s", z_mean.shape)
        logging.debug("Shape of z_log_var: %s", z_log_var.shape)
        return z_mean, z_log_var

    def reparameterize(self, mean, log_var):
        """Draw ``z = mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        epsilon = torch.randn_like(mean)
        return mean + torch.exp(0.5 * log_var) * epsilon

    def decode(self, z):
        """Map latent vectors of shape ``(batch, latent_dim)`` to images in [0, 1]."""
        logging.debug("Input shape to decoder: %s", z.shape)
        d = self.fc_decode(z)
        logging.debug("Shape after fc_decode: %s", d.shape)
        # Restore the (channels, H/8, W/8) layout the encoder flattened
        d = d.view(d.size(0), 128, self.img_height // 8, self.img_width // 8)
        logging.debug("Shape after reshape: %s", d.shape)
        d = F.relu(self.deconv1(d))
        logging.debug("Shape after deconv1: %s", d.shape)
        d = F.relu(self.deconv2(d))
        logging.debug("Shape after deconv2: %s", d.shape)
        outputs = torch.sigmoid(self.deconv3(d))
        logging.debug("Shape after deconv3: %s", outputs.shape)
        return outputs

    def forward(self, x):
        """Encode, sample and decode; returns ``(x_reconstructed, z_mean, z_log_var)``."""
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_reconstructed = self.decode(z)
        return x_reconstructed, z_mean, z_log_var

    def compute_loss(self, x, x_reconstructed, z_mean, z_log_var):
        """Return the summed reconstruction error plus the KL divergence to N(0, I).

        Both terms are summed over the batch so they are on the same scale. The KL
        term was previously averaged over batch and latent dimensions while the
        reconstruction term was summed over every pixel, which made the KL
        contribution negligible.

        Args:
            x: Input batch.
            x_reconstructed: Decoder output for ``x`` (same shape as ``x``).
            z_mean: Posterior means, shape ``(batch, latent_dim)``.
            z_log_var: Posterior log-variances, shape ``(batch, latent_dim)``.

        Returns:
            Scalar tensor holding the total loss.
        """
        # .item() forces a device sync, so only gather debug values when needed
        debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
        if debug_enabled:
            logging.debug("Shape of input: %s", x.shape)
            logging.debug("Shape of reconstructed: %s", x_reconstructed.shape)
            logging.debug("x min: %s, x max: %s", x.min().item(), x.max().item())
            logging.debug(
                "x_reconstructed min: %s, x_reconstructed max: %s",
                x_reconstructed.min().item(),
                x_reconstructed.max().item(),
            )

        reconstruction_loss = F.mse_loss(x_reconstructed, x, reduction='sum')
        # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed
        kl_loss = -0.5 * torch.sum(1 + z_log_var - torch.square(z_mean) - torch.exp(z_log_var))
        total_loss = reconstruction_loss + kl_loss

        if debug_enabled:
            logging.debug("Reconstruction loss: %s", reconstruction_loss.item())
            logging.debug("KL loss: %s", kl_loss.item())
            logging.debug("Total loss: %s", total_loss.item())
        return total_loss

    def training_step(self, x):
        """Run one optimisation step on batch ``x`` and return the loss as a float.

        Raises:
            AttributeError: If no optimizer has been attached as ``self.optimizer``.
        """
        optimizer = getattr(self, 'optimizer', None)
        if optimizer is None:
            raise AttributeError(
                "ConvVAE.training_step requires an optimizer; assign one to "
                "`model.optimizer` before training"
            )
        optimizer.zero_grad()
        x_reconstructed, z_mean, z_log_var = self(x)
        loss = self.compute_loss(x, x_reconstructed, z_mean, z_log_var)
        loss.backward()
        optimizer.step()
        return loss.item()

    def generate_images(self, num_samples):
        """Decode ``num_samples`` latent vectors drawn from N(0, I).

        Returns:
            NumPy array of shape ``(num_samples, in_channels, H, W)``.
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim).to(self.device)
            return self.decode(z).cpu().numpy()

    @staticmethod
    def _to_displayable(image):
        """Turn one ``(C, H, W)`` array into a layout ``imshow`` accepts.

        ``imshow`` only handles ``(H, W)``, ``(H, W, 3)`` and ``(H, W, 4)`` data.
        Three- and four-channel images are moved to channels-last; for any other
        channel count (e.g. the 2-channel magnitude/phase stacks) only the first
        channel is shown.
        """
        if image.shape[0] in (3, 4):
            return image.transpose(1, 2, 0)
        return image[0]

    def plot_images(self, training_images, generated_images):
        """Save a 2-row comparison of training and generated images to ``CVAE.png``.

        Args:
            training_images: Array of shape ``(N, C, H, W)``.
            generated_images: Array of shape ``(M, C, H, W)``.

        At most five columns are drawn; nothing is written when either input is empty.
        """
        num_images = min(5, training_images.shape[0], generated_images.shape[0])
        if num_images == 0:
            logging.warning("plot_images: no images to plot; skipping")
            return

        # squeeze=False keeps `axs` 2-D even when there is a single column
        fig, axs = plt.subplots(2, num_images, figsize=(15, 6), squeeze=False)

        rows = (
            (training_images, 'Training Image'),
            (generated_images, 'Generated Image'),
        )
        for row, (images, title) in enumerate(rows):
            for col in range(num_images):
                ax = axs[row, col]
                # aspect='auto' keeps very wide spectrograms readable in the panel
                ax.imshow(self._to_displayable(images[col]), aspect='auto')
                ax.set_title(title)
                ax.axis('off')

        fig.suptitle('Training vs Generated Images')
        fig.savefig('CVAE.png')
        plt.close(fig)

In [4]:
def get_image_paths(clean_data_dir):
    """Collect matching magnitude/phase file pairs from a directory.

    A pair is ``<stem>_magnitude.npy`` together with ``<stem>_phase.npy``. The stem
    is everything before the ``_magnitude.npy`` suffix, so stems that themselves
    contain underscores (e.g. ``clip_01``) are paired correctly.

    Args:
        clean_data_dir: Directory to scan (not recursive).

    Returns:
        List of ``(magnitude_path, phase_path)`` tuples, sorted by file name so the
        order does not depend on the file system. Magnitude files without a
        matching phase file are skipped.
    """
    magnitude_suffix = '_magnitude.npy'
    phase_suffix = '_phase.npy'

    image_paths = []
    # os.listdir order is arbitrary; sort for a reproducible result
    for filename in sorted(os.listdir(clean_data_dir)):
        if not filename.endswith(magnitude_suffix):
            continue
        # Strip only the suffix rather than splitting on the first underscore
        stem = filename[:-len(magnitude_suffix)]
        magnitude_path = os.path.join(clean_data_dir, filename)
        phase_path = os.path.join(clean_data_dir, stem + phase_suffix)
        if os.path.exists(phase_path):
            image_paths.append((magnitude_path, phase_path))
    return image_paths

class SpectrogramDataset(Dataset):
    """Dataset that loads, scales and crops one spectrogram per path pair.

    Each item is produced on access by ``preprocess_spectrograms``; nothing is
    cached between calls.
    """

    def __init__(self, image_paths):
        # Sequence of (magnitude_path, phase_path) tuples
        self.image_paths = image_paths

    def __len__(self):
        """Number of path pairs held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``idx``-th spectrogram as a float32 tensor."""
        magnitude_path, phase_path = self.image_paths[idx]

        # Only the stacked array is needed here; scaler and shape are discarded
        stacked, _scaler, _shape = preprocess_spectrograms(magnitude_path, phase_path)

        # Drop the final row and the final column of every channel
        # (array is laid out as channels, height, width)
        trimmed = stacked[:, :-1, :-1]

        return torch.tensor(trimmed, dtype=torch.float32)

In [5]:
# Data location and training hyperparameters.
# The path is anchored at the Drive mount point (/content/drive) so it resolves
# regardless of the notebook's current working directory.
clean_data_dir = '/content/drive/MyDrive/spectrogram_npy'
batch_size = 8    # samples per optimisation step
latent_dim = 32   # size of the VAE latent vector
epochs = 10       # full passes over the dataset

In [6]:
# Collect every (magnitude, phase) pair and wrap them in a shuffled DataLoader
image_paths = get_image_paths(clean_data_dir)
if not image_paths:
    # Stop here with a readable message rather than failing less clearly further down
    raise FileNotFoundError(
        f"No '*_magnitude.npy' / '*_phase.npy' pairs found in {clean_data_dir!r}"
    )
dataset = SpectrogramDataset(image_paths)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [7]:
# Seed PyTorch before the model is built so weight initialisation, batch shuffling
# and latent sampling are repeatable from a fresh kernel
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image dimensions of the cropped spectrograms fed to the model.
# Both are multiples of 8, which the three stride-2 layers of ConvVAE need for the
# reconstruction to come back at the input size.
img_height = 128
img_width = 10328

# Create the model instance
vae = ConvVAE(in_channels=2, latent_dim=latent_dim, img_height=img_height, img_width=img_width, device=device).to(device)
# training_step() looks the optimizer up on the model, so attach it here
vae.optimizer = optim.Adam(vae.parameters(), lr=1e-3)

In [None]:
# Training loop with loss history
log_every = 1                 # print the epoch summary every `log_every` epochs
loss_history = []             # mean per-batch loss, one entry per epoch
training_images_batch = None  # first batch seen, kept for the comparison plot

vae.train()  # make the training mode explicit
for epoch in range(epochs):
    epoch_loss = 0
    print(f"epoch {epoch}")
    for batch_x in tqdm(dataloader):
        batch_x = batch_x.to(device)
        loss = vae.training_step(batch_x)
        epoch_loss += loss

        # Save a batch of training images for plotting
        if training_images_batch is None:
            training_images_batch = batch_x.cpu().numpy()

    # Average of the per-batch losses returned by training_step
    epoch_loss /= len(dataloader)
    loss_history.append(epoch_loss)
    if epoch % log_every == 0:
        print(f'Epoch {epoch}, Loss: {epoch_loss}')

epoch 0


  2%|▏         | 2/88 [01:16<53:22, 37.24s/it]  

In [None]:
# Loss curve: one point per epoch, written to loss_history.png
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(loss_history, label='Training Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss History')
ax.legend()
fig.savefig('loss_history.png')
plt.close(fig)

# Sample from the prior and compare against the stored training batch
num_samples = 5
generated_images = vae.generate_images(num_samples)
vae.plot_images(training_images_batch[:num_samples], generated_images)


In [None]:
# Persist only the learned parameters (state dict), not the full module object
model_state = vae.state_dict()
torch.save(model_state, 'vae_model.pth')
print("Model saved to 'vae_model.pth'")