In [6]:
import json
import os
import pickle
from multiprocessing import Pool

from joblib import Parallel, delayed
import numpy as np
import pandas as pd
from tqdm import tqdm

# Local libraries
import sys
sys.path.insert(1, '../../Data/scripts/strategies')
from baseline import CasesArchive, Explorer
from classifiers import include_splits
from trees import CaseNode

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Root of the project's data tree. Set the DATA_DIR environment variable to run
# this notebook on another machine; without it the original location is used.
data_dir = os.environ.get('DATA_DIR') or '/Users/Javiera/Desktop/Harvard/Code/Data'
features_path = data_dir + '/features/features-2.csv'
clf_path = data_dir + '/scripts/strategies/classifiers.pkl'
clf_ids_path = data_dir + '/scripts/strategies/clf_ids.csv'
rl_ids_path = data_dir + '/scripts/strategies/rl_ids.csv'
test_ids_path = data_dir + '/scripts/strategies/test_ids.csv'

input_order = ['photo', 'spec', 'color']
settings_path = data_dir + '/scripts/strategies/strategies/settings8.json'

# Tag each sample with its split ('rl' or 'test'). With drop_rows=True, rows
# outside these splits are presumably discarded — confirm in classifiers.include_splits.
splits = {'rl': rl_ids_path, 'test': test_ids_path}
features = pd.read_csv(features_path)
features = include_splits(features, splits, drop_rows=True)
# NOTE: unpickling can execute arbitrary code; only load a trusted classifiers.pkl.
clf_df = pd.read_pickle(clf_path)

with open(settings_path) as f:
    settings = json.load(f)
root_args = settings['root_args']
tree_args = settings['tree_args']
# The trained classifiers travel with tree_args (merged into setup_base below).
tree_args['all_clf'] = clf_df

In [13]:
# Preview only the first rows; rendering the whole wide frame bloats the notebook.
features.head(10)

Unnamed: 0,id_gaia,id_sdss,label,K_sdss,color_gaia,feh_sdss,gravity_sdss,h_0_0_gaia,h_0_1_gaia,h_0_2_gaia,...,h_9_5_gaia,h_9_6_gaia,h_9_7_gaia,h_9_8_gaia,h_9_9_gaia,lengths_gaia,max_time,min_time,z_sdss,split
0,3705768362187755776,spec-0847-52426-0599.fits,RRC,0.515388,0.132263,0.175637,0.393701,0.218460,0.343241,0.367859,...,0.705418,0.612690,0.620811,0.344524,0.815633,25,2331.350724,1689.732893,0.877002,test
1,2780868247577183104,spec-0418-51817-0365.fits,RRAB,0.515388,0.143858,0.317280,0.393701,0.344367,0.461929,0.398477,...,0.588351,0.339127,0.401497,0.401976,0.446707,6,2330.702549,1694.854847,0.520935,test
2,3303225709169790848,spec-3121-54749-0169.fits,RRAB,0.552201,0.172824,0.458923,0.393701,0.104263,0.318136,0.567020,...,0.498554,0.667637,0.567155,0.441816,0.909953,14,2250.313420,1711.184346,0.591609,rl
3,4410124135237650176,spec-0344-51693-0409.fits,RRAB,0.515388,0.167132,0.317280,0.393701,0.218550,0.222448,0.628313,...,0.615428,0.635436,0.719707,0.456332,0.856989,11,2248.256900,1713.489451,0.743860,test
4,1308954745892960768,spec-2808-54524-0607.fits,RRC,0.552201,0.127830,0.458923,0.787402,0.259685,0.262177,0.684108,...,0.715492,0.679118,0.700970,0.381787,0.854940,13,2244.909256,1721.753031,0.542656,rl
5,4460161874328256896,spec-2530-53881-0559.fits,RRC,0.552201,0.124512,0.458923,0.787402,0.047066,0.478550,0.170266,...,0.563934,0.705951,0.662846,0.441300,0.738589,22,2296.144273,1716.745561,0.650571,test
6,3689922166248251264,spec-0292-51609-0470.fits,RRAB,0.441761,0.138760,0.317280,0.393701,0.186982,0.282644,0.680774,...,0.671900,0.754174,0.875972,0.696356,0.764046,17,2331.347443,1691.230567,0.667653,rl
7,3688533345623140608,spec-0337-51997-0580.fits,RRAB,0.478575,0.144940,0.317280,0.393701,0.283365,0.304835,0.490551,...,0.560764,0.790582,0.629856,0.509946,0.854497,12,2332.347087,1693.481220,0.854600,test
8,6909836470632143488,spec-0637-52174-0294.fits,RRC,0.515388,0.135398,0.458923,0.787402,0.281359,0.264978,0.673002,...,0.789797,0.800348,0.909220,0.533152,0.861621,10,2275.242848,1758.842053,0.660690,rl
9,3706857050498504320,spec-0845-52381-0244.fits,RRC,0.552201,0.123135,0.175637,0.393701,0.137594,0.295605,0.500561,...,0.740538,0.704151,0.821723,0.439899,0.926514,19,2329.274113,1685.227729,0.691201,rl


