In [1]:
from functions import *

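Audio to feature spectrogram:
1. Mel-spectrogram
2. Power-to-dB
3. Standardization: $z=\frac{x-\mu}{\sigma}$, where $\mu$ and $\sigma$ are recorded for each segment so the step can be undone

Feature spectrogram to audio (the inverse pipeline):
1. Unstandardization: $x=z\cdot\sigma+\mu$, using the segment's stored $\mu$ and $\sigma$
2. dB-to-power
3. Griffin-Lim (estimates the missing phase to recover a waveform)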

In [3]:
# Capacity used to preallocate the per-segment arrays; must be >= the number of segments kept.
n_segs = 10_095

In [2]:
# Load the song-level split table: one row per song with its id and test/validation flags.
# Forward slashes keep the path portable; Windows accepts them as well, whereas a
# backslash-separated literal is a single odd filename on POSIX systems.
df = pd.read_csv('data/deam_split.csv')
songs = df['song_id']
test = df['test']
val = df['validation']

In [4]:
# Build the dataset: one standardized spectrogram per audio segment, plus the
# per-segment mean/std and a (test, validation, train) membership mask.
max_pct_zero = 0.5  # keep a segment only if at most this fraction of its samples are exactly zero
# NOTE(review): backslash separators make these glob patterns Windows-only.
paths = [f'data\\DEAM_standard_10s\\{song}-*.wav' for song in songs]

# Preallocate for n_segs segments; the arrays are trimmed to the j segments actually kept below.
# NOTE(review): if more than n_segs segments pass the filter, X[j] raises IndexError.
X = np.zeros((n_segs, mels, ts, 1))
means = np.zeros((n_segs, 1))
stds  = np.zeros((n_segs, 1))
test_val_train = np.zeros((n_segs, 3), dtype=bool)  # columns: test, validation, train
ids = []  # source file path of each kept segment
j = 0  # number of segments kept so far
for i in tqdm(range(len(paths))):
    # NOTE(review): glob result order is not guaranteed; sorted(glob(...)) would make
    # the segment order (and therefore row order of X / ids) reproducible.
    files = glob(paths[i])
    for file in files:
        seg, sr = lb.load(file)
        n = seg.shape[0]
        # Skip segments dominated by exact-zero samples.
        if (seg == 0).sum()/n <= max_pct_zero:
            # FIXME: this right-hand side is incomplete: it never uses `seg`.
            # It should call the spectrogram helper (presumably from `functions`)
            # that returns (db_norm, mean, std) for this segment.
            db_norm, mean, std = audio_to
            X[j, :, :, 0] = db_norm
            means[j] = mean
            stds[j] = std
            # Every segment inherits its song's split.
            test_val_train[j, 0] = test[i]
            test_val_train[j, 1] = val[i]
            # Train when neither flag is set (assumes test/val are 0/1 flags).
            test_val_train[j, 2] = 1 - max(test[i], val[i])
            ids.append(file)
            j += 1
# Drop the unused preallocated tail.
X = X[:j, :, :, :]
test_val_train = test_val_train[:j, :]
means = means[:j]
stds = stds[:j]
ids = np.array(ids)

  0%|          | 0/1802 [00:00<?, ?it/s]

In [6]:
# Single LeakyReLU activation (slope 0.01 for negative inputs), passed as `activation`
# to the hidden conv layers of the encoder and decoder.
# NOTE(review): Keras 3 renames `alpha` to `negative_slope`; `alpha` is deprecated there.
lrelu = keras.layers.LeakyReLU(alpha=1e-2)

In [7]:
def get_encoder(latent_dim):
    """Build the convolutional encoder.

    Three 64-filter 3x3 convolutions with LeakyReLU and strides (2, 3),
    (2, 5), (2, 1) downsample a (mels, ts, 1) input by 8 along the first
    spatial axis and 15 along the second ('same' padding, so sizes round up).
    A final 1x1 convolution projects to `latent_dim` channels.

    Args:
        latent_dim: number of output channels of the encoder.

    Returns:
        keras.Model named 'encoder'.
    """
    inputs = keras.Input(shape=(mels, ts, 1))
    hidden = inputs
    for stride in ((2, 3), (2, 5), (2, 1)):
        hidden = Conv2D(64, 3, activation=lrelu, strides=stride, padding='same')(hidden)
    latents = Conv2D(latent_dim, 1, padding='same')(hidden)
    return keras.Model(inputs, latents, name='encoder')

In [8]:
def get_decoder(latent_dim):
    """Build the convolutional decoder that mirrors `get_encoder`.

    The latent input shape is read from an encoder built with the same
    `latent_dim`. Three 64-filter 3x3 transposed convolutions with LeakyReLU
    apply the encoder's strides in reverse order ((2, 1), (2, 5), (2, 3)),
    then a linear 3x3 transposed convolution maps back to one channel.

    NOTE(review): a throwaway encoder is built on every call just to read its
    output shape; computing the shape directly would avoid the extra layers.

    Args:
        latent_dim: number of channels of the latent input.

    Returns:
        keras.Model named 'decoder'.
    """
    latent_shape = get_encoder(latent_dim).output.shape[1:]
    latent_inputs = keras.Input(shape=latent_shape)
    hidden = latent_inputs
    for stride in ((2, 1), (2, 5), (2, 3)):
        hidden = Conv2DTranspose(64, 3, activation=lrelu, strides=stride, padding='same')(hidden)
    reconstruction = Conv2DTranspose(1, 3, padding='same')(hidden)
    return keras.Model(latent_inputs, reconstruction, name='decoder')

In [9]:
def get_vqvae(latent_dim, num_embeddings, beta=0.25):
    """Assemble the full model: encoder -> vector quantizer -> decoder.

    Args:
        latent_dim: encoder output channels; also passed to VectorQuantizer.
        num_embeddings: number of codebook entries passed to VectorQuantizer.
        beta: forwarded to VectorQuantizer (presumably the commitment-loss
            weight; confirm in `functions`).

    Returns:
        keras.Model named 'vq_vae' mapping (mels, ts, 1) inputs to reconstructions.
    """
    # Creation order (quantizer, encoder, decoder) matches the original so
    # auto-generated layer names are unchanged.
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, beta=beta, name='vector_quantizer')
    encoder = get_encoder(latent_dim)
    decoder = get_decoder(latent_dim)
    inputs = keras.Input(shape=(mels, ts, 1))
    reconstructions = decoder(vq_layer(encoder(inputs)))
    return keras.Model(inputs, reconstructions, name='vq_vae')

In [11]:
# VQ-VAE hyperparameters: latent channels, codebook size, and the `beta` passed to the quantizer.
latent_dim, num_embeddings, beta = 32, 512, 0.25

In [12]:
# Split segments using the (test, validation, train) membership columns.
is_test, is_val, is_train = test_val_train[:, 0], test_val_train[:, 1], test_val_train[:, 2]

X_train, X_val, X_test = X[is_train], X[is_val], X[is_test]
X_nontest = np.concatenate([X_train, X_val], axis=0)

train_id, val_id, test_id = ids[is_train], ids[is_val], ids[is_test]

# NOTE(review): despite its name, this variance is taken over train + validation data.
train_variance = np.var(X_nontest)

In [16]:
# Construct the trainer from the non-test variance and the model factory, then
# restore previously trained weights.
trainer = VQVAETrainer(train_variance, get_vqvae)
# Build first so the model's variables exist before load_weights restores them;
# the batch dimension is left unspecified.
trainer.build((None, mels, ts, 1))
trainer.load_weights('vqvae_models/vqvae_model_2023_02_19_v0.h5')

In [17]:
# Run the trained model on the first two test segments.
pred = trainer.predict(X_test[:2])



In [27]:
# Fetch the vector-quantizer layer from the trained VQ-VAE and display its `embeddings` (codebook) array.
quant = trainer.vqvae.get_layer(name='vector_quantizer')
quant.embeddings.numpy()

array([[ 2.0979561e-03,  2.3785140e-02,  3.4285475e-02, ...,
         9.6459284e-02,  3.0276664e-02,  1.5243682e-02],
       [-4.5241009e-02, -2.9260887e-02, -2.1417916e-02, ...,
         7.8100711e-01, -1.3184089e-03, -3.9494004e-02],
       [ 7.5309915e-03, -2.9880738e-02, -3.3727456e-02, ...,
         4.5857084e-01, -7.6894350e-03,  1.1415266e-02],
       ...,
       [-2.2935154e-02,  4.5479570e-02, -2.0887196e-02, ...,
         8.9959401e-01, -3.8612567e-02, -3.6722198e-02],
       [-3.1723928e-02, -5.8093779e-03,  4.1860115e-02, ...,
         5.2589375e-01, -3.8901597e-02, -3.8535632e-02],
       [-2.4184074e-02,  1.2749281e-02,  1.7295409e-02, ...,
        -1.3421308e+00,  2.6127433e-02, -8.8082682e-03]], dtype=float32)