Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Indexing Images via OCR #823

Merged
merged 15 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/interface/desktop/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const textFileTypes = [
'org', 'md', 'markdown', 'txt', 'html', 'xml',
// Other valid text file extensions from https://google.github.io/magika/model/config.json
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
const binaryFileTypes = ['pdf']
const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
const validFileTypes = textFileTypes.concat(binaryFileTypes);

const schema = {
Expand Down
11 changes: 8 additions & 3 deletions src/khoj/interface/web/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@

To get started, just start typing below. You can also type / to see a list of commands.
`.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx'];
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx'];
let chatOptions = [];
function createCopyParentText(message) {
return function(event) {
Expand Down Expand Up @@ -974,7 +974,12 @@
fileType = "text/html";
} else if (fileExtension === "pdf") {
fileType = "application/pdf";
} else {
} else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported
resolve();
return;
Expand Down
Empty file.
118 changes: 118 additions & 0 deletions src/khoj/processor/content/images/image_to_entries.py
MythicalCow marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple

from rapidocr_onnxruntime import RapidOCR

from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry

logger = logging.getLogger(__name__)


class ImageToEntries(TextToEntries):
def __init__(self):
super().__init__()

# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None

# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):
file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)

# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)

# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

return num_new_embeddings, num_deleted_embeddings

@staticmethod
def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract entries by page from specified image files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for image_file in image_files:
try:
loader = RapidOCR()
bytes = image_files[image_file]
# write the image to a temporary file
timestamp_now = datetime.utcnow().timestamp()
# use either png or jpg
if image_file.endswith(".png"):
tmp_file = f"tmp_image_file_{timestamp_now}.png"
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
with open(tmp_file, "wb") as f:
bytes = image_files[image_file]
f.write(bytes)
try:
image_entries_per_file = ""
result, _ = loader(tmp_file)
if result:
expanded_entries = [text[1] for text in result]
image_entries_per_file = " ".join(expanded_entries)
except ImportError:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
continue
entry_to_location_map.append((image_entries_per_file, image_file))
entries.extend([image_entries_per_file])
file_to_text_map[image_file] = image_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)
return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map))

@staticmethod
def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each image entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
raw=parsed_entry,
heading=heading,
file=f"{entry_filename}",
)
)

logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")

return entries
30 changes: 29 additions & 1 deletion src/khoj/routers/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
Expand Down Expand Up @@ -41,6 +42,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None


Expand All @@ -65,7 +67,14 @@ async def update(
),
):
user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}}
index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
Expand All @@ -81,6 +90,7 @@ async def update(
markdown=index_files["markdown"],
pdf=index_files["pdf"],
plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"],
)

Expand Down Expand Up @@ -133,6 +143,7 @@ async def update(
"num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]),
}

Expand Down Expand Up @@ -300,6 +311,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False

try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx")
Expand Down
4 changes: 2 additions & 2 deletions src/khoj/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
return "docx", encoding
elif file_type in ["image/jpeg"]:
return "jpeg", encoding
return "image", encoding
elif file_type in ["image/png"]:
return "png", encoding
return "image", encoding
elif content_group in ["code", "text"]:
return "plaintext", encoding
else:
Expand Down
1 change: 1 addition & 0 deletions src/khoj/utils/rawconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None
docx: Optional[TextContentConfig] = None


Expand Down
Binary file added tests/data/images/nasdaq.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/images/testocr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions tests/test_image_to_entries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

from khoj.processor.content.images.image_to_entries import ImageToEntries


def test_png_to_jsonl():
with open("tests/data/images/testocr.png", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/testocr.png": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "opencv-python" in entries[1][0].raw


def test_jpg_to_jsonl():
with open("tests/data/images/nasdaq.jpg", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/nasdaq.jpg": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "investments" in entries[1][0].raw
Loading