In [8]:
# Flat setup shared by all cases: root arguments, overridden by the tree
# arguments where keys collide, without the 'p_thres' entry (KeyError if absent).
setup_base = {**root_args, **tree_args}
del setup_base['p_thres']

In [11]:
# Grid of source combinations to pre-compute cases for: every observation count
# from min_obs up to the largest `lengths_gaia` in `features`, crossed with the
# spectrum and color options (presumably 0 = absent, 1 = present).
min_obs = 5
max_obs = features.lengths_gaia.max()
obs_range = [*range(min_obs, max_obs + 1)]
spec_range, col_range = [0, 1], [0, 1]

# NOTE(review): the saved output shows this cell was interrupted (KeyboardInterrupt),
# so `archive` may be incomplete when the following cells use it.
archive = CasesArchive(features, setup_base)
archive.build_archive(obs_range, spec_range, col_range)

  0%|          | 0/188 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
# Sanity check on the feature table after include_splits: (rows, columns).
features.shape

(2291, 522)

In [13]:
# Build the case-based recommender from the case groups pre-computed by `archive`.
# NOTE(review): n_neighbors=100 is hard-coded; presumably the number of neighbouring
# cases consulted per recommendation — confirm in baseline.Explorer.
cases_groups = archive.cases_groups
explorer = Explorer(cases_groups, n_neighbors=100)
# Called in this order; presumably set_knn indexes what fill_memory stores — confirm.
explorer.fill_memory()
explorer.set_knn()

In [21]:
def recommend_path(dataset, explorer, base_setup=None, initial_sources=(5, 0, 0)):
    '''Recommend observational path for each sample in dataset with case-based strategy.

    For every sample, a `CaseNode` is built from `initial_sources` and then grown one
    source at a time. The explorer proposes a step, and the step is applied only when
    its reward gain `delta_r` is strictly positive and the source can still be added
    (photometry below `max_obs`, spectrum or color not yet used). The walk stops at the
    first proposal that is missing, non-positive, unknown or exhausted.

    Parameters
    ----------
    dataset: pd.DataFrame
        Resulting dataframe from calling `build_sets` with features for each sample.
    explorer: baseline.Explorer
        Custom object for case-based recommendation; must provide
        `recommend_step(case) -> (step, delta_r)`.
    base_setup: dict, optional
        Base setup passed to `CasesArchive.costumize_setup`. Defaults to the
        notebook-level `setup_base`, looked up when the function is called.
    initial_sources: tuple, default (5, 0, 0)
        Leading positional arguments for `CasesArchive.costumize_setup`; presumably
        the starting (n_obs, n_spec, n_color) — confirm in baseline.CasesArchive.

    Returns
    -------
    dataset_cases: list
        List with a `CaseNode` object representing the recommended path for each sample in dataset.
    '''
    if base_setup is None:
        base_setup = setup_base

    dataset_cases = []
    for _, sample in dataset.iterrows():
        # Copy so the per-sample edits below can never leak into `base_setup`
        # (a no-op if costumize_setup already returns a fresh dict).
        setup = dict(CasesArchive.costumize_setup(*initial_sources, base_setup))
        setup['all_features'] = sample
        case = CaseNode(**setup)

        while True:
            step, delta_r = explorer.recommend_step(case)
            # `not delta_r > 0` (rather than `delta_r <= 0`) also stops on NaN gains.
            if step is None or not delta_r > 0:
                break
            if step == 'photo' and case.n_obs < case.max_obs:
                setup['n_obs'] = setup['n_obs'] + 1
            # NOTE(review): the config cell lists sources as 'photo', 'spec', 'color';
            # confirm recommend_step returns 'spectrum', otherwise this branch never fires.
            elif step == 'spectrum' and case.n_spec == 0:
                setup['n_spec'] = 1
            elif step == 'color' and case.n_color == 0:
                setup['n_color'] = 1
            else:
                # Unknown step, or the recommended source is already exhausted.
                break
            case = CaseNode(**setup)
        dataset_cases.append(case)

    return dataset_cases

In [22]:
# Recommend an observational path for every sample of each evaluation set.
# NOTE(review): `val` and `test` are not defined in any visible cell (the config
# cell only builds `features` with 'rl'/'test' splits); confirm they are created
# earlier, otherwise this cell fails on a fresh kernel.
cases_val = recommend_path(val, explorer)
cases_test = recommend_path(test, explorer)

In [23]:
# Persist the recommended paths (lists of CaseNode objects) with pickle.
# NOTE(review): `val_save` and `test_save` are not defined in any visible cell —
# confirm where they are set before running on a fresh kernel.
with open(val_save, 'wb') as f:
    pickle.dump(cases_val, f)
    
with open(test_save, 'wb') as f:
    pickle.dump(cases_test, f)

In [24]:
# Mean reward of the recommended paths, shown as (test, validation).
tuple(np.mean([case.reward for case in cases]) for cases in (cases_test, cases_val))

(0.6848559729973343, 0.7022574139066285)