Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【功能新增】增强对PPT、DOC知识库文件的OCR识别 #2013

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion document_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader
from .myimgloader import RapidOCRLoader
from .mydocloader import RapidOCRDocLoader
from .mypptloader import RapidOCRPPTLoader
71 changes: 71 additions & 0 deletions document_loaders/mydocloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm


class RapidOCRDocLoader(UnstructuredFileLoader):
    """Load a .docx file, adding OCR text for pictures embedded in paragraphs."""

    def _get_elements(self) -> List:
        """Extract the document text and partition it with ``unstructured``.

        Returns:
            The elements returned by ``partition_text`` for the extracted text
            (paragraph text, OCR text of inline pictures, and table cell text).
        """

        def doc2text(filepath):
            """Return the text of the .docx at ``filepath``, one item per line.

            Paragraphs and tables are visited in document order. Pictures
            referenced from a paragraph are run through RapidOCR; for tables
            only the paragraph text of each cell is collected.

            Raises:
                ValueError: if ``iter_block_items`` is given an unsupported parent.
            """
            from docx.table import _Cell, Table
            from docx.oxml.table import CT_Tbl
            from docx.oxml.text.paragraph import CT_P
            from docx.text.paragraph import Paragraph
            from docx import Document, ImagePart
            from PIL import Image
            from io import BytesIO
            import numpy as np
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            doc = Document(filepath)
            # Collect fragments and join once instead of growing a string.
            pieces = []

            def iter_block_items(parent):
                """Yield the Paragraph and Table children of ``parent`` in order."""
                # Aliased so the class does not shadow the ``Document`` factory
                # imported by the enclosing function.
                from docx.document import Document as DocumentObject

                if isinstance(parent, DocumentObject):
                    parent_elm = parent.element.body
                elif isinstance(parent, _Cell):
                    parent_elm = parent._tc
                else:
                    raise ValueError("RapidOCRDocLoader parse fail")

                for child in parent_elm.iterchildren():
                    if isinstance(child, CT_P):
                        yield Paragraph(child, parent)
                    elif isinstance(child, CT_Tbl):
                        yield Table(child, parent)

            # The context manager closes the progress bar even if parsing fails.
            with tqdm.tqdm(total=len(doc.paragraphs) + len(doc.tables),
                           desc="RapidOCRDocLoader block index: 0") as b_unit:
                for i, block in enumerate(iter_block_items(doc)):
                    b_unit.set_description(
                        "RapidOCRDocLoader block index: {}".format(i))
                    b_unit.refresh()
                    if isinstance(block, Paragraph):
                        pieces.append(block.text.strip() + "\n")
                        # All pictures referenced from this paragraph.
                        pictures = block._element.xpath('.//pic:pic')
                        for picture in pictures:
                            # Relationship ids of the embedded image data.
                            for img_id in picture.xpath('.//a:blip/@r:embed'):
                                part = doc.part.related_parts[img_id]
                                if isinstance(part, ImagePart):
                                    pil_image = Image.open(BytesIO(part._blob))
                                    result, _ = ocr(np.array(pil_image))
                                    if result:
                                        ocr_result = [line[1] for line in result]
                                        # Terminate the OCR text with a newline so
                                        # it does not run into the next block.
                                        pieces.append("\n".join(ocr_result) + "\n")
                    elif isinstance(block, Table):
                        for row in block.rows:
                            for cell in row.cells:
                                for paragraph in cell.paragraphs:
                                    pieces.append(paragraph.text.strip() + "\n")
                    b_unit.update(1)
            return "".join(pieces)

        text = doc2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == '__main__':
    # Manual smoke test: load the bundled sample document and print the result.
    # The sample is located relative to this file rather than the current
    # working directory, so the script can be launched from anywhere.
    from pathlib import Path

    sample_path = (Path(__file__).resolve().parent.parent
                   / "tests" / "samples" / "ocr_test.docx")
    loader = RapidOCRDocLoader(file_path=str(sample_path))
    docs = loader.load()
    print(docs)
59 changes: 59 additions & 0 deletions document_loaders/mypptloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm


