-
Notifications
You must be signed in to change notification settings - Fork 166
/
constants.py
110 lines (86 loc) · 2.68 KB
/
constants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import enum
# TODO: revisit the naming of these constants and enums for clarity/consistency
class DeploymentType(enum.Enum):
    """Where a model deployment runs: locally or on Azure ML (AML)."""
    LOCAL = enum.auto()
    AML = enum.auto()
# Dict key under which MII-specific configuration is nested in a deployment config.
MII_CONFIGS_KEY = 'mii_configs'
class Tasks(enum.Enum):
    """Enumeration of the inference task types recognized by MII."""
    TEXT_GENERATION = enum.auto()
    TEXT_CLASSIFICATION = enum.auto()
    QUESTION_ANSWERING = enum.auto()
    FILL_MASK = enum.auto()
    TOKEN_CLASSIFICATION = enum.auto()
    CONVERSATIONAL = enum.auto()
    TEXT2IMG = enum.auto()
# Canonical user-facing string name for each task in the Tasks enum.
# These are the values accepted in deployment configs and request payloads.
TEXT_GENERATION_NAME = "text-generation"
TEXT_CLASSIFICATION_NAME = "text-classification"
QUESTION_ANSWERING_NAME = "question-answering"
FILL_MASK_NAME = "fill-mask"
TOKEN_CLASSIFICATION_NAME = "token-classification"
CONVERSATIONAL_NAME = "conversational"
TEXT2IMG_NAME = "text-to-image"
class ModelProvider(enum.Enum):
    """Model-loading backends MII can dispatch to."""
    HUGGING_FACE = enum.auto()
    ELEUTHER_AI = enum.auto()
    HUGGING_FACE_LLM = enum.auto()
    DIFFUSERS = enum.auto()


# User-facing string identifier for each provider.
MODEL_PROVIDER_NAME_HF = "hugging-face"
MODEL_PROVIDER_NAME_EA = "eleuther-ai"
MODEL_PROVIDER_NAME_HF_LLM = "hugging-face-llm"
MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers"

# Maps a provider name string to its ModelProvider enum member.
MODEL_PROVIDER_MAP = {
    MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE,
    MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI,
    MODEL_PROVIDER_NAME_HF_LLM: ModelProvider.HUGGING_FACE_LLM,
    MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS,
}

# Maps a model architecture/type string to the provider that loads it.
SUPPORTED_MODEL_TYPES = {
    "roberta": ModelProvider.HUGGING_FACE,
    "gpt2": ModelProvider.HUGGING_FACE,
    "bert": ModelProvider.HUGGING_FACE,
    "gpt_neo": ModelProvider.HUGGING_FACE,
    "gptj": ModelProvider.HUGGING_FACE,
    "opt": ModelProvider.HUGGING_FACE,
    "gpt-neox": ModelProvider.ELEUTHER_AI,
    "bloom": ModelProvider.HUGGING_FACE_LLM,
    "stable-diffusion": ModelProvider.DIFFUSERS,
}
# All task name strings MII accepts in a deployment config.
SUPPORTED_TASKS = [
    TEXT_GENERATION_NAME,
    TEXT_CLASSIFICATION_NAME,
    QUESTION_ANSWERING_NAME,
    FILL_MASK_NAME,
    TOKEN_CLASSIFICATION_NAME,
    CONVERSATIONAL_NAME,
    TEXT2IMG_NAME,
]

# Request-payload keys that must be present for each task.
REQUIRED_KEYS_PER_TASK = {
    TEXT_GENERATION_NAME: ["query"],
    TEXT_CLASSIFICATION_NAME: ["query"],
    QUESTION_ANSWERING_NAME: ["context", "question"],
    FILL_MASK_NAME: ["query"],
    TOKEN_CLASSIFICATION_NAME: ["query"],
    CONVERSATIONAL_NAME: [
        "text",
        "conversation_id",
        "past_user_inputs",
        "generated_responses",
    ],
    TEXT2IMG_NAME: ["query"],
}
# --- Deployment-config dictionary keys ---
MODEL_NAME_KEY = 'model_name'
TASK_NAME_KEY = 'task_name'
MODEL_PATH_KEY = 'model_path'
# Toggles DeepSpeed inference optimizations.
ENABLE_DEEPSPEED_KEY = 'ds_optimize'
# Toggles DeepSpeed ZeRO.
ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero'
DEEPSPEED_CONFIG_KEY = 'ds_config'
CHECKPOINT_KEY = "checkpoint"
# --- Environment-variable names and their defaults ---
MII_CACHE_PATH = "MII_CACHE_PATH"
MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache"
MII_DEBUG_MODE = "MII_DEBUG_MODE"
MII_DEBUG_MODE_DEFAULT = "0"
MII_DEBUG_DEPLOY_KEY = "MII_DEBUG_DEPLOY_KEY"
MII_DEBUG_BRANCH = "MII_DEBUG_BRANCH"
MII_DEBUG_BRANCH_DEFAULT = "main"
MII_MODEL_PATH_DEFAULT = "/tmp/mii_models"
# Maximum gRPC message size in bytes.
GRPC_MAX_MSG_SIZE = 2**30  # 1 GiB