-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_tasks.py
337 lines (301 loc) · 12 KB
/
test_tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
"""
File implementing a single forward pass on the test set.
Also provides a main function with argument parsing, so the file can be run as a standalone script.
Saves results in a file.
WARNING: some features might be deprecated because the script has not been used in a long time.
"""
import numpy as np
import tensorflow as tf
import json
import os
import argparse
import shutil
import yaml
import time
import pickle
from scipy.special import softmax
from task_utils import (
MergeConfigs,
ForwardModel,
LoadDatasets,
LoadMetrics,
LoadLosses,
AssignWeights,
WeightRecoveryEvaluation
)
from model.gates.dselect_k_gate import DSelectKGate
from model.main_model import MoE
from model.main_model_stacked import MoEStacked
# Absolute directory containing this script; used to anchor relative result paths.
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
def main():
    """
    Parses the arguments from the command line and launches a test procedure using test_model.

    Reads the model/train/task config files, loads the requested dataset,
    instantiates the model, runs a single evaluation pass over it, and saves
    the resulting metrics and losses as JSON files in the results folder.
    inputs: None
    outputs: None.
    """

    def _str2bool(value):
        # argparse's `type=bool` silently maps ANY non-empty string (including
        # "False") to True, so the flag could never be switched off from the
        # CLI. Parse the usual truthy/falsy spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got %r." % value)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_config",
        default="./config/model_config/dense_experts_trimmed_lasso_simplex_gate.json",
        type=str,
        help="Path to the model's config (in which its architecture is defined).",
    )
    # NOTE(review): --from_pretrained is parsed but never used below — confirm
    # whether weight loading was meant to happen here or is handled elsewhere.
    parser.add_argument(
        "--from_pretrained",
        default=None,
        type=str,
        help="Path to pretrained weights of the model architecture defined in the config.",
    )
    parser.add_argument(
        "--train_config",
        default="./config/train_config/example_train.json",
        type=str,
        help="Path to train config (also containing val parameters to be used for testing).",
    )
    parser.add_argument(
        "--task_config",
        default="./config/task_config/task_configs.yml",
        type=str,
        help="Path to task configs in which task parameters are defined.",
    )
    parser.add_argument(
        "--results_location",
        default="../results/",
        type=str,
        help="Path to location where to save results (performance metrics).",
    )
    parser.add_argument(
        "--experiment_name",
        default="test_" + str(int(time.time())),
        type=str,
        help="Name of the folder where to save results (located in results_location).",
    )
    parser.add_argument(
        "--ground_truth_weights_location",
        default="./data/raw/synthetic_regressions/shared_features/gate_weights.pkl",
        type=str,
        help="Path to optimal weights for support recovery evaluation.",
    )
    parser.add_argument(
        "--use_test_dataset",
        default=True,
        type=_str2bool,  # fixed: was `type=bool`, which made "False" parse as True
        help="Whether to use the test or val dataset for out-of-sample evaluation.",
    )
    parser.add_argument(
        "--perform_weight_recovery_analysis",
        default=False,
        action='store_true',
        help="Whether to use the ground_truth_weights_location mentioned to perform weight recovery analysis.",
    )
    parser.add_argument(
        "--use_MoE_stacked",
        default=False,
        action='store_true',
        help="Whether or not to use the MoE implementation with stacked experts and gates (so long as they are all of the same type).",
    )
    args = parser.parse_args()
    # We open, merge, and add info to the config files (used to define the model architecture,
    # the training hyperparameters, and the characteristics of the set of tasks to solve).
    with open(args.model_config) as f:
        model_config = json.load(f)
    with open(args.train_config) as f:
        train_config = json.load(f)
    with open(args.task_config) as f:
        # NOTE(review): FullLoader can construct arbitrary Python objects;
        # prefer yaml.safe_load if task configs are ever user-supplied.
        task_config = yaml.load(f, Loader=yaml.FullLoader)
    config = MergeConfigs(model_config, train_config, task_config)
    config['ground_truth_weights_location'] = args.ground_truth_weights_location
    config['perform_weight_recovery_analysis'] = args.perform_weight_recovery_analysis
    config['use_MoE_stacked'] = args.use_MoE_stacked
    # We load the datasets used for our tests.
    # NOTE: either the test dataset OR val dataset from trainval can be used.
    print("Loading datasets...")
    if args.use_test_dataset:
        test_dataset = LoadDatasets(config, trainval=False)
    else:
        _, test_dataset = LoadDatasets(config)
    # We instantiate our model to test.
    print("\nInstantiating model...")
    if config['use_MoE_stacked']:
        model = MoEStacked(config)
    else:
        model = MoE(config)
    # We instantiate the task losses and metrics.
    print("Instantiating losses and metrics...")
    task_losses_list = LoadLosses(config)
    task_metrics_list = LoadMetrics(config)
    # We configure the path/folder where we will store our test results.
    results_path = os.path.join(args.results_location, args.experiment_name)
    print("Creating directory '% s' for performance records..." % results_path)
    abs_results_path = os.path.join(ROOT_PATH, results_path)
    # Recreate the results folder from scratch so results from a previous run
    # with the same experiment name never get mixed with the new ones.
    if os.path.exists(abs_results_path):
        shutil.rmtree(abs_results_path)
    # makedirs (vs the previous mkdir) also creates any missing parent
    # directories, e.g. when results_location itself does not exist yet.
    os.makedirs(os.path.join(abs_results_path, "performance"))
    # We execute once our custom testing loop here (defined below), which returns monitoring objects.
    (
        metrics_test,
        batch_loss_test,
        weight_recovery_eval
    ) = test_model(
        model,
        test_dataset,
        task_losses_list,
        task_metrics_list,
        config
    )
    # We store our monitoring objects in a dict to facilitate saving in json files.
    metrics = {
        "test": metrics_test,
        "weight_recovery_eval": weight_recovery_eval
    }
    loss = {
        "test": batch_loss_test,
    }
    # We save our monitoring objects.
    with open(os.path.join(abs_results_path, "performance", "metrics.json"), "w") as f:
        json.dump(metrics, f)
    with open(os.path.join(abs_results_path, "performance", "loss.json"), "w") as f:
        json.dump(loss, f)
    return
