# 2018-02-16 / FMA sub-sampling mk2

* Backup sampler, in case entrofy is too slow or doesn't work out.
* Requirements:
    * a subsample of size N
    * for each instrument, at least K likely positives (high-likelihood excerpts)
    

In [1]:
import os
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

In [3]:
# Interactive matplotlib backend. NOTE(review): no visible cell plots anything,
# so this may be removable.
%matplotlib nbagg

In [24]:
# Per-excerpt instrument likelihoods: index labels are '<track>_<segment>'
# (e.g. '000002_0000') and there is one column per instrument.
# Set VGGISH_LIKELIHOODS_CSV to use a local copy; the default is the
# hard-coded path this notebook was originally written against.
LIKELIHOODS_CSV = os.environ.get(
    'VGGISH_LIKELIHOODS_CSV',
    '/home/bmcfee/data/vggish-likelihoods-a226b3-maxagg10.csv.gz',
)
df = pd.read_csv(LIKELIHOODS_CSV, index_col=0)

In [25]:
# Peek at the first five excerpts.
df.iloc[:5]

Unnamed: 0,accordion,bagpipes,banjo,bass,cello,clarinet,cymbals,drums,flute,guitar,...,mandolin,organ,piano,saxophone,synthesizer,trombone,trumpet,ukulele,violin,voice
000002_0000,0.01542,0.008608,0.010215,0.035007,0.008873,0.00893,0.086853,0.67135,0.021807,0.13501,...,0.006079,0.011073,0.084341,0.015115,0.781432,0.012166,0.025021,0.044818,0.067646,0.999691
000002_0001,0.01542,0.008608,0.010215,0.076214,0.008873,0.00893,0.086853,0.630533,0.021807,0.244505,...,0.006079,0.011073,0.084341,0.015115,0.781432,0.012166,0.025021,0.044818,0.067646,0.999691
000002_0002,0.01542,0.008608,0.010215,0.076214,0.008873,0.00893,0.089177,0.858667,0.021807,0.244505,...,0.006079,0.011073,0.084341,0.015115,0.188291,0.012166,0.025021,0.044818,0.067646,0.999691
000002_0003,0.01542,0.008608,0.010215,0.076214,0.004974,0.00893,0.089177,0.858667,0.012667,0.244505,...,0.003388,0.009051,0.04038,0.00912,0.131694,0.00595,0.014247,0.044818,0.067646,0.999691
000002_0004,0.01542,0.008608,0.009334,0.076214,0.004974,0.00893,0.089177,0.858667,0.012667,0.244505,...,0.003388,0.017866,0.078745,0.00912,0.204007,0.00595,0.014247,0.028634,0.088025,0.999691


In [26]:
# K: number of excerpts to select per instrument.
K = 500
# N: target subsample size, K per instrument column (23 instruments in this
# data). Derived from the frame instead of hard-coding the instrument count.
N = K * df.shape[1]

# Algorithm
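- For each instrument, take the top-K entries by likelihood, at most one per track.
- Skip tracks that have already been selected for an earlier instrument, so each track appears at most once overall.
- Visit instruments in ascending order of median likelihood, so the rarest instruments get first pick of tracks.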

In [42]:
# Median likelihood of each instrument over all excerpts (one value per column).
dfm = df.median(axis='index')

In [47]:
# Visiting order for the greedy sampler: lowest median likelihood first.
medians_ascending = dfm.sort_values()
medians_ascending.index

Index(['banjo', 'mandolin', 'harp', 'clarinet', 'bagpipes', 'accordion',
       'harmonica', 'ukulele', 'trombone', 'trumpet', 'cello', 'saxophone',
       'flute', 'organ', 'mallet_percussion', 'cymbals', 'violin', 'piano',
       'bass', 'synthesizer', 'guitar', 'drums', 'voice'],
      dtype='object')

In [48]:
# Greedy selection: visit instruments in ascending order of median likelihood
# and, for each, take the K highest-likelihood excerpts whose track has not
# already been picked (so every track contributes at most one excerpt).
idx_set = set()     # selected excerpt ids, e.g. '000002_0003'
track_set = set()   # track ids already represented in the sample
shortfall = {}      # instrument -> number picked, for instruments that got fewer than K

for col in tqdm(dfm.sort_values().index):
    count = 0

    # Excerpts for this instrument, most likely first.
    for v in df[col].sort_values(ascending=False).index:
        # Labels are '<track>_<segment>'; only the track id matters for uniqueness.
        track, _ = v.split('_', maxsplit=1)

        if track in track_set:
            continue

        idx_set.add(v)
        track_set.add(track)
        count += 1

        if count == K:
            break

    if count < K:
        # Ran out of unused tracks: this instrument is under-represented.
        shortfall[col] = count

if shortfall:
    warnings.warn('Fewer than K={} excerpts selected for: {}'.format(K, shortfall))




In [49]:
# Index with a sorted list: pandas >= 2.0 rejects a set as a .loc indexer, and
# sorting gives a reproducible row order independent of set iteration order.
df_sample = df.loc[sorted(idx_set)]

In [50]:
# Per-instrument summary statistics of the sampled excerpts.
df_sample.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,accordion,bagpipes,banjo,bass,cello,clarinet,cymbals,drums,flute,guitar,...,mandolin,organ,piano,saxophone,synthesizer,trombone,trumpet,ukulele,violin,voice
count,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,...,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0,11500.0
mean,0.085787,0.07199531,0.054406,0.173591,0.102443,0.058038,0.09279,0.377352,0.124724,0.517153,...,0.059258,0.109836,0.20194,0.120517,0.266107,0.093009,0.111922,0.094019,0.289565,0.483926
std,0.216186,0.2189486,0.183548,0.196825,0.206596,0.155045,0.187944,0.319502,0.241927,0.337568,...,0.1596,0.198223,0.211666,0.23177,0.264546,0.202217,0.212164,0.201834,0.320587,0.331609
min,1.5e-05,1.050304e-07,3e-06,0.000856,0.000153,3.7e-05,2e-06,7e-05,2.3e-05,0.001715,...,4e-06,9.4e-05,0.001086,0.000132,2e-05,1.1e-05,4.1e-05,1.2e-05,0.000242,0.002652
25%,0.004035,0.001688136,0.001414,0.049797,0.012927,0.00481,0.010485,0.095295,0.011623,0.202477,...,0.002033,0.018933,0.078108,0.012683,0.049319,0.006875,0.009251,0.004268,0.054098,0.189448
50%,0.012542,0.006368988,0.00461,0.105575,0.028529,0.013288,0.030221,0.28279,0.030994,0.454625,...,0.005547,0.042239,0.139646,0.027828,0.173179,0.016869,0.023927,0.012511,0.141109,0.401221
75%,0.044121,0.02315261,0.015032,0.206402,0.068752,0.036228,0.078883,0.601537,0.084218,0.88963,...,0.018741,0.098972,0.236478,0.084476,0.411509,0.054253,0.076791,0.049772,0.427234,0.823966
max,0.989912,0.9993549,0.981382,0.924326,0.951616,0.965801,0.955233,0.999531,0.994685,0.999874,...,0.913021,0.981344,0.970724,0.966764,0.971533,0.950456,0.943341,0.944238,0.999661,0.999998


In [51]:
# Sanity checks: one excerpt per selected track, and never more than N in total.
assert len(idx_set) == len(track_set), "some track contributed more than one excerpt"
assert len(idx_set) <= N, "sample exceeds the target size N"
len(idx_set)

11500

In [53]:
import json

In [54]:
with open('../notebooks/subsample_idx_greedy.json', 'w') as fd:
    # Sorted so the file is identical across runs: iteration order of a set of
    # strings depends on the per-process hash seed.
    json.dump(sorted(idx_set), fd, indent=2)