-
Notifications
You must be signed in to change notification settings - Fork 4
/
plotroccurve.jl
88 lines (85 loc) · 2.42 KB
/
plotroccurve.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import LaTeXStrings
import PGFPlotsX
"""
    plotroccurve(estimator, features_df, labels_df, single_label_name, positive_class; kwargs...)

Plot the ROC curve of a single `estimator`.

Wraps `estimator` in a one-element vector and forwards every positional and
keyword argument to the vector method of `plotroccurve`. The result of that
call is returned unchanged.

# Arguments
- `estimator::AbstractFittable`: the fitted model to evaluate.
- `features_df::DataFrames.AbstractDataFrame`: feature table passed through.
- `labels_df::DataFrames.AbstractDataFrame`: label table passed through.
- `single_label_name::Symbol`: name of the label column to evaluate.
- `positive_class::AbstractString`: the class treated as positive.

# Keywords
All keyword arguments are forwarded verbatim to the vector method.
"""
function plotroccurve(
        estimator::AbstractFittable,
        features_df::DataFrames.AbstractDataFrame,
        labels_df::DataFrames.AbstractDataFrame,
        single_label_name::Symbol,
        positive_class::AbstractString;
        kwargs...,
        )
    # `[estimator]` would have eltype `typeof(estimator)`. Julia's parametric
    # types are invariant, so a `Vector{ConcreteEstimator}` is NOT an
    # `AbstractVector{AbstractFittable}` and would not dispatch to the vector
    # method. Give the vector an explicit abstract element type instead.
    vectorofestimators = AbstractFittable[estimator]
    return plotroccurve(
        vectorofestimators,
        features_df,
        labels_df,
        single_label_name,
        positive_class;
        kwargs...,
        )
end
"""
    plotroccurve(vectorofestimators, features_df, labels_df, single_label_name, positive_class; legend_pos = "outer north east")

Plot one ROC curve per estimator on a shared PGFPlotsX axis.

For each estimator, this function:
1. Calls `singlelabelbinaryclassificationmetrics_resultdict` on the given data,
   with `threshold = 0.5`.
2. Reads the `:ytrue` and `:yscore` entries from the result.
3. Computes the curve with `roccurve`.
4. Plots the false positive rate against the true positive rate, with a legend
   entry taken from `estimator.name`.

# Arguments
- `vectorofestimators::AbstractVector{<:AbstractFittable}`: the estimators to
  compare; must be non-empty.
- `features_df::DataFrames.AbstractDataFrame`: feature table.
- `labels_df::DataFrames.AbstractDataFrame`: label table.
- `single_label_name::Symbol`: name of the label column to evaluate.
- `positive_class::AbstractString`: the class treated as positive.

# Keywords
- `legend_pos::AbstractString = "outer north east"`: value for the axis'
  `legend pos` option.

# Returns
The PGFPlotsX axis wrapped in `PGFPlotsXPlot`.

# Throws
- `ErrorException` if `vectorofestimators` is empty.
"""
function plotroccurve(
        vectorofestimators::AbstractVector{<:AbstractFittable},
        features_df::DataFrames.AbstractDataFrame,
        labels_df::DataFrames.AbstractDataFrame,
        single_label_name::Symbol,
        positive_class::AbstractString;
        legend_pos::AbstractString = "outer north east",
        )
    # Normalize to a plain `String` in a fresh local rather than re-typing the
    # keyword argument inside the body.
    legendposition = convert(String, legend_pos)
    if isempty(vectorofestimators)
        error("length(vectorofestimators) == 0")
    end
    # Alternating Plot / LegendEntry objects; their types differ, and the
    # container is only splatted into `Axis`, so `Any` is appropriate here.
    all_plots_and_legends = Any[]
    for estimator_i in vectorofestimators
        metrics_i = singlelabelbinaryclassificationmetrics_resultdict(
            estimator_i,
            features_df,
            labels_df,
            single_label_name,
            positive_class;
            threshold = 0.5,
            )
        ytrue_i = metrics_i[:ytrue]
        yscore_i = metrics_i[:yscore]
        # The thresholds returned by `roccurve` are not needed for the plot.
        allfpr_i, alltpr_i, _ = roccurve(ytrue_i, yscore_i)
        plot_i = PGFPlotsX.@pgf(
            PGFPlotsX.Plot(
                PGFPlotsX.Coordinates(allfpr_i, alltpr_i),
                ),
            )
        legend_i = PGFPlotsX.@pgf(
            PGFPlotsX.LegendEntry(
                LaTeXStrings.LaTeXString(estimator_i.name),
                ),
            )
        push!(all_plots_and_legends, plot_i)
        push!(all_plots_and_legends, legend_i)
    end
    p = PGFPlotsX.@pgf(
        PGFPlotsX.Axis(
            {
                xlabel = "False positive rate",
                ylabel = "True positive rate",
                no_markers,
                legend_pos = legendposition,
            },
            all_plots_and_legends...,
            ),
        )
    wrapper = PGFPlotsXPlot(p)
    return wrapper
end