add local llm implementation #119

Merged: 1 commit, Mar 12, 2024
Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-FROM infiniflow/ragflow-base:v1.0
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
USER root

WORKDIR /ragflow
README.md (6 changes: 3 additions & 3 deletions)
@@ -21,20 +21,20 @@
</a>
</p>

-[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
+[RagFlow](http://ragflow.io) is a knowledge management platform built on a custom-built document understanding engine and LLMs,
providing reasoned and well-founded answers to your questions. Clone this repository and you can deploy your own knowledge management
platform to empower your business with AI.

<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
</div>

-# Features
+# Key Features
- **Custom-built document understanding engine.** Our deep learning engine is built around the needs of analyzing and searching various types of documents across different domains.
- For documents from different domains and for different purposes, the engine applies different analysis and search strategies.
- Easily intervene in and adjust the data processing procedure when things go beyond expectation.
- Multi-media document understanding is supported using OCR and multi-modal LLM.
-- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
+- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document, including table content. See [README](./deepdoc/README.md).
- For PDF files, layouts and table structures, including rows, columns and spans, are recognized.
- Tables spanning multiple pages are stitched back together.
- Table structure components are reconstructed into HTML tables.
api/apps/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
-app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
+app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024

Session(app)
login_manager = LoginManager()
api/apps/llm_app.py (2 changes: 1 addition & 1 deletion)
@@ -85,7 +85,7 @@ def my_llms():
}
res[o["llm_factory"]]["llm"].append({
"type": o["model_type"],
"name": o["model_name"],
"name": o["llm_name"],
"used_token": o["used_tokens"]
})
return get_json_result(data=res)
api/db/db_models.py (2 changes: 1 addition & 1 deletion)
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
-progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")


class Dialog(DataBaseModel):
api/db/services/knowledgebase_service.py (1 change: 1 addition & 0 deletions)
@@ -47,6 +47,7 @@ def get_detail(cls, kb_id):
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
+cls.model.language,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
api/settings.py (10 changes: 8 additions & 2 deletions)
@@ -42,7 +42,7 @@
ERROR_REPORT_WITH_PATH = False

MAX_TIMESTAMP_INTERVAL = 60
-SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
+SESSION_VALID_PERIOD = 7 * 24 * 60 * 60

REQUEST_TRY_TIMES = 3
REQUEST_WAIT_SEC = 2
@@ -69,6 +69,12 @@
"image2text_model": "glm-4v",
"asr_model": "",
},
"local": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
}
}
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "通义千问")
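
The default factory is read from the user_default_llm section of the service config, so opting into the new "local" entry above is a one-line config change; a minimal sketch of the idea (the config fragment below is hypothetical, only the key names come from the code above):

# Hypothetical service-config fragment (sketch, not part of this PR):
#
#   user_default_llm:
#     factory: "local"
#
# With that setting, LLM_FACTORY resolves to "local", which maps to
# LocalLLM / HuEmbedding / LocalCV through the dicts in rag/llm/__init__.py.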
@@ -134,7 +140,7 @@
USE_DATA_AUTHENTICATION = False
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
USE_DEFAULT_TIMEOUT = False
-AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60  # s
+AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60  # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False

deepdoc/parser/excel_parser.py (16 changes: 15 additions & 1 deletion)
@@ -20,13 +20,27 @@ def __call__(self, fnm):
for i,c in enumerate(r):
if not c.value:continue
t = str(ti[i].value) if i < len(ti) else ""
t += (":" if t else "") + str(c.value)
t += (":" if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
res.append(l)
return res

+    @staticmethod
+    def row_number(fnm, binary):
+        if fnm.split(".")[-1].lower().find("xls") >= 0:
+            wb = load_workbook(BytesIO(binary))
+            total = 0
+            for sheetname in wb.sheetnames:
+                ws = wb[sheetname]
+                total += len(list(ws.rows))  # ws.rows is a generator; materialize before len()
+            return total
+
+        if fnm.split(".")[-1].lower() in ["csv", "txt"]:
+            txt = binary.decode("utf-8")
+            return len(txt.split("\n"))


if __name__ == "__main__":
psr = HuExcelParser()
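For context, the new row_number helper can be exercised on its own; a minimal sketch, assuming an .xlsx file on disk (the path is hypothetical):

# Hypothetical usage of HuExcelParser.row_number (sketch, not from the PR):
with open("sample.xlsx", "rb") as f:
    binary = f.read()
print(HuExcelParser.row_number("sample.xlsx", binary))  # row count across all sheets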
docker/nginx/nginx.conf (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ http {
keepalive_timeout 65;

#gzip on;
-client_max_body_size 82M;
+client_max_body_size 128M;

include /etc/nginx/conf.d/ragflow.conf;
}
rag/app/table.py (17 changes: 10 additions & 7 deletions)
@@ -25,7 +25,7 @@


class Excel(ExcelParser):
-def __call__(self, fnm, binary=None, callback=None):
+def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
if not binary:
wb = load_workbook(fnm)
else:
@@ -35,6 +35,7 @@ def __call__(self, fnm, binary=None, callback=None):
total += len(list(wb[sheetname].rows))

res, fails, done = [], [], 0
+rn = 0
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
@@ -46,6 +47,9 @@ def __call__(self, fnm, binary=None, callback=None):
rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
+rn += 1
+if rn - 1 < from_page: continue
+if rn - 1 >= to_page: break
row = [
cell.value for ii,
cell in enumerate(r) if ii not in missed]
@@ -111,7 +115,7 @@ def column_data_type(arr):
return arr, ty


-def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
+def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
For csv or txt file, the delimiter between columns is TAB.
@@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
headers = lines[0].split(kwargs.get("delimiter", "\t"))
rows = []
for i, line in enumerate(lines[1:]):
+if i < from_page: continue
+if i >= to_page: break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue
rows.append(row)
-if len(rows) % 999 == 0:
-callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
-f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

-callback(0.6, ("Extract records: {}".format(len(rows)) + (
+callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

dfs = [pd.DataFrame(np.array(rows), columns=headers)]
@@ -209,7 +212,7 @@

KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
-callback(0.6, "")
+callback(0.35, "")

return res

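For context, the new from_page/to_page parameters let callers ingest a large table in windows instead of one pass; a minimal sketch of calling chunk that way (file name, kb_id and callback are hypothetical):

# Hypothetical paged call to rag.app.table.chunk (sketch, not from the PR):
def progress(prog=None, msg=""):
    print(prog, msg)   # chunk reports fractions such as 0.3 and 0.35

res = chunk("huge_table.xlsx", from_page=0, to_page=3000,
            lang="Chinese", callback=progress, kb_id="hypothetical-kb-id")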
rag/llm/__init__.py (9 changes: 6 additions & 3 deletions)
@@ -19,22 +19,25 @@


EmbeddingModel = {
"Infiniflow": HuEmbedding,
"local": HuEmbedding,
"OpenAI": OpenAIEmbed,
"通义千问": HuEmbedding, #QWenEmbed,
"智谱AI": ZhipuEmbed
}


CvModel = {
"OpenAI": GptV4,
"Infiniflow": GptV4,
"local": LocalCV,
"通义千问": QWenCV,
"智谱AI": Zhipu4V
}


ChatModel = {
"OpenAI": GptTurbo,
"Infiniflow": GptTurbo,
"智谱AI": ZhipuChat,
"通义千问": QWenChat,
"local": LocalLLM
}
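
These dicts key the implementations by factory name, so model construction is a plain dictionary lookup; a minimal sketch (the factory string would normally come from LLM_FACTORY in api/settings.py; the constructor arguments here are placeholders):

# Hypothetical lookup (sketch, not from the PR):
factory = "local"
chat_mdl = ChatModel[factory]("", "")   # -> LocalLLM; key/model_name are unused by the RPC proxy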

rag/llm/chat_model.py (42 changes: 40 additions & 2 deletions)
@@ -20,6 +20,7 @@
import openai

from rag.nlp import is_english
+from rag.utils import num_tokens_from_string


class Base(ABC):
@@ -86,7 +87,6 @@ def __init__(self, key, model_name="glm-3-turbo"):
self.model_name = model_name

def chat(self, system, history, gen_conf):
-from http import HTTPStatus
if system: history.insert(0, {"role": "system", "content": system})
try:
response = self.client.chat.completions.create(
@@ -100,4 +100,42 @@ def chat(self, system, history, gen_conf):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: " + str(e), 0

+class LocalLLM(Base):
+    class RPCProxy:
+        def __init__(self, host, port):
+            self.host = host
+            self.port = int(port)
+            self.__conn()
+
+        def __conn(self):
+            from multiprocessing.connection import Client
+            self._connection = Client(
+                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+        def __getattr__(self, name):
+            import pickle
+
+            def do_rpc(*args, **kwargs):
+                # Retry up to three times, reconnecting after each failure.
+                for _ in range(3):
+                    try:
+                        self._connection.send(pickle.dumps((name, args, kwargs)))
+                        return pickle.loads(self._connection.recv())
+                    except Exception:
+                        self.__conn()
+                raise Exception("RPC connection lost!")
+
+            return do_rpc
+
+    def __init__(self, key, model_name="glm-3-turbo"):
+        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            ans = self.client.chat(history, gen_conf)
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
rag/llm/cv_model.py (8 changes: 8 additions & 0 deletions)
@@ -138,3 +138,11 @@ def describe(self, image, max_tokens=1024):
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens


+class LocalCV(Base):
+    def __init__(self, key, model_name="glm-4v", lang="Chinese"):
+        pass
+
+    def describe(self, image, max_tokens=1024):
+        return "", 0
rag/llm/rpc_server.py (90 changes: 90 additions & 0 deletions)
@@ -0,0 +1,90 @@
import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread
import torch


class RPCHandler:
    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))


models = []
tokenizer = None


def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
    line = "\n".join(line) + "\nAssistant: "
    tokens = tokenizer([line], return_tensors='pt')
    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k]
              for k in tokens.keys()}
    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
    return res.split("Assistant: ")[-1]


def Model():
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.utils import GenerationConfig

    models = []
    for _ in range(2):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto',
                                                 trust_remote_code=True)
        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
        m.generation_config.pad_token_id = m.generation_config.eos_token_id
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
                                              trust_remote_code=True)

    # Run the server
    rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
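
To tie the pieces together: the server keeps two copies of the model and dispatches pickled (func_name, args, kwargs) tuples received over multiprocessing.connection, and LocalLLM.RPCProxy in rag/llm/chat_model.py speaks the same protocol with the same authkey. A minimal end-to-end sketch (model name and prompt are hypothetical):

# Hypothetical launch (sketch, not from the PR):
#   python rag/llm/rpc_server.py --model_name Qwen/Qwen-7B-Chat --port 7860
# A raw client equivalent to what RPCProxy does:
from multiprocessing.connection import Client
import pickle

conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
conn.send(pickle.dumps(("chat",
                        ([{"role": "user", "content": "Say hi."}],
                         {"max_new_tokens": 32}),
                        {})))
print(pickle.loads(conn.recv()))   # the generated answer, or a pickled exception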