forked from shogun-toolbox/shogun
/
trained_model_serialization_unittest.cc.py
158 lines (119 loc) · 4.81 KB
/
trained_model_serialization_unittest.cc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python
# This software is distributed under BSD 3-clause license (see LICENSE file).
#
# Authors: Michele Mazzoni, Sergey Lisitsyn
# Classes to ignore: mostly because default initialization isn't enough
# to set up the machine for training (i.e. Multitask and DomainAdaptation);
# other reasons are given below.
# Names of machine classes that generate_tests must leave out of the
# generated type lists (checked with ``name in IGNORE``).
IGNORE = [
    # LinearMachines
    'CFeatureBlockLogisticRegression',
    'CLibLinearMTL',
    'CMultitaskLinearMachine',
    'CMultitaskLogisticRegression',
    'CMultitaskL12LogisticRegression',
    'CMultitaskLeastSquaresRegression',
    'CMultitaskTraceLogisticRegression',
    'CMultitaskClusteredLogisticRegression',
    'CLatentSVM',
    'CLatentSOSVM',
    'CDomainAdaptationSVMLinear',
    'CLinearLatentMachine',

    # KernelMachines
    'CDomainAdaptationSVM',
    'CMKLRegression',
    'CMKLClassification',
    'CMKLOneClass',
    'CSVM',  # doesn't implement a solver
    'CMKL',

    # LinearMulticlassMachines
    'CDomainAdaptationMulticlassLibLinear',
    'CMulticlassTreeGuidedLogisticRegression',
    'CShareBoost',  # apply() takes features subset

    # KernelMulticlassMachines
    'CMulticlassSVM',  # doesn't implement a solver
    'CMKLMulticlass',
    'CScatterSVM',  # error C <= 0
    'CMulticlassLibSVM',  # error C <= 0
]
def read_defined_guards(config_file):
    """Return the names of the macros defined in a configuration header.

    The whole file is lower-cased before it is scanned, so matching is
    effectively case-insensitive and every returned name is lower-case.

    Args:
        config_file: path of the file to scan for ``#define NAME`` lines.

    Returns:
        A list with one lower-case name per ``#define`` occurrence, in
        order of appearance (duplicates are kept).
    """
    with open(config_file) as f:
        config = f.read().lower()
    # Raw string: '\w' inside a plain literal is an invalid escape sequence.
    return re.findall(r'#define (\w+)', config)
def is_guarded(include, defined_guards):
    """Tell whether a header depends on a macro that is not defined.

    The header is lower-cased and scanned for ``#ifdef NAME`` directives
    only (the pattern does not match ``#ifndef`` or ``#if defined``).

    Args:
        include: path of the header file to inspect.
        defined_guards: collection of lower-case macro names that are
            considered defined.

    Returns:
        True if at least one ``#ifdef`` name found in the header is
        missing from ``defined_guards``, False otherwise (including when
        the header has no ``#ifdef`` at all).
    """
    with open(include) as header:
        # Raw string: '\w' inside a plain literal is an invalid escape.
        guards = re.findall(r'#ifdef (\w+)', header.read().lower())
    # A generator lets any() stop at the first undefined guard.
    return any(g not in defined_guards for g in guards)
def ignore_in_class_list(include):
    """Tell whether a header contains the IGNORE_IN_CLASSLIST marker.

    Args:
        include: path of the header file to read.

    Returns:
        True if the (case-sensitive) text ``IGNORE_IN_CLASSLIST`` occurs
        anywhere in the file, False otherwise.
    """
    marker = 'IGNORE_IN_CLASSLIST'
    with open(include) as source:
        contents = source.read()
    return marker in contents
def is_pure_virtual(name, tags):
    """Tell whether any tag line marks ``name`` as pure virtual.

    Args:
        name: class name to look for.
        tags: iterable of ctags lines (strings).

    Returns:
        True if some line contains ``name`` immediately followed by a tab
        and ``implementation:pure virtual``; False otherwise.
    """
    needle = name + '\timplementation:pure virtual'
    for tag in tags:
        if needle in tag:
            return True
    return False
def use_gpl(path, defined_guards):
    """Tell whether the file at ``path`` may be used given the GPL setting.

    Args:
        path: file path; it is treated as GPL code when it contains
            ``src/gpl/``.
        defined_guards: collection of defined (lower-case) macro names.

    Returns:
        True for any path outside ``src/gpl/``; for a path inside it,
        True only when ``use_gpl_shogun`` is among ``defined_guards``.
    """
    if 'src/gpl/' not in path:
        return True
    return 'use_gpl_shogun' in defined_guards
def is_shogun_class(c):
    """Tell whether a ctags line looks like a Shogun class entry.

    A line qualifies when it starts with ``C`` followed by an upper-case
    character and contains the text ``class`` anywhere.

    Args:
        c: one line of the tags file.

    Returns:
        True if the line qualifies, False otherwise. Lines shorter than
        two characters (including the empty string) return False
        instead of raising IndexError.
    """
    # The length check protects the two index accesses below.
    return len(c) > 1 and c[0] == 'C' and c[1].isupper() and 'class' in c
