From 31dd93036485a5958046ee1b77646790fb87fe07 Mon Sep 17 00:00:00 2001 From: onstabb Date: Tue, 12 Dec 2023 10:30:53 +0100 Subject: [PATCH] - Added fixes for contact service functions --- src/contacts/routers.py | 5 ++--- src/contacts/service.py | 30 +++++++++++++++++------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/contacts/routers.py b/src/contacts/routers.py index dbfed06..13e7c4e 100644 --- a/src/contacts/routers.py +++ b/src/contacts/routers.py @@ -43,7 +43,6 @@ def create_contact(data_in: ContactCreateDataIn, current_user: CurrentActiveComp ) contact = service.create_contact_by_initiator(current_user, target_user, data_in) - if data_in.action == ContactState.ESTABLISHED: notification_manager.put_notification( UserPublicOut.model_validate(current_user, from_attributes=True), @@ -60,7 +59,7 @@ def update_contact_state( current_user: CurrentActiveCompletedUser, state_data: ContactStateIn, ): - contact = service.get_contact_by_users_pair(current_user.id, target_user.id) + contact = service.get_contact_by_users_pair(current_user, target_user, use_id_for_current_user=False) if not contact: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Contact doesn't exists") @@ -84,7 +83,7 @@ def send_message( current_user: CurrentActiveCompletedUser, target_user: TargetActiveCompletedUser, ): - contact = service.get_contact_by_users_pair(current_user, target_user) + contact = service.get_contact_by_users_pair(current_user, target_user, use_id_for_current_user=False) if not contact or not contact.established: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Contact must be established") diff --git a/src/contacts/service.py b/src/contacts/service.py index 54dd0b6..f05e8fa 100644 --- a/src/contacts/service.py +++ b/src/contacts/service.py @@ -51,23 +51,26 @@ def get_contacts_for_user(user: User, *, limit: int = 0, **filters) -> list[dict return result -def get_contact_by_users_pair(user: User, target_user: User) -> Contact | None: +def get_contact_by_users_pair( + current_user: User, target_user: User, use_id_for_current_user: bool = True +) -> Contact | None: """ Retrieve a Contact instance for the given pair of users. - :param user: The first user in the pair. + :param current_user: The first user in the pair. :param target_user: The second user in the pair. + :param use_id_for_current_user: if True then current_user object in found contact will be changed to the ObjectId :return: A Contact instance if it exists. - Note: if a contact is found, this function updates the Contact instance by setting - the appropriate user ID based on the role of the provided 'user' in the contact. If 'user' - is the respondent, the 'respondent' attribute in the Contact instance is set to the user's ID. - If 'user' is the initiator, the 'initiator' attribute is set to the user's ID. + Note: if a contact is found and parameter `use_id_for_current_user` is True, + this function updates the Contact instance by setting the appropriate user ID based on the role of the provided + 'user' in the contact. If 'user' is the respondent, the 'respondent' attribute in the Contact instance is set to the + user's ID. If 'user' is the initiator, the 'initiator' attribute is set to the user's ID. """ query = ( - (Query(initiator=user) & Query(respondent=target_user)) | - (Query(respondent=user) & Query(initiator=target_user)) + (Query(initiator=current_user) & Query(respondent=target_user)) | + (Query(respondent=current_user) & Query(initiator=target_user)) ) try: contact: Contact = Contact.objects.get(query) @@ -75,10 +78,11 @@ def get_contact_by_users_pair(user: User, target_user: User) -> Contact | None: return None # From this we cand understand who is target user - if user == contact.respondent: - contact.respondent = user.id - else: - contact.initiator = user.id + if use_id_for_current_user: + if current_user == contact.respondent: + contact.respondent = current_user.id + else: + contact.initiator = current_user.id return contact @@ -109,7 +113,7 @@ def create_message(contact: Contact, sender: User, message_in: MessageIn) -> Mes def get_messages_count_from_sender(contact: Contact, sender: User) -> int: - if sender.id not in (contact.initiator, contact.respondent): + if sender not in (contact.initiator, contact.respondent): raise ValueError(f"Contact does not contain user with {sender.id}") return len(contact.messages.filter(sender=sender))