diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py
index 0eb8cd1c5..de338ab1b 100644
--- a/backend/app/dependencies.py
+++ b/backend/app/dependencies.py
@@ -3,6 +3,9 @@
 import boto3
 import pika
+import aio_pika
+from aio_pika.abc import AbstractChannel
+
 from app.config import settings
 from app.search.connect import connect_elasticsearch
 from minio import Minio
@@ -74,18 +77,27 @@ async def get_external_fs() -> AsyncGenerator[Minio, None]:
     yield file_system
 
 
-def get_rabbitmq() -> BlockingChannel:
+async def get_rabbitmq() -> AbstractChannel:
     """Client to connect to RabbitMQ for listeners/extractors interactions."""
-    credentials = pika.PlainCredentials(settings.RABBITMQ_USER, settings.RABBITMQ_PASS)
-    parameters = pika.ConnectionParameters(
-        settings.RABBITMQ_HOST, credentials=credentials
-    )
+    RABBITMQ_URL = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASS}@{settings.RABBITMQ_HOST}/"
+
     logger.debug("Connecting to rabbitmq at %s", settings.RABBITMQ_HOST)
-    connection = pika.BlockingConnection(parameters)
-    channel = connection.channel()
+    connection = await aio_pika.connect_robust(RABBITMQ_URL)
+    channel = await connection.channel()
+
+    print(f"DEBUG: get_rabbitmq() called. Returning channel of type: {type(channel)}")
     return channel
 
 
+# Keep the old function for compatibility if needed
+def get_blocking_rabbitmq() -> BlockingChannel:
+    """Legacy blocking RabbitMQ client (for extractors that need it)"""
+    credentials = pika.PlainCredentials(settings.RABBITMQ_USER, settings.RABBITMQ_PASS)
+    parameters = pika.ConnectionParameters(settings.RABBITMQ_HOST, credentials=credentials)
+    connection = pika.BlockingConnection(parameters)
+    return connection.channel()
+
+
 async def get_elasticsearchclient():
     es = await connect_elasticsearch()
     return es
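One thing worth noting about the async get_rabbitmq above: it opens a new robust connection on every call and never closes it, and FastAPI cannot clean it up because nothing is yielded. A minimal sketch of a yield-style variant that releases the connection once the request is done, in the same spirit as the existing get_external_fs generator dependency; the helper name is made up for illustration and this is not part of the diff:

# Sketch only, not part of the diff: a yield-style dependency so FastAPI closes
# the connection after the request that used the channel has finished.
from typing import AsyncGenerator

import aio_pika
from aio_pika.abc import AbstractChannel

from app.config import settings


async def get_rabbitmq_managed() -> AsyncGenerator[AbstractChannel, None]:
    url = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASS}@{settings.RABBITMQ_HOST}/"
    connection = await aio_pika.connect_robust(url)
    try:
        # Hand the channel to the route; control returns here afterwards.
        yield await connection.channel()
    finally:
        await connection.close()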
diff --git a/backend/app/rabbitmq/listeners.py b/backend/app/rabbitmq/listeners.py
index c9defa7f1..dd3dbf4af 100644
--- a/backend/app/rabbitmq/listeners.py
+++ b/backend/app/rabbitmq/listeners.py
@@ -17,34 +17,29 @@
 from app.routers.users import get_user_job_key
 from fastapi import Depends
 from pika.adapters.blocking_connection import BlockingChannel
+import aio_pika
+from aio_pika.abc import AbstractChannel
 
 
-async def create_reply_queue():
-    channel: BlockingChannel = dependencies.get_rabbitmq()
-
-    if (
-        config_entry := await ConfigEntryDB.find_one({"key": "instance_id"})
-    ) is not None:
+async def create_reply_queue(channel: AbstractChannel):
+    if (config_entry := await ConfigEntryDB.find_one({"key": "instance_id"})) is not None:
         instance_id = config_entry.value
     else:
-        # If no ID has been generated for this instance, generate a 10-digit alphanumeric identifier
         instance_id = "".join(
-            random.choice(
-                string.ascii_uppercase + string.ascii_lowercase + string.digits
-            )
+            random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits)
             for _ in range(10)
         )
         config_entry = ConfigEntryDB(key="instance_id", value=instance_id)
         await config_entry.insert()
 
-    queue_name = "clowder.%s" % instance_id
-    channel.exchange_declare(exchange="clowder", durable=True)
-    result = channel.queue_declare(
-        queue=queue_name, durable=True, exclusive=False, auto_delete=False
-    )
-    queue_name = result.method.queue
-    channel.queue_bind(exchange="clowder", queue=queue_name)
-    return queue_name
+    queue_name = f"clowder.{instance_id}"
+
+    # Use aio_pika methods instead of pika methods
+    exchange = await channel.declare_exchange("clowder", durable=True)
+    queue = await channel.declare_queue(queue_name, durable=True, exclusive=False, auto_delete=False)
+    await queue.bind(exchange)
+
+    return queue.name
 
 
 async def submit_file_job(
@@ -52,8 +47,9 @@ async def submit_file_job(
     routing_key: str,
     parameters: dict,
     user: UserOut,
-    rabbitmq_client: BlockingChannel,
+    rabbitmq_client: AbstractChannel,
 ):
+    print(f"DEBUG submit_file_job: Got client of type: {type(rabbitmq_client)}")
     # Create an entry in job history with unique ID
     job = EventListenerJobDB(
         listener_id=routing_key,
@@ -65,6 +61,7 @@
     )
     await job.insert()
 
+
     current_secretKey = await get_user_job_key(user.email)
     msg_body = EventListenerJobMessage(
         filename=file_out.name,
@@ -75,15 +72,19 @@
         job_id=str(job.id),
         parameters=parameters,
     )
-    reply_to = await create_reply_queue()
+
+    # Use aio_pika publishing
+    # Get the existing clowder exchange
+    reply_to = await create_reply_queue(rabbitmq_client)
     print("RABBITMQ_CLIENT: " + str(rabbitmq_client))
-    rabbitmq_client.basic_publish(
-        exchange="",
-        routing_key=routing_key,
-        body=json.dumps(msg_body.dict(), ensure_ascii=False),
-        properties=pika.BasicProperties(
-            content_type="application/json", delivery_mode=1, reply_to=reply_to
+    await rabbitmq_client.default_exchange.publish(
+        aio_pika.Message(
+            body=json.dumps(msg_body.dict(), ensure_ascii=False).encode('utf-8'),
+            content_type="application/json",
+            delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
+            reply_to=reply_to,
         ),
+        routing_key=routing_key,
     )
     return str(job.id)
 
 
@@ -93,7 +94,7 @@ async def submit_dataset_job(
     routing_key: str,
     parameters: dict,
     user: UserOut,
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel,
 ):
     # Create an entry in job history with unique ID
     job = EventListenerJobDB(
@@ -113,13 +114,14 @@
         job_id=str(job.id),
         parameters=parameters,
     )
-    reply_to = await create_reply_queue()
-    rabbitmq_client.basic_publish(
-        exchange="",
-        routing_key=routing_key,
-        body=json.dumps(msg_body.dict(), ensure_ascii=False),
-        properties=pika.BasicProperties(
-            content_type="application/json", delivery_mode=1, reply_to=reply_to
+    reply_to = await create_reply_queue(rabbitmq_client)
+    await rabbitmq_client.default_exchange.publish(
+        aio_pika.Message(
+            body=json.dumps(msg_body.dict(), ensure_ascii=False).encode('utf-8'),
+            content_type="application/json",
+            delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
+            reply_to=reply_to,
         ),
+        routing_key=routing_key,
    )
    return str(job.id)
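For context on the reply_to flow above: both submit functions publish the job to the extractor's own queue via the default exchange and point reply_to at the clowder.<instance_id> queue, where extractors post status updates. A rough sketch of draining that reply queue with aio_pika, essentially what backend/message_listener.py does later in this diff; the function name is illustrative only:

# Sketch only: read status updates from the reply queue created by create_reply_queue().
import json

from aio_pika.abc import AbstractChannel


async def drain_reply_queue(channel: AbstractChannel, queue_name: str) -> None:
    queue = await channel.get_queue(queue_name)
    async with queue.iterator() as messages:
        async for message in messages:
            async with message.process():  # acks the message on clean exit
                update = json.loads(message.body.decode("utf-8"))
                print(update.get("job_id"), update.get("status"))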
diff --git a/backend/app/routers/datasets.py b/backend/app/routers/datasets.py
index ae608174b..b55f333bb 100644
--- a/backend/app/routers/datasets.py
+++ b/backend/app/routers/datasets.py
@@ -70,6 +70,8 @@
 from fastapi.security import HTTPBearer
 from minio import Minio
 from pika.adapters.blocking_connection import BlockingChannel
+import aio_pika
+from aio_pika.abc import AbstractChannel
 from pymongo import DESCENDING
 from rocrate.model.person import Person
 from rocrate.rocrate import ROCrate
@@ -944,7 +946,7 @@ async def save_file(
     fs: Minio = Depends(dependencies.get_fs),
     file: UploadFile = File(...),
     es=Depends(dependencies.get_elasticsearchclient),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(Authorization("uploader")),
 ):
     if (dataset := await DatasetDB.get(PydanticObjectId(dataset_id))) is not None:
@@ -996,7 +998,7 @@ async def save_files(
     user=Depends(get_current_user),
     fs: Minio = Depends(dependencies.get_fs),
     es=Depends(dependencies.get_elasticsearchclient),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(Authorization("uploader")),
 ):
     if (dataset := await DatasetDB.get(PydanticObjectId(dataset_id))) is not None:
@@ -1056,7 +1058,7 @@ async def save_local_file(
     folder_id: Optional[str] = None,
     user=Depends(get_current_user),
     es=Depends(dependencies.get_elasticsearchclient),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(Authorization("uploader")),
 ):
     if (dataset := await DatasetDB.get(PydanticObjectId(dataset_id))) is not None:
@@ -1110,7 +1112,7 @@ async def create_dataset_from_zip(
     fs: Minio = Depends(dependencies.get_fs),
     file: UploadFile = File(...),
     es: Elasticsearch = Depends(dependencies.get_elasticsearchclient),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     token: str = Depends(get_token),
 ):
     if file.filename.endswith(".zip") is False:
@@ -1427,7 +1429,7 @@ async def get_dataset_extract(
     # parameters don't have a fixed model shape
     parameters: dict = None,
     user=Depends(get_current_user),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(Authorization("uploader")),
 ):
     if extractorName is None:
diff --git a/backend/app/routers/feeds.py b/backend/app/routers/feeds.py
index 47e2b2d23..86c3b1740 100644
--- a/backend/app/routers/feeds.py
+++ b/backend/app/routers/feeds.py
@@ -4,7 +4,8 @@
 from beanie.operators import Or, RegEx
 from fastapi import APIRouter, Depends, HTTPException
 from pika.adapters.blocking_connection import BlockingChannel
-
+import aio_pika
+from aio_pika.abc import AbstractChannel
 from app.deps.authorization_deps import FeedAuthorization, ListenerAuthorization
 from app.keycloak_auth import get_current_user, get_current_username
 from app.models.feeds import FeedDB, FeedIn, FeedOut
@@ -41,7 +42,7 @@ async def check_feed_listeners(
     es_client,
     file_out: FileOut,
     user: UserOut,
-    rabbitmq_client: BlockingChannel,
+    rabbitmq_client: AbstractChannel,
 ):
     """Automatically submit new file to listeners on feeds that fit the search criteria."""
     listener_ids_found = []
diff --git a/backend/app/routers/files.py b/backend/app/routers/files.py
index d4276e0ce..b34ec9bc5 100644
--- a/backend/app/routers/files.py
+++ b/backend/app/routers/files.py
@@ -2,7 +2,9 @@
 import time
 from datetime import datetime, timedelta
 from typing import List, Optional, Union
-
+import json
+from json import JSONEncoder
+from aio_pika import Message
 from app import dependencies
 from app.config import settings
 from app.db.file.download import _increment_file_downloads
@@ -33,14 +35,23 @@
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from minio import Minio
 from pika.adapters.blocking_connection import BlockingChannel
+import aio_pika
+from aio_pika.abc import AbstractChannel
 
 router = APIRouter()
 security = HTTPBearer()
 
+class CustomJSONEncoder(JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, PydanticObjectId):
+            return str(obj)
+        # Handle other non-serializable types if needed
+        return super().default(obj)
+
 
 async def _resubmit_file_extractors(
     file: FileOut,
-    rabbitmq_client: BlockingChannel,
+    rabbitmq_client: AbstractChannel,
     user: UserOut,
     credentials: HTTPAuthorizationCredentials = Security(security),
 ):
@@ -85,7 +96,7 @@ async def add_file_entry(
     user: UserOut,
     fs: Minio,
     es: Elasticsearch,
-    rabbitmq_client: BlockingChannel,
+    rabbitmq_client: AbstractChannel,
     file: Optional[io.BytesIO] = None,
     content_type: Optional[str] = None,
     public: bool = False,
@@ -135,23 +146,44 @@ async def add_file_entry(
     # Add entry to the file index
     await index_file(es, FileOut(**new_file.dict()))
 
-    # TODO - timing issue here, check_feed_listeners needs to happen asynchronously.
-    time.sleep(1)
+    # Publish a message when indexing is complete
 
-    # Submit file job to any qualifying feeds
-    await check_feed_listeners(
-        es,
-        FileOut(**new_file.dict()),
-        user,
-        rabbitmq_client,
+
+    # FIXED: Use aio_pika publishing
+    message_body = {
+        "event_type": "file_indexed",
+        "file_data": json.loads(new_file.json()),
+        "user": json.loads(user.json()),
+        "timestamp": datetime.now().isoformat()
+    }
+
+    # Get the exchange first
+    exchange = await rabbitmq_client.get_exchange("clowder")
+
+    # Use aio_pika publish method
+    await exchange.publish(
+        aio_pika.Message(
+            body=json.dumps(message_body).encode('utf-8'),
+            content_type="application/json",
+            delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
+        ),
+        routing_key="file_indexed_events",
     )
 
+    # Submit file job to any qualifying feeds
+    # await check_feed_listeners(
+    #     es,
+    #     FileOut(**new_file.dict()),
+    #     user,
+    #     rabbitmq_client,
+    # )
+
 
 async def add_local_file_entry(
     new_file: FileDB,
     user: UserOut,
     es: Elasticsearch,
-    rabbitmq_client: BlockingChannel,
+    rabbitmq_client: AbstractChannel,
     content_type: Optional[str] = None,
 ):
     """Insert FileDB object into MongoDB (makes Clowder ID). Bytes are not stored in DB and versioning not supported
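CustomJSONEncoder above is declared but nothing in this diff uses it; the publishing code serializes by round-tripping through new_file.json() and json.loads() instead. If the encoder is intended to serialize the payload directly, a hedged sketch of how it could be wired into add_file_entry; note that datetimes and other non-JSON types would still need handling, as the encoder's own comment points out:

# Sketch only: build the payload from .dict() and let the encoder stringify ObjectIds.
# Assumes it sits inside add_file_entry, where new_file, user, json, datetime and
# CustomJSONEncoder are already in scope.
message_body = {
    "event_type": "file_indexed",
    "file_data": new_file.dict(),
    "user": user.dict(),
    "timestamp": datetime.now().isoformat(),
}
payload = json.dumps(message_body, cls=CustomJSONEncoder).encode("utf-8")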
@@ -163,17 +195,35 @@ async def add_local_file_entry(
     # Add entry to the file index
     await index_file(es, FileOut(**new_file.dict()))
-
-    # TODO - timing issue here, check_feed_listeners needs to happen asynchronously.
-    time.sleep(1)
+    # Publish a message when indexing is complete
+
+    message_body = {
+        "event_type": "file_indexed",
+        "file_data": json.loads(new_file.json()),
+        "user": json.loads(user.json()),
+        "timestamp": datetime.now().isoformat()
+    }
+
+    # Get the exchange first
+    exchange = await rabbitmq_client.get_exchange("clowder")
+
+    # Use aio_pika publish method
+    await exchange.publish(
+        aio_pika.Message(
+            body=json.dumps(message_body).encode('utf-8'),
+            content_type="application/json",
+            delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
+        ),
+        routing_key="file_indexed_events",
+    )
 
     # Submit file job to any qualifying feeds
-    await check_feed_listeners(
-        es,
-        FileOut(**new_file.dict()),
-        user,
-        rabbitmq_client,
-    )
+    # await check_feed_listeners(
+    #     es,
+    #     FileOut(**new_file.dict()),
+    #     user,
+    #     rabbitmq_client,
+    # )
 
 
 # TODO: Move this to MongoDB middle layer
@@ -218,7 +268,7 @@ async def update_file(
     file: UploadFile = File(...),
     es: Elasticsearch = Depends(dependencies.get_elasticsearchclient),
     credentials: HTTPAuthorizationCredentials = Security(security),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(FileAuthorization("uploader")),
 ):
     # Check all connection and abort if any one of them is not available
@@ -556,7 +606,7 @@ async def post_file_extract(
     parameters: dict = None,
     user=Depends(get_current_user),
     credentials: HTTPAuthorizationCredentials = Security(security),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(FileAuthorization("uploader")),
 ):
     if extractorName is None:
@@ -583,7 +633,7 @@ async def resubmit_file_extractions(
     file_id: str,
     user=Depends(get_current_user),
     credentials: HTTPAuthorizationCredentials = Security(security),
-    rabbitmq_client: BlockingChannel = Depends(dependencies.get_rabbitmq),
+    rabbitmq_client: AbstractChannel = Depends(dependencies.get_rabbitmq),
     allow: bool = Depends(FileAuthorization("editor")),
 ):
     """This route will check metadata. We get the extractors run from metadata from extractors.
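Both publish paths above look the exchange up with rabbitmq_client.get_exchange("clowder"), which assumes something else (create_reply_queue or a registered extractor) has already declared it; on a fresh broker that lookup can fail. Since exchange declaration is idempotent when the arguments match, a hedged alternative is to declare it at the publish site instead; a sketch under that assumption, using the same names as the code above:

# Sketch only: declare_exchange re-declares "clowder" with the same name and
# durable flag, so it succeeds whether or not the exchange already exists.
exchange = await rabbitmq_client.declare_exchange("clowder", durable=True)
await exchange.publish(
    aio_pika.Message(
        body=json.dumps(message_body).encode("utf-8"),
        content_type="application/json",
        delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
    ),
    routing_key="file_indexed_events",
)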
- """ +async def callback(message: AbstractIncomingMessage, es, rabbitmq_client): + """This method receives messages from RabbitMQ and processes them.""" async with message.process(): msg = json.loads(message.body.decode("utf-8")) - job_id = msg["job_id"] - message_str = msg["status"] - timestamp = datetime.strptime( - msg["start"], "%Y-%m-%dT%H:%M:%S%z" - ) # incoming format: '2023-01-20T08:30:27-05:00' - timestamp = timestamp.replace(tzinfo=datetime.utcnow().tzinfo) + if "event_type" in msg and msg["event_type"] == "file_indexed": + logger.info(f"This is an event type file indexed!") - # TODO: Updating an event message could go in rabbitmq/listeners + # Convert string IDs back to PydanticObjectId if needed + file_data = msg.get("file_data", {}) + user_data = msg.get("user", {}) # Fixed variable name - # Check if the job exists, and update if so - job = await EventListenerJobDB.find_one( - EventListenerJobDB.id == PydanticObjectId(job_id) - ) - if job: - # Update existing job with new info - job.updated = timestamp - parsed = parse_message_status(message_str) - cleaned_msg = parsed["cleaned_msg"] - incoming_status = parsed["status"] - - # Don't override a finished status if a message comes in late - if job.status in [ - EventListenerJobStatus.SUCCEEDED, - EventListenerJobStatus.ERROR, - EventListenerJobStatus.SKIPPED, - ]: - cleaned_status = job.status - else: - cleaned_status = incoming_status - - # Prepare fields to update based on status (don't overwrite whole object to avoid async issues) - field_updates = { - EventListenerJobDB.status: cleaned_status, - EventListenerJobDB.latest_message: cleaned_msg, - EventListenerJobDB.updated: timestamp, - } - - if job.started is not None: - field_updates[EventListenerJobDB.duration] = ( - timestamp - job.started - ).total_seconds() - elif incoming_status == EventListenerJobStatus.STARTED: - field_updates[EventListenerJobDB.duration] = 0 - - logger.info(f"[{job_id}] {timestamp} {incoming_status.value} {cleaned_msg}") - - # Update the job timestamps/duration depending on what status we received - if incoming_status == EventListenerJobStatus.STARTED: - field_updates[EventListenerJobDB.started] = timestamp - elif incoming_status in [ - EventListenerJobStatus.SUCCEEDED, - EventListenerJobStatus.ERROR, - EventListenerJobStatus.SKIPPED, - ]: - # job.finished = timestamp - field_updates[EventListenerJobDB.finished] = timestamp - - await job.set(field_updates) - - # Add latest message to the job updates - event_msg = EventListenerJobUpdateDB( - job_id=job_id, status=cleaned_msg, timestamp=timestamp + if "id" in file_data and isinstance(file_data["id"], str): + file_data["id"] = PydanticObjectId(file_data["id"]) + + # Create FileOut object + file_out = FileOut(**file_data) + + # Create UserOut object from the user data in the message + user = UserOut(**user_data) # Use user_data, not user + + # Now call check_feed_listeners with the injected dependencies + await check_feed_listeners( + es, # Elasticsearch client + file_out, + user, + rabbitmq_client, # RabbitMQ client ) - await event_msg.insert() return True + else: - # We don't know what this job is. Reject the message. - logger.error("Job ID %s not found in database, skipping message." 
-            return False
+            job_id = msg["job_id"]
+            message_str = msg["status"]
+            timestamp = datetime.strptime(
+                msg["start"], "%Y-%m-%dT%H:%M:%S%z"
+            )  # incoming format: '2023-01-20T08:30:27-05:00'
+            timestamp = timestamp.replace(tzinfo=datetime.utcnow().tzinfo)
+
+            # TODO: Updating an event message could go in rabbitmq/listeners
+
+            # Check if the job exists, and update if so
+            job = await EventListenerJobDB.find_one(
+                EventListenerJobDB.id == PydanticObjectId(job_id)
+            )
+            if job:
+                # Update existing job with new info
+                job.updated = timestamp
+                parsed = parse_message_status(message_str)
+                cleaned_msg = parsed["cleaned_msg"]
+                incoming_status = parsed["status"]
+
+                # Don't override a finished status if a message comes in late
+                if job.status in [
+                    EventListenerJobStatus.SUCCEEDED,
+                    EventListenerJobStatus.ERROR,
+                    EventListenerJobStatus.SKIPPED,
+                ]:
+                    cleaned_status = job.status
+                else:
+                    cleaned_status = incoming_status
+
+                # Prepare fields to update based on status (don't overwrite whole object to avoid async issues)
+                field_updates = {
+                    EventListenerJobDB.status: cleaned_status,
+                    EventListenerJobDB.latest_message: cleaned_msg,
+                    EventListenerJobDB.updated: timestamp,
+                }
+
+                if job.started is not None:
+                    field_updates[EventListenerJobDB.duration] = (
+                        timestamp - job.started
+                    ).total_seconds()
+                elif incoming_status == EventListenerJobStatus.STARTED:
+                    field_updates[EventListenerJobDB.duration] = 0
+
+                logger.info(f"[{job_id}] {timestamp} {incoming_status.value} {cleaned_msg}")
+
+                # Update the job timestamps/duration depending on what status we received
+                if incoming_status == EventListenerJobStatus.STARTED:
+                    field_updates[EventListenerJobDB.started] = timestamp
+                elif incoming_status in [
+                    EventListenerJobStatus.SUCCEEDED,
+                    EventListenerJobStatus.ERROR,
+                    EventListenerJobStatus.SKIPPED,
+                ]:
+                    # job.finished = timestamp
+                    field_updates[EventListenerJobDB.finished] = timestamp
+
+                await job.set(field_updates)
+
+                # Add latest message to the job updates
+                event_msg = EventListenerJobUpdateDB(
+                    job_id=job_id, status=cleaned_msg, timestamp=timestamp
+                )
+                await event_msg.insert()
+                return True
+            else:
+                # We don't know what this job is. Reject the message.
+                logger.error("Job ID %s not found in database, skipping message." % job_id)
+                return False
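To exercise the new file_indexed branch above without uploading a file through the API, a throwaway publisher can inject an event by hand. A sketch under stated assumptions: the broker is local with the guest/guest defaults used below, and the file_data and user values are placeholders, so the real FileOut and UserOut models will expect their full field sets before check_feed_listeners can succeed:

# Sketch only: hand-publish a file_indexed event so the callback branch above can
# be exercised in isolation. Broker URL and payload field values are placeholders.
import asyncio
import json
from datetime import datetime

import aio_pika


async def publish_test_event() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        # Matches the declaration in create_reply_queue, so it is idempotent.
        exchange = await channel.declare_exchange("clowder", durable=True)
        body = {
            "event_type": "file_indexed",
            "file_data": {"id": "64b000000000000000000000", "name": "test.txt"},  # placeholder
            "user": {"email": "test@example.com"},  # placeholder
            "timestamp": datetime.now().isoformat(),
        }
        await exchange.publish(
            aio_pika.Message(
                body=json.dumps(body).encode("utf-8"),
                content_type="application/json",
            ),
            routing_key="file_indexed_events",
        )


asyncio.run(publish_test_event())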
 
 
 async def listen_for_messages():
     await startup_beanie()
 
+    # Initialize dependencies using your existing functions
+    es = await get_elasticsearchclient()
+
     # For some reason, Pydantic Settings environment variable overrides aren't being applied, so get them here.
     RABBITMQ_USER = os.getenv("RABBITMQ_USER", "guest")
     RABBITMQ_PASS = os.getenv("RABBITMQ_PASS", "guest")
 
@@ -206,10 +243,16 @@ async def listen_for_messages():
         durable=True,
     )
     await queue.bind(exchange)
+    await queue.bind(exchange, routing_key="file_indexed_events")  # Add this line
 
     logger.info(f" [*] Listening to {exchange}")
+
+    # Create a partial function that includes the dependencies
+    from functools import partial
+    callback_with_deps = partial(callback, es=es, rabbitmq_client=channel)
+
     await queue.consume(
-        callback=callback,
+        callback=callback_with_deps,
         no_ack=False,
     )
 
@@ -219,9 +262,11 @@
         await asyncio.Future()
     finally:
         await connection.close()
+        await es.close()  # Close ES connection when done
 
 
 if __name__ == "__main__":
+    logger.info(" Message listener starting...")
     start = datetime.now()
     while time_ran < timeout:
         try:
diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml
index e396d5f35..7d34f513d 100644
--- a/docker-compose.dev.yml
+++ b/docker-compose.dev.yml
@@ -157,7 +157,7 @@ services:
       - rabbitmq
 
   extractors-messages:
-    image: "clowder/clowder2-messages:latest"
+    image: "clowder2-messages:test"
     build:
       dockerfile: backend/messages.Dockerfile
     environment: