In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import sklearn
import imblearn
import pickle
from joblib import load, dump

In [None]:
# Download features
!scp -r -P 22334 -i mi_llave guanaco.inf.uach.cl:/home/shared/astro/PLAsTiCC/fats_featurs.tar.gz .
!tar xzvf fats_features.tar.gz

In [20]:
# Select seed, load ids and features

seed = 1
with open(f"ids/seed{seed}/maxClass15k/dataset_ids_before_balancing.pkl", "rb") as f:
    lc_ids = pickle.load(f)
    
features = {}
for subset in ['train', 'validation', 'test']:
    tmp = []
    for lc_id in lc_ids[subset]:
        with open("features/fats"+str(int(lc_id))+".pkl", "rb") as f:
            tmp.append(load(f))
    features[subset] = pd.concat(tmp, axis=0)

In [34]:
# Train balanced RF
from sklearn.metrics import f1_score, classification_report
from imblearn.ensemble import BalancedRandomForestClassifier

rf = BalancedRandomForestClassifier(n_estimators=500, criterion='entropy', replacement=True,
                                    max_depth=10, class_weight='balanced', n_jobs=8) 
                

rf.fit(features['train'].fillna(-1000).values, 
       lc_ids['labels_train'].astype('int')) 

preds = rf.predict(features['validation'].fillna(-1000).values)
print(classification_report(preds, lc_ids['labels_validation'].astype('int')))
print(rf.features)          

              precision    recall  f1-score   support

           6       0.91      0.51      0.66       260
          16       0.95      0.95      0.95      1577
          53       0.98      0.99      0.99       146
          65       0.93      0.94      0.94      1581
          88       0.96      0.98      0.97      1494
          92       0.96      0.98      0.97      1487

    accuracy                           0.95      6545
   macro avg       0.95      0.89      0.91      6545
weighted avg       0.95      0.95      0.95      6545



In [36]:
# Result on test set
preds = rf.predict(features['test'].fillna(-1000).values)
print(classification_report(preds, lc_ids['labels_test'].astype('int')))

              precision    recall  f1-score   support

           6       0.94      0.48      0.64       286
          16       0.94      0.97      0.95      1547
          53       0.99      1.00      0.99       147
          65       0.93      0.95      0.94      1569
          88       0.96      0.99      0.97      1489
          92       0.97      0.98      0.97      1508

    accuracy                           0.95      6546
   macro avg       0.95      0.89      0.91      6546
weighted avg       0.95      0.95      0.95      6546



In [None]:
# Features sorted by importance
features['train'].columns.values[np.argsort(rf.feature_importances_)[::-1]]