In [125]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import umap
from tqdm import tqdm

from src.models.spectrogram_vae import SpectrogramVAE
from src.utils import audio_to_spectrogram
from src.plot_utils import *

In [126]:
# Apply seaborn's 'dark' style to the figures created from here on.
sns.set(style='dark')

In [127]:
# Effect whose parameter space is explored, and how many random settings to draw.
DAFX_NAME = "mda Overdrive"
NUM_EXAMPLES = 10_000

# Checkpoint file the SpectrogramVAE is restored from.
CHECKPOINT = "/home/kieran/Level5ProjectAudioVAE/src/l5proj_spectrogram_vae/hdx3y4ly/checkpoints/epoch=169-step=35530.ckpt"

# The run identifier is the third path component counted from the end,
# i.e. .../<run id>/checkpoints/<file>.ckpt
CHECKPOINT_ID = CHECKPOINT.rsplit("/", 3)[-3]

In [128]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [129]:
# Look up the effect object for DAFX_NAME; it is used below to draw random
# parameter settings and to apply them to audio.
dafx = dafx_from_name(DAFX_NAME)

In [130]:
# Restore the trained VAE from the checkpoint, then move it to the chosen device.
model = SpectrogramVAE.load_from_checkpoint(CHECKPOINT)
model = model.to(DEVICE)

# Switch to evaluation mode; as the cell's last expression this also displays
# the module summary.
model.eval()

SpectrogramVAE(
  (encoder_conv): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (mu): Linear(in_features=37152, out_features=128, bias=True)
  (log_var): Linear(in_features=37152, 

In [131]:
# Build the dataset from the 'clean' effect, sized to NUM_EXAMPLES per epoch.
clean_dafx = dafx_from_name('clean')
dataset = get_audio_dataset(
    clean_dafx,
    num_examples_per_epoch=NUM_EXAMPLES,
)

100%|██████████████████████████████████████████| 88/88 [00:00<00:00, 638.10it/s]


Loaded 88 files for train = 66.89 hours.





In [None]:
# Collect, for NUM_EXAMPLES random parameter settings of the effect, the
# setting itself and the model output `z` obtained for the processed audio.
settings = []
embeddings = []

# One example is taken from the dataset and reused for every iteration, so the
# effect's parameter setting is the only thing varied between embeddings.
x = next(iter(dataset))

# Inference only: with autograd disabled no graph is built during the
# NUM_EXAMPLES forward passes, which saves memory and time.
with torch.no_grad():
    for _ in tqdm(range(NUM_EXAMPLES)):
        setting = dafx.get_random_parameter_settings()

        # Apply setting to audio
        y = dafx.apply(x, setting)
        # Add two leading axes before the spectrogram transform
        # (presumably batch and channel -- TODO confirm against audio_to_spectrogram).
        y = y.unsqueeze(0).unsqueeze(0)

        X = audio_to_spectrogram(signal=y,
                                 n_fft=model.hparams.n_fft,
                                 hop_length=model.hparams.hop_length,
                                 window_size=model.hparams.window_size).to(DEVICE)

        # The model returns a 4-tuple; only its last element is kept
        # (presumably the latent vector -- verify against SpectrogramVAE.forward).
        _, _, _, z = model(X)

        settings.append(setting.detach().cpu().numpy())
        embeddings.append(z.detach().cpu().numpy())

 36%|███▋      | 3642/10000 [00:37<00:59, 106.40it/s]

In [None]:
# Stack the per-example results into 2-D arrays with one row per example.
# reshape(N, -1) is used instead of squeeze(): squeeze() would also drop the
# example axis when only one example was collected, and the parameter axis for
# an effect with a single parameter, which breaks the `settings[:, i]` column
# indexing done when the DataFrame columns are built.
# NOTE(review): assumes each stored embedding flattens to one latent vector -- confirm.
data = np.array(embeddings).reshape(len(embeddings), -1)
settings = np.array(settings).reshape(len(settings), -1)

In [None]:
# Reduce the collected embeddings with UMAP; the first two output columns are
# used as plot coordinates later on.
reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    metric='euclidean',
)
emb = reducer.fit_transform(data)

In [None]:
# Column data for the results table: one column per effect parameter (keyed by
# the parameter's name), followed by the two UMAP coordinates.
df_data = {
    **{
        dafx.idx_to_param_map[i]: settings[:, i]
        for i in range(dafx.get_num_params())
    },
    "x": emb[:, 0],
    "y": emb[:, 1],
}

In [None]:
# Assemble the table from the column dictionary.
df = pd.DataFrame(data=df_data)

In [None]:
# Display the table: one column per effect parameter plus the UMAP x/y coordinates.
df

In [None]:
# File-name stem for the saved figure:
# <checkpoint run id>_<last word of the effect name>_<number of examples>settings
EXPERIMENT_NAME = f"{CHECKPOINT_ID}_{DAFX_NAME.split()[-1]}_{NUM_EXAMPLES}settings"

In [None]:
# Draw one scatter panel per effect parameter: every panel shows the same 2-D
# UMAP coordinates, coloured by that parameter's value, and the figure is saved
# under ./figures/random_param_plot/.
sns.set()

n = dafx.get_num_params()  # number of panels: one per effect parameter
max_columns = 3  # set a maximum number of columns

num_rows, num_cols = get_subplot_dimensions(n, max_columns=max_columns)

# squeeze=False keeps `axs` two-dimensional for every grid shape (including
# 1xN, Nx1 and 1x1), so no special-casing is needed to index it.
fig, axs = plt.subplots(num_rows, num_cols,
                        figsize=(4*num_cols, 4*num_rows + 2),
                        squeeze=False)

# Row-major flattening matches the original divmod(count, num_cols) ordering.
panel_axes = axs.ravel()
used_axes = panel_axes[:n]

for i, current_ax in enumerate(used_axes):
    param_name = dafx.idx_to_param_map[i]

    sc = current_ax.scatter(x=df['x'], y=df['y'], c=df[param_name], s=10, alpha=0.7, cmap='magma')

    current_ax.set_title(param_name)
    current_ax.grid()
    current_ax.set_xlabel("")
    current_ax.set_ylabel("")

fig.suptitle(f"{DAFX_NAME} {NUM_EXAMPLES} parameter configurations latent space")

# Remove the grid cells that did not receive a panel (when n does not fill the
# num_rows x num_cols grid).
for spare_ax in panel_axes[n:]:
    spare_ax.remove()

# Adjust the spacing between subplots
fig.tight_layout()

# Add a single shared colorbar. It is built from the last panel's scatter only.
# NOTE(review): this is only accurate if all parameters share the same value
# range (e.g. normalised settings) -- confirm.
cbar = fig.colorbar(sc, ax=used_axes.tolist(), aspect=20, shrink=.5, pad=.1, orientation='horizontal')
cbar.set_label('Parameter value')

# Make sure the output directory exists so the save cannot fail on a fresh checkout.
out_path = Path(f"./figures/random_param_plot/{EXPERIMENT_NAME}.png")
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path)