In [134]:
import pandas as pd
import numpy as np
import torch
import os
import umap

from tqdm import tqdm
from functools import reduce

In [135]:
# Checkpoint to analyse: inputs are read from, and outputs written to, DIR.
CHECKPOINT_ID = "TEST"
DATA_FOLDER = "./data/"
DIR = os.path.join(DATA_FOLDER, CHECKPOINT_ID)

# Ids per supervised group (top/bottom brightness_diff, top/bottom depth_diff).
N = 20

In [136]:
def get_top_n_indices(df, column_name, n):
    """Return positional indices of the ``n`` largest values of ``df[column_name]``.

    The indices are ordered from largest to smallest value. They refer to
    positions in ``df[column_name].values``, not to index labels.
    """
    column_values = df[column_name].values
    descending_order = np.argsort(-column_values)
    return descending_order[:n]

In [137]:
def get_bottom_n_indices(df, column_name, n):
    """Return positional indices of the ``n`` smallest values of ``df[column_name]``.

    The indices are ordered from smallest to largest value. They refer to
    positions in ``df[column_name].values``, not to index labels.
    """
    ascending_order = np.argsort(df[column_name].values)
    return ascending_order[:n]

In [138]:
# Per-sample metadata for this checkpoint. The cells below rely on its
# 'id', 'brightness_diff' and 'depth_diff' columns.
metadata = pd.read_csv(
    os.path.join(DIR, "metadata.csv")
)

In [139]:
def common_elements(*arrays):
    """Return the elements that occur in at least two of the given arrays.

    Duplicates inside a single array do not count. The result is an unordered
    list with each shared element listed once. Zero or one array yields [].
    """
    seen_once = set()
    shared = set()
    for arr in arrays:
        for item in set(arr):
            if item in seen_once:
                shared.add(item)
            else:
                seen_once.add(item)
    return list(shared)

In [144]:
def get_supervised_feature_ids(df, n, tolerance = 1_000):
    """Select four mutually disjoint groups of ``n`` ids from ``df``.

    The groups are the ids with the ``n`` highest and ``n`` lowest
    ``brightness_diff`` values, and the ``n`` highest and ``n`` lowest
    ``depth_diff`` values. Any id that lands in two or more groups is removed
    from the candidate pool for all later rounds, and the selection is
    repeated until the groups no longer overlap.

    Args:
        df: DataFrame with ``id``, ``brightness_diff`` and ``depth_diff`` columns.
        n: Number of ids per group.
        tolerance: Maximum number of selection rounds to attempt.

    Returns:
        ``(top_bright, bottom_bright, top_depth, bottom_depth)`` as arrays of
        ids. Returns ``(None, None, None, None)`` when no disjoint selection of
        ``n`` ids per group is found, either because the candidate pool got
        too small or because ``tolerance`` rounds were used up.
    """
    excluded_ids = set()

    def retrieve_ids(current_df, idx):
        # Map positional indices (from the index helpers) back to ids.
        return current_df['id'].values[idx]

    # Run at most `tolerance` rounds. The previous counter check ran only
    # tolerance - 1 rounds before reporting that the tolerance was surpassed.
    for iterations in range(1, tolerance + 1):
        current_df = df[~df['id'].isin(excluded_ids)]

        # Four disjoint groups of n ids need at least 4 * n candidates. Below
        # that the groups always overlap, and pruning would continue until the
        # pool is empty, returning empty groups disguised as a success.
        if len(current_df) < 4 * n:
            print(f"Only {len(current_df)} candidate rows left, cannot pick 4 disjoint groups of {n}; no IDs returned")
            return None, None, None, None

        top_bright = retrieve_ids(current_df, get_top_n_indices(current_df, 'brightness_diff', n))
        bottom_bright = retrieve_ids(current_df, get_bottom_n_indices(current_df, 'brightness_diff', n))
        top_depth = retrieve_ids(current_df, get_top_n_indices(current_df, 'depth_diff', n))
        bottom_depth = retrieve_ids(current_df, get_bottom_n_indices(current_df, 'depth_diff', n))

        # Ids that were picked by two or more of the four groups.
        intersection = common_elements(top_bright, bottom_bright, top_depth, bottom_depth)

        if len(intersection) == 0:
            print(f"Found top {n} features in {iterations} iteration(s)!")
            return top_bright, bottom_bright, top_depth, bottom_depth

        excluded_ids.update(intersection)

    print(f"Iterations surpassed tolerance of {tolerance}, no IDs returned")
    return None, None, None, None

