In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, Binarizer
from sklearn.neighbors import NearestNeighbors
import bayessets

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [2]:
DATA_PATH = os.path.expanduser("~/Documents/datasets/synth/setexpansion")
DATASETS = ('densebinary', 'sparsebinary', 'denseinteger', 'sparseinteger')

In [3]:
np.random.seed(42)

In [4]:
SAMPLES_PER_ATTRIBUTE = 10 # how many queries will be produced per attribute
MINIMUM_EXAMPLE_COUNT = 6 # ignores attribute if there's less than this amount of examples
RATIO_OF_SAMPLE = 0.4 # 40% of sample size becomes query and 60% are valid targets
MAXIMUM_SAMPLE_SIZE = 12 # maximum amount of examples used for querying

def merge(targets):
    if len(targets) > 1:
        ret = set()
        for x in targets:
            ret = ret.union(set(x))
        return ','.join(sorted(ret))
    return ','.join(sorted(targets.iloc[0]))

for ds in DATASETS:
    df = pd.read_csv(os.path.join(DATA_PATH, ds + '.csv'))
    m = df.drop('target', axis=1)
    y = df.target
    
    queries = []
    targets = []
    
    for cls in y.unique():
        valid_indices = df.index[df.target == cls].tolist()
        size = len(valid_indices)
        
        if size >= MINIMUM_EXAMPLE_COUNT:
            sample_size = min(int(size * RATIO_OF_SAMPLE), MAXIMUM_SAMPLE_SIZE)
            for _ in range(SAMPLES_PER_ATTRIBUTE):
                np.random.shuffle(valid_indices)
                query = ','.join(map(str, sorted(valid_indices[:sample_size])))
                target = ','.join(map(str, sorted(valid_indices[sample_size:])))
                queries.append(query)
                targets.append(target)
                
    df_query = pd.DataFrame({'query': queries, 'target': targets}).drop_duplicates()
    
    unique_queries = len(df_query['query']) == df_query['query'].nunique()
    
    if not unique_queries:
        df_query = df_query.groupby(by='query').aggregate(merge).reset_index()
       
    unique_queries = len(df_query['query']) == df_query['query'].nunique() 
    assert unique_queries
    
    df_query.to_csv(os.path.join(DATA_PATH, ds + '_query.csv'), index=False)
    
    break

In [9]:
df_query.iloc[0].query

'681,1339,2425,3169,3241,4833,6097,7543,8179,8282,8715,9222'