From d30a2e48974468b3a5c02292edbabcd6cb5ffc31 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Apr 2026 14:31:25 +0800 Subject: [PATCH 1/2] Fix org member identity mapping for sync and SSO --- backend/app/api/auth.py | 5 +- backend/app/api/feishu.py | 6 +- backend/app/services/channel_user_service.py | 44 ++++++-- backend/app/services/org_sync_adapter.py | 83 ++++++++++++--- backend/app/services/registration_service.py | 16 +-- backend/app/services/sso_service.py | 104 +++++++++++++------ backend/tests/test_identity_id_mapping.py | 56 ++++++++++ backend/tests/test_org_sync_adapter.py | 48 +++++++++ 8 files changed, 296 insertions(+), 66 deletions(-) create mode 100644 backend/tests/test_identity_id_mapping.py create mode 100644 backend/tests/test_org_sync_adapter.py diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index ef76429b5..296031954 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -971,7 +971,8 @@ async def bind_identity( user_info = await auth_provider.get_user_info(access_token) # Check if identity is already linked to another user - existing_user = await sso_service.check_duplicate_identity(db, provider, user_info.provider_user_id) + lookup_provider_user_id = user_info.provider_union_id or user_info.provider_user_id + existing_user = await sso_service.check_duplicate_identity(db, provider, lookup_provider_user_id) if existing_user and existing_user.id != current_user.id: raise HTTPException( status_code=409, @@ -983,7 +984,7 @@ async def bind_identity( db, str(current_user.id), provider, - user_info.provider_user_id, + lookup_provider_user_id, user_info.raw_data, ) diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index aff83664a..a26cc254b 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -510,7 +510,8 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession "email": sender_email, "mobile": _user_info.get("mobile"), "avatar_url": _avatar_url, - "unionid": _user_info.get("user_id"), # tenant-level user_id + "external_id": _user_info.get("user_id"), + "unionid": _user_info.get("union_id"), "open_id": sender_open_id, } logger.info(f"[Feishu] Resolved sender: {sender_name} (user_id={sender_user_id_feishu})") @@ -1185,7 +1186,8 @@ async def _handle_feishu_file(db, agent_id, config, message, sender_open_id, cha "avatar_url": _avatar_url, "email": _user_info.get("email"), "mobile": _user_info.get("mobile"), - "unionid": _user_info.get("user_id"), + "external_id": _user_info.get("user_id"), + "unionid": _user_info.get("union_id"), "open_id": sender_open_id, } except Exception: diff --git a/backend/app/services/channel_user_service.py b/backend/app/services/channel_user_service.py index 56fd90d41..08a09ccac 100644 --- a/backend/app/services/channel_user_service.py +++ b/backend/app/services/channel_user_service.py @@ -23,6 +23,27 @@ class ChannelUserService: """Service for resolving channel users via OrgMember and SSO patterns.""" + def _get_channel_ids( + self, + channel_type: str, + external_user_id: str, + extra_info: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + unionid = (extra_info.get("unionid") or extra_info.get("union_id") or "").strip() or None + open_id = (extra_info.get("open_id") or "").strip() or None + external_id = (extra_info.get("external_id") or external_user_id or "").strip() or None + + if channel_type == "feishu": + open_id = open_id or external_user_id + external_id = (extra_info.get("external_id") or "").strip() or None + elif channel_type == "dingtalk": + open_id = open_id or None + elif channel_type == "wecom": + unionid = None + open_id = open_id or None + + return unionid, open_id, external_id + async def resolve_channel_user( self, db: AsyncSession, @@ -104,13 +125,15 @@ async def resolve_channel_user( db, user.id, provider.id, tenant_id ) if existing_member: - # Reuse the org-synced record: update its channel-specific IDs - # so future lookups by external_id work without a new shell. - if channel_type == "feishu": - if external_user_id.startswith("on_"): - existing_member.unionid = existing_member.unionid or external_user_id - elif external_user_id.startswith("ou_"): - existing_member.open_id = existing_member.open_id or external_user_id + unionid, open_id, external_id = self._get_channel_ids( + channel_type, external_user_id, extra_info + ) + if unionid and not existing_member.unionid: + existing_member.unionid = unionid + if open_id and not existing_member.open_id: + existing_member.open_id = open_id + if external_id and not existing_member.external_id: + existing_member.external_id = external_id logger.info( f"[{channel_type}] Reusing org-synced OrgMember {existing_member.id} " f"for user {user.id} instead of creating a duplicate shell" @@ -232,6 +255,7 @@ async def _create_org_member_shell( ) -> OrgMember: """Create a shell OrgMember record for this identity.""" name = extra_info.get("name") or f"{channel_type.capitalize()} User {external_user_id[:8]}" + unionid, open_id, external_id = self._get_channel_ids(channel_type, external_user_id, extra_info) member = OrgMember( name=name, @@ -239,9 +263,9 @@ async def _create_org_member_shell( provider_id=provider.id, user_id=linked_user_id, tenant_id=provider.tenant_id, - external_id=external_user_id, - unionid=extra_info.get("unionid"), - open_id=extra_info.get("open_id"), + external_id=external_id, + unionid=unionid, + open_id=open_id, avatar_url=extra_info.get("avatar_url"), phone=extra_info.get("mobile"), title=extra_info.get("title", ""), diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index 8a37984eb..e0e129c14 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -386,6 +386,7 @@ async def _upsert_member( ) -> dict[str, Any]: """Insert or update a member, platform user, and identity.""" stats = {"user_created": False, "profile_synced": False} + self._validate_member_identifiers(provider, user) # Find department using user's actual department list. # DingTalk's dept_id_list last item is the most specific (leaf) department. @@ -413,23 +414,7 @@ async def _upsert_member( ) department = dept_result.scalars().first() - # Check if exists by unionid or external_id or open_id (any matches), and provider - conditions = [] - if user.unionid: - conditions.append(OrgMember.unionid == user.unionid) - if user.external_id: - conditions.append(OrgMember.external_id == user.external_id) - if user.open_id: - conditions.append(OrgMember.open_id == user.open_id) - - if conditions: - result = await db.execute( - select(OrgMember).where( - OrgMember.provider_id == provider.id, - or_(*conditions) - ) - ) - existing_member = result.scalars().first() + existing_member = await self._find_existing_member(db, provider, user) now = datetime.now() @@ -535,6 +520,70 @@ async def _upsert_member( await db.flush() return stats + def _provider_requires_unionid(self, provider: IdentityProvider) -> bool: + provider_type = (provider.provider_type or self.provider_type or "").lower() + return provider_type in {"feishu", "dingtalk"} + + def _validate_member_identifiers(self, provider: IdentityProvider, user: ExternalUser) -> None: + user.unionid = (user.unionid or "").strip() + user.external_id = (user.external_id or "").strip() + user.open_id = (user.open_id or "").strip() + + if self._provider_requires_unionid(provider) and not user.unionid: + raise ValueError( + f"unionid is required for {provider.provider_type} org sync user {user.external_id or user.name}" + ) + + if user.unionid and user.external_id and user.unionid == user.external_id: + raise ValueError( + f"invalid unionid for org sync user {user.external_id or user.name}: unionid must not equal external_id" + ) + + async def _find_existing_member( + self, + db: AsyncSession, + provider: IdentityProvider, + user: ExternalUser, + ) -> OrgMember | None: + if user.unionid: + result = await db.execute( + select(OrgMember).where( + OrgMember.provider_id == provider.id, + OrgMember.unionid == user.unionid, + ) + ) + existing_member = result.scalars().first() + if existing_member: + return existing_member + + fallback_conditions = [] + if user.external_id: + fallback_conditions.append(OrgMember.external_id == user.external_id) + if user.open_id: + fallback_conditions.append(OrgMember.open_id == user.open_id) + + if not fallback_conditions: + return None + + fallback_query = select(OrgMember).where( + OrgMember.provider_id == provider.id, + or_(*fallback_conditions), + ) + + # When unionid is required, only allow external/open id fallback to attach + # shell records that do not have a conflicting unionid yet. + if self._provider_requires_unionid(provider) and user.unionid: + fallback_query = fallback_query.where( + or_( + OrgMember.unionid.is_(None), + OrgMember.unionid == "", + OrgMember.unionid == user.unionid, + ) + ) + + result = await db.execute(fallback_query) + return result.scalars().first() + async def _resolve_platform_user(self, db: AsyncSession, user: ExternalUser) -> User | None: """Resolve platform user from external user info.""" # 1. Try by Email matching (primary way now) diff --git a/backend/app/services/registration_service.py b/backend/app/services/registration_service.py index 1d748b6b8..b319ad6a1 100644 --- a/backend/app/services/registration_service.py +++ b/backend/app/services/registration_service.py @@ -267,7 +267,10 @@ async def handle_sso_registration( tenant_id = tenant.id if tenant else None # Check if identity already exists - existing = await sso_service.resolve_user_identity(db, provider_user_id, provider_type, tenant_id=tenant_id) + lookup_provider_user_id = user_info.get("union_id") or user_info.get("unionId") or provider_user_id + existing = await sso_service.resolve_user_identity( + db, lookup_provider_user_id, provider_type, tenant_id=tenant_id + ) if existing: # Identity already linked @@ -279,7 +282,7 @@ async def handle_sso_registration( db, str(existing_user.id), provider_type, - provider_user_id, + lookup_provider_user_id, user_info, tenant_id=str(existing_user.tenant_id) if existing_user.tenant_id else tenant_id, ) @@ -360,8 +363,9 @@ async def register_with_sso( tenant_id = tenant.id if tenant else None # Try to find existing user by identity + lookup_provider_user_id = user_info_obj.provider_union_id or user_info_obj.provider_user_id existing_user = await sso_service.resolve_user_identity( - db, user_info_obj.provider_user_id, provider_type, tenant_id=tenant_id + db, lookup_provider_user_id, provider_type, tenant_id=tenant_id ) if existing_user: @@ -377,7 +381,7 @@ async def register_with_sso( db, str(existing_by_email.id), provider_type, - user_info_obj.provider_user_id, + lookup_provider_user_id, user_info, tenant_id=str(existing_by_email.tenant_id) if existing_by_email.tenant_id else tenant_id, ) @@ -387,7 +391,7 @@ async def register_with_sso( user, is_new = await self.handle_sso_registration( db, provider_type, - user_info_obj.provider_user_id, + lookup_provider_user_id, user_info, ) @@ -525,4 +529,4 @@ async def sync_org_member_contact_from_user( # Global registration service -registration_service = RegistrationService() \ No newline at end of file +registration_service = RegistrationService() diff --git a/backend/app/services/sso_service.py b/backend/app/services/sso_service.py index e60f47338..dcca54701 100644 --- a/backend/app/services/sso_service.py +++ b/backend/app/services/sso_service.py @@ -197,6 +197,49 @@ async def resolve_user_identity( ) return user_result.scalar_one_or_none() + def _get_identity_payload(self, identity_data: dict[str, Any] | None) -> dict[str, Any]: + if not identity_data: + return {} + raw_data = identity_data.get("raw_data") + if isinstance(raw_data, dict): + return raw_data + return identity_data + + def _extract_identity_ids( + self, + provider_type: str, + provider_user_id: str, + identity_data: dict[str, Any] | None, + ) -> tuple[str | None, str | None, str | None]: + payload = self._get_identity_payload(identity_data) + identity_data = identity_data or {} + + raw_open_id = ( + payload.get("open_id") + or payload.get("openId") + or identity_data.get("open_id") + or identity_data.get("openId") + ) + raw_union_id = ( + payload.get("union_id") + or payload.get("unionId") + or identity_data.get("union_id") + or identity_data.get("unionId") + ) + + external_id = None + if provider_type == "feishu": + external_id = payload.get("user_id") + elif provider_type == "dingtalk": + external_id = payload.get("userid") or payload.get("staffId") + elif provider_type == "wecom": + external_id = provider_user_id + + open_id = (raw_open_id or "").strip() or None + union_id = (raw_union_id or "").strip() or None + external_id = (external_id or "").strip() or None + return union_id, open_id, external_id + async def link_identity( self, db: AsyncSession, @@ -240,32 +283,33 @@ async def link_identity( uid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id - # Extract the raw open_id from identity_data (raw provider response). - # For Feishu: raw_data has 'open_id' and 'union_id' as separate fields. - # For DingTalk: raw_data has 'openId' and 'unionId'. - # Storing open_id separately prevents duplicate user creation when the - # lookup key alternates between open_id and union_id across SSO sessions. - raw_open_id = None - if identity_data: - raw_open_id = ( - identity_data.get("open_id") # Feishu - or identity_data.get("openId") # DingTalk - ) + # Extract canonical provider IDs from the raw payload. Some callers wrap + # the provider response in {"raw_data": ...}, so we normalize that here. + raw_union_id, raw_open_id, raw_external_id = self._extract_identity_ids( + provider_type, provider_user_id, identity_data + ) # Check if OrgMember already exists for this provider user. # Search across unionid, external_id, and open_id to handle the case where # the lookup key differs between sync (uses user_id/employee_id as external_id) # and SSO (uses union_id or open_id as provider_user_id). - conditions = [ - OrgMember.unionid == provider_user_id, - OrgMember.external_id == provider_user_id, - OrgMember.open_id == provider_user_id, - ] - if raw_open_id and raw_open_id != provider_user_id: - # Also search by the actual open_id from raw data, in case the member - # was created with open_id as its primary key (e.g. from a previous SSO login) - conditions.append(OrgMember.open_id == raw_open_id) - conditions.append(OrgMember.external_id == raw_open_id) + lookup_ids = {provider_user_id} + if raw_union_id: + lookup_ids.add(raw_union_id) + if raw_open_id: + lookup_ids.add(raw_open_id) + if raw_external_id: + lookup_ids.add(raw_external_id) + lookup_ids.discard("") + lookup_ids.discard(None) + + conditions = [] + for lookup_id in lookup_ids: + conditions.extend([ + OrgMember.unionid == lookup_id, + OrgMember.external_id == lookup_id, + OrgMember.open_id == lookup_id, + ]) member_query = select(OrgMember).where( OrgMember.provider_id == provider.id, @@ -279,10 +323,16 @@ async def link_identity( # Always link user member.user_id = uid - # Fill in open_id if not already set — prevents future lookup misses + if raw_external_id and not member.external_id: + member.external_id = raw_external_id + if raw_open_id and not member.open_id: member.open_id = raw_open_id + if raw_union_id and member.unionid != raw_union_id: + if not member.unionid or member.unionid in {provider_user_id, member.open_id, member.external_id}: + member.unionid = raw_union_id + # Passive identity enrichment: update profile fields from SSO data. # OrgMember records created by org-sync may have placeholder values # (e.g. name=userid, no avatar/email). We fill them in here so they @@ -331,12 +381,8 @@ async def link_identity( provider_id=provider.id, user_id=uid, tenant_id=tenant_id, - # For Feishu/DingTalk: external_id stores union_id (cross-app stable). - # open_id is stored separately so it can also be matched on next login. - external_id=provider_user_id, - unionid=provider_user_id if provider_type != "wecom" else None, - # Explicitly store the raw open_id so future SSO lookups can match on it - # even if the lookup key is union_id (and vice versa). + external_id=raw_external_id, + unionid=raw_union_id if provider_type != "wecom" else None, open_id=raw_open_id, ) db.add(member) @@ -466,4 +512,4 @@ def add_domain_hint(self, domain: str, tenant_id: str): # Global SSO service instance -sso_service = SSOService() \ No newline at end of file +sso_service = SSOService() diff --git a/backend/tests/test_identity_id_mapping.py b/backend/tests/test_identity_id_mapping.py new file mode 100644 index 000000000..0827b716c --- /dev/null +++ b/backend/tests/test_identity_id_mapping.py @@ -0,0 +1,56 @@ +from app.services.channel_user_service import ChannelUserService +from app.services.sso_service import sso_service + + +def test_sso_extract_identity_ids_uses_real_union_id_not_open_id(): + union_id, open_id, external_id = sso_service._extract_identity_ids( + "feishu", + "ou_open_123", + { + "raw_data": { + "open_id": "ou_open_123", + "union_id": "on_union_456", + "user_id": "u_emp_789", + } + }, + ) + + assert union_id == "on_union_456" + assert open_id == "ou_open_123" + assert external_id == "u_emp_789" + + +def test_sso_extract_identity_ids_handles_registration_wrapped_payload(): + union_id, open_id, external_id = sso_service._extract_identity_ids( + "dingtalk", + "open_123", + { + "name": "Alice", + "raw_data": { + "openId": "open_123", + "unionId": "union_456", + }, + }, + ) + + assert union_id == "union_456" + assert open_id == "open_123" + assert external_id is None + + +def test_channel_user_service_keeps_feishu_user_id_out_of_unionid(): + service = ChannelUserService() + + union_id, open_id, external_id = service._get_channel_ids( + "feishu", + "ou_open_123", + { + "external_id": "u_emp_789", + "unionid": "on_union_456", + "open_id": "ou_open_123", + }, + ) + + assert union_id == "on_union_456" + assert open_id == "ou_open_123" + assert external_id == "u_emp_789" diff --git a/backend/tests/test_org_sync_adapter.py b/backend/tests/test_org_sync_adapter.py new file mode 100644 index 000000000..39a7071e0 --- /dev/null +++ b/backend/tests/test_org_sync_adapter.py @@ -0,0 +1,48 @@ +from types import SimpleNamespace + +import pytest + +from app.services.org_sync_adapter import BaseOrgSyncAdapter, ExternalUser + + +class _DummyAdapter(BaseOrgSyncAdapter): + provider_type = "feishu" + + @property + def api_base_url(self) -> str: + return "https://example.com" + + async def get_access_token(self) -> str: + return "token" + + async def fetch_departments(self): + return [] + + async def fetch_users(self, department_external_id: str): + return [] + + +def test_validate_member_identifiers_requires_unionid_for_feishu(): + adapter = _DummyAdapter() + provider = SimpleNamespace(provider_type="feishu") + user = ExternalUser(external_id="ou_123", name="Alice", unionid="") + + with pytest.raises(ValueError, match="unionid is required"): + adapter._validate_member_identifiers(provider, user) + + +def test_validate_member_identifiers_rejects_unionid_equal_to_external_id(): + adapter = _DummyAdapter() + provider = SimpleNamespace(provider_type="dingtalk") + user = ExternalUser(external_id="same-id", name="Bob", unionid="same-id") + + with pytest.raises(ValueError, match="must not equal external_id"): + adapter._validate_member_identifiers(provider, user) + + +def test_validate_member_identifiers_allows_wecom_without_unionid(): + adapter = _DummyAdapter() + provider = SimpleNamespace(provider_type="wecom") + user = ExternalUser(external_id="zhangsan", name="Zhang San", unionid="") + + adapter._validate_member_identifiers(provider, user) From 4101cfbd034342af0534b0c98eb84f95fee7894e Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Apr 2026 22:36:28 +0800 Subject: [PATCH 2/2] fix: unify SSO identity fallback order --- backend/app/api/auth.py | 7 +- backend/app/services/auth_provider.py | 17 +-- backend/app/services/org_sync_adapter.py | 16 ++- backend/app/services/registration_service.py | 12 +- backend/app/services/sso_service.py | 140 ++++++++++++------- backend/tests/test_identity_id_mapping.py | 20 +++ backend/tests/test_org_sync_adapter.py | 54 +++++++ 7 files changed, 201 insertions(+), 65 deletions(-) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 296031954..d6a2138ae 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -972,7 +972,12 @@ async def bind_identity( # Check if identity is already linked to another user lookup_provider_user_id = user_info.provider_union_id or user_info.provider_user_id - existing_user = await sso_service.check_duplicate_identity(db, provider, lookup_provider_user_id) + existing_user = await sso_service.check_duplicate_identity( + db, + provider, + lookup_provider_user_id, + identity_data=user_info.raw_data, + ) if existing_user and existing_user.id != current_user.id: raise HTTPException( status_code=409, diff --git a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py index d40cd583b..ed51d18e1 100644 --- a/backend/app/services/auth_provider.py +++ b/backend/app/services/auth_provider.py @@ -110,19 +110,12 @@ async def find_or_create_user( # 1. Try lookup via sso_service (which now uses OrgMember) provider_user_id = user_info.provider_union_id or user_info.provider_user_id user = await sso_service.resolve_user_identity( - db, provider_user_id, self.provider_type, tenant_id=tenant_id + db, + provider_user_id, + self.provider_type, + tenant_id=tenant_id, + identity_data=user_info.raw_data, ) - - # Feishu: fallback to open_id if union_id lookup misses - if ( - not user - and self.provider_type == "feishu" - and user_info.provider_union_id - and user_info.provider_user_id - ): - user = await sso_service.resolve_user_identity( - db, user_info.provider_user_id, self.provider_type, tenant_id=tenant_id - ) is_new = False if not user: diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index e0e129c14..166b8fc79 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -136,6 +136,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: user_count = 0 profile_count = 0 sync_start = datetime.now() + partial_failure = False # Ensure provider exists provider = await self._ensure_provider(db) @@ -149,6 +150,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: await self._upsert_department(db, provider, dept) dept_count += 1 except Exception as e: + partial_failure = True errors.append(f"Department {dept.external_id}: {str(e)}") logger.error(f"[OrgSync] Failed to sync department {dept.external_id}: {e}") @@ -157,6 +159,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: try: users = await self.fetch_users(dept.external_id) except Exception as e: + partial_failure = True logger.error(f"[OrgSync] Failed to fetch users in department {dept.external_id}: {e}") errors.append(f"Fetch users in dept {dept.external_id}: {str(e)}") continue @@ -171,6 +174,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: profile_count += 1 member_count += 1 except Exception as e: + partial_failure = True logger.error(f"[OrgSync] Failed to sync member {user.external_id} ({user.name}): {e}") errors.append(f"Member {user.external_id}: {str(e)}") @@ -181,9 +185,15 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: self.provider.config = config await db.flush() - # Reconciliation: mark records not updated in this sync as deleted - await self._reconcile(db, provider.id, sync_start) - await db.flush() + if partial_failure: + logger.warning( + f"[OrgSync] Skipping reconcile for provider {provider.id} because this sync had partial failures" + ) + errors.append("Reconcile skipped due to partial sync failures") + else: + # Reconciliation: mark records not updated in this sync as deleted + await self._reconcile(db, provider.id, sync_start) + await db.flush() # Recalculate member counts for all departments (crucial for DingTalk/WeCom) await self._update_member_counts(db, provider.id) diff --git a/backend/app/services/registration_service.py b/backend/app/services/registration_service.py index b319ad6a1..dfb06dbb1 100644 --- a/backend/app/services/registration_service.py +++ b/backend/app/services/registration_service.py @@ -269,7 +269,11 @@ async def handle_sso_registration( # Check if identity already exists lookup_provider_user_id = user_info.get("union_id") or user_info.get("unionId") or provider_user_id existing = await sso_service.resolve_user_identity( - db, lookup_provider_user_id, provider_type, tenant_id=tenant_id + db, + lookup_provider_user_id, + provider_type, + tenant_id=tenant_id, + identity_data=user_info, ) if existing: @@ -365,7 +369,11 @@ async def register_with_sso( # Try to find existing user by identity lookup_provider_user_id = user_info_obj.provider_union_id or user_info_obj.provider_user_id existing_user = await sso_service.resolve_user_identity( - db, lookup_provider_user_id, provider_type, tenant_id=tenant_id + db, + lookup_provider_user_id, + provider_type, + tenant_id=tenant_id, + identity_data=user_info, ) if existing_user: diff --git a/backend/app/services/sso_service.py b/backend/app/services/sso_service.py index dcca54701..205e4f610 100644 --- a/backend/app/services/sso_service.py +++ b/backend/app/services/sso_service.py @@ -8,7 +8,7 @@ from typing import Any from loguru import logger -from sqlalchemy import select, or_ +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models.identity import IdentityProvider @@ -147,7 +147,12 @@ async def auto_associate_tenant(self, db: AsyncSession, email: str) -> str | Non return None async def resolve_user_identity( - self, db: AsyncSession, provider_user_id: str, provider_type: str, tenant_id: str | None = None + self, + db: AsyncSession, + provider_user_id: str, + provider_type: str, + tenant_id: str | None = None, + identity_data: dict[str, Any] | None = None, ) -> User | None: """Resolve user from external identity via OrgMember. @@ -173,19 +178,13 @@ async def resolve_user_identity( if not provider: return None - # Find OrgMember by unionid, external_id, or open_id - # For Feishu/DingTalk we often use unionid, for WeCom we use external_id (userid) - member_query = select(OrgMember).where( - OrgMember.provider_id == provider.id, - OrgMember.status == "active", - or_( - OrgMember.unionid == provider_user_id, - OrgMember.external_id == provider_user_id, - OrgMember.open_id == provider_user_id - ) + member = await self._find_identity_member( + db, + provider.id, + provider_type, + provider_user_id, + identity_data, ) - member_result = await db.execute(member_query) - member = member_result.scalar_one_or_none() if not member or not member.user_id: return None @@ -240,6 +239,67 @@ def _extract_identity_ids( external_id = (external_id or "").strip() or None return union_id, open_id, external_id + def _identity_lookup_chain( + self, + provider_type: str, + provider_user_id: str, + identity_data: dict[str, Any] | None, + ) -> list[tuple[str, str]]: + raw_union_id, raw_open_id, raw_external_id = self._extract_identity_ids( + provider_type, provider_user_id, identity_data + ) + + lookup_chain: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() + + def add(field: str, value: str | None) -> None: + normalized = (value or "").strip() + key = (field, normalized) + if not normalized or key in seen: + return + seen.add(key) + lookup_chain.append(key) + + add("unionid", raw_union_id) + add("external_id", raw_external_id) + add("open_id", raw_open_id) + + if not lookup_chain: + fallback_id = (provider_user_id or "").strip() + if provider_type == "wecom": + add("external_id", fallback_id) + else: + add("unionid", fallback_id) + add("external_id", fallback_id) + add("open_id", fallback_id) + + return lookup_chain + + async def _find_identity_member( + self, + db: AsyncSession, + provider_id: uuid.UUID, + provider_type: str, + provider_user_id: str, + identity_data: dict[str, Any] | None = None, + ): + from app.models.org import OrgMember + + for field, lookup_value in self._identity_lookup_chain(provider_type, provider_user_id, identity_data): + column = getattr(OrgMember, field) + member_result = await db.execute( + select(OrgMember).where( + OrgMember.provider_id == provider_id, + OrgMember.status == "active", + column == lookup_value, + ) + ) + member = member_result.scalar_one_or_none() + if member: + return member + + return None + async def link_identity( self, db: AsyncSession, @@ -283,41 +343,16 @@ async def link_identity( uid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id - # Extract canonical provider IDs from the raw payload. Some callers wrap - # the provider response in {"raw_data": ...}, so we normalize that here. raw_union_id, raw_open_id, raw_external_id = self._extract_identity_ids( provider_type, provider_user_id, identity_data ) - - # Check if OrgMember already exists for this provider user. - # Search across unionid, external_id, and open_id to handle the case where - # the lookup key differs between sync (uses user_id/employee_id as external_id) - # and SSO (uses union_id or open_id as provider_user_id). - lookup_ids = {provider_user_id} - if raw_union_id: - lookup_ids.add(raw_union_id) - if raw_open_id: - lookup_ids.add(raw_open_id) - if raw_external_id: - lookup_ids.add(raw_external_id) - lookup_ids.discard("") - lookup_ids.discard(None) - - conditions = [] - for lookup_id in lookup_ids: - conditions.extend([ - OrgMember.unionid == lookup_id, - OrgMember.external_id == lookup_id, - OrgMember.open_id == lookup_id, - ]) - - member_query = select(OrgMember).where( - OrgMember.provider_id == provider.id, - OrgMember.status == "active", - or_(*conditions) + member = await self._find_identity_member( + db, + provider.id, + provider_type, + provider_user_id, + identity_data, ) - member_result = await db.execute(member_query) - member = member_result.scalar_one_or_none() if member: # Always link user @@ -436,7 +471,12 @@ async def unlink_identity( return True async def check_duplicate_identity( - self, db: AsyncSession, provider_type: str, provider_user_id: str, tenant_id: str | None = None + self, + db: AsyncSession, + provider_type: str, + provider_user_id: str, + tenant_id: str | None = None, + identity_data: dict[str, Any] | None = None, ) -> User | None: """Check if an external identity is already linked to another user. @@ -449,7 +489,13 @@ async def check_duplicate_identity( Returns: Existing user if identity is already linked, None otherwise """ - return await self.resolve_user_identity(db, provider_user_id, provider_type, tenant_id) + return await self.resolve_user_identity( + db, + provider_user_id, + provider_type, + tenant_id, + identity_data=identity_data, + ) async def validate_sso_enablement(self, db: AsyncSession, tenant_id: uuid.UUID) -> bool: """Check if SSO can be enabled for this tenant under IP restrictions. diff --git a/backend/tests/test_identity_id_mapping.py b/backend/tests/test_identity_id_mapping.py index 0827b716c..b3b67cdb8 100644 --- a/backend/tests/test_identity_id_mapping.py +++ b/backend/tests/test_identity_id_mapping.py @@ -2,6 +2,26 @@ from app.services.sso_service import sso_service +def test_sso_identity_lookup_chain_prioritizes_unionid_then_userid_then_openid(): + lookup_chain = sso_service._identity_lookup_chain( + "feishu", + "ou_open_123", + { + "raw_data": { + "open_id": "ou_open_123", + "union_id": "on_union_456", + "user_id": "u_emp_789", + } + }, + ) + + assert lookup_chain == [ + ("unionid", "on_union_456"), + ("external_id", "u_emp_789"), + ("open_id", "ou_open_123"), + ] + + def test_sso_extract_identity_ids_uses_real_union_id_not_open_id(): union_id, open_id, external_id = sso_service._extract_identity_ids( "feishu", diff --git a/backend/tests/test_org_sync_adapter.py b/backend/tests/test_org_sync_adapter.py index 39a7071e0..338173438 100644 --- a/backend/tests/test_org_sync_adapter.py +++ b/backend/tests/test_org_sync_adapter.py @@ -1,3 +1,5 @@ +import asyncio +from contextlib import asynccontextmanager from types import SimpleNamespace import pytest @@ -22,6 +24,47 @@ async def fetch_users(self, department_external_id: str): return [] +class _FakeDB: + def __init__(self): + self.flush_calls = 0 + + @asynccontextmanager + async def begin_nested(self): + yield + + async def flush(self): + self.flush_calls += 1 + + +class _SyncAdapterWithFailure(_DummyAdapter): + def __init__(self): + super().__init__() + self.reconcile_called = False + self.member_counts_updated = False + self.provider = SimpleNamespace(id="provider-1", config={}) + + async def _ensure_provider(self, db): + return self.provider + + async def _upsert_department(self, db, provider, dept): + return None + + async def _upsert_member(self, db, provider, user, department_external_id): + raise ValueError("unionid is required") + + async def _reconcile(self, db, provider_id, sync_start): + self.reconcile_called = True + + async def _update_member_counts(self, db, provider_id): + self.member_counts_updated = True + + async def fetch_departments(self): + return [SimpleNamespace(external_id="dept-1", name="Dept 1")] + + async def fetch_users(self, department_external_id: str): + return [ExternalUser(external_id="user-1", name="Alice", unionid="")] + + def test_validate_member_identifiers_requires_unionid_for_feishu(): adapter = _DummyAdapter() provider = SimpleNamespace(provider_type="feishu") @@ -46,3 +89,14 @@ def test_validate_member_identifiers_allows_wecom_without_unionid(): user = ExternalUser(external_id="zhangsan", name="Zhang San", unionid="") adapter._validate_member_identifiers(provider, user) + + +def test_sync_org_structure_skips_reconcile_after_member_failure(): + adapter = _SyncAdapterWithFailure() + db = _FakeDB() + + result = asyncio.run(adapter.sync_org_structure(db)) + + assert adapter.reconcile_called is False + assert adapter.member_counts_updated is True + assert "Reconcile skipped due to partial sync failures" in result["errors"]