In [145]:
# Ids of the N highest/lowest brightness_diff and depth_diff rows, with the
# four groups kept disjoint.
b_t, b_b, d_t, d_b = get_supervised_feature_ids(metadata, n=N)

Found top 20 features in 2 iteration(s)!


In [156]:
def set_class_labels(df, bright_top, bright_bottom, deep_top, deep_bottom):
    """Add an integer ``class`` column to ``df`` in place and return None.

    Rows whose ``id`` is in ``bright_top`` get 1, ``bright_bottom`` 2,
    ``deep_top`` 3 and ``deep_bottom`` 4. All other rows get -1. The groups
    are applied in that order, so an id present in several groups keeps the
    last matching label.
    """
    df['class'] = -1

    label_groups = (
        (1, bright_top),
        (2, bright_bottom),
        (3, deep_top),
        (4, deep_bottom),
    )
    for label, ids in label_groups:
        df.loc[df['id'].isin(ids), 'class'] = label

In [157]:
# Label metadata rows with the groups selected above (1-4); every other row gets -1.
set_class_labels(metadata,
                 bright_top=b_t, bright_bottom=b_b,
                 deep_top=d_t, deep_bottom=d_b)

metadata['class'].value_counts()

-1    920
 1     20
 3     20
 2     20
 4     20
Name: class, dtype: int64

In [161]:
# UMAP settings shared by every run below.
METRIC = 'euclidean'
MIN_DIST = 0.1
N_NEIGHBOURS = 15

# Ids per supervised group for each run; 0 means a plain unsupervised fit.
N_SUPERVISED = [0, 5, 10, 25, 100]

In [162]:
# UMAP outputs collected by the loop below, keyed as 'n=<ids per supervised group>'.
embeddings = {}

# Input features for UMAP. The supervised fits pair this with metadata['class'],
# so it must have one row per metadata row.
y_emb = np.load(os.path.join(DIR, 'y_embeddings.npy'))

In [164]:
# Fit one UMAP embedding per entry of N_SUPERVISED. Entry 0 is unsupervised;
# the others are supervised by the class labels of that many ids per group.
# The loop variable is deliberately NOT called `N`: reusing that name
# overwrote the global N = 20 with the last value (100), so re-running
# earlier cells silently used the wrong group size.
for n_supervised in tqdm(N_SUPERVISED):
    # NOTE(review): no random_state is passed, so embeddings presumably differ
    # between runs; consider fixing a seed if reproducibility matters.
    reducer = umap.UMAP(n_neighbors=N_NEIGHBOURS,
                        min_dist=MIN_DIST,
                        metric=METRIC)

    if n_supervised == 0:
        emb = reducer.fit_transform(y_emb)
    else:
        b_t, b_b, d_t, d_b = get_supervised_feature_ids(metadata, n_supervised)
        if b_t is None:
            # Fail loudly instead of crashing later inside isin(None).
            raise RuntimeError(f"Could not select {n_supervised} disjoint supervised ids per group")

        # Overwrites metadata['class'] in place. After the loop it holds the
        # labels of the last N_SUPERVISED entry, not those displayed earlier.
        set_class_labels(metadata, b_t, b_b, d_t, d_b)
        # Rows outside the four groups are labelled -1 (unlabelled for UMAP).
        masked_target = metadata['class'].values

        emb = reducer.fit_transform(y_emb, y=masked_target)

    embeddings[f'n={n_supervised}'] = emb

 20%|█████████▏                                    | 1/5 [00:05<00:20,  5.20s/it]

Found top 5 features in 1 iteration(s)!


 40%|██████████████████▍                           | 2/5 [00:08<00:11,  3.83s/it]

