In [72]:
import pandas as pd
import numpy as np
import torch

from tqdm import tqdm

In [73]:
# Length of the settings vector produced by get_random_settings() and of the
# predicted-parameter tensor returned by predict_for_embeddings().
NUM_PARAMS = 5
# Number of samples along the last axis of the dummy audio tensors (2**17).
AUDIO_LEN = 131_072
# Number of iterations of the data-generation loop, i.e. metadata records built.
NUM_EXAMPLES = 1_000

In [74]:
def get_dummy_audio():
    """Return a placeholder audio tensor filled with uniform random values.

    Returns:
        A tensor of shape ``(1, AUDIO_LEN)`` drawn from ``torch.rand``
        (values in ``[0, 1)``).
    """
    shape = (1, AUDIO_LEN)
    return torch.rand(shape)

In [75]:
def get_audio_features(x):
    """Return two scalar features taken from the first row of ``x``.

    Args:
        x: tensor indexed as ``x[0, 0]`` and ``x[0, -1]``; both indexed
            elements must be single values so ``.item()`` succeeds.

    Returns:
        A tuple ``(first, last)`` of Python floats: the absolute value of
        the first element and of the last element of row 0.
    """
    row = x[0]
    first, last = row[0], row[-1]
    return first.abs().item(), last.abs().item()

In [83]:
def get_random_settings():
    """Return the settings vector used for one example.

    NOTE: despite the name, nothing here is random -- every entry is the
    constant 0.5.

    Returns:
        A float NumPy array of length ``NUM_PARAMS`` filled with 0.5.
    """
    return np.full(NUM_PARAMS, 0.5)

In [84]:
def apply(signal, params):
    """Scale ``signal`` by the first entry of ``params``.

    Args:
        signal: value to scale; anything supporting ``signal * scalar``.
        params: indexable collection; only ``params[0]`` is read.

    Returns:
        ``signal * params[0]``.
    """
    gain = params[0]
    return signal * gain

In [85]:
def get_audio_embeddings(x, y):
    """Return placeholder embeddings for a pair of signals.

    Both arguments are ignored; the embeddings are freshly drawn uniform
    random tensors.

    Args:
        x: unused.
        y: unused.

    Returns:
        A tuple ``(z_x, z_y)`` of two ``(1, 128)`` tensors from ``torch.rand``.
    """
    embedding_shape = (1, 128)
    z_x = torch.rand(embedding_shape)
    z_y = torch.rand(embedding_shape)
    return z_x, z_y

In [86]:
def predict_for_embeddings(z_x, z_y):
    """Return placeholder predictions for a pair of embeddings.

    Args:
        z_x: tensor joined with ``z_y`` along dim 1.
        z_y: tensor joined with ``z_x`` along dim 1.

    Returns:
        A tuple ``(y_hat, p, z)``:
        ``y_hat`` is a ``(1, AUDIO_LEN)`` uniform random tensor,
        ``p`` is a length-``NUM_PARAMS`` tensor of ones, and
        ``z`` is ``z_x`` and ``z_y`` concatenated along dim 1.
    """
    y_hat = torch.rand(1, AUDIO_LEN)
    p = torch.ones(NUM_PARAMS)
    z = torch.cat((z_x, z_y), dim=1)
    return y_hat, p, z

In [87]:
def save_audio(signal, f_name):
    """Build and return the output path for one example's audio.

    NOTE: this is a stub -- nothing is written to disk and ``signal`` is
    not used.

    Args:
        signal: the audio that would be saved (currently unused).
        f_name: identifier interpolated into the file name; the
            data-generation loop in this notebook passes the example index.

    Returns:
        The string ``"./audio{f_name}.wav"``.
    """
    # Use the argument rather than a global loop variable `i`, so the result
    # depends only on what the caller passed in.
    return f"./audio{f_name}.wav"

In [104]:
# Accumulators filled by the loop below: one embedding array and one metadata
# record per example.
y_embeddings, metadata = [], []

# The input signal and its two scalar features are computed once and shared
# by every example.
x = get_dummy_audio()
x_b, x_d = get_audio_features(x)

for i in tqdm(range(NUM_EXAMPLES)):
    # Settings for this example and the signal obtained by applying them.
    p = get_random_settings()
    y = apply(x, p)

    # Embed input and target, then predict from the embeddings; the third
    # return value of predict_for_embeddings is not used here.
    z_x, z_y = get_audio_embeddings(x, y)
    y_hat, p_hat, _ = predict_for_embeddings(z_x, z_y)

    # Features and file name for the predicted audio.
    y_hat_b, y_hat_d = get_audio_features(y_hat)
    f_name = save_audio(y_hat, i)

    # Metadata record for this example, using plain Python types.
    data_dict = {
        'idx': i,
        'audio_file': f_name,
        'p': p.tolist(),
        'p_hat': p_hat.cpu().detach().tolist(),
        'x_b': x_b,
        'x_d': x_d,
        'y_hat_b': y_hat_b,
        'y_hat_d': y_hat_d
    }

    y_embeddings.append(z_y.cpu().detach().numpy())
    metadata.append(data_dict)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1177.46it/s]


In [105]:
# Inspect the metadata record left over from the final loop iteration.
data_dict

{'idx': 999,
 'audio_file': './audio999.wav',
 'p': [0.5, 0.5, 0.5, 0.5, 0.5],
 'p_hat': [1.0, 1.0, 1.0, 1.0, 1.0],
 'x_b': 0.5457613468170166,
 'x_d': 0.26437288522720337,
 'y_hat_b': 0.13721776008605957,
 'y_hat_d': 0.682701587677002}

In [106]:
# Stack the per-example embedding arrays into a single NumPy array. Each list
# entry comes from a (1, 128) tensor, so the stacked array keeps that
# singleton middle axis.
embeddings = np.array(y_embeddings)

In [107]:
# Display the stacked embeddings (NumPy abbreviates the printout with "...").
embeddings

array([[[0.45345825, 0.09091669, 0.14712751, ..., 0.92848897,
         0.06354427, 0.7937085 ]],

       [[0.13964194, 0.14951897, 0.47155482, ..., 0.8443642 ,
         0.8752798 , 0.25924605]],

       [[0.7502417 , 0.5776028 , 0.18736202, ..., 0.51862115,
         0.6732958 , 0.69785696]],

       ...,

       [[0.908244  , 0.162776  , 0.418715  , ..., 0.8724022 ,
         0.7061511 , 0.8146104 ]],

       [[0.9140881 , 0.75790155, 0.52919185, ..., 0.9824308 ,
         0.6462053 , 0.01463723]],

       [[0.31213206, 0.61359096, 0.29308003, ..., 0.07543206,
         0.5162881 , 0.284436  ]]], dtype=float32)