diff --git a/spire/journal/actions.py b/spire/journal/actions.py index afcdbdd..a79768f 100644 --- a/spire/journal/actions.py +++ b/spire/journal/actions.py @@ -612,6 +612,21 @@ async def get_journal_entries( return query.all() +async def get_journal_entry( + db_session: Session, journal_entry_id: UUID +) -> Optional[JournalEntry]: + """ + Returns a journal entry by its id. Raises a JournalEntryNotFound error if no such entry is + found in the database. + """ + journal_entry = ( + db_session.query(JournalEntry) + .filter(JournalEntry.id == journal_entry_id) + .one_or_none() + ) + return journal_entry + + async def delete_journal_entry( db_session: Session, journal_spec: JournalSpec, diff --git a/spire/journal/api.py b/spire/journal/api.py index 3770463..84370ef 100644 --- a/spire/journal/api.py +++ b/spire/journal/api.py @@ -1280,21 +1280,12 @@ async def update_entry_content( es_index = journal.search_index try: - journal_entry_container = await actions.get_journal_entries( - db_session, - journal_spec, - entry_id, - request.state.user_group_id_list, + journal_entry = await actions.get_journal_entry( + db_session=db_session, journal_entry_id=entry_id ) - if len(journal_entry_container) == 0: + if journal_entry is None: raise actions.EntryNotFound() - assert len(journal_entry_container) == 1 - journal_entry = journal_entry_container[0] - except actions.JournalNotFound: - logger.error( - f"Journal not found with ID={journal_id} for user={request.state.user_id}" - ) - raise HTTPException(status_code=404) + except actions.EntryNotFound: logger.error( f"Entry not found with ID={entry_id} in journal with ID={journal_id}" @@ -1810,14 +1801,11 @@ async def create_tags( if es_index is not None: try: - entry_container = await actions.get_journal_entries( - db_session, - journal_spec, - entry_id, - user_group_id_list=request.state.user_group_id_list, + journal_entry = await actions.get_journal_entry( + db_session=db_session, journal_entry_id=entry_id ) - assert len(entry_container) == 1 - entry = entry_container[0] + assert journal_entry != None + entry = journal_entry all_tags = await actions.get_journal_entry_tags( db_session, journal_spec, @@ -1963,14 +1951,11 @@ async def update_tags( if es_index is not None: try: - entry_container = await actions.get_journal_entries( - db_session, - journal_spec, - entry_id, - request.state.user_group_id_list, + journal_entry = await actions.get_journal_entry( + db_session=db_session, journal_entry_id=entry_id ) - assert len(entry_container) == 1 - entry = entry_container[0] + assert journal_entry != None + entry = journal_entry all_tags_str = [tag.tag for tag in tags] search.new_entry( es_client, @@ -2062,14 +2047,11 @@ async def delete_tag( if es_index is not None: try: - entry_container = await actions.get_journal_entries( - db_session, - journal_spec, - entry_id, - user_group_id_list=request.state.user_group_id_list, + journal_entry = await actions.get_journal_entry( + db_session=db_session, journal_entry_id=entry_id ) - assert len(entry_container) == 1 - entry = entry_container[0] + assert journal_entry != None + entry = journal_entry all_tags = await actions.get_journal_entry_tags( db_session, journal_spec, @@ -2165,7 +2147,6 @@ async def search_journal( max_score: Optional[float] = 1.0 for entry in rows: - tags: List[str] = [tag.tag for tag in entry.tags] entry_url = f"{journal_url}/entries/{str(entry.id)}" content_url = f"{entry_url}/content" result = JournalSearchResult( @@ -2173,7 +2154,7 @@ async def search_journal( content_url=content_url, title=entry.title, content=entry.content, - tags=tags, + tags=entry.tags, created_at=str(entry.created_at), updated_at=str(entry.updated_at), score=1.0, diff --git a/spire/journal/search.py b/spire/journal/search.py index 1e1460e..6c3a8c6 100644 --- a/spire/journal/search.py +++ b/spire/journal/search.py @@ -14,10 +14,11 @@ import elasticsearch from elasticsearch.client import IndicesClient from elasticsearch.helpers import bulk -from sqlalchemy import and_, or_, not_ +from sqlalchemy import and_, or_, not_, func from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.orm import Session, Query + from . import actions from .data import JournalSpec, JournalEntryResponse from ..db import yield_connection_from_env @@ -734,6 +735,38 @@ def search_database( query = query.order_by(JournalEntry.created_at.desc()) num_entries = query.count() query = query.limit(size).offset(start) + + journal_entries_temp = query.cte(name="journal_entries_temp") + + entries_ids_with_tags = ( + db_session.query(journal_entries_temp.c.id, JournalEntryTag.tag).join( + JournalEntryTag, + JournalEntryTag.journal_entry_id == journal_entries_temp.c.id, + ) + ).cte(name="entries_ids_with_tags") + + aggregated_tags = ( + db_session.query( + entries_ids_with_tags.c.id, + func.array_agg(entries_ids_with_tags.c.tag).label("tags"), + ) + .group_by(entries_ids_with_tags.c.id) + .cte(name="aggregated_tags") + ) + + query = db_session.query( + journal_entries_temp.c.id.label("id"), + aggregated_tags.c.tags.label("tags"), + journal_entries_temp.c.title.label("title"), + journal_entries_temp.c.content.label("content"), + journal_entries_temp.c.context_id.label("context_id"), + journal_entries_temp.c.context_url.label("context_url"), + journal_entries_temp.c.context_type.label("context_type"), + journal_entries_temp.c.version_id.label("version_id"), + journal_entries_temp.c.created_at.label("created_at"), + journal_entries_temp.c.updated_at.label("updated_at"), + ).join(aggregated_tags, journal_entries_temp.c.id == aggregated_tags.c.id) + rows = query.all() return num_entries, rows