Found top 10 features in 1 iteration(s)!


 60%|███████████████████████████▌                  | 3/5 [00:10<00:06,  3.31s/it]

Found top 25 features in 2 iteration(s)!


 80%|████████████████████████████████████▊         | 4/5 [00:13<00:03,  3.07s/it]

Found top 100 features in 7 iteration(s)!


100%|██████████████████████████████████████████████| 5/5 [00:15<00:00,  3.20s/it]


In [166]:
# Store each run's 2-D coordinates as x_emb_<key> / y_emb_<key> columns.
for label, coords in embeddings.items():
    x_coords, y_coords = coords[:, 0], coords[:, 1]
    metadata[f'x_emb_{label}'] = x_coords
    metadata[f'y_emb_{label}'] = y_coords

In [167]:
# Inspect the merged frame (pandas truncates long/wide frames in the display).
# NOTE(review): the 'Unnamed: 0' column looks like an index column written by an
# earlier to_csv without index=False. Confirm, and consider
# pd.read_csv(..., index_col=0) when loading metadata.csv.
metadata

Unnamed: 0,id,audio_file,x_b,x_d,y_hat_b,y_hat_d,brightness_diff,depth_diff,p_drive,p_hpf,...,x_emb_n=0,y_emb_n=0,x_emb_n=5,y_emb_n=5,x_emb_n=10,y_emb_n=10,x_emb_n=25,y_emb_n=25,x_emb_n=100,y_emb_n=100
0,0,./audio0.wav,0.491584,0.366129,0.179557,0.043799,-0.312027,-0.322329,0.5,0.5,...,4.375982,2.352890,5.141922,9.395411,2.983065,5.224524,4.036384,6.244405,7.815265,4.868262
1,1,./audio1.wav,0.491584,0.366129,0.471837,0.252225,-0.019747,-0.113904,0.5,0.5,...,6.368274,3.637930,5.560197,7.228184,5.753263,4.174337,4.018713,3.969869,5.302396,2.416869
2,2,./audio2.wav,0.491584,0.366129,0.329368,0.900148,-0.162216,0.534020,0.5,0.5,...,7.542970,1.791972,6.328410,6.736017,2.947152,4.533545,1.732041,6.375722,7.962466,6.277854
3,3,./audio3.wav,0.491584,0.366129,0.205323,0.487679,-0.286260,0.121551,0.5,0.5,...,7.668360,2.764169,6.358431,6.503232,5.879977,4.400071,4.064588,4.004843,4.031938,4.853545
4,4,./audio4.wav,0.491584,0.366129,0.626251,0.029368,0.134667,-0.336761,0.5,0.5,...,5.418626,5.127433,7.816084,10.375862,5.425009,4.477740,4.175371,7.850127,8.340319,3.233426
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,995,./audio995.wav,0.491584,0.366129,0.798466,0.889531,0.306882,0.523403,0.5,0.5,...,6.356206,3.661452,7.171679,8.201308,6.022523,6.252724,3.605871,5.807974,6.142357,3.342150
996,996,./audio996.wav,0.491584,0.366129,0.763310,0.865694,0.271727,0.499565,0.5,0.5,...,7.237864,2.426145,5.642355,6.615960,3.712283,3.820810,2.269624,6.818748,6.085785,6.159033
997,997,./audio997.wav,0.491584,0.366129,0.219154,0.087598,-0.272429,-0.278530,0.5,0.5,...,6.686103,3.992872,7.011633,7.014889,5.542364,6.425185,3.443305,6.061607,5.169100,6.250965
998,998,./audio998.wav,0.491584,0.366129,0.498680,0.463480,0.007096,0.097351,0.5,0.5,...,3.591636,2.518865,7.318441,9.835830,3.946112,5.076990,1.687181,5.263514,6.265598,4.829258


In [169]:
# Persist metadata together with the UMAP coordinates. index=False: the default
# RangeIndex carries no information, and writing it would add yet another
# 'Unnamed: ...' column whenever this CSV is read back.
metadata.to_csv(os.path.join(DIR, "full_data.csv"), index=False)