class RapidOCRPPTLoader(UnstructuredFileLoader):
    """Load a PowerPoint file, adding OCR text for picture shapes."""

    def _get_elements(self) -> List:
        """Extract the presentation text and partition it with ``unstructured``.

        Returns:
            The elements returned by ``partition_text`` for the extracted text
            (text frames, table cells and OCR text of picture shapes).
        """

        def ppt2text(filepath):
            """Return the text of the presentation at ``filepath``.

            Slides are processed in order; within a slide the top-level shapes
            are visited top-to-bottom, then left-to-right.
            """
            from pptx import Presentation
            from pptx.enum.shapes import MSO_SHAPE_TYPE
            from PIL import Image
            import numpy as np
            from io import BytesIO
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            prs = Presentation(filepath)
            # Collect fragments and join once instead of growing a string.
            pieces = []

            def extract_text(shape):
                """Append the text carried by ``shape`` (recursing into groups)."""
                if shape.has_text_frame:
                    pieces.append(shape.text.strip() + "\n")
                if shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            for paragraph in cell.text_frame.paragraphs:
                                pieces.append(paragraph.text.strip() + "\n")
                if shape.shape_type == MSO_SHAPE_TYPE.PICTURE:
                    pil_image = Image.open(BytesIO(shape.image.blob))
                    result, _ = ocr(np.array(pil_image))
                    if result:
                        ocr_result = [line[1] for line in result]
                        # Terminate the OCR text with a newline so it does not
                        # run into the text of the next shape.
                        pieces.append("\n".join(ocr_result) + "\n")
                elif shape.shape_type == MSO_SHAPE_TYPE.GROUP:
                    for child_shape in shape.shapes:
                        extract_text(child_shape)

            def position_key(shape):
                """Sort key: top first, then left; a missing offset counts as 0."""
                # A shape without an explicit position reports None, which
                # cannot be ordered against the integer offsets of other shapes.
                return (shape.top or 0, shape.left or 0)

            # The context manager closes the progress bar even if parsing fails.
            with tqdm.tqdm(total=len(prs.slides),
                           desc="RapidOCRPPTLoader slide index: 1") as b_unit:
                # Walk every slide.
                for slide_number, slide in enumerate(prs.slides, start=1):
                    b_unit.set_description(
                        "RapidOCRPPTLoader slide index: {}".format(slide_number))
                    b_unit.refresh()
                    # Top-to-bottom, then left-to-right.
                    for shape in sorted(slide.shapes, key=position_key):
                        extract_text(shape)
                    b_unit.update(1)
            return "".join(pieces)

        text = ppt2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == '__main__':
    # Manual smoke test: load the bundled sample presentation and print the
    # result. The sample is located relative to this file rather than the
    # current working directory, so the script can be launched from anywhere.
    from pathlib import Path

    sample_path = (Path(__file__).resolve().parent.parent
                   / "tests" / "samples" / "ocr_test.pptx")
    loader = RapidOCRPPTLoader(file_path=str(sample_path))
    docs = loader.load()
    print(docs)
13 changes: 8 additions & 5 deletions server/knowledge_base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,14 @@ def process_entry(entry):
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
"CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
# "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRDocLoader": ['.docx', '.doc'],
"RapidOCRPPTLoader": ['.ppt', '.pptx', ],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.epub', '.odt','.tsv'],
"UnstructuredEmailLoader": ['.eml', '.msg'],
"UnstructuredEPubLoader": ['.epub'],
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
Expand All @@ -109,7 +114,6 @@ def process_entry(entry):
"UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
"EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
}
# Flat list of every file extension listed under any loader in LOADER_DICT.
# The per-loader lists are concatenated as-is, so an extension registered
# under more than one loader appears more than once.
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

Expand Down Expand Up @@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
if file_extension in extensions:
return LoaderClass


# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
'''
根据loader_name和文件路径或内容返回文档加载器。
'''
loader_kwargs = loader_kwargs or {}
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
Expand Down
Binary file added tests/samples/ocr_test.docx
Binary file not shown.
Binary file added tests/samples/ocr_test.pptx
Binary file not shown.