In [134]:
import pandas as pd
import numpy as np
import torch
import os
import umap

from tqdm import tqdm
from functools import reduce

In [135]:
# Notebook configuration.
# DIR is the CHECKPOINT_ID sub-folder of DATA_FOLDER; metadata.csv is read from it.
DATA_FOLDER = "./data/"
CHECKPOINT_ID = "TEST"
DIR = os.path.join(DATA_FOLDER, CHECKPOINT_ID)
# Number of ids requested per group from get_supervised_feature_ids.
N = 20

In [136]:
def get_top_n_indices(df, column_name, n):
    """Return the positions of the ``n`` largest values in ``df[column_name]``.

    The result holds 0-based positions into the column's value array (not
    index labels), ordered from the largest value downwards. If the column
    has fewer than ``n`` entries, all positions are returned.
    """
    # Sorting the negated values ascending yields a descending order.
    negated = -df[column_name].values
    descending_positions = np.argsort(negated)
    return descending_positions[:n]

In [137]:
def get_bottom_n_indices(df, column_name, n):
    """Return the positions of the ``n`` smallest values in ``df[column_name]``.

    The result holds 0-based positions into the column's value array (not
    index labels), ordered from the smallest value upwards. If the column
    has fewer than ``n`` entries, all positions are returned.
    """
    column_values = df[column_name].values
    ascending_positions = np.argsort(column_values)
    return ascending_positions[:n]

In [138]:
# Load the metadata table for the selected checkpoint directory.
# Later cells read its 'id', 'brightness_diff' and 'depth_diff' columns.
metadata = pd.read_csv(os.path.join(DIR, "metadata.csv"))

In [139]:
def common_elements(*arrays):
    """Collect every value that occurs in at least two of the given arrays.

    Each argument is converted to a set, every unordered pair of sets is
    intersected, and the union of those intersections is returned as a list.
    The order of the returned list is that of set iteration and should not
    be relied upon.
    """
    as_sets = list(map(set, arrays))
    shared = set()
    for position, left in enumerate(as_sets):
        # Only pair with the sets that come after, so each pair is seen once.
        for right in as_sets[position + 1:]:
            shared.update(left & right)
    return list(shared)

In [144]:
def get_supervised_feature_ids(df, n, tolerance = 1_000):
    """Select four mutually disjoint groups of ``n`` ids from ``df``.

    The groups are the rows with the largest and the smallest
    ``brightness_diff`` and the largest and the smallest ``depth_diff``.
    Whenever an id lands in more than one group, it is removed from the
    candidate pool and all four groups are selected again; this repeats
    until the groups no longer share any id.

    Args:
        df: DataFrame with ``id``, ``brightness_diff`` and ``depth_diff``
            columns. It is only read, never modified.
        n: Number of ids wanted in each group.
        tolerance: Maximum number of selection passes to attempt.

    Returns:
        ``(top_bright, bottom_bright, top_depth, bottom_depth)`` as arrays of
        ids on success. ``(None, None, None, None)`` if disjoint groups
        cannot be formed, either because too few candidate rows remain or
        because ``tolerance`` passes were not enough; a message is printed
        in both cases.
    """
    duplicates = []
    iterations = 1

    def retrieve_ids(current_df, idx):
        # idx holds positions produced by argsort, so index the raw value
        # array rather than going through the DataFrame index labels.
        return current_df['id'].values[idx]

    while True:
        # Candidate pool: every row whose id has not been seen in two groups.
        current_df = df[~df['id'].isin(duplicates)]

        # With fewer than 2 * n rows the top-n and bottom-n picks of a column
        # necessarily share a row, so disjoint groups are impossible. Without
        # this guard the loop would keep discarding rows until the pool is
        # empty and then report success with four empty groups.
        if len(current_df) < 2 * n:
            print(f"Only {len(current_df)} candidate row(s) left, cannot form disjoint groups of {n}, no IDs returned")
            return None, None, None, None

        top_bright_idx = get_top_n_indices(current_df, 'brightness_diff', n)
        top_bright = retrieve_ids(current_df, top_bright_idx)

        bottom_bright_idx = get_bottom_n_indices(current_df, 'brightness_diff', n)
        bottom_bright = retrieve_ids(current_df, bottom_bright_idx)

        top_depth_idx = get_top_n_indices(current_df, 'depth_diff', n)
        top_depth = retrieve_ids(current_df, top_depth_idx)

        bottom_depth_idx = get_bottom_n_indices(current_df, 'depth_diff', n)
        bottom_depth = retrieve_ids(current_df, bottom_depth_idx)

        # Ids that appear in at least two of the four groups.
        intersection = common_elements(top_bright,bottom_bright,top_depth,bottom_depth)

        if len(intersection) == 0:
            print(f"Found top {n} features in {iterations} iteration(s)!")
            return top_bright, bottom_bright, top_depth, bottom_depth

        duplicates.extend(intersection)

        iterations += 1

        # `iterations` now numbers the *next* pass, so give up only once
        # `tolerance` passes have actually been attempted. The previous
        # `>=` check stopped one pass early.
        if iterations > tolerance:
            print(f"Iterations surpassed tolerance of {tolerance}, no IDs returned")
            return None, None, None, None

In [145]:
# b_t / b_b: ids with the largest / smallest brightness_diff;
# d_t / d_b: ids with the largest / smallest depth_diff.
# All four are None if get_supervised_feature_ids could not find disjoint groups.
b_t, b_b, d_t, d_b = get_supervised_feature_ids(metadata, N)

Found top 20 features in 2 iteration(s)!


In [146]:
# give labels to columns
metadata['class'] = 'NA'

metadata.loc[metadata['id'].isin(b_t), 'class'] = 'bright'
metadata.loc[metadata['id'].isin(b_b), 'class'] = 'not_bright'
metadata.loc[metadata['id'].isin(d_t), 'class'] = 'deep'
metadata.loc[metadata['id'].isin(d_b), 'class'] = 'not_deep'

In [147]:
# Count rows per label; as the last expression it is shown as the cell output.
metadata['class'].value_counts()

NA            920
bright         20
deep           20
not_bright     20
not_deep       20
Name: class, dtype: int64