from dataclasses import dataclass, field
from typing import List, Optional

from ..core import flatten_dict


@dataclass
class ModelConfig:
    """
    Arguments which define the model and tokenizer to load.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use. Pass `--attn_implementation=flash_attention_2` to use "
                "FlashAttention-2, which must be installed separately via `pip install flash-attn --no-build-isolation`."
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA rank (r)."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "LoRA dropout."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "LoRA target modules."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze & train."},
    )
    lora_task_type: str = field(
        default="CAUSAL_LM", metadata={"help": "The task_type to pass for LoRA (use SEQ_CLS for reward modeling)."}
    )
    load_in_8bit: bool = field(
        default=False, metadata={"help": "Use 8-bit precision for the base model (works only with LoRA)."}
    )
    load_in_4bit: bool = field(
        default=False, metadata={"help": "Use 4-bit precision for the base model (works only with LoRA)."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "The 4-bit quantization type to use (fp4 or nf4)."}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Use nested quantization."})

    def to_dict(self):
        # Copy all dataclass fields into a plain dict, then flatten any nested values.
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8-bit and 4-bit precision at the same time")

        # The CLI parses `--lora_target_modules all-linear` into the list
        # ["all-linear"]; PEFT expects the bare string "all-linear" to mean
        # "target every linear layer", so normalize it here.
        if self.lora_target_modules == ["all-linear"]:
            self.lora_target_modules = "all-linear"
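

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): parse a ModelConfig
# from the command line and map it onto `transformers`/`peft` objects. The
# `HfArgumentParser`, `BitsAndBytesConfig`, and `LoraConfig` APIs below are
# real, but the mapping itself is an assumption about how a training script
# might consume these fields, not code shipped with this file. Because of the
# relative import at the top, run it with `python -m <package>.model_config`.
if __name__ == "__main__":
    import torch
    from transformers import BitsAndBytesConfig, HfArgumentParser

    parser = HfArgumentParser(ModelConfig)
    (model_config,) = parser.parse_args_into_dataclasses()

    # Resolve the dtype string into a torch.dtype; "auto" and None pass through.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ("auto", None)
        else getattr(torch, model_config.torch_dtype)
    )

    # Build a BitsAndBytesConfig only when quantized loading was requested
    # (__post_init__ guarantees 4-bit and 8-bit are mutually exclusive).
    quantization_config = None
    if model_config.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
        )
    elif model_config.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    # Build a LoraConfig only when PEFT training is enabled.
    peft_config = None
    if model_config.use_peft:
        from peft import LoraConfig

        peft_config = LoraConfig(
            r=model_config.lora_r,
            lora_alpha=model_config.lora_alpha,
            lora_dropout=model_config.lora_dropout,
            target_modules=model_config.lora_target_modules,
            modules_to_save=model_config.lora_modules_to_save,
            task_type=model_config.lora_task_type,
        )

    print(model_config.to_dict())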