# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script implements the usecase of rag-search upon FlexFlow.
Functionality:
1. FlexFlowLLM Class:
- Initializes and configures FlexFlow.
- Loads configurations from a file or uses default settings.
- Compiles and starts the language model server for text generation.
- Stops the server when operations are complete.
2. FF_LLM_wrapper Class:
- Serves as a wrapper for FlexFlow.
- Implements the necessary interface to interact with the LangChain library.
3. Main:
- Initializes FlexFlow.
- Compiles and starts the server with specific generation configurations.
- Taking in specific source information with RAG(Retrieval Augmented Generation) technique for Q&A towards specific realm/knowledgebase.
- Use LLMChain to run the model and generate response.
- Stops the FlexFlow server after generating the response.
"""
import json
import os
from types import SimpleNamespace
from typing import Any, List, Optional

import flexflow.serve as ff
from langchain.chains import LLMChain
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma  # FAISS is available as an alternative vector store

class FlexFlowLLM:
    def __init__(self, config_file=""):
        self.configs = self.get_configs(config_file)
        ff.init(self.configs)
        self.llm = self.create_llm()
        self.ssms = self.create_ssms()

    def get_configs(self, config_file):
        # Load configurations from a file or use default settings
        if config_file and os.path.isfile(config_file):
            with open(config_file) as f:
                return json.load(f)
        else:
            # Define sample configs
            ff_init_configs = {
                # required parameters
                "num_gpus": 2,
                "memory_per_gpu": 14000,
                "zero_copy_memory_per_node": 40000,
                # optional parameters
                "num_cpus": 4,
                "legion_utility_processors": 4,
                "data_parallelism_degree": 1,
                "tensor_parallelism_degree": 1,
                "pipeline_parallelism_degree": 2,
                "offload": False,
                "offload_reserve_space_size": 1024**2,
                "use_4bit_quantization": False,
                "use_8bit_quantization": False,
                "profiling": False,
                "inference_debugging": False,
                "fusion": True,
            }
            llm_configs = {
                # required llm arguments
                "llm_model": "meta-llama/Llama-2-7b-hf",
                # optional llm parameters
                "cache_path": "",
                "refresh_cache": False,
                "full_precision": False,
                "ssms": [
                    {
                        # required ssm parameter
                        "ssm_model": "JackFram/llama-160m",
                        # optional ssm parameters
                        "cache_path": "",
                        "refresh_cache": False,
                        "full_precision": False,
                    }
                ],
                # "prompt": "",
                "output_file": "",
            }
            # Merge dictionaries
            ff_init_configs.update(llm_configs)
            return ff_init_configs
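
    # A config file passed to FlexFlowLLM(config_file=...) is plain JSON whose
    # keys mirror the defaults above. Note that get_configs returns the file's
    # contents verbatim, so every key accessed later (cache_path, ssms, ...)
    # must be present. A minimal illustrative sketch (values are examples only):
    #
    #   {
    #       "num_gpus": 1,
    #       "memory_per_gpu": 14000,
    #       "zero_copy_memory_per_node": 40000,
    #       "llm_model": "meta-llama/Llama-2-7b-hf",
    #       "cache_path": "",
    #       "refresh_cache": false,
    #       "full_precision": false,
    #       "ssms": [{"ssm_model": "JackFram/llama-160m", "cache_path": "",
    #                 "refresh_cache": false, "full_precision": false}],
    #       "output_file": ""
    #   }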

    def create_llm(self):
        # Create the (large) LLM from the merged configs
        configs = SimpleNamespace(**self.configs)
        ff_data_type = (
            ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
        )
        llm = ff.LLM(
            configs.llm_model,
            data_type=ff_data_type,
            cache_path=configs.cache_path,
            refresh_cache=configs.refresh_cache,
            output_file=configs.output_file,
        )
        return llm

    def create_ssms(self):
        # Create the SSMs (the small draft models used for speculative inference)
        configs = SimpleNamespace(**self.configs)
        ssms = []
        for ssm_config in configs.ssms:
            ssm_config = SimpleNamespace(**ssm_config)
            ff_data_type = (
                ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF
            )
            ssm = ff.SSM(
                ssm_config.ssm_model,
                data_type=ff_data_type,
                cache_path=ssm_config.cache_path,
                refresh_cache=ssm_config.refresh_cache,
                output_file=configs.output_file,
            )
            ssms.append(ssm)
        return ssms

    def compile_and_start(
        self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch
    ):
        # Compile the SSMs for inference and load their weights into memory
        for ssm in self.ssms:
            ssm.compile(
                generation_config,
                max_requests_per_batch,
                max_seq_length,
                max_tokens_per_batch,
            )
        # Compile the LLM for inference and load the weights into memory,
        # registering the SSMs as its draft models
        self.llm.compile(
            generation_config,
            max_requests_per_batch,
            max_seq_length,
            max_tokens_per_batch,
            ssms=self.ssms,
        )
        # Start the server
        self.llm.start_server()
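
    # A note on what compile_and_start sets up (a summary of FlexFlow's
    # speculative-inference design, not a normative spec): during generation
    # the SSMs cheaply draft candidate tokens, and the LLM verifies those
    # drafts, keeping only tokens it would have produced itself. The SSMs
    # therefore speed up decoding without changing the LLM's output quality.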

    def generate(self, prompt):
        results = self.llm.generate(prompt)
        # llm.generate may return a single result or a list of results;
        # output_text is raw bytes, so decode it to a string in either case
        if isinstance(results, list):
            result_txt = results[0].output_text.decode("utf-8")
        else:
            result_txt = results.output_text.decode("utf-8")
        return result_txt

    def stop_server(self):
        self.llm.stop_server()

    def __enter__(self):
        return self.llm.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        return self.llm.__exit__(exc_type, exc_value, traceback)


class FF_LLM_wrapper(LLM):
    flexflow_llm: FlexFlowLLM

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        response = self.flexflow_llm.generate(prompt)
        return response
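
# Illustrative (hypothetical) standalone use of the wrapper, outside any chain;
# LangChain's LLM.__call__ routes the prompt to _call above:
#
#   wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm)
#   text = wrapper("What are the approaches to Task Decomposition?")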


if __name__ == "__main__":
    # Initialization
    ff_llm = FlexFlowLLM()

    # Compile the models and start the server
    gen_config = ff.GenerationConfig(do_sample=False, temperature=0.9, topp=0.8, topk=1)
    ff_llm.compile_and_start(
        gen_config,
        max_requests_per_batch=1,
        max_seq_length=256,
        max_tokens_per_batch=200,
    )

    # The wrapper class serves as the 'Model' in LCEL
    ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm)

    # USE CASE 2: RAG search
    # Load web page content
    loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
    data = loader.load()

    # Split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
    all_splits = text_splitter.split_documents(data)

    # Initialize embeddings (reads the OpenAI API key from the environment)
    embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))

    # Create the vector store
    vectorstore = Chroma.from_documents(all_splits, embeddings)

    # Expose the vector store through the retriever interface
    # (the code below calls similarity_search on the store directly)
    retriever = vectorstore.as_retriever()
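
    # A hypothetical alternative to the direct similarity_search call below,
    # using the retriever created above:
    #
    #   docs = retriever.get_relevant_documents(question)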

    # Test that similarity search is working
    question = "What are the approaches to Task Decomposition?"
    docs = vectorstore.similarity_search(question)

    # Truncate each retrieved document so the combined prompt fits within the
    # max_seq_length configured above
    max_chars_per_doc = 50
    docs_text_list = [doc.page_content[:max_chars_per_doc] for doc in docs]
    docs_text = "".join(docs_text_list)

    # Build a prompt template
    prompt_rag = PromptTemplate.from_template(
        "Summarize the main themes in these retrieved docs: {docs_text}"
    )

    # Chain the wrapped FlexFlow model with the prompt
    llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag)

    # Run the chain
    rag_result = llm_chain_rag(docs_text)
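
    # Inspect the generated summary; LLMChain.__call__ returns a dict keyed by
    # the chain's output variable (typically "text"), e.g.:
    #
    #   print(rag_result["text"])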

    # Stop the server
    ff_llm.stop_server()