Skip to content

Commit

Permalink
Improve sync between FTS db and PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Jun 14, 2024
1 parent e26a56b commit 71b1c85
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 67 deletions.
14 changes: 7 additions & 7 deletions radis/opensearch/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def register_app():
update_documents,
)

def handle_created_reports(report_ids: list[int]) -> None:
create_documents(report_ids)
def handle_created_reports(reports: list[Report]) -> None:
create_documents(reports)

register_reports_created_handler(
ReportsCreatedHandler(
Expand All @@ -44,8 +44,8 @@ def handle_created_reports(report_ids: list[int]) -> None:
)
)

def handle_updated_reports(report_ids: list[int]) -> None:
update_documents(report_ids)
def handle_updated_reports(reports: list[Report]) -> None:
update_documents(reports)

register_reports_updated_handler(
ReportsUpdatedHandler(
Expand All @@ -54,8 +54,8 @@ def handle_updated_reports(report_ids: list[int]) -> None:
)
)

def handle_deleted_reports(document_ids: list[str]) -> None:
delete_documents(document_ids)
def handle_deleted_reports(reports: list[Report]) -> None:
delete_documents(reports)

register_reports_deleted_handler(
ReportsDeletedHandler(
Expand All @@ -65,7 +65,7 @@ def handle_deleted_reports(document_ids: list[str]) -> None:
)

def fetch_opensearch_document(report: Report) -> dict[str, Any]:
return fetch_document(report.document_id)
return fetch_document(report)

register_document_fetcher("opensearch", fetch_opensearch_document)

Expand Down
19 changes: 7 additions & 12 deletions radis/opensearch/utils/document_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,31 @@ def _dictify_report_for_opensearch(report: Report) -> dict[str, Any]:
}


def create_documents(report_ids: list[int]) -> None:
def create_documents(reports: list[Report]) -> None:
client = get_client()
reports = Report.objects.filter(id__in=report_ids)

for report in reports:
index_name = f"reports_{report.language.code}"
body = _dictify_report_for_opensearch(report)
client.create(index=index_name, id=report.document_id, body=body)


def update_documents(report_ids: list[int]) -> None:
def update_documents(reports: list[Report]) -> None:
client = get_client()
reports = Report.objects.filter(id__in=report_ids)

for report in reports:
index_name = f"reports_{report.language.code}"
body = _dictify_report_for_opensearch(report)
client.update(index=index_name, id=report.document_id, body={"doc": body})


def delete_documents(document_ids: list[str]) -> None:
def delete_documents(reports: list[Report]) -> None:
client = get_client()

for document_id in document_ids:
client.delete(index="reports", id=document_id)
for report in reports:
client.delete(index=f"reports_{report.language.code}", id=report.document_id)


def fetch_document(document_id: str) -> dict[str, Any]:
def fetch_document(report: Report) -> dict[str, Any]:
client = get_client()
return client.get(index="reports", id=document_id)
return client.get(index=f"reports_{report.language.code}", id=report.document_id)


def document_from_opensearch_response(record: dict[str, Any]) -> ReportDocument:
Expand Down
28 changes: 11 additions & 17 deletions radis/reports/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from rest_framework.exceptions import ValidationError

from ..models import Language, Metadata, Modality, Report
from ..signals import report_signal_processor


class MetadataSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -57,27 +56,22 @@ def create(self, validated_data: Any) -> Any:
metadata = validated_data.pop("metadata")
modalities = validated_data.pop("modalities")

try:
report_signal_processor.pause()

with transaction.atomic():
language_instance, _ = Language.objects.get_or_create(**language)
with transaction.atomic():
language_instance, _ = Language.objects.get_or_create(**language)

report = Report.objects.create(**validated_data, language=language_instance)
report = Report.objects.create(**validated_data, language=language_instance)

report.groups.set(groups)
report.groups.set(groups)

for metadata in metadata:
Metadata.objects.create(report=report, **metadata)
for metadata in metadata:
Metadata.objects.create(report=report, **metadata)

modality_instances: list[Modality] = []
for modality in modalities:
modality_instance, _ = Modality.objects.get_or_create(**modality)
modality_instances.append(modality_instance)
modality_instances: list[Modality] = []
for modality in modalities:
modality_instance, _ = Modality.objects.get_or_create(**modality)
modality_instances.append(modality_instance)

report.modalities.set(modality_instances)
finally:
report_signal_processor.resume()
report.modalities.set(modality_instances)

return report

Expand Down
58 changes: 39 additions & 19 deletions radis/reports/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rest_framework.serializers import BaseSerializer

from ..models import Report
from ..signals import report_signal_processor
from ..site import (
document_fetchers,
reports_created_handlers,
Expand Down Expand Up @@ -65,6 +66,13 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response:

return Response(data)

def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
try:
report_signal_processor.pause()
return super().create(request, *args, **kwargs)
finally:
report_signal_processor.resume()

def perform_create(self, serializer: BaseSerializer) -> None:
super().perform_create(serializer)
assert serializer.instance
Expand All @@ -76,27 +84,32 @@ def on_commit():
for handler in reports_created_handlers:
report_ids = [report.id for report in reports]
logger.debug(f"{handler.name} - handle newly created reports: {report_ids}")
handler.handle(report_ids)
handler.handle(reports)

transaction.on_commit(on_commit)

def update(self, request: Request, *args: Any, **kwargs: Any) -> Response:
# DRF itself does not support upsert.
# Workaround adapted from https://gist.github.com/tomchristie/a2ace4577eff2c603b1b
upsert = request.GET.get("upsert", "").lower() in ["true", "1", "yes"]
if not upsert:
return super().update(request, *args, **kwargs)
else:
instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.data)
serializer.is_valid(raise_exception=True)

if instance is None:
self.perform_create(serializer)
return Response(serializer.data, status=status.HTTP_201_CREATED)

self.perform_update(serializer)
return Response(serializer.data)
try:
report_signal_processor.pause()

# DRF itself does not support upsert.
# Workaround adapted from https://gist.github.com/tomchristie/a2ace4577eff2c603b1b
upsert = request.GET.get("upsert", "").lower() in ["true", "1", "yes"]
if not upsert:
return super().update(request, *args, **kwargs)
else:
instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.data)
serializer.is_valid(raise_exception=True)

if instance is None:
self.perform_create(serializer)
return Response(serializer.data, status=status.HTTP_201_CREATED)

self.perform_update(serializer)
return Response(serializer.data)
finally:
report_signal_processor.resume()

def get_object_or_none(self) -> Report | None:
try:
Expand All @@ -118,7 +131,7 @@ def on_commit():
for handler in reports_updated_handlers:
report_ids = [report.id for report in reports]
logger.debug(f"{handler.name} - handle updated reports: {report_ids}")
handler.handle(report_ids)
handler.handle(reports)

transaction.on_commit(on_commit)

Expand All @@ -127,12 +140,19 @@ def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Respons
assert request.method
raise MethodNotAllowed(request.method)

def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
try:
report_signal_processor.pause()
return super().destroy(request, *args, **kwargs)
finally:
report_signal_processor.resume()

def perform_destroy(self, instance: Report) -> None:
super().perform_destroy(instance)

def on_commit():
for handler in reports_deleted_handlers:
logger.debug(f"{handler.name} - handle deleted report: {instance.document_id}")
handler.handle([instance.document_id])
handler.handle([instance])

transaction.on_commit(on_commit)
6 changes: 3 additions & 3 deletions radis/reports/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def _handle_save(self, sender: type[Report], instance: Report, created: bool, **
if created:
logger.debug("Received a signal that a report has been created: %s", instance)
transaction.on_commit(
lambda: [handler.handle([instance.id]) for handler in reports_created_handlers]
lambda: [handler.handle([instance]) for handler in reports_created_handlers]
)
else:
logger.debug("Received a signal that a report has been updated: %s", instance)
transaction.on_commit(
lambda: [handler.handle([instance.id]) for handler in reports_updated_handlers]
lambda: [handler.handle([instance]) for handler in reports_updated_handlers]
)

def _handle_delete(self, sender: type[Report], instance: Report, **kwargs):
Expand All @@ -55,7 +55,7 @@ def _handle_delete(self, sender: type[Report], instance: Report, **kwargs):

logger.debug("Received a signal that a report has been deleted: %s", instance)
transaction.on_commit(
lambda: [handler.handle([instance.document_id]) for handler in reports_deleted_handlers]
lambda: [handler.handle([instance]) for handler in reports_deleted_handlers]
)


Expand Down
6 changes: 3 additions & 3 deletions radis/reports/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class ReportsCreatedHandler(NamedTuple):
name: str
handle: Callable[[list[int]], None]
handle: Callable[[list[Report]], None]


reports_created_handlers: list[ReportsCreatedHandler] = []
Expand All @@ -23,7 +23,7 @@ def register_reports_created_handler(handler: ReportsCreatedHandler) -> None:

class ReportsUpdatedHandler(NamedTuple):
name: str
handle: Callable[[list[int]], None]
handle: Callable[[list[Report]], None]


reports_updated_handlers: list[ReportsUpdatedHandler] = []
Expand All @@ -39,7 +39,7 @@ def register_reports_updated_handler(handler: ReportsUpdatedHandler) -> None:

class ReportsDeletedHandler(NamedTuple):
name: str
handle: Callable[[list[str]], None]
handle: Callable[[list[Report]], None]


reports_deleted_handlers: list[ReportsDeletedHandler] = []
Expand Down
12 changes: 6 additions & 6 deletions radis/vespa/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def register_app():
from .utils.document_utils import fetch_document
from .vespa_app import MAX_RETRIEVAL_HITS, MAX_SEARCH_HITS

def handle_created_reports(report_ids: list[int]) -> None:
process_created_reports.delay(report_ids)
def handle_created_reports(reports: list[Report]) -> None:
process_created_reports.delay([report.id for report in reports])

register_reports_created_handler(
ReportsCreatedHandler(
Expand All @@ -41,8 +41,8 @@ def handle_created_reports(report_ids: list[int]) -> None:
)
)

def handle_updated_reports(report_ids: list[int]) -> None:
process_updated_reports.delay(report_ids)
def handle_updated_reports(reports: list[Report]) -> None:
process_updated_reports.delay([report.id for report in reports])

register_reports_updated_handler(
ReportsUpdatedHandler(
Expand All @@ -51,8 +51,8 @@ def handle_updated_reports(report_ids: list[int]) -> None:
)
)

def handle_deleted_reports(document_ids: list[str]) -> None:
process_deleted_reports.delay(document_ids)
def handle_deleted_reports(reports: list[Report]) -> None:
process_deleted_reports.delay([report.document_id for report in reports])

register_reports_deleted_handler(
ReportsDeletedHandler(
Expand Down

0 comments on commit 71b1c85

Please sign in to comment.