def get_shogun_classes(tags):
    """Collect the Shogun classes described by a list of ctags lines.

    Only lines accepted by ``is_shogun_class`` are considered. In ctags
    format each line is TAG\\tLOCATION\\t..\\tinherits:CLASS.

    Args:
        tags: iterable of ctags lines (strings).

    Returns:
        A dict mapping each class name (first field) to a dict with
        ``'include'`` (second field) and ``'base'`` (the text after
        ``inherits:`` in the last field, or None when the last field
        does not start with ``inherits:``). A later line for the same
        name replaces an earlier one.
    """
    prefix = 'inherits:'
    found = {}
    for line in tags:
        if not is_shogun_class(line):
            continue
        fields = line.strip().split('\t')
        symbol, location = fields[0], fields[1]
        last = fields[-1]
        if last.startswith(prefix):
            parent = last[len(prefix):]
        else:
            parent = None
        found[symbol] = {'include': location, 'base': parent}
    return found
def get_ancestors(classes, name):
    """List the known ancestors of a class, nearest first.

    Args:
        classes: dict mapping class names to dicts holding a ``'base'``
            entry.
        name: class whose ancestry is wanted; must be a key of
            ``classes``.

    Returns:
        The chain of base classes starting with the direct base. The
        chain stops *before* the first base that is not itself a key of
        ``classes`` (so such a base, or None, is never included).

    Raises:
        KeyError: if ``name`` is not in ``classes``.
    """
    parent = classes[name]['base']
    if parent not in classes:
        return []
    lineage = [parent]
    lineage.extend(get_ancestors(classes, parent))
    return lineage
def read_ctags(filename):
    """Read a ctags file and return its lines.

    Args:
        filename: path of the tags file.

    Returns:
        The list produced by ``readlines()`` (line endings are kept).

    Raises:
        Exception: if ``filename`` does not exist.
    """
    if os.path.exists(filename):
        with open(filename) as handle:
            return handle.readlines()
    raise Exception('Failed to found ctags file at %s' % (filename))
def generate_tests(input_file, config_file):
    """Build the include lines and gtest type lists for trainable machines.

    Classes are read from the ctags file and kept when one of the five
    base machine classes is among their ancestors and they are not in
    IGNORE, not guarded by an undefined macro, not pure virtual, not
    marked IGNORE_IN_CLASSLIST and allowed by the GPL setting.

    Args:
        input_file: path of the ctags file.
        config_file: path of the file scanned for ``#define`` names.

    Returns:
        A string made of one ``#include "..."`` line per selected class,
        an empty line, and one ``typedef ::testing::Types<...> XTypes;``
        line for each of the ``Machine`` and ``KernelMachine`` groups.
    """
    tags = read_ctags(input_file)
    classes = get_shogun_classes(tags)
    guards = read_defined_guards(config_file)

    bases = [
        'CLinearMachine', 'CKernelMachine', 'CLinearMulticlassMachine',
        'CKernelMulticlassMachine', 'CNativeMulticlassMachine'
    ]

    # Gather all the machines that inherit from the classes in bases
    machines = dict((base, {}) for base in bases)
    for name, attrs in classes.items():
        ancestors = get_ancestors(classes, name)
        header = attrs['include']
        for base in bases:
            # Checks are ordered so that header files are only opened
            # for classes that actually derive from this base.
            rejected = (
                base not in ancestors
                or name in IGNORE
                or is_guarded(header, guards)
                or is_pure_virtual(name, tags)
                or ignore_in_class_list(header)
                or not use_gpl(header, guards)
            )
            if not rejected:
                machines[base][name] = attrs

    include_line = '#include "{0}"\n'
    typelist_line = 'typedef ::testing::Types<{0}> {1}Types;\n'

    # Base class -> test group whose type list receives its machines.
    group_of_base = {
        'CLinearMachine': 'Machine',
        'CNativeMulticlassMachine': 'Machine',
        'CLinearMulticlassMachine': 'Machine',
        'CKernelMachine': 'KernelMachine',
        'CKernelMulticlassMachine': 'KernelMachine',
    }
    group_members = {
        'Machine': [],
        'KernelMachine': []
    }

    include_lines = []
    for base, members in machines.items():
        group_members[group_of_base[base]].extend(members.keys())
        for info in members.values():
            include_lines.append(include_line.format(info['include']))
    headers = ''.join(include_lines)

    typelists = ''.join(
        typelist_line.format(", ".join(names), group)
        for group, names in group_members.items()
    )

    return headers + '\n' + typelists
# Script entry point. Usage:
#   ./trained_model_serialization_unittest.cc.py \
#       <input file> <output file> <config file>
# <input file> is the ctags file, <config file> is scanned for #define
# names, and the generated text is written to <output file>.
import os
import re
import sys

# Fail with a readable message instead of a bare IndexError traceback.
if len(sys.argv) < 4:
    sys.exit('usage: %s <input file> <output file> <config file>'
             % sys.argv[0])

input_file = sys.argv[1]
output_file = sys.argv[2]
config_file = sys.argv[3]

outputText = generate_tests(input_file, config_file)
with open(output_file, 'w') as f:
    # outputText is one string: write() emits it in a single call, whereas
    # writelines() would iterate over it character by character.
    f.write(outputText)