-
Notifications
You must be signed in to change notification settings - Fork 2
/
precision_recall_curve_model_comparision.py
67 lines (46 loc) · 1.77 KB
/
precision_recall_curve_model_comparision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Example of computing and comparing the precision-recall curves of two models.
"""
# %%
import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics as skm
# Ground-truth labels: 4 negatives followed by 6 positives.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])

# Example score pairs (output of model 0, output of model 1) for y_true, keyed
# by what the comparison is meant to illustrate. Exactly one pair is used below.
SCORE_SCENARIOS = {
    # model 1's curve dominates model 0's
    "model1_dominates": (
        np.array([0.75, 0.5, 0.3, 0.35, 0.45, 0.7, 0.3, 0.33, 0.5, 0.8]),
        np.array([0.6, 0.3, 0.3, 0.55, 0.65, 0.4, 0.55, 0.33, 0.75, 0.3]),
    ),
    # model 1 is better
    "model1_better": (
        np.array([0.75, 0.5, 0.3, 0.35, 0.45, 0.7, 0.3, 0.33, 0.5, 0.8]),
        np.array([0.7, 0.3, 0.3, 0.55, 0.75, 0.4, 0.5, 0.33, 0.72, 0.3]),
    ),
    # looking only at the curves it is not obvious which model is better
    "not_obvious": (
        np.array([0.7, 0.45, 0.3, 0.35, 0.45, 0.7, 0.3, 0.33, 0.55, 0.8]),
        np.array([0.6, 0.3, 0.3, 0.55, 0.65, 0.4, 0.5, 0.33, 0.75, 0.3]),
    ),
}

# Key of the score pair evaluated by the rest of the script; change it to
# explore the other examples.
SCENARIO = "not_obvious"
y_score0, y_score1 = SCORE_SCENARIOS[SCENARIO]
# %%
# Precision-recall curve of each model. The threshold arrays are kept for
# inspection but are not used further in this script.
precision0, recall0, thresholds0 = skm.precision_recall_curve(y_true, y_score0)
precision1, recall1, thresholds1 = skm.precision_recall_curve(y_true, y_score1)

# Two summary numbers per model:
# - average_precision_score: sklearn's (non-interpolated) average precision.
# - auc(recall, precision): trapezoidal area under the curve, i.e. with linear
#   interpolation between PR points, so it generally differs from the above.
avg_prec0 = skm.average_precision_score(y_true, y_score0)
auc0 = skm.auc(recall0, precision0)
print(f"Model 0 average_precision={avg_prec0} area under curve={auc0}")
avg_prec1 = skm.average_precision_score(y_true, y_score1)
auc1 = skm.auc(recall1, precision1)
print(f"Model 1 average_precision={avg_prec1} area under curve={auc1}")
# %% plot curve
# For each model, draw the curve points as markers and join them with straight
# lines; only the line carries the legend label.
for model_recall, model_precision, color, model_label in (
    (recall0, precision0, 'r', 'model 0'),
    (recall1, precision1, 'b', 'model 1'),
):
    plt.plot(model_recall, model_precision, color + 'o')
    plt.plot(model_recall, model_precision, color, label=model_label)
plt.xlabel('Recall')
plt.ylabel('Precision')
# Small headroom above precision 1.0 so points at the top edge stay visible.
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('Precision-Recall curve for 2 ml models')
plt.legend()
plt.show()
# %%