This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
/
classification_model.py
72 lines (58 loc) · 2.57 KB
/
classification_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Model artifact wrapping any classification model"""
from .base_model import Model
# Values of ``model_info["framework"]`` for which
# ClassificationModel._update_functionality rewrites a binary classifier's
# ``predict_proba`` so that it returns only column 1 of the original output
# (in sklearn's column ordering, the probability of ``classes_[1]``).
PREDICT_PROBA_FRAMEWORKS = ["sklearn", "xgboost"]
class ClassificationModel(Model):
    """Class wrapper around a classification model to be assessed.

    ClassificationModel serves as an adapter between arbitrary binary or
    multi-class classification models and the evaluations in Lens.
    Evaluations depend on ClassificationModel implementing `predict` and
    (optionally) `predict_proba`.

    For binary classifiers (``len(model_like.classes_) == 2``) whose framework
    is listed in ``PREDICT_PROBA_FRAMEWORKS``, `predict_proba` is re-bound to
    return only column 1 of the underlying output (the probability of
    ``classes_[1]``) instead of the full two-column array.

    Parameters
    ----------
    name : str
        Label of the model.
    model_like : model_like
        A binary or multi-class classification model or pipeline. It must
        have a `predict` function that returns an array containing the class
        label for each sample. It can also optionally have a `predict_proba`
        function that returns an array containing the class probabilities for
        each sample.
    tags : optional
        Tags passed through unchanged to the base ``Model``.
    """

    def __init__(self, name: str, model_like=None, tags=None):
        # NOTE(review): the two lists are presumably the functions Lens may
        # use and the subset that is required -- confirm against Model.__init__.
        super().__init__(
            "Classification",
            ["predict", "predict_proba"],
            ["predict"],
            name,
            model_like,
            tags,
        )

    def _update_functionality(self):
        """Conditionally update functionality based on the model framework.

        For a binary classifier from a framework in PREDICT_PROBA_FRAMEWORKS,
        replace this instance's `predict_proba` with a wrapper that keeps only
        column 1 of the original output. Calling this method again is a
        no-op once the wrapper is installed, so the output is never sliced
        twice.
        """
        if self.model_info["framework"] not in PREDICT_PROBA_FRAMEWORKS:
            return
        func = getattr(self, "predict_proba", None)
        if func is None or getattr(func, "_positive_class_only", False):
            # Nothing to wrap, or already wrapped by an earlier call.
            return
        if len(self.model_like.classes_) != 2:
            return

        def _positive_class_proba(x):
            # Assumes func(x) is a 2-D (n_samples, 2) array, as sklearn and
            # xgboost classifiers return -- TODO confirm for custom pipelines.
            return func(x)[:, 1]

        _positive_class_proba._positive_class_only = True
        # Assigned through __dict__ directly, presumably to bypass any custom
        # __setattr__ on Model -- confirm before switching to setattr().
        self.__dict__["predict_proba"] = _positive_class_proba
class DummyClassifier:
    """Class wrapper around classification model predictions.

    This class can be used when a classification model is not available but
    its outputs are. The outputs include the array containing the predicted
    class labels and/or the array containing the class probabilities.
    Wrap the outputs with this class into a dummy classifier and pass it as
    the model to `ClassificationModel`.

    Both `predict` and `predict_proba` ignore their input and return the
    stored output unchanged, so the stored arrays must already correspond to
    the samples being evaluated.

    Parameters
    ----------
    name : str
        Label of the model.
    predict_output : array, optional
        Array containing the output of a model's "predict" method.
    predict_proba_output : array, optional
        Array containing the output of a model's "predict_proba" method.
    tags : optional
        Arbitrary tags, stored unchanged as ``self.tags``.
    """

    def __init__(
        self, name: str, predict_output=None, predict_proba_output=None, tags=None
    ):
        self.predict_output = predict_output
        self.predict_proba_output = predict_proba_output
        self.name = name
        self.tags = tags

    def predict(self, X=None):
        """Return the stored ``predict_output``; `X` is ignored."""
        return self.predict_output

    def predict_proba(self, X=None):
        """Return the stored ``predict_proba_output``; `X` is ignored.

        Returns None when no probabilities were supplied.
        """
        # NOTE(review): this method exists even when predict_proba_output is
        # None, so attribute-based capability detection will still find it --
        # confirm how the evaluation side handles a None return.
        return self.predict_proba_output