This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
opt.py
148 lines (124 loc) · 4.41 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Opt is the system for passing around options throughout ParlAI.
"""
import copy
import json
import pickle
import traceback
import parlai.utils.logging as logging
from typing import List
from parlai.utils.io import PathManager
# these keys are automatically removed upon save. This is a rather blunt hammer.
# It's preferred you indicate this at option definition time.
# Keys stripped from an opt both when it is saved and when it is loaded.
# Each key is listed exactly once; consumers only test membership / delete the
# key if present, so ordering carries no meaning.
__AUTOCLEAN_KEYS__: List[str] = [
    "override",
    "batchindex",
    "download_path",
    "datapath",
    # we don't save interactive mode, it's only decided by scripts or CLI
    "interactive_mode",
]
class Opt(dict):
    """
    Class for tracking options.

    Functions like a dict, but allows us to track the history of arguments as they are
    set.

    Two bookkeeping attributes live next to the mapping itself:

    * ``history``: list of ``(key, value, location)`` tuples, one appended for
      every ``opt[key] = value`` assignment.
    * ``deepcopies``: list of formatted stack entries, one appended every time
      the opt is deep-copied.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the opt like a regular dict and start with empty tracking lists.
        """
        super().__init__(*args, **kwargs)
        self.history = []
        self.deepcopies = []

    def __setitem__(self, key, val):
        """
        Set ``key`` to ``val`` and record the assignment in ``history``.
        """
        # With limit=2 the formatted stack holds (caller, this frame), oldest
        # first, so index 0 is the caller. Using [0] instead of [-2] picks the
        # same entry but cannot raise IndexError on a shallow stack.
        loc = traceback.format_stack(limit=2)[0]
        self.history.append((key, val, loc))
        super().__setitem__(key, val)

    def __getstate__(self):
        """
        Return ``(history, deepcopies, plain-dict contents)`` for pickling.
        """
        return (self.history, self.deepcopies, dict(self))

    def __setstate__(self, state):
        """
        Restore the tracking lists and contents from a ``__getstate__`` tuple.
        """
        self.history, self.deepcopies, data = state
        self.update(data)

    def __reduce__(self):
        """
        Pickle as an empty ``Opt`` plus the state handed to ``__setstate__``.
        """
        return (Opt, (), self.__getstate__())

    def __deepcopy__(self, memo):
        """
        Override deepcopy so that history is copied over to new object.

        The location of the deepcopy is appended to ``self.deepcopies`` before
        copying, so both the original and the copy list it.

        :param memo:
            the memo dictionary supplied by ``copy.deepcopy``. It is passed on
            to the copies of the values so that objects shared between values
            stay shared in the copy, and so that an opt which (indirectly)
            contains itself does not recurse forever.

        :return:
            a new ``Opt`` with deep-copied values and shallow-copied tracking
            lists.
        """
        if memo is None:
            memo = {}

        # track location of deepcopy. With limit=3 the stack normally reads
        # (caller, copy.deepcopy, this frame); index 0 is the oldest of those
        # and stays valid even if fewer than three frames exist.
        loc = traceback.format_stack(limit=3)[0]
        self.deepcopies.append(loc)

        # register the result before descending so self-references resolve to it
        duplicate = Opt()
        memo[id(self)] = duplicate

        # copy all our children, sharing the memo. dict.update is used so the
        # copy's history is not polluted by these assignments.
        dict.update(
            duplicate, {k: copy.deepcopy(v, memo) for k, v in self.items()}
        )
        # deepcopy the history. history is only tuples, so we can do it shallow
        duplicate.history = copy.copy(self.history)
        # deepcopy the list of deepcopies. also shallow bc only strings
        duplicate.deepcopies = copy.copy(self.deepcopies)
        return duplicate

    def display_deepcopies(self):
        """
        Display all deepcopies.

        :return:
            a numbered, newline-joined listing of deepcopy locations, or a
            fixed message when no deepcopy has been recorded.
        """
        if not self.deepcopies:
            return 'No deepcopies performed on this opt.'
        return '\n'.join(f'{i}. {loc}' for i, loc in enumerate(self.deepcopies, 1))

    def display_history(self, key):
        """
        Display the history for an item in the dict.

        :param key:
            the option name whose recorded assignments should be listed.

        :return:
            a numbered, newline-joined listing of every recorded assignment to
            ``key``, or ``'No history for <key>'`` when there is none.
        """
        changes = []
        i = 0
        for key_, val, loc in self.history:
            if key != key_:
                continue
            i += 1
            changes.append(f'{i}. {key} was set to {val} at:\n{loc}')
        if changes:
            return '\n'.join(changes)
        return f'No history for {key}'

    def save(self, filename: str) -> None:
        """
        Save the opt to disk.

        Attempts to 'clean up' any residual values automatically.

        :param filename:
            destination path; the opt is written there as indented JSON.
        """
        # start with a shallow copy
        dct = dict(self)

        # clean up some things we probably don't want to save
        for key in __AUTOCLEAN_KEYS__:
            dct.pop(key, None)

        with PathManager.open(filename, 'w', encoding='utf-8') as f:
            json.dump(dct, fp=f, indent=4)
            # extra newline for convenience of working with jq
            f.write('\n')

    @classmethod
    def load(cls, optfile: str) -> 'Opt':
        """
        Load an Opt from disk.

        The file is read as JSON first; if it cannot be decoded as UTF-8 text
        it is read as a pickle instead. Auto-cleaned keys are dropped from the
        result.

        :param optfile:
            path of the opt file to read.

        :return:
            a new instance of ``cls`` built from the file's contents.
        """
        try:
            # try json first
            with PathManager.open(optfile, 'r', encoding='utf-8') as t_handle:
                dct = json.load(t_handle)
        except UnicodeDecodeError:
            # oops it's pickled
            # NOTE(security): unpickling can execute arbitrary code. Only load
            # opt files that come from a trusted source.
            with PathManager.open(optfile, 'rb') as b_handle:
                dct = pickle.load(b_handle)
        for key in __AUTOCLEAN_KEYS__:
            if key in dct:
                del dct[key]
        return cls(dct)

    def log(self, header="Opt"):
        """
        Log every option, sorted by key, followed by the git commit.

        :param header:
            title logged (with a trailing colon) before the options.
        """
        # imported here rather than at module level, as in the original code
        from parlai.core.params import print_git_commit

        logging.info(header + ":")
        for key in sorted(self.keys()):
            valstr = str(self[key])
            if valstr.replace(" ", "").replace("\n", "") != valstr:
                # show newlines as escaped keys, whitespace with quotes, etc
                valstr = repr(valstr)
            logging.info(f"    {key}: {valstr}")
        print_git_commit()