-
Notifications
You must be signed in to change notification settings - Fork 381
/
global_vars.py
116 lines (103 loc) · 4.03 KB
/
global_vars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import sys

import yaml
from transformers import GenerationConfig

from models import alpaca, stablelm, koalpaca, flan_alpaca, mpt
from models import camel, t5_vicuna, vicuna, starchat, redpajama, bloom
def initialize_globals(args):
    """Load the model selected by *args* and publish it via module globals.

    Side effects: (re)binds the module globals ``model``, ``model_type``,
    ``stream_model``, ``tokenizer``, ``gen_config``, ``gen_config_raw`` and
    ``gen_config_summarization``.  Exits the process when neither
    ``args.base_url`` nor ``args.ft_ckpt_url`` matches a known model family.

    Fixes vs. the previous version:
      * ``"stable-vicuna"`` is now tested before the bare ``"vicuna"``
        substring; previously the ``"vicuna"`` check always matched first,
        making the ``"stable-vicuna"`` branch unreachable.
      * the unsupported-model path exits with a non-zero status via
        ``sys.exit(1)`` instead of ``quit()`` (an interactive-session helper
        that exited with status 0 even on this error path).
      * the repeated ``.lower()`` calls are hoisted out of the checks.
    """
    global model, model_type, stream_model, tokenizer
    global gen_config, gen_config_raw
    global gen_config_summarization

    base = args.base_url.lower()
    finetuned = args.ft_ckpt_url.lower()

    # (haystack, substring, resulting model type). First match wins, so
    # ordering matters: more specific substrings ("stable-vicuna",
    # "openassistant/stablelm", "alpacagpt4") must precede their more
    # general counterparts.
    rules = (
        (base, "stable-vicuna", "stable-vicuna"),
        (base, "vicuna", "vicuna"),
        (base, "mpt", "mpt"),
        (base, "redpajama", "redpajama"),
        (base, "starchat", "starchat"),
        (base, "camel", "camel"),
        (base, "flan-alpaca", "flan-alpaca"),
        (base, "openassistant/stablelm", "os-stablelm"),
        (base, "stablelm", "stablelm"),
        (base, "fastchat-t5", "t5-vicuna"),
        (base, "koalpaca-polyglot", "koalpaca-polyglot"),
        (finetuned, "alpacagpt4", "alpaca-gpt4"),
        (finetuned, "alpaca", "alpaca"),
        (finetuned, "llama-deus", "llama-deus"),
        (finetuned, "vicuna-lora-evolinstruct", "evolinstruct-vicuna"),
        (finetuned, "alpacoom", "alpacoom"),
    )
    for haystack, needle, candidate in rules:
        if needle in haystack:
            model_type_tmp = candidate
            break
    else:
        print("unsupported model type")
        sys.exit(1)

    print(f"determined model type: {model_type_tmp}")

    # Drop any previously loaded model/tokenizer so their memory can be
    # reclaimed before loading the replacement; NameError means this is the
    # first initialization and there is nothing to delete yet.
    try:
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
    except NameError:
        pass

    load_model = get_load_model(model_type_tmp)
    model, tokenizer = load_model(
        base=args.base_url,
        finetuned=args.ft_ckpt_url,
        multi_gpu=args.multi_gpu,
        force_download_ckpt=args.force_download_ckpt
    )

    gen_config, gen_config_raw = get_generation_config(args.gen_config_path)
    gen_config_summarization, _ = get_generation_config(
        args.gen_config_summarization_path
    )

    # Publish the type/stream aliases only after loading succeeded.
    model_type = model_type_tmp
    stream_model = model
def get_load_model(model_type):
    """Map a model-type string to its module-level ``load_model`` callable.

    Returns ``None`` for an unrecognised *model_type*.  Several types share
    one loader: the alpaca loader also serves "alpaca-gpt4", "llama-deus"
    and "evolinstruct-vicuna", and the vicuna loader serves both "vicuna"
    and "stable-vicuna".
    """
    if model_type in ("alpaca", "alpaca-gpt4", "llama-deus", "evolinstruct-vicuna"):
        return alpaca.load_model
    if model_type in ("stablelm", "os-stablelm"):
        return stablelm.load_model
    if model_type == "koalpaca-polyglot":
        return koalpaca.load_model
    if model_type == "flan-alpaca":
        return flan_alpaca.load_model
    if model_type == "camel":
        return camel.load_model
    if model_type == "t5-vicuna":
        return t5_vicuna.load_model
    if model_type in ("vicuna", "stable-vicuna"):
        return vicuna.load_model
    if model_type == "starchat":
        return starchat.load_model
    if model_type == "mpt":
        return mpt.load_model
    if model_type == "redpajama":
        return redpajama.load_model
    if model_type == "alpacoom":
        return bloom.load_model
    return None
def get_generation_config(path):
    """Load the YAML file at *path* and build a generation config from it.

    Returns a ``(GenerationConfig, raw_dict)`` pair, where ``raw_dict`` is
    the plain mapping stored under the file's top-level
    ``"generation_config"`` key.
    """
    with open(path, 'rb') as stream:
        parsed = yaml.safe_load(stream)
    raw_config = parsed["generation_config"]
    return GenerationConfig(**raw_config), raw_config
def get_constraints_config(path):
    """Load the constraints YAML at *path*.

    Returns a ``(ConstraintsConfig, raw_list)`` pair, where ``raw_list`` is
    the value stored under the file's top-level ``"constraints"`` key.

    NOTE(review): ``ConstraintsConfig`` is not imported or defined anywhere
    in this module's visible code — confirm it is provided elsewhere,
    otherwise calling this function raises ``NameError``.
    """
    with open(path, 'rb') as stream:
        raw = yaml.safe_load(stream)
    return ConstraintsConfig(**raw), raw["constraints"]