diff --git a/few/evaluation.py b/few/evaluation.py index db4c173..56e0b8c 100644 --- a/few/evaluation.py +++ b/few/evaluation.py @@ -8,7 +8,7 @@ import numpy as np from sklearn.metrics import explained_variance_score, mean_absolute_error, mean_squared_error, median_absolute_error, r2_score import pdb -from sklearn.metrics import silhouette_samples, silhouette_score, accuracy_score +from sklearn.metrics import silhouette_samples, silhouette_score, accuracy_score, roc_auc_score import itertools as it import sys from sklearn.metrics.pairwise import pairwise_distances @@ -107,6 +107,7 @@ class EvaluationMixin(object): 'fisher': lambda y,yhat: 1 - fisher(yhat,y), 'accuracy': lambda y,yhat: 1 - accuracy_score(yhat,y), 'random': lambda y,yhat: np.random.rand(), + 'roc_auc': lambda y,yhat: 1 - roc_auc_score(y,yhat) # 'relief': lambda y,yhat: 1-ReliefF(n_jobs=-1).fit(yhat.reshape(-1,1),y).feature_importances_ } #