In [6]:
import pandas as pd
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.naive_bayes import GaussianNB
import scipy.stats as stats

from utils.base_set import X_train, y_train, seed
from utils.randomized_search import pNpUniform

### Naive Bayes

In [7]:
# Randomized hyper-parameter search for a Gaussian Naive Bayes classifier,
# scored by cross-validated ROC AUC.
k = 5    # number of stratified folds
n = 100  # number of parameter settings sampled by the search

model = GaussianNB()
cv = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)

params = {
    # This is probably overfitting. We should at most test values around the
    # class probabilities estimated from the set.
    # pNpUniform is a project helper; presumably it draws a pair (p, 1 - p),
    # judging by the recorded best_params_ -- verify in utils.randomized_search.
    # NOTE(review): with two classes the priors only add a constant to the
    # log-odds, so a ranking metric such as roc_auc is likely (almost)
    # insensitive to them -- confirm before keeping this in the search space.
    "priors": pNpUniform(0, 1),
    # Uniform over [loc, loc + scale], i.e. [0, 0.01].
    # NOTE(review): the recorded best value sits right at the upper edge of
    # this range; a wider (or log-scaled) range may be worth trying.
    "var_smoothing": stats.uniform(loc=0, scale=1e-2),
}

classifier_NBayes = RandomizedSearchCV(
    estimator=model,
    param_distributions=params,
    n_iter=n,
    cv=cv,
    scoring='roc_auc',
    random_state=seed,
)
# Kept as the last expression so the fitted search object is displayed.
classifier_NBayes.fit(X_train, y_train)

In [8]:
# Best sampled parameter setting and its mean cross-validated score,
# printed one per line.
best, auc_roc = classifier_NBayes.best_params_, classifier_NBayes.best_score_
print(best, auc_roc, sep="\n")

{'priors': (0.46353737605279044, 0.5364626239472096), 'var_smoothing': 0.009998052376029345}
0.8409246677937755


In [9]:
# Per-candidate search results, best rank first, restricted to the sampled
# parameters and their mean test score.
columns_to_keep = [
    'param_priors',
    'param_var_smoothing',
    'mean_test_score',
    'rank_test_score',
]
(
    pd.DataFrame(classifier_NBayes.cv_results_)
    .sort_values(by="rank_test_score")
    .loc[:, columns_to_keep]
)

Unnamed: 0,param_priors,param_var_smoothing,mean_test_score,rank_test_score
6,"(0.46353737605279044, 0.5364626239472096)",0.009998,0.840925,1
52,"(0.5409363175488333, 0.4590636824511667)",0.009751,0.840805,2
16,"(0.2535078793981672, 0.7464921206018328)",0.009733,0.840805,2
43,"(0.19961932066469246, 0.8003806793353075)",0.009767,0.840805,2
59,"(0.5631251785254403, 0.4368748214745597)",0.009801,0.840692,5
...,...,...,...,...
26,"(0.9331706881631359, 0.06682931183686414)",0.000324,0.792758,96
37,"(0.08590641854125636, 0.9140935814587436)",0.000172,0.790466,97
63,"(0.8831395670588915, 0.11686043294110848)",0.000132,0.789434,98
14,"(0.9535598966771824, 0.04644010332281756)",0.000093,0.787484,99
