In [None]:
import argparse
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from audio_encoders_pytorch.audio_encoders_pytorch import AutoEncoder1d, TanhBottleneck
import utils.load_datasets
import torch.nn.functional as F
import os
from tqdm import tqdm
from accelerate import Accelerator
import random
import json
import matplotlib.pyplot as plt
import numpy as np
from torchaudio.transforms import Spectrogram
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns

def _labels_to_color_values(labels, num_samples):
    """Convert ``labels`` into a 1-D numpy array usable as scatter colours (or None)."""
    if labels is None:
        return None

    def _to_numpy(item):
        # Tensors may live on an accelerator device or carry autograd history.
        if torch.is_tensor(item):
            return item.detach().cpu().numpy()
        return np.asarray(item)

    if isinstance(labels, (list, tuple)):
        if len(labels) == 0:
            return None
        # Accept either a flat list of labels or a list of per-batch label arrays.
        values = np.concatenate([np.atleast_1d(_to_numpy(item)) for item in labels], axis=0)
    else:
        values = np.atleast_1d(_to_numpy(labels))

    if values.ndim > 1:
        values = values.reshape(values.shape[0], -1)
        if values.shape[1] == 1:
            values = values[:, 0]
        else:
            # NOTE(review): multi-column labels are assumed to be one-hot / score
            # vectors, so the arg-max is used as the class index — confirm
            # against the dataset's label format.
            values = values.argmax(axis=1)

    if values.shape[0] != num_samples:
        raise ValueError(
            f"labels has {values.shape[0]} entries but there are {num_samples} latent samples"
        )
    return values


