forked from embedchain/embedchain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedchain.py
285 lines (240 loc) · 9.69 KB
/
embedchain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import os
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from gpt4all import GPT4All
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
from embedchain.loaders.local_text import LocalTextLoader
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.text import TextChunker
from embedchain.vectordb.chroma_db import ChromaDB
# BUGFIX: load_dotenv() must run BEFORE any os.getenv() call below.
# Previously it ran after the OpenAI embedding function was constructed,
# so API keys stored only in a .env file were silently ignored.
load_dotenv()

# Embedding function backed by OpenAI's hosted embedding endpoint (used by App).
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    organization_id=os.getenv("OPENAI_ORGANIZATION"),
    model_name="text-embedding-ada-002"
)

# Local open-source embedding function (used by OpenSourceApp).
# NOTE(review): this downloads/loads the model eagerly at import time.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

# Lazily-instantiated GPT4All model, shared across OpenSourceApp instances.
gpt4all_model = None

ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")
class EmbedChain:
    """
    Ties together data loaders, chunkers, a vector database and an LLM.

    Subclasses implement `get_llm_model_answer` to supply the actual LLM
    call (see `App` and `OpenSourceApp` in this module).
    """

    def __init__(self, db=None, ef=None):
        """
        Initializes the EmbedChain instance, sets up a vector DB client and
        creates a collection.

        :param db: The instance of the VectorDB subclass. When None, a new
            ChromaDB is created.
        :param ef: Embedding function forwarded to the default ChromaDB;
            unused when an explicit `db` is passed.
        """
        if db is None:
            db = ChromaDB(ef=ef)
        self.db_client = db.client
        self.collection = db.collection
        # Record of every add/add_local call: [data_type, url_or_content].
        self.user_asks = []

    def _get_loader(self, data_type):
        """
        Returns the appropriate data loader for the given data type.

        :param data_type: The type of the data to load.
        :return: The loader for the given data type.
        :raises ValueError: If an unsupported data type is provided.
        """
        loaders = {
            'youtube_video': YoutubeVideoLoader(),
            'pdf_file': PdfFileLoader(),
            'web_page': WebPageLoader(),
            'qna_pair': LocalQnaPairLoader(),
            'text': LocalTextLoader(),
        }
        if data_type in loaders:
            return loaders[data_type]
        else:
            raise ValueError(f"Unsupported data type: {data_type}")

    def _get_chunker(self, data_type):
        """
        Returns the appropriate chunker for the given data type.

        :param data_type: The type of the data to chunk.
        :return: The chunker for the given data type.
        :raises ValueError: If an unsupported data type is provided.
        """
        chunkers = {
            'youtube_video': YoutubeVideoChunker(),
            'pdf_file': PdfFileChunker(),
            'web_page': WebPageChunker(),
            'qna_pair': QnaPairChunker(),
            'text': TextChunker(),
        }
        if data_type in chunkers:
            return chunkers[data_type]
        else:
            raise ValueError(f"Unsupported data type: {data_type}")

    def add(self, data_type, url):
        """
        Adds the data from the given URL to the vector db.
        Loads the data, chunks it, create embedding for each chunk
        and then stores the embedding to vector database.

        :param data_type: The type of the data to add.
        :param url: The URL where the data is located.
        """
        loader = self._get_loader(data_type)
        chunker = self._get_chunker(data_type)
        self.user_asks.append([data_type, url])
        self.load_and_embed(loader, chunker, url)

    def add_local(self, data_type, content):
        """
        Adds the data you supply to the vector db.
        Loads the data, chunks it, create embedding for each chunk
        and then stores the embedding to vector database.

        :param data_type: The type of the data to add.
        :param content: The local data. Refer to the `README` for formatting.
        """
        loader = self._get_loader(data_type)
        chunker = self._get_chunker(data_type)
        self.user_asks.append([data_type, content])
        self.load_and_embed(loader, chunker, content)

    def load_and_embed(self, loader, chunker, url):
        """
        Loads the data from the given URL, chunks it, and adds it to the database.

        :param loader: The loader to use to load the data.
        :param chunker: The chunker to use to chunk the data.
        :param url: The URL where the data is located.
        """
        embeddings_data = chunker.create_chunks(loader, url)
        documents = embeddings_data["documents"]
        metadatas = embeddings_data["metadatas"]
        ids = embeddings_data["ids"]
        # get existing ids, and discard doc if any common id exist.
        existing_docs = self.collection.get(
            ids=ids,
            # where={"url": url}
        )
        existing_ids = set(existing_docs["ids"])
        if len(existing_ids):
            # Keep only chunks whose ids are not already stored, so re-adding
            # the same source does not create duplicates.
            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
            if not data_dict:
                print(f"All data from {url} already exists in the database.")
                return
            ids = list(data_dict.keys())
            documents, metadatas = zip(*data_dict.values())
        self.collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )
        print(f"Successfully saved {url}. Total chunks count: {self.collection.count()}")

    def _format_result(self, results):
        # Convert a Chroma query result (parallel lists, first query only)
        # into (langchain Document, distance) pairs.
        return [
            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
            for result in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def get_llm_model_answer(self, prompt):
        # Subclasses must override this with an actual LLM call.
        raise NotImplementedError

    def retrieve_from_database(self, input_query):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query

        :param input_query: The query to use.
        :return: The content of the document that matched your query.
        """
        result = self.collection.query(
            query_texts=[input_query,],
            n_results=1,
        )
        result_formatted = self._format_result(result)
        if result_formatted:
            content = result_formatted[0][0].page_content
        else:
            # No match returned; fall back to an empty context string.
            content = ""
        return content

    def generate_prompt(self, input_query, context):
        """
        Generates a prompt based on the given query and context, ready to be passed to an LLM

        :param input_query: The query to use.
        :param context: Similar documents to the query used as context.
        :return: The prompt
        """
        prompt = f"""Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
        {context}
        Query: {input_query}
        Helpful Answer:
        """
        return prompt

    def get_answer_from_llm(self, prompt):
        """
        Gets an answer for the given prompt by passing it to an LLM.

        :param prompt: The fully formed prompt (context + query).
        :return: The answer.
        """
        answer = self.get_llm_model_answer(prompt)
        return answer

    def query(self, input_query):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :return: The answer to the query.
        """
        context = self.retrieve_from_database(input_query)
        prompt = self.generate_prompt(input_query, context)
        answer = self.get_answer_from_llm(prompt)
        return answer
class App(EmbedChain):
    """
    The EmbedChain app.
    Has two functions: add and query.
    adds(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds answer to the given query using vector database and LLM.
    """

    def __init__(self, model="gpt-3.5-turbo-0613", db=None, ef=None):
        """
        :param model: OpenAI chat model name used to answer queries.
        :param db: Optional VectorDB instance; defaults to a new ChromaDB.
        :param ef: Optional embedding function; defaults to the OpenAI one.
        """
        if ef is None:
            ef = openai_ef
        self.model = model
        super().__init__(db, ef)

    def get_llm_model_answer(self, prompt):
        """
        Sends the prompt to the OpenAI chat completion API and returns
        the model's answer as a string.

        :param prompt: The fully formed prompt (context + query).
        :return: The answer text from the first choice.
        """
        # BUGFIX: `openai` was used here but never imported anywhere in this
        # module, so every call raised NameError. Import it locally so the
        # module stays importable even when the package is absent.
        import openai

        messages = []
        messages.append({
            "role": "user", "content": prompt
        })
        # temperature=0 keeps answers deterministic for a given context.
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=0,
            max_tokens=1000,
            top_p=1,
        )
        return response["choices"][0]["message"]["content"]
class OpenSourceApp(EmbedChain):
    """
    The OpenSource app.
    Same as App, but uses an open source embedding model and LLM.
    Has two function: add and query.
    adds(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds answer to the given query using vector database and LLM.
    """

    def __init__(self, db=None, ef=None):
        """
        :param db: Optional VectorDB instance; defaults to a new ChromaDB.
        :param ef: Optional embedding function; defaults to the module-level
            sentence-transformer embedding function.
        """
        print("Loading open source embedding model. This may take some time...")
        # Fall back to the open-source embedder when none was supplied.
        embedding_fn = sentence_transformer_ef if ef is None else ef
        print("Successfully loaded open source embedding model.")
        super().__init__(db, embedding_fn)

    def get_llm_model_answer(self, prompt):
        """
        Generates an answer with a locally-run GPT4All model.

        :param prompt: The fully formed prompt (context + query).
        :return: The generated answer text.
        """
        # Instantiate the GPT4All model once and cache it at module level
        # so repeated queries reuse the same loaded weights.
        global gpt4all_model
        if gpt4all_model is None:
            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
        return gpt4all_model.generate(
            prompt=prompt,
        )