-
Notifications
You must be signed in to change notification settings - Fork 10
/
evaluate.py
127 lines (119 loc) · 3.82 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Wrapper to Evaluate model
"""
# -------------------------------------------------------------------#
# Copyright 2019 The Tefla Authors. All Rights Reserved.
# -------------------------------------------------------------------#
from __future__ import division, print_function, absolute_import

import ast
import os

import click
from bokeh.io import save
from prettytable import PrettyTable

from tefla.core.eval_metrices import Evaluation
class ListAsArguments(click.Option):
    """click.Option subclass that parses the raw command-line value as a
    Python literal (e.g. ``"['a.csv', 'b.csv']"``), so list-valued options
    can be passed as a single argument.

    Overrides ``type_cast_value`` from ``click.Option``.
    """

    def type_cast_value(self, ctx, value):
        """Parse ``value`` with ``ast.literal_eval``.

        Raises:
            click.BadParameter: if ``value`` is not a valid Python literal.
        """
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            # literal_eval raises ValueError or SyntaxError on malformed
            # input. Catching BaseException here (as before) would also
            # swallow KeyboardInterrupt/SystemExit — far too broad.
            raise click.BadParameter(value)
# pylint: disable=no-value-for-parameter
@click.command()
@click.option(
    '--truth_file',
    default=None,
    show_default=True,
    required=True,
    help='Path to file containing ground truth')
@click.option(
    '--pred_files',
    default='[]',
    cls=ListAsArguments,
    show_default=True,
    required=True,
    help='Path to file containing predictions.')
@click.option(
    '--eval_list',
    default='[]',
    cls=ListAsArguments,
    show_default=True,
    required=True,
    help='List of evaluation metrics to be evaluated.')
@click.option(
    '--plot_list',
    default='[]',
    cls=ListAsArguments,
    show_default=True,
    help='List of Evaluation plots.')
@click.option(
    '--over_all',
    default=False,
    show_default=True,
    help='Flag If overall results are required instead of classwise results.')
@click.option(
    '--ensemble_voting',
    default="soft",
    show_default=True,
    help='The type of voting strategy to be used in case of ensemble.')
@click.option(
    '--ensemble_weights',
    default='[]',
    cls=ListAsArguments,
    show_default=True,
    help='Weights in case of ensemble with weights')
@click.option(
    '--class_names', default='[]', cls=ListAsArguments, show_default=True, help='Name of classes')
@click.option(
    '--convert_binary',
    default=False,
    show_default=True,
    help='Flag to indicate if problem should be evaluated as binary(normal vs abnormal)')
@click.option(
    '--binary_threshold',
    default=0.5,
    show_default=True,
    help='Threshold value to determine normal and abnormal.')
@click.option(
    '--save_dir', default='.', show_default=True, help='Path where evaluation plots will be saved.')
@click.option(
    '--eval_type',
    default=None,
    show_default=True,
    required=True,
    help='Evaluation type classification or Regression.')
# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
def main(truth_file, pred_files, eval_list, plot_list, over_all, ensemble_voting, ensemble_weights,
         class_names, convert_binary, binary_threshold, save_dir, eval_type):
    """Wrapper function to call the eval_metrices API.

    Dispatches to classification or regression evaluation based on
    ``eval_type``, prints the resulting metrics as a table (one row per
    class, or a single overall row), and saves any returned bokeh plots
    as numbered HTML files under ``save_dir``.

    Raises:
        ValueError: if ``eval_type`` is neither 'classification' nor
            'regression' (case-insensitive).
    """
    eval_model = Evaluation()
    eval_type = eval_type.lower()  # hoisted: compared twice below
    if eval_type == 'classification':
        evl_result, evl_plots = eval_model.eval_classification(
            truth_file,
            pred_files,
            eval_list,
            plot_list,
            over_all=over_all,
            ensemble_voting=ensemble_voting,
            ensemble_weights=ensemble_weights,
            class_names=class_names,
            convert_binary=convert_binary,
            binary_threshold=binary_threshold)
    elif eval_type == 'regression':
        evl_result, evl_plots = eval_model.eval_regression(
            truth_file, pred_files, eval_list, plot_list, ensemble_weights=ensemble_weights)
    else:
        raise ValueError("invalid option, provide either classification or regression")

    # One row per result key (class name or 'overall'), one column per metric.
    p_tab = PrettyTable()
    p_tab.field_names = ["class"] + eval_list
    for keys, values in evl_result.items():
        p_tab.add_row([keys] + [values[evl] for evl in eval_list])
    print(p_tab)

    # Plots come back as bokeh figures; persist each as <index>.html.
    if evl_plots:
        for i, plt in enumerate(evl_plots):
            save(plt, filename=os.path.join(save_dir, '{}.html'.format(i)))
# Script entry point: click parses the CLI arguments and invokes main().
if __name__ == '__main__':
    main()