-
Notifications
You must be signed in to change notification settings - Fork 1
/
compute_parsing_accuracy.py
341 lines (277 loc) · 12 KB
/
compute_parsing_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""Computes the parsing accuracy of the model on a test suite."""
import argparse
import copy
import json
import os
from os.path import dirname, abspath, join # noqa: E402, F401
import sys
from datetime import datetime
from typing import Any
from pytz import timezone
import gin
import numpy as np
import pandas as pd
import torch
np.random.seed(0)  # fix RNG seeds so parsing sweeps are reproducible
torch.manual_seed(0)
# Make the repository root importable so the `experiments`, `logic`, and
# `parsing` packages below resolve when this script is run directly.
parent = dirname(dirname(abspath(__file__)))
sys.path.append(parent)
from experiments.get_compositional_queries import run_iid_compositional_accuracies
from experiments.utils import load_test_data, check_correctness
from logic.core import ExplainBot # noqa: E402, F401
# Needed for gin configs
from parsing.gpt.few_shot_inference import get_few_shot_predict_f # noqa: E402, F401
def safe_name(model):
    """Returns a filesystem-safe version of *model* (slashes become underscores)."""
    return "_".join(model.split("/"))
def compute_accuracy(data, predict_func, verbose: bool = False, error_analysis: bool = False,
                     feature_names: Any = None):
    """Computes the parsing accuracy across some data.

    Args:
        data: Mapping from user-input string to its gold parse string.
        predict_func: Callable producing the predicted parse for an input.
            When ``error_analysis`` is True it must return a
            ``(parse_text, nn_prompts)`` pair.
        verbose: If True, prints each miss and the final error rate.
        error_analysis: If True, also records whether every keyword of the
            gold parse appears in the nearest-neighbor prompts.
        feature_names: Feature names to skip when checking parse keywords
            (only consulted when ``error_analysis`` is True).

    Returns:
        Tuple ``(error_rate, parses_store)`` where ``parses_store`` maps
        per-example keys (``parsed_text_{j}``, ``correct_parse_{j}``,
        ``user_input_{j}``, and optionally ``includes_all_words_{j}``) to
        their values.
    """
    nn_prompts = None
    misses, total = 0, 0
    parses_store = {}
    print(f"There are {len(data)} eval points", flush=True)
    for j, user_input in enumerate(data):
        has_all_parse_words = None
        if error_analysis:
            parse_text, nn_prompts = predict_func(user_input)
        else:
            parse_text = predict_func(user_input)
        # BUG FIX: the original read `if type(parse_text == tuple):`, which
        # evaluates the type of a bool and is therefore always truthy. Use
        # isinstance to unwrap (parse, None) pairs only for actual tuples.
        if isinstance(parse_text, tuple):
            if len(parse_text) == 2 and parse_text[1] is None:
                parse_text = parse_text[0]
        # Get the gold label parse
        correct_parse = data[user_input]
        # Do this to make sure extra spaces are ignored around input
        is_correct = check_correctness(parse_text, correct_parse)
        if error_analysis:
            parses = []
            print(nn_prompts)
            if nn_prompts is None:
                parses = []
            else:
                for prompt in nn_prompts:
                    split_prompt = prompt.split("parsed: ")
                    nn_parse = split_prompt[1]
                    parses.append(nn_parse)
            parses = " ".join(parses)
            has_all_parse_words = True
            for word in correct_parse.split(" "):
                # Cases not to look at parse word, i.e., things like numbers,
                # bools, and the feature names
                try:
                    float(word)
                    continue
                except ValueError:  # narrowed from a bare `except:`
                    pass
                if word == "true" or word == "false":
                    continue
                if word in feature_names:
                    continue
                if word not in parses:
                    has_all_parse_words = False
        if not is_correct:
            misses += 1
            if verbose:
                print(">>>>>")
                print(f"User: {user_input}")
                print(f"Parsed: {parse_text}")
                print(f"Correct: {correct_parse}")
                print(">>>>>")
        parses_store[f"parsed_text_{j}"] = parse_text
        parses_store[f"correct_parse_{j}"] = correct_parse
        parses_store[f"user_input_{j}"] = user_input
        if error_analysis:
            parses_store[f"includes_all_words_{j}"] = has_all_parse_words
        total += 1
        if j % 25 == 0:
            print(f"Error Rate | it {j} | {round(misses / total, 3)}", flush=True)
    # Guard against an empty test suite (the original raised ZeroDivisionError).
    error_stat = misses / total if total else 0.0
    if verbose:
        print(f"Final Error Rate: {round(error_stat, 3)}", flush=True)
    return error_stat, parses_store
def main():
    """Runs a prompt-count sweep of parsing accuracy and writes results to CSV.

    Reads configuration from the module-level ``args`` namespace (and ``run``
    when wandb logging is enabled). Skips the run entirely if the target CSV
    already exists.
    """
    # Fixed metric columns; per-example parse entries are added dynamically
    # from the store returned by compute_accuracy.
    results = {
        "dataset": [],
        "model": [],
        "num_prompts": [],
        "accuracy": [],
        "in_domain_accuracy": [],
        "compositional_accuracy": [],
        "overall_accuracy": [],
        "total_in_domain": [],
        "total_compositional": [],
        "guided_decoding": [],
        "iid_errors_pct_not_all_words": [],
        "comp_errors_pct_not_all_words": []
    }
    model = args.model
    guided_decoding = args.gd
    dset = args.dataset
    program_only_text = ""
    if args.program_only:
        program_only_text += "-program-only"
    results_location = (f"./experiments/results_store/{safe_name(model)}_{dset}_gd-{guided_decoding}"
                        f"_debug-{args.debug}{program_only_text}.csv")
    if os.path.exists(results_location):
        # BUG FIX: original message contained a stray "f" before the brace
        # and printed "...results file: f/path/...".
        print(f"Skipping already existing results file: {results_location}\nPlease delete or move the results file to "
              f"produce new ones.")
        return
    print(f"-----------------", flush=True)
    print("Debug:", args.debug, flush=True)
    print("Dataset:", dset, flush=True)
    print("Model:", model, flush=True)
    # daily_dialog gin configs are stored under the shorter "da" prefix.
    if dset == "daily_dialog":
        config_dset_id = "da"
    else:
        config_dset_id = dset
    if dset not in ["boolq", "olid", "daily_dialog"]:
        raise NameError(f"Unknown dataset {dset}")
    test_suite = f"./experiments/parsing_interrolang_dev/dev_set_interrolang_{dset}.txt"
    # On Windows the model argument arrives pre-quoted, hence two branches.
    if sys.platform.startswith('linux') or sys.platform.startswith('darwin'):
        if model == "nearest-neighbor":
            config = f"./configs/{config_dset_id}_nn.gin"
        elif model == "EleutherAI/gpt-neo-2.7B":
            config = f"./configs/{config_dset_id}.gin"
        elif model == "flan-t5-base":
            config = f"./configs/{config_dset_id}_flan-t5.gin"
        else:
            config = f"./configs/{config_dset_id}_large.gin"
    elif sys.platform.startswith('win32') or sys.platform.startswith('cygwin'):
        if model == "\'nearest-neighbor\'":
            config = f"./configs/{config_dset_id}_nn.gin"
        elif model == "\'EleutherAI/gpt-neo-2.7B\'":
            config = f"./configs/{config_dset_id}.gin"
        elif model == "\'FLAN-T5\'":
            config = f"./configs/{config_dset_id}_flan-t5.gin"
        else:
            config = f"./configs/{config_dset_id}_large.gin"
    else:
        raise OSError("Unknown operating system!")
    # Parse config
    gin.parse_config_file(config)
    testing_data = load_test_data(test_suite)
    # load the model
    bot, get_parse_text = load_model(dset, guided_decoding, model)
    # t5 models don't expose nearest-neighbor prompts, so skip error analysis.
    error_analysis = False
    if "t5" not in model:
        error_analysis = True
    # load the number of prompts to perform in the sweep
    n_prompts_configs = load_n_prompts(model)
    if args.debug:
        n_prompts_configs = [10, 2]
    feature_names = copy.deepcopy(list(bot.conversation.stored_vars["dataset"].contents["X"].columns))
    for num_prompts in n_prompts_configs:
        # Set the bot to the number of prompts
        bot.set_num_prompts(num_prompts)
        print("Num prompts:", bot.prompts.num_prompt_template)
        assert bot.prompts.num_prompt_template == num_prompts, "Prompt update failing"
        # Compute the accuracy
        error_rate, all_parses = compute_accuracy(testing_data,
                                                  get_parse_text,
                                                  args.verbose,
                                                  error_analysis=error_analysis,
                                                  feature_names=feature_names)
        # Add parses to results
        for key in all_parses:
            if key not in results:
                results[key] = [all_parses[key]]
            else:
                results[key].append(all_parses[key])
        # Compute the compositional / iid accuracy splits
        iid_comp_results = run_iid_compositional_accuracies(dset,
                                                            all_parses,
                                                            bot,
                                                            program_only=args.program_only)
        in_acc, comp_acc, ov_all, total_in, total_comp, iid_pct_keys, comp_pct_keys = iid_comp_results
        # Store metrics
        results["total_in_domain"].append(total_in)
        results["total_compositional"].append(total_comp)
        results["in_domain_accuracy"].append(in_acc)
        results["compositional_accuracy"].append(comp_acc)
        results["overall_accuracy"].append(ov_all)
        results["guided_decoding"].append(guided_decoding)
        results["model"].append(model)
        results["dataset"].append(dset)
        results["accuracy"].append(1 - error_rate)
        results["num_prompts"].append(num_prompts)
        results["iid_errors_pct_not_all_words"].append(iid_pct_keys)
        results["comp_errors_pct_not_all_words"].append(comp_pct_keys)
        # Write everything to dataframe after each sweep step, so partial
        # results survive a crash later in the sweep.
        final_results = results
        result_df = pd.DataFrame(final_results)
        result_df.to_csv(results_location)
        print("Saved locally...", flush=True)
        # optionally upload to wandb
        if args.wandb:
            import wandb
            results_table = wandb.Table(data=result_df)
            if args.debug:
                table_name = "parsing-accuracy-debug"
            else:
                table_name = "parsing-accuracy"
            run.log({table_name: results_table})
            print("Logged to wandb...", flush=True)
    print(f"Saved results to {results_location}")
    print(f"-----------------")
def load_n_prompts(model):
    """Returns the list of prompt counts to sweep for *model*."""
    # Nearest-neighbor and t5 models ignore extra prompts, so one is enough.
    if model == "nearest-neighbor" or "t5" in model:
        return [1]
    # The larger EleutherAI models only fit 10 prompts in context.
    if model in ("EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-2.7B"):
        return [10]
    return [20]
def load_model(dset, guided_decoding, model):
    """Loads the model"""
    # Configures gin and constructs an ExplainBot plus a parse callable.
    # NOTE(review): `dset` is not referenced in this body — possibly dead.
    print("Initializing model...", flush=True)
    # t5 models are fully specified by their gin config file; only the
    # few-shot GPT-style parsers need the name/decoding flags injected here.
    if "t5" not in model:
        if sys.platform.startswith('linux') or sys.platform.startswith('darwin'):
            gin.parse_config(f"ExplainBot.parsing_model_name = '{model}'")
        elif sys.platform.startswith('win32') or sys.platform.startswith('cygwin'):
            # NOTE(review): on Windows the model name appears to arrive
            # already quoted (see the __main__ CLI handling) — confirm.
            gin.parse_config(f"ExplainBot.parsing_model_name = {model}")
        else:
            raise OSError("Unknown operating system!")
        gin.parse_config(f"ExplainBot.use_guided_decoding = {guided_decoding}")
        # if args.debug or "gpt-neo-2.7B" in model:
        #     gin.parse_config("get_few_shot_predict_f.device = 'cpu'")
        # else:
        #     gin.parse_config("get_few_shot_predict_f.device = 'cuda'")
        gin.parse_config("get_few_shot_predict_f.device = 'cpu'")
    # Case for NN and few shot gpt models
    bot = ExplainBot()

    def get_parse_text(user_input_to_parse):
        # Returns (parse_text, includes_all_words); on any failure the parse
        # text carries the exception message instead of propagating it.
        includes_all_words = None
        try:
            _, result_parse_text, includes_all_words = bot.compute_parse_text(user_input_to_parse,
                                                                              error_analysis=True)
        except Exception as e:
            result_parse_text = f"Exception: {e}, likely OOM"
        return result_parse_text, includes_all_words

    return bot, get_parse_text
if __name__ == "__main__":
    # CLI entry point. `args` (and `run`, when --wandb is set) are created at
    # module scope and read as globals by main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--gd", action="store_true")
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--id", type=str, required=True, help="a unique id to associate with the run")
    parser.add_argument("--down_sample", action="store_true", help="this will break each run on 10 samples")
    parser.add_argument("--program_only", action="store_true", help="only uses the program name for templates")
    parser.add_argument("--subset", choices=["train", "dev", "test", "user"])
    args = parser.parse_args()
    if args.wandb:
        import wandb
        run = wandb.init(project="project-ttm", entity="dslack")
    pst = timezone('US/Pacific')
    sa_time = datetime.now(pst)
    # NOTE(review): `time` is computed but never referenced in this file —
    # possibly dead code; it also shadows any `time` module import.
    time = sa_time.strftime('%Y-%m-%d_%H-%M')
    if args.wandb:
        wandb.run.name = f"{args.id}-{safe_name(args.model)}_{args.dataset}_gd-{args.gd}"
    main()