In [278]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import umap
from tqdm import tqdm

from src.models.spectrogram_vae import SpectrogramVAE
from src.utils import audio_to_spectrogram
from src.plot_utils import *

In [279]:
# Global seaborn look for this notebook. Note: the plotting cell calls sns.set()
# again, which switches back to seaborn's default theme before drawing.
sns.set_theme(style="dark")

In [280]:
DAFX_NAME = "mda Ambience"
NUM_EXAMPLES = 2_000
SEED = 0  # seeds numpy/torch below so random parameter draws are reproducible

# Machine-specific default; set SPECTROGRAM_VAE_CHECKPOINT to point at another checkpoint.
CHECKPOINT = os.environ.get(
    "SPECTROGRAM_VAE_CHECKPOINT",
    "/home/kieran/Level5ProjectAudioVAE/src/l5proj_spectrogram_vae/hdx3y4ly/checkpoints/epoch=169-step=35530.ckpt",
)
# Run id = directory two levels above the file: ".../<run id>/checkpoints/<file>.ckpt".
CHECKPOINT_ID = Path(CHECKPOINT).parts[-3]

np.random.seed(SEED)
torch.manual_seed(SEED)

In [281]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [282]:
# Build the audio effect named by DAFX_NAME; it is used below to draw random parameter
# settings and to process audio. NOTE(review): dafx_from_name is not imported explicitly —
# presumably it comes from `from src.plot_utils import *`; TODO confirm.
dafx = dafx_from_name(DAFX_NAME)

In [283]:
# map_location makes a checkpoint saved on GPU loadable when DEVICE is 'cpu'.
model = SpectrogramVAE.load_from_checkpoint(CHECKPOINT, map_location=DEVICE).to(DEVICE)
# Switch to inference mode; the trailing ';' keeps the full model repr out of the output.
model.eval();

In [284]:
# Audio source built with the 'clean' effect (presumably a pass-through, i.e. dry audio);
# the effect under study is applied by hand in the embedding loop below.
dataset = get_audio_dataset(
    dafx_from_name('clean'),
    num_examples_per_epoch=NUM_EXAMPLES,
)

In [285]:
settings = []
embeddings = []

# A single clip is drawn once and reused for every configuration, presumably so the
# embeddings differ only through the effect parameters.
x = next(iter(dataset))

# Pure inference: no_grad stops autograd from recording a graph for each of the
# NUM_EXAMPLES forward passes, saving memory and time.
with torch.no_grad():
    for _ in tqdm(range(NUM_EXAMPLES)):
        setting = dafx.get_random_parameter_settings()

        # Apply setting to audio, then add two leading singleton dims
        # (presumably batch and channel — TODO confirm against audio_to_spectrogram).
        y = dafx.apply(x, setting)
        y = y.unsqueeze(0).unsqueeze(0)

        X = audio_to_spectrogram(signal=y,
                                 n_fft=model.hparams.n_fft,
                                 hop_length=model.hparams.hop_length,
                                 window_size=model.hparams.window_size).to(DEVICE)

        # Of the model's four outputs, only the last (z) is kept as the embedding.
        _, _, _, z = model(X)

        settings.append(setting.cpu().detach().numpy())
        embeddings.append(z.cpu().detach().numpy())

In [286]:
# Stack the per-example arrays into 2-D (num_examples, features) matrices.
# reshape(len, -1) rather than squeeze(): squeeze also drops size-1 feature axes
# (e.g. an effect with a single parameter), which would make settings 1-D and break
# column indexing like settings[:, i]; it also collapses the row axis when there is one example.
data = np.stack(embeddings).reshape(len(embeddings), -1)
settings = np.stack(settings).reshape(len(settings), -1)

In [287]:
# Project the latent vectors down to 2-D with UMAP for plotting.
emb = umap.UMAP(
    metric='euclidean',
    min_dist=0.1,
    n_neighbors=15,
).fit_transform(data)

