-
Notifications
You must be signed in to change notification settings - Fork 1
/
64shot_mlp.py
106 lines (91 loc) · 3.29 KB
/
64shot_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
from sklearn.metrics import balanced_accuracy_score
from utils import (build_model, compute_accuracy_heatmaps, diagnose_output,
prepare_dataset, print_dataset_info, repeat_and_collate,
set_classification_targets)
def classify(**args):
    """
    Main method that prepares dataset, builds model, executes training and displays results.

    :param args: keyword arguments passed from cli parser
    :return: balanced accuracy score of the trained model on the test set
    """
    # verbose print-outs are only enabled for single-run experiments
    allow_print = args['repetitions'] == 1

    # resolve the classification target so the dataset is assembled accordingly
    cls_target, cls_str = set_classification_targets(args['cls_choice'])
    d = prepare_dataset(
        args['dataset_choice'],
        cls_target,
        args['batch_size'],
        args['norm_choice'],
        mp_heatmap=True)

    print('\n\tTask: Classify «{}» using «{}»\n'.format(cls_str, d['data_str']))
    print_dataset_info(d)

    model = build_model(0, d['num_classes'], name='64shot_mlp', new_input=True)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    if allow_print:
        model.summary()
        print('')

    # training pass followed by a single evaluation pass
    model.fit(
        x=d['train_data'],
        steps_per_epoch=d['train_steps'],
        epochs=args['epochs'],
        verbose=1,
        class_weight=d['class_weights'])
    model.evaluate(d['eval_data'], steps=d['test_steps'], verbose=1)

    # test-set predictions feed both the accuracy heatmaps and the final score
    pred = model.predict(d['test_data'], steps=d['test_steps'], verbose=1)
    # keep only each sample's probability for its true class, to judge how
    # confidently the model scored every individual shot
    target_preds = [p[label] for p, label in zip(pred, d['test_labels'])]
    compute_accuracy_heatmaps(d, target_preds, cls_target, args['epochs'])

    return balanced_accuracy_score(d['test_labels'], pred.argmax(axis=1))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # (flags, add_argument keyword arguments) for every supported cli option
    cli_options = (
        (('-r', '--repetitions'),
         dict(type=int, default=1, dest='repetitions',
              help='Number of times to repeat experiment')),
        (('-b', '--batchsize'),
         dict(type=int, default=64, dest='batch_size',
              help='Target batch size of dataset preprocessing')),
        (('-d', '--dataset'),
         dict(type=int, choices=[1, 2], default=1, dest='dataset_choice',
              help='Which dataset(s) to use. 1=hh_12, 2=hh_all')),
        (('-c', '--classification'),
         dict(type=int, choices=[0, 1, 2], default=2, dest='cls_choice',
              help='Which classification target to pursue. 0=classes, 1=subgroups, 2=minerals')),
        (('-e', '--epochs'),
         dict(type=int, default=10, dest='epochs',
              help='How many epochs to train for')),
        (('-n', '--normalisation'),
         dict(type=int, choices=[0, 1, 2], default=2, dest='norm_choice',
              help='Which normalisation to use. 0=None, 1=snv, 2=minmax')),
    )
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    # run the experiment the requested number of times and aggregate results
    repeat_and_collate(classify, **vars(parser.parse_args()))