Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions spire/journal/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,21 @@ async def get_journal_entries(
return query.all()


async def get_journal_entry(
    db_session: Session, journal_entry_id: UUID
) -> Optional[JournalEntry]:
    """
    Return the journal entry with the given id, or None if no such entry exists.

    Note: this function does NOT raise when the entry is missing -- it returns
    None (callers are expected to check for None and raise their own
    EntryNotFound). The previous docstring incorrectly claimed a
    JournalEntryNotFound error was raised.

    Args:
        db_session: Active SQLAlchemy session used for the lookup.
        journal_entry_id: Primary key (UUID) of the journal entry to fetch.

    Returns:
        The matching JournalEntry, or None when no row has this id.
    """
    # one_or_none() yields None for zero matches; since id is the primary key,
    # more than one match should be impossible (it would raise
    # MultipleResultsFound, signalling data corruption rather than a user error).
    return (
        db_session.query(JournalEntry)
        .filter(JournalEntry.id == journal_entry_id)
        .one_or_none()
    )


async def delete_journal_entry(
db_session: Session,
journal_spec: JournalSpec,
Expand Down
53 changes: 17 additions & 36 deletions spire/journal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,21 +1280,12 @@ async def update_entry_content(
es_index = journal.search_index

try:
journal_entry_container = await actions.get_journal_entries(
db_session,
journal_spec,
entry_id,
request.state.user_group_id_list,
journal_entry = await actions.get_journal_entry(
db_session=db_session, journal_entry_id=entry_id
)
if len(journal_entry_container) == 0:
if journal_entry is None:
raise actions.EntryNotFound()
assert len(journal_entry_container) == 1
journal_entry = journal_entry_container[0]
except actions.JournalNotFound:
logger.error(
f"Journal not found with ID={journal_id} for user={request.state.user_id}"
)
raise HTTPException(status_code=404)

except actions.EntryNotFound:
logger.error(
f"Entry not found with ID={entry_id} in journal with ID={journal_id}"
Expand Down Expand Up @@ -1810,14 +1801,11 @@ async def create_tags(

if es_index is not None:
try:
entry_container = await actions.get_journal_entries(
db_session,
journal_spec,
entry_id,
user_group_id_list=request.state.user_group_id_list,
journal_entry = await actions.get_journal_entry(
db_session=db_session, journal_entry_id=entry_id
)
assert len(entry_container) == 1
entry = entry_container[0]
assert journal_entry != None
entry = journal_entry
all_tags = await actions.get_journal_entry_tags(
db_session,
journal_spec,
Expand Down Expand Up @@ -1963,14 +1951,11 @@ async def update_tags(

if es_index is not None:
try:
entry_container = await actions.get_journal_entries(
db_session,
journal_spec,
entry_id,
request.state.user_group_id_list,
journal_entry = await actions.get_journal_entry(
db_session=db_session, journal_entry_id=entry_id
)
assert len(entry_container) == 1
entry = entry_container[0]
assert journal_entry != None
entry = journal_entry
all_tags_str = [tag.tag for tag in tags]
search.new_entry(
es_client,
Expand Down Expand Up @@ -2062,14 +2047,11 @@ async def delete_tag(

if es_index is not None:
try:
entry_container = await actions.get_journal_entries(
db_session,
journal_spec,
entry_id,
user_group_id_list=request.state.user_group_id_list,
journal_entry = await actions.get_journal_entry(
db_session=db_session, journal_entry_id=entry_id
)
assert len(entry_container) == 1
entry = entry_container[0]
assert journal_entry != None
entry = journal_entry
all_tags = await actions.get_journal_entry_tags(
db_session,
journal_spec,
Expand Down Expand Up @@ -2165,15 +2147,14 @@ async def search_journal(
max_score: Optional[float] = 1.0

for entry in rows:
tags: List[str] = [tag.tag for tag in entry.tags]
entry_url = f"{journal_url}/entries/{str(entry.id)}"
content_url = f"{entry_url}/content"
result = JournalSearchResult(
entry_url=entry_url,
content_url=content_url,
title=entry.title,
content=entry.content,
tags=tags,
tags=entry.tags,
created_at=str(entry.created_at),
updated_at=str(entry.updated_at),
score=1.0,
Expand Down
35 changes: 34 additions & 1 deletion spire/journal/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import elasticsearch
from elasticsearch.client import IndicesClient
from elasticsearch.helpers import bulk
from sqlalchemy import and_, or_, not_
from sqlalchemy import and_, or_, not_, func
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.orm import Session, Query


from . import actions
from .data import JournalSpec, JournalEntryResponse
from ..db import yield_connection_from_env
Expand Down Expand Up @@ -734,6 +735,38 @@ def search_database(
query = query.order_by(JournalEntry.created_at.desc())
num_entries = query.count()
query = query.limit(size).offset(start)

journal_entries_temp = query.cte(name="journal_entries_temp")

entries_ids_with_tags = (
db_session.query(journal_entries_temp.c.id, JournalEntryTag.tag).join(
JournalEntryTag,
JournalEntryTag.journal_entry_id == journal_entries_temp.c.id,
)
).cte(name="entries_ids_with_tags")

aggregated_tags = (
db_session.query(
entries_ids_with_tags.c.id,
func.array_agg(entries_ids_with_tags.c.tag).label("tags"),
)
.group_by(entries_ids_with_tags.c.id)
.cte(name="aggregated_tags")
)

query = db_session.query(
journal_entries_temp.c.id.label("id"),
aggregated_tags.c.tags.label("tags"),
journal_entries_temp.c.title.label("title"),
journal_entries_temp.c.content.label("content"),
journal_entries_temp.c.context_id.label("context_id"),
journal_entries_temp.c.context_url.label("context_url"),
journal_entries_temp.c.context_type.label("context_type"),
journal_entries_temp.c.version_id.label("version_id"),
journal_entries_temp.c.created_at.label("created_at"),
journal_entries_temp.c.updated_at.label("updated_at"),
).join(aggregated_tags, journal_entries_temp.c.id == aggregated_tags.c.id)

rows = query.all()

return num_entries, rows
Expand Down