forked from adaniefei/AccessMath_Pose
-
Notifications
You must be signed in to change notification settings - Fork 1
/
spk_train_03_train_classifier.py
59 lines (43 loc) · 2.1 KB
/
spk_train_03_train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import sys
from AM_CommonTools.configuration.configuration import Configuration
from AccessMath.util.misc_helper import MiscHelper
from AccessMath.data.meta_data_DB import MetaDataDB
from AccessMath.speaker.actions.pose_feature_extractor import PoseFeatureExtractor
from sklearn.ensemble import RandomForestClassifier
def main():
    """Train the speaker-action random-forest classifier and save it to disk.

    Usage: ``python <script> config``, where ``config`` is a file readable by
    ``Configuration.from_file``.

    Steps:
      1. Load the configuration and the AccessMath meta-data database given by
         ``VIDEO_DATABASE_PATH``.
      2. For every lecture of the dataset named ``SPEAKER_TRAINING_SET_NAME``,
         load ``<OUTPUT_PATH>/<SPEAKER_ACTION_FEATURES_DIR>/<db name>_<title>.pickle``.
      3. Merge them with ``PoseFeatureExtractor.combine_datasets`` and fit a
         ``RandomForestClassifier`` with ``SPEAKER_ACTION_CLASSIFIER_RF_TREES``
         trees (default 64) and max depth ``SPEAKER_ACTION_CLASSIFIER_RF_DEPTH``
         (default 16), ``random_state=0``.
      4. Save the fitted model with ``MiscHelper.dump_save`` to
         ``<OUTPUT_PATH>/<SPEAKER_ACTION_CLASSIFIER_DIR>/<SPEAKER_ACTION_CLASSIFIER_FILENAME>``.

    Some problems are reported on stdout and the function returns ``None``
    without training:
      * a missing command-line argument;
      * an unreadable database;
      * an unknown training-set name;
      * an empty training set;
      * missing feature files.
    """
    if len(sys.argv) < 2:
        print("Usage")
        print("\tpython {0:s} config".format(sys.argv[0]))
        return

    # initialization #
    config = Configuration.from_file(sys.argv[1])

    # Read the setting outside the try so a configuration problem is not
    # misreported as an invalid database file.
    database_path = config.get_str("VIDEO_DATABASE_PATH")
    try:
        database = MetaDataDB.from_file(database_path)
    except Exception as err:
        # Narrowed from a bare ``except:`` so Ctrl+C / SystemExit still
        # propagate. The reason is shown instead of being discarded.
        print("Invalid AccessMath Database file")
        print("\t{0}: {1}".format(database_path, err))
        return

    # get paths and other configuration parameters ....
    output_dir = config.get_str("OUTPUT_PATH")
    # get_str (like the other path settings) because the value becomes part of a path.
    features_dir = os.path.join(output_dir, config.get_str("SPEAKER_ACTION_FEATURES_DIR"))
    classifier_dir = os.path.join(output_dir, config.get_str("SPEAKER_ACTION_CLASSIFIER_DIR"))
    os.makedirs(classifier_dir, exist_ok=True)
    classifier_filename = os.path.join(classifier_dir, config.get_str("SPEAKER_ACTION_CLASSIFIER_FILENAME"))

    dataset_name = config.get("SPEAKER_TRAINING_SET_NAME")
    try:
        training_set = database.datasets[dataset_name]
    except KeyError:
        print("Unknown training set: {0}".format(dataset_name))
        return
    if not training_set:
        print("Training set {0} contains no lectures".format(dataset_name))
        return

    # get classifier parameters
    rf_n_trees = config.get_int("SPEAKER_ACTION_CLASSIFIER_RF_TREES", 64)
    rf_depth = config.get_int("SPEAKER_ACTION_CLASSIFIER_RF_DEPTH", 16)

    # Pair each lecture's lower-cased title (the key combine_datasets looks up)
    # with its feature file, preserving the training-set order.
    lecture_files = []
    for lecture in training_set:
        filename = os.path.join(features_dir, database.name + "_" + lecture.title + ".pickle")
        lecture_files.append((lecture.title.lower(), filename))

    # Report every missing file at once instead of failing on the first one.
    missing = [filename for _, filename in lecture_files if not os.path.isfile(filename)]
    if missing:
        print("Missing speaker action feature files:")
        for filename in missing:
            print("\t" + filename)
        return

    # read all training data available ....
    training_titles = [title for title, _ in lecture_files]
    # NOTE(review): dump_load presumably unpickles the file (".pickle"
    # extension). Only point the features dir at trusted, locally generated data.
    train_dataset = {title: MiscHelper.dump_load(filename) for title, filename in lecture_files}
    train_x, train_y, _ = PoseFeatureExtractor.combine_datasets(training_titles, train_dataset)

    # train the classifier and save it
    clf = RandomForestClassifier(n_estimators=rf_n_trees, max_depth=rf_depth, random_state=0)
    clf = clf.fit(train_x, train_y)
    MiscHelper.dump_save(clf, classifier_filename)
# Run the training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()