-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
executable file
·48 lines (38 loc) · 1.51 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import dataclasses as dc
def _ctrl_field(default, help_text):
    """Return a dataclass field with *default* and ``{"help": help_text}`` metadata."""
    return dc.field(default=default, metadata={"help": help_text})


@dc.dataclass
class CtrlArguments:
    """Options for training data, problem formulation, tokenizer and windowing.

    Each field carries a ``"help"`` entry in its metadata. Presumably that entry
    is consumed by an argument parser (e.g. HuggingFace's ``HfArgumentParser``),
    but this is not shown here; confirm against callers.
    """

    # Training data file(s); the help text describes a comma-separated list.
    train_data: str = _ctrl_field(
        "data/training_cunique_with_distractors.json",
        "A CSV list of training data files",
    )
    # Problem formulation identifier (see help text for the documented options).
    formulation: str = _ctrl_field(
        "areg_ltr",
        "Type of problem definition: autoregressive (areg) or u-PMLM (upmlm) or mixed (if predict_questions is set)",
    )
    # Strategy name for handling over-long contexts.
    context_strategy: str = _ctrl_field(
        "take_first",
        "How to deal with contexts greater than a specified length",
    )
    # Path to a serialized tokenizer (HuggingFace `tokenizers` JSON format).
    tokenizer_file: str = _ctrl_field(
        "tokenizer.json",
        "A JSON file (in the format provided by HuggingFace's tokenizers library) with a trained tokenizer",
    )
    # Maximum sequence length.
    sequence_length: int = _ctrl_field(
        256,
        "The max sequence length",
    )
    # Whether the control code is prepended to every sliding window rather
    # than only at the start of the sequence.
    force_prepend_control: bool = _ctrl_field(
        False,
        "If the control code should be prepended for all sliding windows. Otherwise, it is only prepended at the start of the sequence",
    )
class GradientPrinter:
    """Print diagnostic statistics for a gradient tensor, labelled by name.

    Instances are callables taking a single gradient tensor. They can
    presumably be registered as autograd hooks, e.g.
    ``param.register_hook(GradientPrinter("embedding"))``; confirm against
    callers.
    """

    def __init__(self, name):
        """Store *name*, used as the label in the printed header."""
        self.name = name

    def __call__(self, grad):
        """Print *grad*, its L2 norm and its mean to stdout.

        Args:
            grad: A tensor exposing ``detach()``, ``cpu()`` and ``numpy()``
                (e.g. a ``torch.Tensor``).

        Returns:
            None. When used as a ``register_hook`` hook, this leaves the
            gradient unchanged.
        """
        # detach() first: gradients produced with create_graph=True still
        # require grad, and Tensor.numpy() raises on such tensors.
        np_grad = grad.detach().cpu().numpy()
        print(f"======== GRAD FOR {self.name} ========")
        print(f"\tGRAD {grad}")
        # With ord and axis unset, np.linalg.norm returns the 2-norm of the
        # flattened array, so this works for gradients of any rank.
        print(f"\tGRAD NORM {np.linalg.norm(np_grad)}")
        print(f"\tGRAD MEAN {np.mean(np_grad)}")
        print()