In [288]:
# Column name -> values: one column per effect parameter, followed by the two UMAP coordinates.
df_data = {
    **{
        dafx.idx_to_param_map[param_idx]: settings[:, param_idx]
        for param_idx in range(dafx.get_num_params())
    },
    "x": emb[:, 0],
    "y": emb[:, 1],
}

In [289]:
# One row per sampled configuration; columns are the parameters plus "x"/"y".
df = pd.DataFrame.from_dict(df_data)

In [290]:
# Inspect the assembled frame: one column per effect parameter plus the UMAP
# coordinates "x" and "y", one row per sampled parameter configuration.
df

In [291]:
# Tag used for the saved figure's file name: "<run id>_<last word of effect name>_<N>settings".
EXPERIMENT_NAME = f"{CHECKPOINT_ID}_{DAFX_NAME.rsplit(maxsplit=1)[-1]}_{NUM_EXAMPLES}settings"

In [292]:
sns.set()


def plot_latent_by_param(df, param_names, title, max_columns=3, cmap='magma'):
    """Draw one scatter of the 2-D embedding per parameter, coloured by that parameter.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain columns "x" and "y" (embedding coordinates) and one column
        per entry of ``param_names``.
    param_names : sequence
        Parameter column names; one subplot each, laid out row-major.
    title : str
        Figure super-title.
    max_columns : int
        Maximum number of subplot columns, forwarded to ``get_subplot_dimensions``.
    cmap : str
        Matplotlib colormap name.

    Returns
    -------
    (fig, axs)
        The figure and its 2-D array of axes (unused grid cells have been removed
        from the figure).

    Raises
    ------
    ValueError
        If ``param_names`` is empty, or the grid from ``get_subplot_dimensions``
        has fewer cells than parameters.
    """
    n = len(param_names)
    if n == 0:
        raise ValueError("plot_latent_by_param needs at least one parameter column")

    # NOTE(review): get_subplot_dimensions is presumably provided by src.plot_utils (star import).
    num_rows, num_cols = get_subplot_dimensions(n, max_columns=max_columns)
    if num_rows * num_cols < n:
        raise ValueError(f"a {num_rows}x{num_cols} grid cannot hold {n} subplots")

    # squeeze=False keeps axs 2-D even for 1xN, Nx1 or 1x1 grids, so flattening always works.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=(4*num_cols, 4*num_rows + 2),
                            sharex=True,
                            sharey=True,
                            squeeze=False)
    flat_axes = axs.ravel()
    used_axes, spare_axes = flat_axes[:n], flat_axes[n:]

    # One colour scale for every panel so the single shared colorbar is valid for all
    # of them; left alone, each scatter would stretch its own min/max over the colormap.
    param_values = df[list(param_names)].to_numpy(dtype=float)
    vmin, vmax = np.nanmin(param_values), np.nanmax(param_values)

    for ax, param_name in zip(used_axes, param_names):
        sc = ax.scatter(x=df['x'], y=df['y'], c=df[param_name], alpha=0.7,
                        cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(param_name)
        # Explicit True: a bare grid() toggles and would hide the theme's grid lines.
        ax.grid(True)
        ax.set_xlabel("")
        ax.set_ylabel("")

    # The grid can have more cells than parameters; remove the empty ones.
    for ax in spare_axes:
        ax.remove()

    fig.suptitle(title)
    fig.tight_layout()

    # Attach the colorbar only to the axes that are still in the figure.
    cbar = fig.colorbar(sc, ax=used_axes.tolist(), aspect=20, shrink=.5, pad=.1,
                        orientation='horizontal')
    cbar.set_label('Parameter value')
    return fig, axs


param_names = [dafx.idx_to_param_map[i] for i in range(dafx.get_num_params())]
fig, axs = plot_latent_by_param(
    df,
    param_names,
    title=f"{DAFX_NAME} {NUM_EXAMPLES} parameter configurations latent space",
    max_columns=3,
)

# Create the output directory on first run instead of failing inside savefig.
out_path = Path(f"./figures/random_param_plot/{EXPERIMENT_NAME}.png")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)