Skip to content

Commit

Permalink
add RapidOCRPDFLoader and RapidOCRLoader (#1275)
Browse files Browse the repository at this point in the history
* add RapidOCRPDFLoader

* update mypdfloader.py and requirements.txt

* add myimgloader.py

* add test samples

* add TODO to mypdfloader

* add loaders to KnowledgeFile class

* add loaders to KnowledgeFile class
  • Loading branch information
imClumsyPanda committed Sep 1, 2023
1 parent 72b9da2 commit 6c4ef26
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 2 deletions.
2 changes: 2 additions & 0 deletions document_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader
25 changes: 25 additions & 0 deletions document_loaders/myimgloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader


class RapidOCRLoader(UnstructuredFileLoader):
    """Unstructured loader that obtains an image's text by running RapidOCR."""

    def _get_elements(self) -> List:
        """Run OCR on ``self.file_path`` and partition the recognized text.

        Returns:
            The result of ``partition_text`` applied to the OCR output, called
            with ``self.unstructured_kwargs``. When OCR yields nothing, an
            empty string is partitioned.
        """
        # Imported lazily so the OCR dependency is only needed when used.
        from rapidocr_onnxruntime import RapidOCR

        engine = RapidOCR()
        detections, _ = engine(self.file_path)

        # The second field of every detection entry is collected (presumably
        # the recognized text); entries are joined one per line.
        if detections:
            recognized = "\n".join([entry[1] for entry in detections])
        else:
            recognized = ""

        from unstructured.partition.text import partition_text
        return partition_text(text=recognized, **self.unstructured_kwargs)


if __name__ == "__main__":
    # Manual smoke test: OCR the sample image and print the loaded documents.
    sample_docs = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg").load()
    print(sample_docs)
37 changes: 37 additions & 0 deletions document_loaders/mypdfloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader


class RapidOCRPDFLoader(UnstructuredFileLoader):
    """Unstructured loader that combines a PDF's text layer with OCR of its images."""

    def _get_elements(self) -> List:
        """Extract text from ``self.file_path`` and partition it into elements.

        For every page, the page's text is collected first, followed by the
        OCR output of each image referenced by that page.

        Returns:
            The result of ``partition_text`` applied to the combined text,
            called with ``self.unstructured_kwargs``.
        """

        def pdf2text(filepath):
            """Return page text and per-image OCR text joined by newlines."""
            # Imported lazily so these dependencies are only needed when used.
            import fitz
            from rapidocr_onnxruntime import RapidOCR
            import numpy as np

            ocr = RapidOCR()
            chunks = []
            # The context manager closes the document even if OCR raises,
            # so the underlying file handle is not leaked.
            with fitz.open(filepath) as doc:
                for page in doc:
                    # TODO: adjust processing to follow the actual order of
                    # text and images on the page.
                    chunks.append(page.get_text(""))

                    for img in page.get_images():
                        # The first field of each get_images() entry is the
                        # reference used to build the pixmap.
                        pix = fitz.Pixmap(doc, img[0])
                        # View the raw samples as (height, width, channels).
                        img_array = np.frombuffer(
                            pix.samples, dtype=np.uint8
                        ).reshape(pix.height, pix.width, -1)
                        result, _ = ocr(img_array)
                        if result:
                            # One chunk per image; the outer join guarantees a
                            # separator before the next image or page so lines
                            # never run together.
                            chunks.append("\n".join(line[1] for line in result))
            return "\n".join(chunks)

        text = pdf2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == "__main__":
    # Manual smoke test: load the sample PDF and print the loaded documents.
    sample_docs = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf").load()
    print(sample_docs)
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
faiss-cpu
accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1

# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
Expand Down
2 changes: 2 additions & 0 deletions requirements_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ faiss-cpu
nltk
accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1

# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
Expand Down
8 changes: 6 additions & 2 deletions server/knowledge_base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def load_embeddings(model: str, device: str):
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"],
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt',
Expand Down Expand Up @@ -196,7 +197,10 @@ def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = Fal

print(f"{self.document_loader_name} used for {self.filepath}")
try:
document_loaders_module = importlib.import_module('langchain.document_loaders')
if self.document_loader_name in []:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
except Exception as e:
print(e)
Expand Down
Binary file added tests/samples/ocr_test.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/samples/ocr_test.pdf
Binary file not shown.

0 comments on commit 6c4ef26

Please sign in to comment.