In [1]:
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split
import numpy as np
from collections import Counter
import os

In [2]:
# Name of the survey whose light curves are processed in this notebook.
survey = "Gaia"

Read the data

In [3]:
# Build the dataset path from the survey name instead of hard-coding it twice,
# so changing `survey` above is enough to point at another survey's file.
# NOTE(review): absolute, machine-specific root — consider making it configurable.
data_root = '/home/Data/Paper_2/Prepare_dataset'
version = 'V5'
path = os.path.join(data_root, survey, version, f'Dataset_{survey}_Phys_{version}.dat')
df = pd.read_csv(path)
df.head()

Unnamed: 0,ID,Path,N,N_b,N_r,Class,T_eff,e_T_eff,E_T_eff,Lum,...,E_Rad,logg,e_logg,E_logg,Mass,e_Mass,E_Mass,rho,e_rho,E_rho
0,3985923473972534400,/home/Data/Databases/GAIA/Consolidate_Gaia/dat...,8,4,4,DSCT_SXPHE,-1.0,-1.0,-1.0,-1.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
1,3986570197263160320,/home/Data/Databases/GAIA/Consolidate_Gaia/dat...,10,5,5,RRAB,6481.0,-1.0,-1.0,-1.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
2,3986754189367115264,/home/Data/Databases/GAIA/Consolidate_Gaia/dat...,12,6,6,RRAB,7381.3335,7173.793,7558.6665,2.83599,...,-1.0,4.5185,-1.0,-1.0,1.55,-1.0,-1.0,1.0603,-1.0,-1.0
3,3987237630885709312,/home/Data/Databases/GAIA/Consolidate_Gaia/dat...,12,7,5,RRAB,7011.0,-1.0,-1.0,-1.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
4,3987697089307190016,/home/Data/Databases/GAIA/Consolidate_Gaia/dat...,8,4,4,RRAB,7191.5,6808.0,8884.0,2.30248,...,-1.0,4.6352,-1.0,-1.0,1.58,-1.0,-1.0,1.5719,-1.0,-1.0


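Selection thresholds: keep only light curves with more than `min_N` observations in every band (the `N_*` columns). Also set the minimum (`min_L`) and maximum (`max_L`) number of light curves per class used when sampling. (`max_N` is defined but not currently applied.)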

In [4]:
# Minimum / maximum number of observations per band for a light curve.
# NOTE(review): max_N is defined but never applied below — confirm whether an
# upper bound on the number of observations is intended.
min_N = 10
max_N = 1000

# Maximum / minimum number of light curves per class (read by sample_classes).
max_L = 10000
min_L = 500

# Per-band observation-count columns (e.g. N_b, N_r), matched by substring.
bands = [col for col in df.columns if 'N_' in col]

# Keep only light curves with strictly more than min_N points in EVERY band.
# NOTE(review): strict '>' rejects curves with exactly min_N points — confirm.
keep = (df[bands] > min_N).all(axis=1)

df = df.loc[keep].reset_index(drop=True)

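Sample at most `max_L` objects per class (10,000 with the settings above), and discard classes with fewer than `min_L` objects (500).

Target split of the sampled data:

- 20% for testing
- 70% for training
- 10% for validation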


In [5]:

def sample_classes(data, min_count=None, max_count=None, random_state=None):
    """Keep at most ``max_count`` randomly chosen rows per class.

    Classes (values of ``data.Class``) with fewer than ``min_count`` rows are
    dropped entirely; every other class is down-sampled without replacement
    to ``min(max_count, rows_in_class)`` rows.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with a ``Class`` column.
    min_count : int, optional
        Minimum number of rows a class needs to be kept. Defaults to the
        notebook-level ``min_L``.
    max_count : int, optional
        Maximum number of rows kept per class. Defaults to the notebook-level
        ``max_L``.
    random_state : int, numpy RandomState/Generator or None, optional
        Seed or generator used for sampling. ``None`` (default) draws from
        NumPy's global random state, as before; an int makes the result
        reproducible.

    Returns
    -------
    pandas.DataFrame
        The sampled rows with their original index labels, grouped by class
        in order of first appearance in ``data``.

    Raises
    ------
    ValueError
        If no class has at least ``min_count`` rows.
    """
    if min_count is None:
        min_count = min_L
    if max_count is None:
        max_count = max_L
    # Share one generator across all classes: re-using the same int seed for
    # every class would pick the same relative positions inside each class.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    dfs = []
    for label in data.Class.unique():
        sel = data[data.Class == label]
        # Skip classes with too few light curves
        if sel.shape[0] < min_count:
            continue
        num = min(max_count, sel.shape[0])
        dfs.append(sel.sample(num, replace=False, axis=0, random_state=random_state))

    # pd.concat([]) would fail with an opaque message; make the cause explicit.
    if not dfs:
        raise ValueError(
            f'No class has at least {min_count} rows; nothing to sample.')
    # Join the dataframes of each class together
    return pd.concat(dfs)

In [6]:
# Down-sample every class to at most max_L rows, dropping classes below min_L.
df = sample_classes(data=df)

In [7]:
# Number of sampled light curves per class.
Counter(df['Class'])

Counter({'RRAB': 1000,
         'RRC': 1000,
         'DSCT_SXPHE': 1000,
         'MIRA_SR': 1000,
         'T2CEP': 1000,
         'CEP': 1000})

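Create the folds. First, 20% of the data is held out as the test set, so the train split contains the remaining 80% of the total.

The train split is then divided with stratified k-fold cross-validation: in each fold, 1/k of the train split is used for validation and the rest for training. Taking 12.5% of the train split for validation (k = 8) gives 70% training, 10% validation and 20% test over the total.

**NOTE:** the intended setup is stated as **5** folds. However, 5 folds give a validation share of 20% of the train split (16% of the total), and the code below uses `n_splits=3` (≈ 26.7% of the total for validation). Confirm the intended number of folds.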

In [8]:
# NOTE(review): the markdown above speaks of 5 folds (and of a 12.5% validation
# share, which corresponds to 8 folds), while this cell uses 3 — confirm the
# intended number of folds.
n_folds = 3
# Fixed seed: with shuffle=True and no random_state the folds change every run.
kfold_seed = 42
kfolds = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=kfold_seed)

In [9]:
path_folds = './Folds'
# Create the output directory (no error if it already exists).
os.makedirs(path_folds, exist_ok=True)

# First split: hold out the test set (80% / 20%, stratified by class).
# A fixed seed makes the test/train assignment reproducible across runs.
test_split_seed = 42
df_temp, df_test = train_test_split(
    df, stratify=df.Class, train_size=0.8, random_state=test_split_seed)
# reset_index() without drop=True keeps the previous row label as an 'index'
# column, which is written to the CSVs below.
# NOTE(review): confirm downstream code expects that 'index' column.
df_temp = df_temp.reset_index()
df_test = df_test.reset_index()

path_test = os.path.join(path_folds, 'test.csv')
df_test.to_csv(path_test, index=False)

# StratifiedKFold builds folds from the labels alone (X is only a placeholder
# for the sample count); split() yields positional indices, hence .iloc.
for n, (train_index, val_index) in enumerate(
        kfolds.split(np.zeros(len(df_temp)), df_temp.Class.values)):
    # Get the train and validation splits
    df_train = df_temp.iloc[train_index]
    df_val = df_temp.iloc[val_index]

    path_folds_ = os.path.join(path_folds, f'Fold_{n + 1}')
    os.makedirs(path_folds_, exist_ok=True)

    path_train = os.path.join(path_folds_, 'train.csv')
    path_val = os.path.join(path_folds_, 'val.csv')

    df_train.to_csv(path_train, index=False)
    df_val.to_csv(path_val, index=False)