def plot_latent_space(latent_representations, labels, idx=0, save_dir = "plots", save_name="latent_space"):
    """
    Plot the latent space of the encoded signals using different methods.

    Two JPEG files are written into ``save_dir`` (created if missing):
    ``{save_name}_heatmap_{idx}.jpg`` holds a 50x50 2-D histogram of the first
    two latent features over [-1, 1] x [-1, 1], and ``{save_name}_{idx}.jpg``
    holds three scatter plots (first two features, PCA, t-SNE).

    Parameters:
    - latent_representations: Torch tensor of encoded latents, one row per
      sample. Tensors with more than two dimensions are flattened per sample.
    - labels: Per-sample labels used to colour the scatter plots. May be a
      tensor/array, a flat sequence, a sequence of per-batch tensors, or None
      (no colouring, no colour bars).
    - idx: Index for distinguishing multiple plots if needed.
    - save_dir: Output directory.
    - save_name: File-name prefix for both images.

    Raises:
    - ValueError: if the number of labels does not match the number of samples.
    """
    latent_representations = latent_representations.cpu().detach().numpy()
    # Flatten latent representations if necessary
    if latent_representations.ndim > 2:
        latent_representations = latent_representations.reshape(latent_representations.shape[0], -1)

    label_values = _labels_to_color_values(labels, latent_representations.shape[0])

    os.makedirs(save_dir, exist_ok=True)
    # Create a 2D histogram (heatmap) from the latent space data
    plt.figure(figsize=(10, 8))
    heatmap, xedges, yedges = np.histogram2d(latent_representations[:, 0], latent_representations[:, 1], bins=50, range=[[-1, 1], [-1, 1]])

    # Plot heatmap
    sns.heatmap(heatmap.T, cmap='viridis', cbar=True, xticklabels=50, yticklabels=50)

    plt.title('Latent Space Heatmap')
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    plt.savefig(os.path.join(save_dir, f"{save_name}_heatmap_{idx}.jpg"), format="jpg")
    plt.close()

    max_samples = 10000  # Adjust this based on memory constraints
    if latent_representations.shape[0] > max_samples:
        indices = np.random.choice(latent_representations.shape[0], max_samples, replace=False)
        latent_representations = latent_representations[indices]
        # Keep the colours aligned with the points that survived subsampling.
        if label_values is not None:
            label_values = label_values[indices]

    num_points = latent_representations.shape[0]
    pca_result = PCA(n_components=2).fit_transform(latent_representations)
    # t-SNE requires perplexity < number of samples; 30.0 is the library default.
    perplexity = float(min(30, max(num_points - 1, 1)))
    tsne_result = TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(latent_representations)

    # Plot configurations: (x values, y values, method name, x label, y label)
    panels = [
        (latent_representations[:, 0], latent_representations[:, 1],
         'Simple', 'Latent Dimension 1', 'Latent Dimension 2'),
        (pca_result[:, 0], pca_result[:, 1],
         'PCA', 'PCA Component 1', 'PCA Component 2'),
        (tsne_result[:, 0], tsne_result[:, 1],
         't-SNE', 't-SNE Dimension 1', 't-SNE Dimension 2'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (xs, ys, method, xlabel, ylabel) in zip(axes, panels):
        # Keep the returned artist: the colour bar needs it as its mappable.
        scatter = ax.scatter(xs, ys, c=label_values, alpha=0.7, cmap='viridis')
        ax.set_title(f'Latent Space - {method} (Pair {idx + 1})')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True)
        if label_values is not None:
            fig.colorbar(scatter, ax=ax)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{save_name}_{idx}.jpg"), format="jpg")
    plt.close(fig)
    
def plot_waveform_and_spectrogram(input_signal, decoded_signal, idx, save_dir="plots", save_name = "waveform_spectrogram"):
    """
    Plot the waveform and spectrogram of the input and decoded signals.

    A 2x2 figure is saved as ``{save_name}_{idx}.jpg`` inside ``save_dir``
    (created if missing): flattened waveforms on the top row and log2
    spectrograms (n_fft=1024) on the bottom row.

    Parameters:
    - input_signal: Torch tensor with the original signal. The spectrogram of
      it must be 3-D because only its first slice ``[0, :, :]`` is drawn
      (presumably the first channel — TODO confirm against the dataset).
    - decoded_signal: Torch tensor with the reconstruction, same layout.
    - idx: Index used in the titles (shown as ``idx + 1``) and the file name.
    - save_dir: Output directory.
    - save_name: File-name prefix.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Detach up front so the function also works on tensors that still carry
    # autograd history (``.numpy()`` refuses tensors that require grad).
    input_signal = input_signal.detach()
    decoded_signal = decoded_signal.detach()

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Time-domain signals
    axes[0, 0].plot(input_signal.cpu().numpy().flatten(), color='blue')
    axes[0, 0].set_title(f'Input Signal - Pair {idx + 1}')

    axes[0, 1].plot(decoded_signal.cpu().numpy().flatten(), color='red')
    axes[0, 1].set_title(f'Decoded Signal - Pair {idx + 1}')

    # Spectrograms (transform is moved to the signal's device before use)
    spectrogram_transform = Spectrogram(n_fft=1024).to(input_signal.device)

    input_spectrogram = spectrogram_transform(input_signal).log2()[0, :, :].cpu().numpy()
    decoded_spectrogram = spectrogram_transform(decoded_signal).log2()[0, :, :].cpu().numpy()

    axes[1, 0].imshow(input_spectrogram, aspect='auto', origin='lower')
    axes[1, 0].set_title(f'Input Spectrogram - Pair {idx + 1}')

    axes[1, 1].imshow(decoded_spectrogram, aspect='auto', origin='lower')
    axes[1, 1].set_title(f'Decoded Spectrogram - Pair {idx + 1}')

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{save_name}_{idx}.jpg"), format="jpg")
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

def compute_spectrogram_loss(input_signal, decoded_signal):
    """
    Mean-squared error between the spectrograms of two signals.

    Both signals go through the same ``Spectrogram(n_fft=1024)`` transform,
    which is placed on ``input_signal``'s device first.

    Parameters:
    - input_signal: The original input signal (torch tensor).
    - decoded_signal: The signal reconstructed by the autoencoder.

    Returns:
    - The MSE loss tensor between the decoded and input spectrograms.
    """
    to_spectrogram = Spectrogram(n_fft=1024).to(input_signal.device)

    target = to_spectrogram(input_signal)
    estimate = to_spectrogram(decoded_signal)

    return F.mse_loss(estimate, target)

def setup_dataloader(batch_size, num_workers, val_split=0.2):
    """
    Build train/validation loaders over the DeepSig2018 dataset.

    NOTE: a second ``setup_dataloader(config, val_split=0.2)`` is defined
    later in this file; since it has the same name, that later definition is
    the one in effect once the module has been fully executed.

    Parameters:
    - batch_size: Batch size for both loaders.
    - num_workers: Number of worker processes for both loaders.
    - val_split: Fraction of the dataset held out for validation.

    Returns:
    - (train_loader, val_loader); only the training loader shuffles.
    """
    full_set = utils.load_datasets.DeepSig2018Dataset(
        "/ext/trey/experiment_diffusion/experiment_rfdiffusion/dataset/GOLD_XYZ_OSC.0001_1024.hdf5")

    n_val = int(len(full_set) * val_split)
    n_train = len(full_set) - n_val
    train_part, val_part = random_split(full_set, [n_train, n_val])

    # Options shared by both loaders.
    common = dict(batch_size=batch_size, pin_memory=True, num_workers=num_workers)
    train_loader = DataLoader(train_part, shuffle=True, **common)
    val_loader = DataLoader(val_part, shuffle=False, **common)

    return train_loader, val_loader

def evaluate_model(model, data_loader, accelerator):
    """
    Evaluate the autoencoder on ``data_loader`` and produce diagnostic plots.

    For every batch the input is encoded and decoded, and two reconstruction
    losses are accumulated (time-domain MSE and spectrogram MSE), weighted by
    batch size. The first sample of each of the first three batches is plotted
    with ``plot_waveform_and_spectrogram``; all latents and labels are then
    passed to ``plot_latent_space``. The averaged losses are logged through
    ``accelerator.log``.

    The model's train/eval mode is restored before returning, so calling this
    from inside a training loop does not leave the model in eval mode.

    Parameters:
    - model: Autoencoder exposing ``encode`` and ``decode`` (possibly wrapped
      by the accelerator).
    - data_loader: Iterable yielding ``(x, labels)`` batches.
    - accelerator: ``accelerate.Accelerator`` used for device placement and logging.

    Returns:
    - (avg_time_loss, avg_spectrogram_loss) as Python floats.

    Raises:
    - ValueError: if ``data_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    # encode/decode are methods of the underlying module, not of a
    # distributed wrapper, so call them on the unwrapped model.
    net = accelerator.unwrap_model(model)

    total_time_loss = 0.0
    total_spectrogram_loss = 0.0
    num_samples = 0

    # To collect all latent representations and their labels
    all_latent_representations = []
    all_labels = []
    try:
        with torch.no_grad():
            for batch_idx, (x, labels) in enumerate(data_loader):
                x = x.to(accelerator.device)
                y = net.encode(x)
                y_decoded = net.decode(y)
                batch_size = x.size(0)

                # Time-domain reconstruction loss
                time_loss = F.mse_loss(y_decoded, x)
                total_time_loss += time_loss.item() * batch_size

                # Spectrogram-based reconstruction loss
                spectrogram_loss = compute_spectrogram_loss(x, y_decoded)
                total_spectrogram_loss += spectrogram_loss.item() * batch_size

                num_samples += batch_size
                # Collect latents and labels on the CPU for visualization
                all_latent_representations.append(y.cpu())
                all_labels.append(torch.as_tensor(labels).cpu())

                # Visualize the first sample of each of the first 3 batches.
                # The batch index is used as the plot index so the images do
                # not overwrite each other.
                if batch_idx < 3 and batch_size > 0:
                    plot_waveform_and_spectrogram(x[0], y_decoded[0], batch_idx)
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    if num_samples == 0:
        raise ValueError("evaluate_model received an empty data_loader")

    avg_time_loss = total_time_loss / num_samples
    avg_spectrogram_loss = total_spectrogram_loss / num_samples

    # Concatenate per-batch results into single tensors so that each latent
    # row lines up with exactly one label.
    all_latent_representations = torch.cat(all_latent_representations, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    plot_latent_space(all_latent_representations, all_labels)

    # Log the losses to the accelerator's trackers
    accelerator.log({
        "eval_time_loss": avg_time_loss,
        "eval_spectrogram_loss": avg_spectrogram_loss
    })

    return avg_time_loss, avg_spectrogram_loss

def setup_model(config):
    """
    Instantiate the 1-D autoencoder described by ``config``.

    Each constructor argument ``<name>`` is read from the config key
    ``ae_<name>``; the bottleneck is always a ``TanhBottleneck``.

    Parameters:
    - config: Mapping containing the ``ae_*`` hyper-parameters listed below.

    Returns:
    - A new ``AutoEncoder1d`` instance.
    """
    hyper_params = (
        'in_channels',
        'channels',
        'multipliers',
        'factors',
        'num_blocks',
        'patch_size',
        'resnet_groups',
    )
    ae_kwargs = {name: config[f'ae_{name}'] for name in hyper_params}
    # The bottleneck type is fixed here; it is not read from the config.
    return AutoEncoder1d(bottleneck=TanhBottleneck(), **ae_kwargs)


def setup_training(config, model):
    """
    Create the optimizer, loss function and LR scheduler for ``model``.

    Parameters:
    - config: Mapping with ``learning_rate``, ``adam_betas`` (two values) and
      ``gamma`` (decay factor for the exponential LR schedule).
    - model: Module whose parameters are optimized.

    Returns:
    - (optimizer, criterion, scheduler): an Adam optimizer, an ``nn.MSELoss``
      instance and an ``ExponentialLR`` scheduler bound to that optimizer.
    """
    betas = tuple(config['adam_betas'])
    adam = Adam(model.parameters(), lr=config['learning_rate'], betas=betas)
    mse = nn.MSELoss()
    lr_schedule = torch.optim.lr_scheduler.ExponentialLR(adam, config['gamma'])
    return adam, mse, lr_schedule


def setup_dataloader(config, val_split=0.2):
    """
    Build train/validation loaders over the DeepSig2018 (MOD) dataset.

    Parameters:
    - config: Mapping with ``batch_size`` and ``num_workers``. An optional
      ``dataset_path`` key overrides the HDF5 file location; when absent the
      original hard-coded path is used, so existing configs keep working.
    - val_split: Fraction of the dataset held out for validation.

    Returns:
    - (train_loader, val_loader); only the training loader shuffles.
    """
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    # Machine-specific default kept for backward compatibility.
    dataset_path = config.get(
        "dataset_path",
        "/ext/trey/experiment_diffusion/experiment_rfdiffusion/dataset/GOLD_XYZ_OSC.0001_1024.hdf5",
    )
    dataset = utils.load_datasets.DeepSig2018Dataset_MOD(dataset_path)

    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)

    return train_loader, val_loader


def setup_accelerator(config):
    """
    Create the ``Accelerator`` (logging to wandb) and start its trackers.

    Parameters:
    - config: Mapping with ``project_name``; the whole mapping is also passed
      to the tracker as the run configuration.

    Returns:
    - (accelerator, run_name), where ``run_name`` is a random integer in
      [0, 1000000] rendered as a string and used as the wandb run name.
    """
    accelerator = Accelerator(log_with="wandb")
    # Use an int bound: ``10e5`` is a float, which random.randint only
    # tolerated with a deprecation warning and rejects on newer Pythons.
    run_name = str(random.randint(0, 10**6))
    accelerator.init_trackers(
        config['project_name'],
        config=config,
        init_kwargs={"wandb": {"name": run_name}}
    )
    return accelerator, run_name


def train_model(model, optimizer, criterion, scheduler, train_loader, val_loader, accelerator, config):
    """
    Train the autoencoder to reconstruct its input.

    The objects are first passed through ``accelerator.prepare``. Each batch
    performs one optimizer step and one scheduler step and logs the training
    loss and learning rate. Validation (``evaluate_model``) runs once at the
    end of every epoch rather than after every step, because a full pass over
    the validation set plus plotting per step makes training impractically
    slow. A checkpoint is written on the main process whenever
    ``epoch % config['save_every'] == 0``.

    Parameters:
    - model, optimizer, criterion, scheduler: Training objects, e.g. from
      ``setup_model`` / ``setup_training``.
    - train_loader, val_loader: Loaders yielding ``(x, labels)`` batches.
    - accelerator: ``accelerate.Accelerator`` handling devices and logging.
    - config: Mapping with ``epochs``, ``save_every`` and ``model_save_dir``.
    """
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )
    num_training_steps = config['epochs'] * len(train_loader)
    progress_bar = tqdm(range(num_training_steps), disable=not accelerator.is_local_main_process)

    step = 1

    for epoch in range(config['epochs']):
        # Re-enter training mode every epoch; evaluation may have switched it.
        model.train()
        for x, _ in train_loader:
            y = model(x)
            loss = criterion(y, x)

            accelerator.backward(loss)
            optimizer.step()
            # NOTE(review): the LR schedule is advanced once per batch, not
            # per epoch — confirm that ``gamma`` is chosen with that in mind.
            scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            # Log a plain float so the tracker does not hold on to the graph.
            accelerator.log({"training_loss": loss.item(), "learning_rate": scheduler.get_last_lr()[0]}, step=step)
            step += 1

        # Validate once per epoch.
        evaluate_model(model, val_loader, accelerator)

        if epoch % config['save_every'] == 0 and accelerator.is_main_process:
            # Save the underlying module so the state-dict keys do not depend
            # on any wrapper added by accelerator.prepare.
            save_checkpoint(accelerator.unwrap_model(model), optimizer, epoch,
                            config['model_save_dir'], f'model_epoch_{epoch}.pth')

    progress_bar.close()
    accelerator.end_training()


