Port File Store from Volume to PG #1241

Merged: 9 commits, Mar 22, 2024
28 changes: 28 additions & 0 deletions backend/alembic/versions/4738e4b3bae1_pg_file_store.py
@@ -0,0 +1,28 @@
"""PG File Store

Revision ID: 4738e4b3bae1
Revises: e91df4e935ef
Create Date: 2024-03-20 18:53:32.461518

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4738e4b3bae1"
down_revision = "e91df4e935ef"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "file_store",
        sa.Column("file_name", sa.String(), nullable=False),
        sa.Column("lobj_oid", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("file_name"),
    )


def downgrade() -> None:
    op.drop_table("file_store")
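
The migration only adds the bookkeeping table; the file bytes themselves are stored as PostgreSQL large objects referenced by lobj_oid. A minimal sketch (not part of the PR) for checking that every row points at a real large object after the migration has run; the connection URL is a placeholder:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:pass@localhost:5432/danswer")  # placeholder URL
with engine.connect() as conn:
    for file_name, lobj_oid in conn.execute(text("SELECT file_name, lobj_oid FROM file_store")):
        # every lobj_oid should have a matching entry in pg_largeobject_metadata
        found = conn.execute(
            text("SELECT 1 FROM pg_largeobject_metadata WHERE oid = :oid"), {"oid": lobj_oid}
        ).first()
        print(file_name, lobj_oid, "ok" if found else "missing large object")
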
@@ -1,7 +1,9 @@
"""Private Personas DocumentSets

Revision ID: e91df4e935ef
Revises: 91fd3b470d1a
Create Date: 2024-03-17 11:47:24.675881

"""
import fastapi_users_db_sqlalchemy
from alembic import op
4 changes: 0 additions & 4 deletions backend/danswer/background/celery/celery.py
@@ -226,8 +226,4 @@ def clean_old_temp_files_task(
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}
22 changes: 11 additions & 11 deletions backend/danswer/connectors/cross_connector_utils/file_utils.py
@@ -2,8 +2,7 @@
import os
import re
import zipfile
from collections.abc import Generator
from pathlib import Path
from collections.abc import Iterator
from typing import Any
from typing import IO

@@ -78,11 +77,11 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
# to the zip file. This file should contain a list of objects with the following format:
# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
def load_files_from_zip(
zip_location: str | Path,
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
ignore_dirs: bool = True,
) -> Generator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]], None, None]:
with zipfile.ZipFile(zip_location, "r") as zip_file:
) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]:
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
zip_metadata = {}
try:
metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
@@ -109,18 +108,19 @@ def load_files_from_zip(
yield file_info, file, zip_metadata.get(file_info.filename, {})


def detect_encoding(file_path: str | Path) -> str:
with open(file_path, "rb") as file:
raw_data = file.read(50000) # Read a portion of the file to guess encoding
return chardet.detect(raw_data)["encoding"] or "utf-8"
def detect_encoding(file: IO[bytes]) -> str:
raw_data = file.read(50000)
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
file.seek(0)
return encoding


def read_file(
file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace"
file: IO, encoding: str = "utf-8", errors: str = "replace"
) -> tuple[str, dict]:
metadata = {}
file_content_raw = ""
for ind, line in enumerate(file_reader):
for ind, line in enumerate(file):
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
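
detect_encoding and read_file now take file-like objects instead of paths, which is what lets the connectors feed them content pulled out of Postgres. A rough usage sketch, with an in-memory BytesIO standing in for the store-backed stream:

import io

from danswer.connectors.cross_connector_utils.file_utils import detect_encoding, read_file

file = io.BytesIO(b"hello from the file store\n")  # any binary file-like object works here
encoding = detect_encoding(file)  # reads a prefix, guesses the encoding, then rewinds the stream
file_content_raw, file_metadata = read_file(file, encoding=encoding)
print(encoding, file_content_raw, file_metadata)
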
78 changes: 44 additions & 34 deletions backend/danswer/connectors/file/connector.py
@@ -1,11 +1,13 @@
import os
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import IO

from sqlalchemy.orm import Session

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.file_utils import detect_encoding
@@ -20,37 +22,40 @@
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.utils.logger import setup_logger

logger = setup_logger()


def _open_files_at_location(
file_path: str | Path,
) -> Generator[tuple[str, IO[Any], dict[str, Any]], Any, None]:
extension = get_file_ext(file_path)
def _read_files_and_metadata(
file_name: str,
db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
"""Reads the file into IO, in the case of a zip file, yields each individual
file contained within, also includes the metadata dict if packaged in the zip"""
extension = get_file_ext(file_name)
metadata: dict[str, Any] = {}
directory_path = os.path.dirname(file_name)

file_content = get_default_file_store(db_session).read_file(file_name, mode="b")

if extension == ".zip":
for file_info, file, metadata in load_files_from_zip(
file_path, ignore_dirs=True
file_content, ignore_dirs=True
):
yield file_info.filename, file, metadata
elif extension in [".txt", ".md", ".mdx"]:
encoding = detect_encoding(file_path)
with open(file_path, "r", encoding=encoding, errors="replace") as file:
yield os.path.basename(file_path), file, metadata
elif extension == ".pdf":
with open(file_path, "rb") as file:
yield os.path.basename(file_path), file, metadata
yield os.path.join(directory_path, file_info.filename), file, metadata
elif extension in [".txt", ".md", ".mdx", ".pdf"]:
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_path}' with extension '{extension}'")
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")


def _process_file(
file_name: str,
file: IO[Any],
metadata: dict[str, Any] = {},
metadata: dict[str, Any] | None = None,
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
@@ -65,8 +70,9 @@ def _process_file(
file=file, file_name=file_name, pdf_pass=pdf_pass
)
else:
file_content_raw, file_metadata = read_file(file)
all_metadata = {**metadata, **file_metadata}
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_file(file, encoding=encoding)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata

# If this is set, we will show this in the UI as the "name" of the file
file_display_name_override = all_metadata.get("file_display_name")
@@ -114,7 +120,8 @@ def _process_file(
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
semantic_identifier=file_display_name_override or file_name,
semantic_identifier=file_display_name_override
or os.path.basename(file_name),
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
@@ -140,24 +147,27 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None

def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
for file_location in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _open_files_at_location(file_location)

for file_name, file, metadata in files:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
with Session(get_sqlalchemy_engine()) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
file_name=str(file_path), db_session=db_session
)

if len(documents) >= self.batch_size:
yield documents
documents = []
for file_name, file, metadata in files:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
)

if len(documents) >= self.batch_size:
yield documents
documents = []

if documents:
yield documents
if documents:
yield documents


if __name__ == "__main__":
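
The connector now treats each entry in file_locations as a name in the Postgres file store and resolves it inside a single Session. The write side is outside this excerpt; a hedged sketch of what it presumably looks like (store_uploaded_file is an illustrative helper, not a function from the PR):

import io

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store


def store_uploaded_file(file_name: str, raw: bytes) -> None:
    # persist the upload under file_name so _read_files_and_metadata() can later
    # fetch it back from Postgres by the same name
    with Session(get_sqlalchemy_engine()) as db_session:
        get_default_file_store(db_session).save_file(file_name, io.BytesIO(raw))
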
2 changes: 1 addition & 1 deletion backend/danswer/connectors/google_drive/connector.py
@@ -388,7 +388,7 @@ def _process_folder_paths(

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorugh
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
10 changes: 9 additions & 1 deletion backend/danswer/connectors/google_site/connector.py
@@ -5,6 +5,7 @@

from bs4 import BeautifulSoup
from bs4 import Tag
from sqlalchemy.orm import Session

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -15,6 +16,8 @@
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -66,8 +69,13 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []

with Session(get_sqlalchemy_engine()) as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)

# load the HTML files
files = load_files_from_zip(self.zip_path)
files = load_files_from_zip(file_content_io)
count = 0
for file_info, file_io, _metadata in files:
# skip non-published files
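
Same pattern for the Google Sites connector: self.zip_path is now a key into the Postgres file store rather than a path on a mounted volume. A minimal sketch of the lookup, with a placeholder name for the stored zip:

from sqlalchemy.orm import Session

from danswer.connectors.cross_connector_utils.file_utils import load_files_from_zip
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store

with Session(get_sqlalchemy_engine()) as db_session:
    zip_io = get_default_file_store(db_session).read_file("google_site_export.zip", mode="b")  # placeholder name
    for file_info, file_io, _metadata in load_files_from_zip(zip_io):
        print(file_info.filename)
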
4 changes: 4 additions & 0 deletions backend/danswer/connectors/web/connector.py
@@ -149,6 +149,10 @@ def __init__(
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))

elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
logger.warning(
"This is not a UI supported Web Connector flow, "
"are you sure you want to do this?"
)
self.to_visit_list = _read_urls_file(base_url)

else:
96 changes: 96 additions & 0 deletions backend/danswer/db/file_store.py
@@ -0,0 +1,96 @@
from abc import ABC
from abc import abstractmethod
from typing import IO

from sqlalchemy.orm import Session

from danswer.db.pg_file_store import create_populate_lobj
from danswer.db.pg_file_store import delete_lobj_by_id
from danswer.db.pg_file_store import delete_pgfilestore_by_file_name
from danswer.db.pg_file_store import get_pgfilestore_by_file_name
from danswer.db.pg_file_store import read_lobj
from danswer.db.pg_file_store import upsert_pgfilestore


class FileStore(ABC):
    """
    An abstraction for storing files and large binary objects.
    """

    @abstractmethod
    def save_file(self, file_name: str, content: IO) -> None:
        """
        Save a file to the blob store

        Parameters:
        - file_name: Name of the file to save
        - content: Contents of the file
        """
        raise NotImplementedError

    @abstractmethod
    def read_file(self, file_name: str, mode: str | None) -> IO:
        """
        Read the content of the given file by its name

        Parameters:
        - file_name: Name of file to read
        - mode: Mode to read the file in (e.g. "b" for binary)

        Returns:
        Contents of the file as a file-like object
        """

    @abstractmethod
    def delete_file(self, file_name: str) -> None:
        """
        Delete a file by its name.

        Parameters:
        - file_name: Name of file to delete
        """


class PostgresBackedFileStore(FileStore):
    def __init__(self, db_session: Session):
        self.db_session = db_session

    def save_file(self, file_name: str, content: IO) -> None:
        try:
            # The large objects in Postgres are saved as special objects that can be listed with
            # SELECT * FROM pg_largeobject_metadata;
            obj_id = create_populate_lobj(content=content, db_session=self.db_session)
            upsert_pgfilestore(
                file_name=file_name, lobj_oid=obj_id, db_session=self.db_session
            )
            self.db_session.commit()
        except Exception:
            self.db_session.rollback()
            raise

    def read_file(self, file_name: str, mode: str | None = None) -> IO:
        file_record = get_pgfilestore_by_file_name(
            file_name=file_name, db_session=self.db_session
        )
        return read_lobj(
            lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
        )

    def delete_file(self, file_name: str) -> None:
        try:
            file_record = get_pgfilestore_by_file_name(
                file_name=file_name, db_session=self.db_session
            )
            delete_lobj_by_id(file_record.lobj_oid, db_session=self.db_session)
            delete_pgfilestore_by_file_name(
                file_name=file_name, db_session=self.db_session
            )
            self.db_session.commit()
        except Exception:
            self.db_session.rollback()
            raise


def get_default_file_store(db_session: Session) -> FileStore:
    # The only supported file store now is the Postgres File Store
    return PostgresBackedFileStore(db_session=db_session)
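
A short round-trip through the new abstraction, assuming a configured database; the file name is a throwaway example:

import io

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store

with Session(get_sqlalchemy_engine()) as db_session:
    store = get_default_file_store(db_session)
    store.save_file("example.txt", io.BytesIO(b"hello file store"))  # creates the large object and the row
    content = store.read_file("example.txt", mode="b").read()  # resolves lobj_oid and streams the bytes back
    store.delete_file("example.txt")  # drops both the large object and the row
    print(content)
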
6 changes: 6 additions & 0 deletions backend/danswer/db/models.py
@@ -859,3 +859,9 @@ class KVStore(Base):

    key: Mapped[str] = mapped_column(String, primary_key=True)
    value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=False)


class PGFileStore(Base):
    __tablename__ = "file_store"
    file_name = mapped_column(String, primary_key=True)
    lobj_oid = mapped_column(Integer, nullable=False)
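
The ORM model mirrors the migration one-to-one, so the store's contents can also be listed with an ordinary query. A small sketch, assuming a configured engine:

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import PGFileStore

with Session(get_sqlalchemy_engine()) as db_session:
    for record in db_session.scalars(select(PGFileStore)):
        print(record.file_name, record.lobj_oid)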