Add Ability to Summarize Documents (#800)
* Uses the entire file text and the summarizer model to generate the document summary.
* Uses the contents of the user's query to create a tailored summary.
* Integrates with File Filters (#788) for a better UX.
MythicalCow committed Jun 18, 2024
1 parent 677d49d commit d4e5c95
Showing 21 changed files with 791 additions and 85 deletions.
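The changes below persist each indexed file's full raw text in a new FileObject table, so a chat request can summarize a whole document rather than only its retrieved chunks. The chat-route wiring lives in changed files that did not load below; the following is a rough sketch of the resulting flow under stated assumptions: the prompt wording and `summarize_file` helper are illustrative, and `aget_summarizer_conversation_config` is assumed to hang off ConversationAdapters as the surrounding hunk suggests.

# Hypothetical sketch of the summarize flow this commit enables; only the
# FileObjectAdapters and summarizer-config calls below appear in the diff.
from khoj.database.adapters import ConversationAdapters, FileObjectAdapters

async def summarize_file(user, file_name: str, query: str) -> str:
    # Full raw text stored for this file by TextToEntries.update_embeddings()
    file_objects = await FileObjectAdapters.async_get_file_objects_by_name(user, file_name)
    if not file_objects:
        return "No file found to summarize."

    # Summarizer model, falling back to the server's default chat model
    chat_model = await ConversationAdapters.aget_summarizer_conversation_config()

    # Tailor the summary to the user's query (wording here is illustrative)
    prompt = (
        "Summarize the document below, focusing on the user's query.\n"
        f"Query: {query}\n\nDocument:\n{file_objects[0].raw_text}"
    )
    return prompt  # routes not shown here send this through `chat_model`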
55 changes: 54 additions & 1 deletion src/khoj/database/adapters/__init__.py
@@ -28,6 +28,7 @@
ClientApplication,
Conversation,
Entry,
FileObject,
GithubConfig,
GithubRepoConfig,
GoogleUser,
@@ -731,7 +732,7 @@ async def aget_summarizer_conversation_config():
if server_chat_settings is None or (
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
):
-            return await ChatModelOptions.objects.filter().afirst()
+            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model

@staticmethod
@@ -846,6 +847,58 @@ async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()


class FileObjectAdapters:
@staticmethod
def update_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
file_object.save()

@staticmethod
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)

@staticmethod
def get_file_objects_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first()

@staticmethod
def get_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).all()

@staticmethod
def delete_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).delete()

@staticmethod
def delete_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).delete()

@staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
await file_object.asave()

@staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)

@staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))

@staticmethod
async def async_get_all_file_objects(user: KhojUser):
return await sync_to_async(list)(FileObject.objects.filter(user=user))

@staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()

@staticmethod
async def async_delete_all_file_objects(user: KhojUser):
return await FileObject.objects.filter(user=user).adelete()


class EntryAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
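One asymmetry to note in the new adapter class: get_file_objects_by_name returns a single object or None (via .first()), while async_get_file_objects_by_name returns a full list. A minimal sketch of both call patterns, assuming a KhojUser instance `user` and an async context for the await:

# Sync upsert pattern, as used by update_embeddings() further down this diff.
file_object = FileObjectAdapters.get_file_objects_by_name(user, "notes.md")
if file_object:
    FileObjectAdapters.update_raw_text(file_object, "updated full text")
else:
    FileObjectAdapters.create_file_object(user, "notes.md", "updated full text")

# Async lookup returns a list, not a single object or None.
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(user, "notes.md")
raw_text = file_objects[0].raw_text if file_objects else ""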
37 changes: 37 additions & 0 deletions src/khoj/database/migrations/0045_fileobject.py
@@ -0,0 +1,37 @@
# Generated by Django 4.2.11 on 2024-06-14 06:13

import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0044_conversation_file_filters"),
]

operations = [
migrations.CreateModel(
name="FileObject",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
("raw_text", models.TextField()),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]
7 changes: 7 additions & 0 deletions src/khoj/database/models/__init__.py
@@ -326,6 +326,13 @@ class EntrySource(models.TextChoices):
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)


class FileObject(BaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)


class EntryDates(BaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
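FileObject deliberately duplicates what Entry stores in chunked form: Entry rows hold the small chunks embedded for search, while one FileObject row keeps the whole file for summarization. A quick ORM sketch of the pairing, assuming a `user` and an already-indexed file (Entry's file_path field is the one update_embeddings uses below):

# One FileObject with the whole text vs. many chunked Entry rows per file.
file_object = FileObject.objects.filter(user=user, file_name="notes.org").first()
num_chunks = Entry.objects.filter(user=user, file_path="notes.org").count()
print(len(file_object.raw_text), "chars across", num_chunks, "indexed chunks")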
2 changes: 1 addition & 1 deletion src/khoj/interface/web/chat.html
@@ -2145,7 +2145,7 @@
<div style="border-top: 1px solid black; ">
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 5px; margin-top: 5px;">
<p style="margin: 0;">Files</p>
<svg id="file-toggle-button" class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
<svg class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
</svg>
</div>
9 changes: 6 additions & 3 deletions src/khoj/processor/content/markdown/markdown_to_entries.py
@@ -33,7 +33,7 @@ def process(
max_tokens = 256
# Extract Entries from specified Markdown files
with timer("Extract entries from specified Markdown files", logger):
-            current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
+            file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)

# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,27 +50,30 @@
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

return num_new_embeddings, num_deleted_embeddings

@staticmethod
-    def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
+    def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
"Extract entries by heading from specified Markdown files"
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
file_to_text_map = dict()
for markdown_file in markdown_files:
try:
markdown_content = markdown_files[markdown_file]
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
)
file_to_text_map[markdown_file] = markdown_content
except Exception as e:
logger.error(
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
)

-        return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
+        return file_to_text_map, MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))

@staticmethod
def process_single_markdown_file(
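This extractor (and the Org, PDF, and plaintext ones below) now returns the file-to-text map alongside the entries, so every caller unpacks a tuple. The new contract, as a short sketch:

files = {"docs/setup.md": "# Setup\nInstall the dependencies first."}
file_to_text_map, entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens=256)
assert file_to_text_map["docs/setup.md"] == files["docs/setup.md"]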
19 changes: 13 additions & 6 deletions src/khoj/processor/content/org_mode/org_to_entries.py
@@ -33,7 +33,7 @@ def process(
# Extract Entries from specified Org files
max_tokens = 256
with timer("Extract entries from specified Org files", logger):
-            current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
+            file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)

with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
@@ -49,33 +49,40 @@
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

return num_new_embeddings, num_deleted_embeddings

@staticmethod
def extract_org_entries(
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
-    ) -> List[Entry]:
+    ) -> Tuple[Dict, List[Entry]]:
"Extract entries from specified Org files"
-        entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
-        return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
+        file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
+        return file_to_text_map, OrgToEntries.convert_org_nodes_to_entries(
+            entries, entry_to_file_map, index_heading_entries
+        )

@staticmethod
-    def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
+    def extract_org_nodes(
+        org_files: dict[str, str], max_tokens
+    ) -> Tuple[Dict, List[List[Orgnode]], Dict[Orgnode, str]]:
"Extract org nodes from specified org files"
entries: List[List[Orgnode]] = []
entry_to_file_map: List[Tuple[Orgnode, str]] = []
file_to_text_map = {}
for org_file in org_files:
try:
org_content = org_files[org_file]
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
org_content, org_file, entries, entry_to_file_map, max_tokens
)
file_to_text_map[org_file] = org_content
except Exception as e:
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)

-        return entries, dict(entry_to_file_map)
+        return file_to_text_map, entries, dict(entry_to_file_map)

@staticmethod
def process_single_org_file(
22 changes: 15 additions & 7 deletions src/khoj/processor/content/pdf/pdf_to_entries.py
@@ -2,10 +2,12 @@
import logging
import os
from datetime import datetime
-from typing import List, Tuple
+from typing import Dict, List, Tuple

from langchain_community.document_loaders import PyMuPDFLoader

# importing FileObjectAdapter so that we can add new files and debug file object db.
# from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@@ -33,7 +35,7 @@ def process(

# Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger):
-            current_entries = PdfToEntries.extract_pdf_entries(files)
+            file_to_text_map, current_entries = PdfToEntries.extract_pdf_entries(files)

# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,14 +52,15 @@
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

return num_new_embeddings, num_deleted_embeddings

@staticmethod
-    def extract_pdf_entries(pdf_files) -> List[Entry]:
+    def extract_pdf_entries(pdf_files) -> Tuple[Dict, List[Entry]]:  # important function
"""Extract entries by page from specified PDF files"""

file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
@@ -73,17 +76,22 @@ def extract_pdf_entries(pdf_files) -> List[Entry]:
pdf_entries_per_file = [page.page_content for page in loader.load()]
except ImportError:
loader = PyMuPDFLoader(f"{tmp_file}")
-                    pdf_entries_per_file = [page.page_content for page in loader.load()]
-                entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
+                    pdf_entries_per_file = [
+                        page.page_content for page in loader.load()
+                    ]  # page_content items list for a given pdf.
+                entry_to_location_map += zip(
+                    pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
+                )  # this is an indexed map of pdf_entries for the pdf.
entries.extend(pdf_entries_per_file)
file_to_text_map[pdf_file] = pdf_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")

-        return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
+        return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))

@staticmethod
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
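Note the shape difference: the Markdown, Org, and plaintext extractors map each file to a single string, while the PDF extractor maps each file to a list of per-page strings; the " ".join(...) in update_embeddings further down flattens such a list back into one document string. A sketch, assuming pdf_files maps filenames to raw PDF bytes:

file_to_text_map, entries = PdfToEntries.extract_pdf_entries(pdf_files)
pages = file_to_text_map["paper.pdf"]  # one string per page, e.g. ["page 1 ...", "page 2 ..."]
full_text = " ".join(pages)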
9 changes: 6 additions & 3 deletions src/khoj/processor/content/plaintext/plaintext_to_entries.py
@@ -32,7 +32,7 @@ def process(

# Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger):
-            current_entries = PlaintextToEntries.extract_plaintext_entries(files)
+            file_to_text_map, current_entries = PlaintextToEntries.extract_plaintext_entries(files)

# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -49,6 +49,7 @@
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

return num_new_embeddings, num_deleted_embeddings
@@ -63,21 +64,23 @@ def extract_html_content(markup_content: str, markup_type: str):
return soup.get_text(strip=True, separator="\n")

@staticmethod
-    def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
+    def extract_plaintext_entries(text_files: Dict[str, str]) -> Tuple[Dict, List[Entry]]:
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
file_to_text_map = dict()
for text_file in text_files:
try:
text_content = text_files[text_file]
entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
text_content, text_file, entries, entry_to_file_map
)
file_to_text_map[text_file] = text_content
except Exception as e:
logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)

# Extract Entries from specified plaintext files
-        return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
+        return file_to_text_map, PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))

@staticmethod
def process_single_plaintext_file(
20 changes: 19 additions & 1 deletion src/khoj/processor/content/text_to_entries.py
@@ -9,7 +9,11 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm

-from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
+from khoj.database.adapters import (
+    EntryAdapters,
+    FileObjectAdapters,
+    get_user_search_model_or_default,
+)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
from khoj.search_filter.date_filter import DateFilter
@@ -120,6 +124,7 @@ def update_embeddings(
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
file_to_text_map: dict[str, List[str]] = None,
):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
@@ -186,6 +191,18 @@ def update_embeddings(
logger.error(f"Error adding entries to database:\n{batch_indexing_error}\n---\n{e}", exc_info=True)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")

if file_to_text_map:
# get the list of file_names using added_entries
filenames_to_update = [entry.file_path for entry in added_entries]
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object
for file_name in filenames_to_update:
raw_text = " ".join(file_to_text_map[file_name])
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
if file_object:
FileObjectAdapters.update_raw_text(file_object, raw_text)
else:
FileObjectAdapters.create_file_object(user, file_name, raw_text)

new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
@@ -210,6 +227,7 @@ def update_embeddings(
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
FileObjectAdapters.delete_file_object_by_name(user, file_path)

return len(added_entries), num_deleted_entries

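Net effect: FileObject rows now follow Entry rows through the whole indexing lifecycle, upserted for every file that produced new entries and deleted in the same pass as their entries. After an index run, each file's full text can be read back directly, for instance:

# Every indexed file now has a queryable full-text copy for the summarizer.
for file_object in FileObjectAdapters.get_all_file_objects(user):
    print(file_object.file_name, len(file_object.raw_text))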
(Diffs for the remaining 12 of the 21 changed files did not load and are not shown.)