def save_checkpoint(model, optimizer, epoch, save_dir, filename):
    """
    Write a training checkpoint to ``save_dir/filename``.

    The saved dictionary has the keys ``epoch``, ``model_state_dict`` and
    ``optimizer_state_dict``.

    Parameters:
    - model: Module whose ``state_dict()`` is stored.
    - optimizer: Optimizer whose ``state_dict()`` is stored.
    - epoch: Epoch number recorded in the checkpoint.
    - save_dir: Directory the file is written into.
    - filename: Name of the checkpoint file.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(save_dir, filename))


def main():
    """
    Entry point: load the JSON config, build everything, train and save.

    Reads ``config_autoencoder.json`` from the current working directory,
    derives ``model_save_dir`` as ``base_save_dir/project_name`` (created if
    missing), runs ``train_model`` and finally writes ``model_<run_name>.pth``
    into that directory from the main process.
    """
    config_path = 'config_autoencoder.json'  # Specify your JSON file path here
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Construct model_save_dir
    config['model_save_dir'] = os.path.join(config['base_save_dir'], config['project_name'])
    os.makedirs(config['model_save_dir'], exist_ok=True)

    accelerator, run_name = setup_accelerator(config)

    model = setup_model(config)
    optimizer, criterion, scheduler = setup_training(config, model)
    train_loader, val_loader = setup_dataloader(config)

    print(f"Training on {accelerator.num_processes} GPUs")
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    print(f"Models will be saved in: {config['model_save_dir']}")

    train_model(model, optimizer, criterion, scheduler, train_loader, val_loader, accelerator, config)

    if accelerator.is_main_process:
        # Pass only the file name: save_checkpoint joins it with the directory
        # itself, so handing it an already-joined path would repeat the
        # directory whenever base_save_dir is a relative path.
        save_checkpoint(accelerator.unwrap_model(model), optimizer, config['epochs'], config['model_save_dir'],
                        f'model_{run_name}.pth')

    print("Training complete and models saved.")


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


Training on 1 GPUs
Number of parameters: 2358298
Models will be saved in: /ext/trey/experiment_diffusion/experiment_rfdiffusion/models/autoencoder_metric_testing





  0%|                                                                                                                                                                                                                                                                                              | 0/399400 [00:00<?, ?it/s][A[A[A


  0%|                                                                                                                                                                                                                                                                                 | 1/399400 [00:16<1784:42:34, 16.09s/it][A[A[A