-
Notifications
You must be signed in to change notification settings - Fork 1
/
decoder.py
132 lines (115 loc) · 6.26 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Parsing engine that converts natural language to the grammar.
This class 'decodes' natural language inputs into the grammar. There are several parsing options
supported, such as fine-tuned t5 models, few-shot gpt-j models, and KNN.
"""
from typing import Callable, Optional

import gin
from petals import AutoDistributedModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
@gin.configurable
class Decoder:
    """Parser that 'decodes' natural language utterances into the grammar.

    Depending on ``parsing_model_name``, completions are produced by a
    distributed petals model, a GPTQ-quantized model, a Llama/Mistral
    few-shot model, or a nearest-neighbor lookup over the prompts.
    """

    def __init__(self,
                 parsing_model_name: str,
                 in_8_bits: bool,
                 no_init: bool = False,
                 use_guided_decoding: bool = True,
                 dataset_name: Optional[str] = None):
        """Init

        Arguments:
            parsing_model_name: The name of the parsing model. Currently supported:
                * petals models: names containing 'petals-team' are loaded as
                  distributed (p2p) causal LMs
                * GPTQ models: names containing 'GPTQ' are loaded 4-bit quantized
                * Llama/Mistral models: names containing 'Llama' or 'Mistral'
                  are loaded as hugging face causal LMs, optionally in 8 bits
                * 'nearest-neighbor': will use knn on the prompts to parse
            in_8_bits: Whether to load Llama/Mistral models in 8-bit mode.
            no_init: If True, will not init any parsing model.
            use_guided_decoding: Whether to use guided decoding.
            dataset_name: The semantic name of the dataset (forwarded to
                init_model, which currently does not use it).
        """
        # Set by init_model: a callable mapping (prompt, grammar) -> response dict.
        self.gen_completions: Optional[Callable] = None
        self.use_guided_dec = use_guided_decoding
        self.gpt_parser_initialized = False
        self.gpt_model = None
        self.gpt_tokenizer = None
        self.parser_name = parsing_model_name
        self.init_model(parsing_model_name,
                        in_8_bits,
                        no_init=no_init,
                        dataset_name=dataset_name)

    def _build_few_shot_complete(self) -> Callable:
        """Return a (prompt, grammar) -> completion fn for the loaded HF model.

        Shared by the petals, GPTQ, and Llama/Mistral branches of init_model,
        which previously duplicated this wiring verbatim.
        """
        # Imported lazily so backends that never reach this path do not
        # require the parsing.gpt package at import time.
        from parsing.gpt.few_shot_inference import get_few_shot_predict_f
        predict_f = get_few_shot_predict_f(model=self.gpt_model,
                                           tokenizer=self.gpt_tokenizer,
                                           use_guided_decoding=self.use_guided_dec)

        def complete(prompt, grammar):
            return predict_f(text=prompt, grammar=grammar)

        return complete

    def init_model(self,
                   parsing_model_name: str,
                   in_8_bits: bool,
                   no_init: bool = False,
                   dataset_name: Optional[str] = None):
        """Initializes the model and sets ``self.gen_completions``.

        Args:
            parsing_model_name: The name of the model (see __init__).
            in_8_bits: Whether to load Llama/Mistral models in 8-bit mode.
            no_init: Do not init the model.
            dataset_name: The semantic name of the dataset. Accepted for
                configurability but not used by any current branch.

        Raises:
            NotImplementedError: If ``parsing_model_name`` matches no
                supported backend.
        """
        # Does not initialize a model
        if no_init:
            return
        if "petals-team" in parsing_model_name:
            # p2p models from petals-team
            self.gpt_tokenizer = AutoTokenizer.from_pretrained(parsing_model_name)
            self.gpt_model = AutoDistributedModelForCausalLM.from_pretrained(parsing_model_name)
            self.gpt_parser_initialized = True
            complete = self._build_few_shot_complete()
        elif "GPTQ" in parsing_model_name:
            # GPTQ quantized model
            self.gpt_tokenizer = AutoTokenizer.from_pretrained(parsing_model_name)
            quantization_config = GPTQConfig(bits=4, tokenizer=self.gpt_tokenizer)
            self.gpt_model = AutoModelForCausalLM.from_pretrained(
                parsing_model_name,
                quantization_config=quantization_config,
                device_map="auto")
            # Models without a pad token fall back to EOS for padding.
            self.gpt_model.config.pad_token_id = self.gpt_model.config.eos_token_id
            self.gpt_parser_initialized = True
            complete = self._build_few_shot_complete()
        elif "Llama" in parsing_model_name or "Mistral" in parsing_model_name:
            # Original (unquantized) model, optionally loaded in 8 bits.
            if not self.gpt_parser_initialized:
                self.gpt_tokenizer = AutoTokenizer.from_pretrained(parsing_model_name)
                if in_8_bits:
                    self.gpt_model = AutoModelForCausalLM.from_pretrained(
                        parsing_model_name, device_map='cuda:0', load_in_8bit=True)
                else:
                    self.gpt_model = AutoModelForCausalLM.from_pretrained(parsing_model_name)
                self.gpt_model.config.pad_token_id = self.gpt_model.config.eos_token_id
                self.gpt_parser_initialized = True
            complete = self._build_few_shot_complete()
        elif parsing_model_name == "nearest-neighbor":
            def complete(prompt, _):
                split_prompts = prompt.split("\n")
                # prompt should be third back, considering
                # "parsed:" is last and the user input is 2nd last
                # then line break
                last_prompt = split_prompts[-4][len("parsed: "):]
                responses = {"generation": "\n".join(split_prompts) + last_prompt}
                return responses
        else:
            raise NotImplementedError(
                f"Unsupported parsing model: {parsing_model_name!r}")
        self.gen_completions = complete

    def complete(self, prompt: str, grammar: str = None):
        """Run a completion."""
        assert self.gen_completions is not None, "Must run init_model first!"
        completed = self.gen_completions(prompt, grammar)
        return completed