def test_model(
    model,
    test_dataset,
    task_losses_list,
    task_metrics_list,
    config
):
    """
    Computes the metrics of the model for each task, when tested on the test set (or the val set).
    Performs weight recovery on unstacked static gates if requested in config.
    inputs:
        - instantiated model to train (does not need to be built yet, unless pretrained)
        - test (or val) dataloader
        - list of task losses objects to aggregate at the end of each forward pass
        - list of task metrics objects whose value we need to save at the end of each forward pass
        - merged config with training-related hyperparameter values
    outputs:
        - metrics computed for all epochs on the test (or val) set
        - mini-batch loss computed for each mini-batch on the test (or val) set
        - weight recovery evaluation results (if requested in config).
    """
    # We instantiate the objects we will use to store the measurements we will make during the tests.
    # One dict per task mapping metric name -> value; values are filled in after the batch loop below.
    metrics_test = [
        {
            metric: 0
            for metric in metrics_dict
        }
        for metrics_dict in task_metrics_list
    ]
    batch_loss_test = []
    # We save our test set size for convenience
    # NOTE(review): assumes batch[0] carries the batch dimension first — confirm against LoadDatasets.
    test_dataset_size = sum([batch[0].shape[0] for batch in test_dataset])
    start_time = time.time()
    # In our implementation, all gates have to be either instance-specific or static; not a mix of both.
    print("\nStart of test procedure")
    # Iterate over the batches of the test dataset.
    for batch in test_dataset:
        # Forward pass and update test metrics
        _, batch_loss = ForwardModel(
            model,
            batch,
            task_losses_list,
            task_metrics_list,
            training=False
        )
        # We store batch-wise measurements here (in our objects).
        batch_loss_test.append(float(batch_loss))
    print("Time taken: %.2fs\n\n" % (time.time() - start_time))
    # Aggregated metrics on the test set for this single epoch, based on our measurements for each step.
    # Batch-size-weighted average of the per-batch losses, so a smaller final batch is not over-weighted.
    mean_loss = sum(
        batch[0].shape[0] * batch_loss for batch,batch_loss in zip(test_dataset,batch_loss_test)
    ) / test_dataset_size
    print(f"Test loss: {mean_loss}\n")
    # We display test metrics at the end of this epoch.
    for i,task_metrics in enumerate(task_metrics_list):
        print(f"--- Test metrics for Task {config['taskset']['tasks'][i]['name']} ---")
        for metric in task_metrics:
            # .result() reads the state accumulated by ForwardModel over all batches.
            metric_value = float(task_metrics[metric].result())
            metrics_test[i][metric] = metric_value
            print(f"-------- Test {metric}: {metric_value} --------\n")
    # If our taskset involves a weight recovery evaluation (i.e.: how do the weights of our trained model compare to that of
    # a ground truth data generating function), then we perform the weight recovery analysis here.
    # Only runs for unstacked models whose first gate is static (does not use the routing input).
    if (
        config["perform_weight_recovery_analysis"]
    ) and (
        (
            not config['use_MoE_stacked'] and not model.gates[0].use_routing_input
        )
    ):
        # Using this on any gate other than (sparse) simplex, trimmed lasso simplex, softmax, topk softmax, dselect-k
        # will trigger an error.
        print("Weight recovery evaluation:")
        # We first load the ground truth weights we expect to have recovered.
        # NOTE(review): pickle.load on an untrusted file can execute arbitrary code —
        # only point ground_truth_weights_location at trusted files.
        with open(config['ground_truth_weights_location'],"rb") as f:
            all_ground_truth_weights = [
                ground_truth_weights.reshape(1,-1) for ground_truth_weights in pickle.load(f)
            ]
        all_ground_truth_weights = np.concatenate(all_ground_truth_weights)
        # Then we extract the learned weights of the gates of our model.
        # Some of our types of static gates need their underlying weights to be reparameterized in order to be compared
        # to the ground truth.
        all_gate_weights = []
        for i,gate in enumerate(model.gates):
            gate_weights = gate.get_weights()[0]
            if "softmax" in gate.name and "top_k" in gate.name:
                # Top-k softmax gate: keep the k largest raw weights and mask the
                # rest with -inf so they get exactly zero mass after the softmax.
                topk = tf.math.top_k(
                    tf.reshape(tf.expand_dims(gate_weights, 1), [-1]),
                    k=gate.get_config()["k"]
                )
                topk_scattered = tf.scatter_nd(
                    tf.reshape(topk.indices, [-1, 1]),
                    topk.values,
                    [gate_weights.shape[0]]
                )
                topk_prep = tf.where(
                    tf.math.equal(topk_scattered, tf.constant(0.0)),
                    -np.inf*tf.ones_like(topk_scattered), # we add the mask here
                    topk_scattered
                )
                gate_weights = tf.nn.softmax(
                    tf.expand_dims(topk_prep, 1), # else, we get an error in the softmax activation
                    axis=0
                ).numpy()
            elif "softmax" in gate.name and "topk" not in gate.name:
                # Plain softmax gate: map the raw weights onto the simplex.
                # NOTE(review): this branch tests "topk" while the branch above tests
                # "top_k" — confirm the gate naming convention; a gate named with
                # "topk" (no underscore) would be skipped by both softmax branches.
                gate_weights = softmax(
                    gate_weights
                )
            elif "d_select_k" in gate.name:
                # DSelect-k gate exposes its own expert-weight computation.
                gate_weights = gate.gate.compute_expert_weights()[0].numpy()
            gate_weights = gate_weights.reshape(1,-1)
            print(f"--- Weights of gate for {config['taskset']['tasks'][i]['name']}: ---")
            print(gate_weights)
            print("\n")
            all_gate_weights.append(gate_weights)
        all_gate_weights = np.concatenate(all_gate_weights)
        # Now that we have loaded the ground truth weights
        # AND extracted/re-parameterized the learned weights of our gates,
        # we finally perform the weight recovery analysis.
        # NOTE(review): presumably AssignWeights returns a column permutation aligning
        # predicted gate weights with the ground truth (gate order is not identifiable)
        # — verify against its definition in task_utils.
        _, perm_predicted_weights = AssignWeights(
            all_gate_weights,
            all_ground_truth_weights
        )
        weight_recovery_eval = WeightRecoveryEvaluation(
            all_ground_truth_weights,
            all_gate_weights[:,perm_predicted_weights],
        )
        return (
            metrics_test,
            batch_loss_test,
            weight_recovery_eval
        )
    else:
        # Weight recovery was not requested (or is not applicable); return a
        # sentinel string in its slot so callers can still save the triple.
        return (
            metrics_test,
            batch_loss_test,
            "no weight recovery evaluation"
        )
# Script entry point: parse CLI arguments and run the test procedure.
if __name__ == "__main__":
    main()