-
Notifications
You must be signed in to change notification settings - Fork 0
/
performance_report.py
35 lines (27 loc) · 1.13 KB
/
performance_report.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import json
import os
import pandas as pd
from src.analyze.performance_measures import PerformanceMeasures
# Command-line interface: one optional string option, -r/--results-dir, which
# is None when omitted. The arguments are parsed as soon as this module is
# executed and the resulting namespace is kept in the module-level ``args``.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-r',
    '--results-dir',
    dest='results_dir',
    default=None,
    type=str,
)
args = parser.parse_args()
def main(results_dir):
    """Generate the performance report and figures for one results directory.

    Loads ``test_predictions.csv`` and ``test_ground_truths.csv`` from
    *results_dir*, passes them to ``PerformanceMeasures`` and asks it to
    generate the report and create the figures with ``save_dir`` set to
    ``<results_dir>/performance-report``.

    The number of classes is read from the ``n_classes`` entry of
    ``<results_dir>/args.json``. It falls back to 4 when that file does not
    exist or has no such entry.

    Args:
        results_dir: Path of the directory that holds the result files.

    Note:
        Only a missing ``args.json`` is tolerated. Errors raised while reading
        the two CSV files are not caught here and propagate to the caller.
    """
    default_num_classes = 4

    try:
        # JSON is UTF-8 by specification, so decode it explicitly instead of
        # relying on the platform's locale encoding.
        with open(os.path.join(results_dir, 'args.json'), 'r', encoding='utf-8') as f:
            run_args = json.load(f)
    except FileNotFoundError:
        num_classes = default_num_classes
    else:
        num_classes = run_args.get('n_classes', default_num_classes)

    preds = pd.read_csv(os.path.join(results_dir, 'test_predictions.csv'))
    gts = pd.read_csv(os.path.join(results_dir, 'test_ground_truths.csv'))

    # The report and the figures share one output directory.
    report_dir = os.path.join(results_dir, 'performance-report')

    perf_meas = PerformanceMeasures(gts, preds, num_classes=num_classes)
    perf_meas.generate_report(save_dir=report_dir)
    perf_meas.create_figures(save_dir=report_dir)
if __name__ == '__main__':
    # Batch variant kept for reference (not executed): report on every
    # matching run directory instead of a single one.
    #
    # import glob
    #
    # dir0 = 'results/spaceml-results-light/**/rey-multilabel-classifier'
    # for d in glob.glob(dir0, recursive=True):
    #     main(d)

    # Honour the -r/--results-dir option parsed at module level. When it is
    # omitted, fall back to the run directory this script used to hard-code,
    # so invoking the script without arguments behaves as before.
    default_results_dir = 'results/spaceml-results/data_232x300-seed_1/reg_clf'
    main(args.results_dir if args.results_dir is not None else default_results_dir)