diff --git a/.github/workflows/server-ci-report.yml b/.github/workflows/server-ci-report.yml index f977e2a2bf4..37aaf693852 100644 --- a/.github/workflows/server-ci-report.yml +++ b/.github/workflows/server-ci-report.yml @@ -124,3 +124,29 @@ jobs: repo: context.repo.repo, body: body }) + + - name: Report retried tests to flaky-test webhook (pull request) + if: >- + steps.report.outputs.flaky_summary != '
TestRetries
' + && steps.report.outputs.failed == '0' + && github.event.workflow_run.event == 'pull_request' + && env.WEBHOOK_URL_FLAKY_TEST != '' + continue-on-error: true + env: + WEBHOOK_URL_FLAKY_TEST: ${{ secrets.WEBHOOK_URL_FLAKY_TEST }} + FLAKY_SUMMARY: ${{ steps.report.outputs.flaky_summary }} + PR_NUMBER: ${{ steps.incoming-pr.outputs.NUMBER }} + REPO: ${{ github.repository }} + run: | + PAYLOAD=$(jq -n \ + --arg repo "$REPO" \ + --arg pr_number "$PR_NUMBER" \ + --arg flaky_summary "$FLAKY_SUMMARY" \ + '{repo:$repo, pr_number:$pr_number, flaky_summary:$flaky_summary}') + + curl -X POST -fsSL \ + --connect-timeout 5 \ + --max-time 30 \ + -H "Content-Type: application/json" \ + -d "$PAYLOAD" \ + "$WEBHOOK_URL_FLAKY_TEST" diff --git a/api/v4/source/users.yaml b/api/v4/source/users.yaml index c22441ec34d..e5c6b6de375 100644 --- a/api/v4/source/users.yaml +++ b/api/v4/source/users.yaml @@ -1534,6 +1534,54 @@ $ref: "#/components/responses/Unauthorized" "404": $ref: "#/components/responses/NotFound" + /api/v4/users/auth_data: + get: + tags: + - users + summary: Get a user by auth data + description: > + Get a user by their external auth data identifier. The `value` is + matched against what is stored in `Users.AuthData`, which for most + identity providers is the identifier as the provider issues it. + + + The exception is Active Directory `objectGUID`: under + `auth_service: ldap` it is stored as the LDAP filter hex-escape + form (e.g. `\61\14\e1\d1\c5\35\18\4a\b6\60\d6\78\50\fd\0d\5d`), + and under `auth_service: saml` it is stored as the standard + Base64 of the same bytes (e.g. `YRTh0cU1GEq2YNZ4UP0NXQ==`). Use + the form matching the user's current `AuthService`. + + + ##### Permissions + + Must be a system admin. + operationId: GetUserByAuthData + parameters: + - name: value + in: query + description: > + The user's AuthData as stored in `Users.AuthData`. Must be + URL-encoded; in particular, Base64 `+` characters must be sent + as `%2B` so they are not decoded as spaces. + required: true + schema: + type: string + responses: + "200": + description: User retrieval successful + content: + application/json: + schema: + $ref: "#/components/schemas/User" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + "404": + $ref: "#/components/responses/NotFound" /api/v4/users/password/reset: post: tags: diff --git a/e2e-tests/playwright/lib/src/ui/components/system_console/sections/system_attributes/system_properties.ts b/e2e-tests/playwright/lib/src/ui/components/system_console/sections/system_attributes/system_properties.ts index ec3c6df9328..5af51a245f4 100644 --- a/e2e-tests/playwright/lib/src/ui/components/system_console/sections/system_attributes/system_properties.ts +++ b/e2e-tests/playwright/lib/src/ui/components/system_console/sections/system_attributes/system_properties.ts @@ -51,6 +51,17 @@ export default class SystemProperties { return this.container.locator(`input[value="${value}"]`); } + displayNameInput(nth: number): Locator { + return this.container.getByTestId('property-display-name-input').nth(nth); + } + + displayNameInputNear(identifierValue: string): Locator { + return this.container + .locator('tr') + .filter({has: this.nameInputByValue(identifierValue)}) + .getByTestId('property-display-name-input'); + } + typeSelector(nth: number): Locator { return this.container.getByTestId('fieldTypeSelectorMenuButton').nth(nth); } @@ -73,6 +84,10 @@ export default class SystemProperties { return this.container.getByTestId('property-field-input').last(); } + lastDisplayNameInput(): Locator { + return this.container.getByTestId('property-display-name-input').last(); + } + lastTypeSelector(): Locator { return this.container.getByTestId('fieldTypeSelectorMenuButton').last(); } @@ -209,7 +224,11 @@ export default class SystemProperties { // ── Validation ────────────────────────────────────────────────────── - validationMessage(text: string): Locator { + identifierValidationError(): Locator { + return this.container.getByTestId('property-field-validation-error'); + } + + validationMessage(text: string | RegExp): Locator { return this.container.getByText(text); } } diff --git a/e2e-tests/playwright/specs/functional/channels/custom_profile_attributes/helpers.ts b/e2e-tests/playwright/specs/functional/channels/custom_profile_attributes/helpers.ts index c08c4877f99..d130f7cb37b 100644 --- a/e2e-tests/playwright/specs/functional/channels/custom_profile_attributes/helpers.ts +++ b/e2e-tests/playwright/specs/functional/channels/custom_profile_attributes/helpers.ts @@ -49,6 +49,7 @@ export type CustomProfileAttribute = { visibility?: string; managed?: string; options?: {name: string; color: string}[]; + display_name?: string; }; }; diff --git a/e2e-tests/playwright/specs/functional/system_console/abac/user_attributes/display_name_in_selector.spec.ts b/e2e-tests/playwright/specs/functional/system_console/abac/user_attributes/display_name_in_selector.spec.ts new file mode 100644 index 00000000000..2b3bcc48edf --- /dev/null +++ b/e2e-tests/playwright/specs/functional/system_console/abac/user_attributes/display_name_in_selector.spec.ts @@ -0,0 +1,207 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {Client4} from '@mattermost/client'; +import type {UserPropertyField} from '@mattermost/types/properties'; + +import {expect, test, enableABAC, navigateToABACPage} from '@mattermost/playwright-lib'; + +import { + CustomProfileAttribute, + deleteCustomProfileAttributes, + setupCustomProfileAttributeFields, +} from '../../../channels/custom_profile_attributes/helpers'; +import {getPolicyIdByName} from '../support'; + +type FieldsMap = Record; + +async function clearExistingFields(client: Client4): Promise { + try { + const existing = await client.getCustomProfileAttributeFields(); + if (existing?.length) { + const map: FieldsMap = {}; + for (const f of existing) { + map[f.id] = f; + } + await deleteCustomProfileAttributes(client, map); + } + } catch { + // No fields to clean up + } +} + +test.describe('ABAC Attribute Selector - display_name rendering and filtering', () => { + /** + * @objective Verify the attribute selector renders display_name when set, + * falls back to `name`, filters on both, and persists the CEL identifier + * in saved policy expressions. + * + * @precondition + * Two admin-managed CPA fields seeded via REST: `dept_head` with + * display_name 'Department Head', and `office` with no display_name. + */ + test( + 'renders and filters by display_name while persisting CEL identifier', + {tag: '@user_attributes'}, + async ({pw}) => { + test.setTimeout(120000); + + await pw.ensureLicense(); + await pw.skipIfNoLicense(); + + const {adminUser, adminClient} = await pw.initSetup(); + + await clearExistingFields(adminClient); + + const seedAttributes: CustomProfileAttribute[] = [ + { + name: 'dept_head', + type: 'text', + attrs: { + display_name: 'Department Head', + visibility: 'when_set', + managed: 'admin', + }, + }, + { + name: 'office', + type: 'text', + attrs: { + visibility: 'when_set', + managed: 'admin', + }, + }, + ]; + + const fieldsMap = await setupCustomProfileAttributeFields(adminClient, seedAttributes); + + const policyName = `Display Name Selector ${pw.random.id()}`; + let policyId: string | null = null; + + try { + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + const {page} = systemConsolePage; + + await navigateToABACPage(page); + await enableABAC(page); + + // # Open the new-policy form + await page.getByRole('button', {name: 'Add policy'}).click(); + await page.waitForLoadState('networkidle'); + + const nameInput = page.locator('#admin\\.access_control\\.policy\\.edit_policy\\.policyName'); + await nameInput.waitFor({state: 'visible', timeout: 10000}); + await nameInput.fill(policyName); + + // # Add an attribute rule and open its selector + const addAttributeButton = page.getByRole('button', {name: /add attribute/i}); + await expect(addAttributeButton).toBeEnabled({timeout: 10000}); + await addAttributeButton.click(); + + const attributeButton = page.locator('[data-testid="attributeSelectorMenuButton"]').first(); + await attributeButton.waitFor({state: 'visible', timeout: 5000}); + + const attributeMenu = page.locator('[id^="attribute-selector-menu"]'); + + if (!(await attributeMenu.isVisible({timeout: 1000}).catch(() => false))) { + await attributeButton.click(); + } + await attributeMenu.waitFor({state: 'visible', timeout: 5000}); + + const deptHeadItem = page.locator('[id^="attribute-selector-menu"] li:has-text("Department Head")'); + const officeItem = page.locator('[id^="attribute-selector-menu"] li:has-text("office")'); + + // * Both fields render: 'Department Head' (display_name) and 'office' (name fallback) + await expect(deptHeadItem).toBeVisible(); + await expect(officeItem).toBeVisible(); + + const filterInput = attributeMenu.locator('.attribute-selector-search input'); + await filterInput.waitFor({state: 'visible', timeout: 5000}); + + // * Filter by display_name keeps only 'Department Head' + await filterInput.fill('department'); + await expect(deptHeadItem).toBeVisible(); + await expect(officeItem).toHaveCount(0); + + // * Filter by CEL identifier keeps only 'Department Head' + await filterInput.fill(''); + await filterInput.fill('dept_head'); + await expect(deptHeadItem).toBeVisible(); + await expect(officeItem).toHaveCount(0); + + // * Filter on the no-display_name field's `name` keeps only 'office' + await filterInput.fill(''); + await filterInput.fill('office'); + await expect(officeItem).toBeVisible(); + await expect(deptHeadItem).toHaveCount(0); + + // # Select 'Department Head' + await filterInput.fill(''); + await deptHeadItem.first().click({force: true}); + + // * The trigger button shows display_name, not the CEL identifier + await expect(attributeButton).toContainText('Department Head', {timeout: 5000}); + + // # Wait for the attribute-selector popover to close before opening the next menu + const attributeMenuPopover = page.locator('[id^="attribute-selector-menu"]'); + await attributeMenuPopover.waitFor({state: 'hidden', timeout: 5000}); + + const operatorButton = page.locator('[data-testid="operatorSelectorMenuButton"]').first(); + await operatorButton.waitFor({state: 'visible', timeout: 5000}); + await operatorButton.click(); + + const operatorMenu = page.locator('[id^="operator-selector-menu"]'); + await operatorMenu.waitFor({state: 'visible', timeout: 5000}); + await operatorMenu.locator('li:has-text("is")').first().click(); + + const valueInput = page.locator('.values-editor__simple-input').first(); + await valueInput.waitFor({state: 'visible', timeout: 10000}); + await valueInput.fill('engineering'); + await valueInput.press('Tab'); + + const saveButton = page.getByRole('button', {name: 'Save'}); + await expect(saveButton).toBeEnabled({timeout: 10000}); + + // # Click Save and wait on the create-policy PUT (button unmounts on navigate) + const createPolicyResponse = page.waitForResponse( + (r) => /\/access_control_policies(?:\?|$)/.test(r.url()) && r.request().method() === 'PUT', + {timeout: 15000}, + ); + await saveButton.click(); + const createResponse = await createPolicyResponse; + expect(createResponse.ok()).toBe(true); + + await page.waitForURL(/\/admin_console\/system_attributes\/membership_policies/, {timeout: 10000}); + + // * The persisted CEL uses the canonical identifier, not display_name + policyId = await getPolicyIdByName(adminClient, policyName); + expect(policyId).not.toBeNull(); + + const policy = await (adminClient as any).doFetch( + `${adminClient.getBaseRoute()}/access_control_policies/${policyId}`, + {method: 'GET'}, + ); + + const rules = (policy?.rules || []) as Array<{actions?: string[]; expression?: string}>; + const membershipRule = rules.find((r) => r.actions?.includes('membership')) || rules[0]; + + expect(membershipRule).toBeDefined(); + expect(membershipRule?.expression || '').toContain('user.attributes.dept_head'); + expect(membershipRule?.expression || '').not.toContain('Department Head'); + } finally { + if (policyId) { + try { + await (adminClient as any).doFetch( + `${adminClient.getBaseRoute()}/access_control_policies/${policyId}`, + {method: 'DELETE'}, + ); + } catch { + // best-effort cleanup + } + } + + await deleteCustomProfileAttributes(adminClient, fieldsMap); + } + }, + ); +}); diff --git a/e2e-tests/playwright/specs/functional/system_console/system_users/user_attributes_admin_editing.spec.ts b/e2e-tests/playwright/specs/functional/system_console/system_users/user_attributes_admin_editing.spec.ts index c83cf9542f3..cccb2305068 100644 --- a/e2e-tests/playwright/specs/functional/system_console/system_users/user_attributes_admin_editing.spec.ts +++ b/e2e-tests/playwright/specs/functional/system_console/system_users/user_attributes_admin_editing.spec.ts @@ -36,6 +36,15 @@ let cpaFieldNames: { skills: string; }; +/** Rendered label for each CPA field: attrs.display_name when set, else field.name. */ +let cpaDisplayLabels: { + department: string; + workEmail: string; + personalWebsite: string; + location: string; + skills: string; +}; + let testUserAttributes: CustomProfileAttribute[]; let team: Team; @@ -76,6 +85,14 @@ test.describe('System Console - Admin User Profile Editing', () => { location: `UAAE_Location_${suffix}`, skills: `UAAE_Skills_${suffix}`, }; + // Mirror display_name values from testUserAttributes; absent display_name falls back to name. + cpaDisplayLabels = { + department: cpaFieldNames.department, + workEmail: 'Work Email', + personalWebsite: 'Personal Website', + location: cpaFieldNames.location, + skills: cpaFieldNames.skills, + }; testUserAttributes = [ { name: cpaFieldNames.department, @@ -92,6 +109,7 @@ test.describe('System Console - Admin User Profile Editing', () => { attrs: { value_type: 'email', visibility: 'when_set', + display_name: 'Work Email', }, }, { @@ -101,6 +119,7 @@ test.describe('System Console - Admin User Profile Editing', () => { attrs: { value_type: 'url', visibility: 'when_set', + display_name: 'Personal Website', }, }, { @@ -242,8 +261,8 @@ test.describe('System Console - Admin User Profile Editing', () => { await systemConsolePage.page.waitForURL(`**/admin_console/user_management/user/${testUser.id}`); await systemConsolePage!.users.userDetail.userCard.container.waitFor({state: 'visible'}); const {userCard} = systemConsolePage!.users.userDetail; - await expect(userCard.getFieldInputByExactLabel(cpaFieldNames.department)).toBeVisible({timeout: 30_000}); - await expect(userCard.getFieldInputByExactLabel(cpaFieldNames.workEmail)).toBeVisible({timeout: 30_000}); + await expect(userCard.getFieldInputByExactLabel(cpaDisplayLabels.department)).toBeVisible({timeout: 30_000}); + await expect(userCard.getFieldInputByExactLabel(cpaDisplayLabels.workEmail)).toBeVisible({timeout: 30_000}); // Remove the intercept now that field visibility is confirmed. // Keeping it active through the test body would intercept the save API call @@ -271,7 +290,7 @@ test.describe('System Console - Admin User Profile Editing', () => { const {userCard} = userDetail; // # Find and edit Department field (custom text attribute) - const departmentInput = userCard.getFieldInputByExactLabel(cpaFieldNames.department); + const departmentInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.department); await departmentInput.clear(); await departmentInput.fill('Marketing'); @@ -300,9 +319,8 @@ test.describe('System Console - Admin User Profile Editing', () => { // * Verify custom user attributes are present for (const field of testUserAttributes) { - await expect( - systemConsolePage!.page.locator('label').filter({hasText: new RegExp(field.name)}), - ).toBeVisible(); + const label = field.attrs?.display_name || field.name; + await expect(systemConsolePage!.page.locator('label').filter({hasText: label})).toBeVisible(); } // * Verify we have input fields (at least 4-5 total) @@ -339,7 +357,7 @@ test.describe('System Console - Admin User Profile Editing', () => { const {userCard} = userDetail; // # Find Location select field - const locationSelect = userCard.getSelectByExactLabel(cpaFieldNames.location); + const locationSelect = userCard.getSelectByExactLabel(cpaDisplayLabels.location); // # Get the first available option (since we can't predict the option value/ID) const firstOption = await locationSelect.locator('option').nth(1); // Skip the default "Select an option" @@ -363,11 +381,11 @@ test.describe('System Console - Admin User Profile Editing', () => { const {userCard} = userDetail; // * Verify Skills multiselect component is displayed - const skillsColumn = userCard.getCpaMultiselectContainer(cpaFieldNames.skills); + const skillsColumn = userCard.getCpaMultiselectContainer(cpaDisplayLabels.skills); await expect(skillsColumn).toBeVisible(); // # Make a change to a different field to trigger save state - const departmentInput = userCard.getFieldInputByExactLabel(cpaFieldNames.department); + const departmentInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.department); await departmentInput.fill('Engineering Updated'); // # Verify save button becomes enabled @@ -407,7 +425,7 @@ test.describe('System Console - Admin User Profile Editing', () => { }); try { // # Find CPA email field (Work Email) - const workEmailInput = userCard.getFieldInputByExactLabel(cpaFieldNames.workEmail); + const workEmailInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.workEmail); await workEmailInput.scrollIntoViewIfNeeded(); const originalEmail = await workEmailInput.inputValue(); @@ -416,7 +434,7 @@ test.describe('System Console - Admin User Profile Editing', () => { await workEmailInput.fill('not-an-email'); // * Verify inline validation error appears - const fieldError = userCard.getFieldError(cpaFieldNames.workEmail); + const fieldError = userCard.getFieldError(cpaDisplayLabels.workEmail); await expect(fieldError).toBeVisible({timeout: 30000}); await expect(fieldError).toContainText('Invalid email address'); @@ -461,7 +479,7 @@ test.describe('System Console - Admin User Profile Editing', () => { }); try { // # Find custom URL field (Personal Website) - const urlInput = userCard.getFieldInputByExactLabel(cpaFieldNames.personalWebsite); + const urlInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.personalWebsite); const originalUrl = await urlInput.inputValue(); // # Enter invalid URL (specifically the one mentioned: "<%>") @@ -469,7 +487,7 @@ test.describe('System Console - Admin User Profile Editing', () => { await urlInput.fill('<%>'); // * Verify inline validation error appears - const fieldError = userCard.getFieldError(cpaFieldNames.personalWebsite); + const fieldError = userCard.getFieldError(cpaDisplayLabels.personalWebsite); await expect(fieldError).toBeVisible(); await expect(fieldError).toContainText('Invalid URL'); @@ -511,14 +529,14 @@ test.describe('System Console - Admin User Profile Editing', () => { }); try { // # Find custom email field (Work Email) - const workEmailInput = userCard.getFieldInputByExactLabel(cpaFieldNames.workEmail); + const workEmailInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.workEmail); // # Enter invalid email await workEmailInput.clear(); await workEmailInput.fill('not-an-email-either'); // * Verify inline validation error appears - const fieldError = userCard.getFieldError(cpaFieldNames.workEmail); + const fieldError = userCard.getFieldError(cpaDisplayLabels.workEmail); await expect(fieldError).toBeVisible(); await expect(fieldError).toContainText('Invalid email address'); @@ -541,7 +559,7 @@ test.describe('System Console - Admin User Profile Editing', () => { await expect(userDetail.cancelButton).not.toBeVisible(); // # Make a change to trigger save needed state - const departmentInput = userCard.getFieldInputByExactLabel(cpaFieldNames.department); + const departmentInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.department); const originalValue = await departmentInput.inputValue(); await departmentInput.clear(); await departmentInput.fill('Changed Value'); @@ -573,7 +591,7 @@ test.describe('System Console - Admin User Profile Editing', () => { await userCard.emailInput.clear(); await userCard.emailInput.fill(newEmail); - const departmentInput = userCard.getFieldInputByExactLabel(cpaFieldNames.department); + const departmentInput = userCard.getFieldInputByExactLabel(cpaDisplayLabels.department); await departmentInput.clear(); await departmentInput.fill('Sales'); diff --git a/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes.spec.ts b/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes.spec.ts index be0aeea8531..a580d14a450 100644 --- a/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes.spec.ts +++ b/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes.spec.ts @@ -165,14 +165,14 @@ test.describe('System Console - User Attributes Management', () => { await expect(nameInput).toBeVisible(); // # Type attribute name (must be a valid CEL identifier — no spaces) - await nameInput.fill('Test_Department'); + await nameInput.fill('test_department'); await nameInput.blur(); await sp.saveAndWaitForSettled(); // * Verify the field was created by fetching from API const fieldsMap = await getFieldsMap(adminClient); - const createdField = Object.values(fieldsMap).find((f) => f.name === 'Test_Department'); + const createdField = Object.values(fieldsMap).find((f) => f.name === 'test_department'); expect(createdField).toBeDefined(); expect(createdField!.type).toBe('text'); @@ -195,7 +195,7 @@ test.describe('System Console - User Attributes Management', () => { // # Type attribute name (must be a valid CEL identifier — no spaces) const nameInput = sp.lastNameInput(); - await nameInput.fill('Office_Location'); + await nameInput.fill('office_location'); await nameInput.blur(); // # Change type to Select (use selectLastType so the index stays correct @@ -212,7 +212,7 @@ test.describe('System Console - User Attributes Management', () => { // * Verify field was created with correct type via API const fieldsMap = await getFieldsMap(adminClient); - const createdField = Object.values(fieldsMap).find((f) => f.name === 'Office_Location'); + const createdField = Object.values(fieldsMap).find((f) => f.name === 'office_location'); expect(createdField).toBeDefined(); expect(createdField!.type).toBe('select'); expect(createdField!.attrs.options).toBeDefined(); @@ -233,16 +233,16 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API - const attributes: CustomProfileAttribute[] = [{name: 'Old_Name', type: 'text'}]; + const attributes: CustomProfileAttribute[] = [{name: 'old_name', type: 'text'}]; const fieldsMap = await setupCustomProfileAttributeFields(adminClient, attributes); // # Navigate to User Attributes page await sp.goto(); - const nameInputLocator = sp.nameInputByValue('Old_Name'); - await expect(nameInputLocator).toBeVisible(); - await nameInputLocator.focus(); - await nameInputLocator.fill('New_Name'); + const nameInput = sp.nameInputByValue('old_name'); + await expect(nameInput).toBeVisible(); + await nameInput.focus(); + await nameInput.fill('new_name'); // blur via keyboard — the CSS-attribute locator no longer matches // after fill() so calling .blur() on it would time out. await sp.page.keyboard.press('Tab'); @@ -251,7 +251,7 @@ test.describe('System Console - User Attributes Management', () => { // * Verify field was updated via API const updatedMap = await getFieldsMap(adminClient); - expect(Object.values(updatedMap).find((f) => f.name === 'New_Name')).toBeDefined(); + expect(Object.values(updatedMap).find((f) => f.name === 'new_name')).toBeDefined(); await cleanupFields(adminClient, {...fieldsMap, ...updatedMap}); }); @@ -268,7 +268,7 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API - const attributes: CustomProfileAttribute[] = [{name: 'To_Delete', type: 'text'}]; + const attributes: CustomProfileAttribute[] = [{name: 'to_delete', type: 'text'}]; const fieldsMap = await setupCustomProfileAttributeFields(adminClient, attributes); const fieldId = Object.keys(fieldsMap)[0]; @@ -276,7 +276,7 @@ test.describe('System Console - User Attributes Management', () => { await sp.goto(); // * Verify the attribute exists - await expect(sp.nameInputByValue('To_Delete')).toBeVisible(); + await expect(sp.nameInputByValue('to_delete')).toBeVisible(); // # Open dot menu for the field await sp.openDotMenu(fieldId); @@ -291,7 +291,7 @@ test.describe('System Console - User Attributes Management', () => { // * Verify field was deleted via API const updatedMap = await getFieldsMap(adminClient); - expect(Object.values(updatedMap).find((f) => f.name === 'To_Delete')).toBeUndefined(); + expect(Object.values(updatedMap).find((f) => f.name === 'to_delete')).toBeUndefined(); await cleanupFields(adminClient, updatedMap); }); @@ -321,13 +321,10 @@ test.describe('System Console - User Attributes Management', () => { // # Click "Duplicate attribute" await sp.duplicateAttribute(); - // * Verify a copy row appeared (server generates "Original (copy)" as the default name) - await expect(sp.nameInputByValue('Original (copy)')).toBeVisible(); + // * Verify a copy row appeared with "_copy" suffix in the name + await expect(sp.nameInputByValue('Original_copy')).toBeVisible(); - // # Rename the copy to a valid CEL identifier. - // "Original (copy)" contains spaces and parentheses which the server rejects with 422. - // Use lastNameInput() for the fill/blur — it's position-based (.last()) so it stays - // valid after the value changes, unlike the value-based nameInputByValue locator. + // # Rename the copy to a valid CEL identifier (server rejects spaces/parens with 422) const copyInput = sp.lastNameInput(); await copyInput.fill('Original_copy'); await copyInput.blur(); @@ -354,7 +351,7 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API - const attributes: CustomProfileAttribute[] = [{name: 'Visibility_Test', type: 'text'}]; + const attributes: CustomProfileAttribute[] = [{name: 'visibility_test', type: 'text'}]; const fieldsMap = await setupCustomProfileAttributeFields(adminClient, attributes); const fieldId = Object.keys(fieldsMap)[0]; @@ -371,7 +368,7 @@ test.describe('System Console - User Attributes Management', () => { // * Verify visibility was updated via API const updatedMap = await getFieldsMap(adminClient); - const updatedField = Object.values(updatedMap).find((f) => f.name === 'Visibility_Test'); + const updatedField = Object.values(updatedMap).find((f) => f.name === 'visibility_test'); expect(updatedField).toBeDefined(); expect(updatedField!.attrs.visibility).toBe('hidden'); @@ -390,7 +387,7 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API - const attributes: CustomProfileAttribute[] = [{name: 'Editable_Test', type: 'text'}]; + const attributes: CustomProfileAttribute[] = [{name: 'editable_test', type: 'text'}]; const fieldsMap = await setupCustomProfileAttributeFields(adminClient, attributes); const fieldId = Object.keys(fieldsMap)[0]; @@ -420,7 +417,7 @@ test.describe('System Console - User Attributes Management', () => { await expect .poll(async () => { const map = await getFieldsMap(adminClient); - return Object.values(map).find((f) => f.name === 'Editable_Test'); + return Object.values(map).find((f) => f.name === 'editable_test'); }) .toMatchObject({attrs: {managed: 'admin'}}); @@ -428,8 +425,9 @@ test.describe('System Console - User Attributes Management', () => { }); /** - * @objective Verify that leaving an attribute name empty shows a validation - * warning and disables the Save button. + * @objective Verify that clearing the auto-derived CEL identifier (Name) + * after entering a Display Name shows the empty-name validation warning + * and disables the Save button. */ test('shows validation warning for empty attribute name', {tag: '@user_attributes'}, async ({pw}) => { const {systemConsolePage} = await setupTest(pw); @@ -441,9 +439,16 @@ test.describe('System Console - User Attributes Management', () => { // # Add a new attribute await sp.addAttribute(); - // # Clear the auto-focused name input (leave it empty). - // Use lastNameInput() so concurrent UAAE/ABAC rows don't shift the index. + // # Fill Display Name so the Name field auto-derives as snake_case + const displayNameInput = sp.lastDisplayNameInput(); + await displayNameInput.fill('Job Title'); + await displayNameInput.blur(); + + // * Verify the Name field auto-populated with the snake_case identifier const nameInput = sp.lastNameInput(); + await expect(nameInput).toHaveValue('job_title'); + + // # Clear the auto-derived identifier and blur to trigger the empty-name warning await nameInput.clear(); await nameInput.blur(); @@ -466,7 +471,7 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API (name must be a valid CEL identifier — no spaces) - const uniqueDupName = `UniqueName_${Date.now()}`; + const uniqueDupName = `unique_name_${Date.now()}`; const attributes: CustomProfileAttribute[] = [{name: uniqueDupName, type: 'text'}]; const fieldsMap = await setupCustomProfileAttributeFields(adminClient, attributes); @@ -503,7 +508,7 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create a text attribute via API - const attributes: CustomProfileAttribute[] = [{name: 'Contact_Number', type: 'text'}]; + const attributes: CustomProfileAttribute[] = [{name: 'contact_number', type: 'text'}]; await setupCustomProfileAttributeFields(adminClient, attributes); // # Navigate to User Attributes page @@ -512,13 +517,13 @@ test.describe('System Console - User Attributes Management', () => { // # Select "Phone" type for the Contact_Number field. // Use selectTypeForField() — resolves the row index by name so concurrent // UAAE/ABAC rows don't shift the positional index. - await sp.selectTypeForField('Contact_Number', 'Phone'); + await sp.selectTypeForField('contact_number', 'Phone'); await sp.saveAndWaitForSettled(); // * Verify field type was updated via API const updatedMap = await getFieldsMap(adminClient); - const updatedField = Object.values(updatedMap).find((f) => f.name === 'Contact_Number'); + const updatedField = Object.values(updatedMap).find((f) => f.name === 'contact_number'); expect(updatedField).toBeDefined(); expect(updatedField!.type).toBe('text'); expect(updatedField!.attrs.value_type).toBe('phone'); @@ -581,21 +586,21 @@ test.describe('System Console - User Attributes Management', () => { // # Create first attribute (text) — use lastNameInput() after each addAttribute() await sp.addAttribute(); const firstInput = sp.lastNameInput(); - await firstInput.fill('Job_Title'); + await firstInput.fill('job_title'); await firstInput.blur(); // # Create second attribute (text) await sp.addAttribute(); const secondInput = sp.lastNameInput(); - await secondInput.fill('Team_Name'); + await secondInput.fill('team_name'); await secondInput.blur(); await sp.saveAndWaitForSettled(); // * Verify both fields were created via API const fieldsMap = await getFieldsMap(adminClient); - expect(Object.values(fieldsMap).find((f) => f.name === 'Job_Title')).toBeDefined(); - expect(Object.values(fieldsMap).find((f) => f.name === 'Team_Name')).toBeDefined(); + expect(Object.values(fieldsMap).find((f) => f.name === 'job_title')).toBeDefined(); + expect(Object.values(fieldsMap).find((f) => f.name === 'team_name')).toBeDefined(); await cleanupFields(adminClient, fieldsMap); }); @@ -612,20 +617,20 @@ test.describe('System Console - User Attributes Management', () => { const sp = systemConsolePage.systemProperties; // # Create an attribute via API - await setupCustomProfileAttributeFields(adminClient, [{name: 'Persistent_Field', type: 'text'}]); + await setupCustomProfileAttributeFields(adminClient, [{name: 'persistent_field', type: 'text'}]); // # Navigate to User Attributes page await sp.goto(); // * Verify attribute exists - await expect(sp.nameInputByValue('Persistent_Field')).toBeVisible(); + await expect(sp.nameInputByValue('persistent_field')).toBeVisible(); // # Edit the name using a value-based locator so concurrent UAAE/ABAC rows // don't shift a positional index to the wrong field. - const nameInput = sp.nameInputByValue('Persistent_Field'); - await expect(nameInput).toHaveValue('Persistent_Field'); + const nameInput = sp.nameInputByValue('persistent_field'); + await expect(nameInput).toHaveValue('persistent_field'); await nameInput.focus(); - await nameInput.fill('Updated_Persistent'); + await nameInput.fill('updated_persistent'); // blur via keyboard — the value-based locator is stale after fill() await sp.page.keyboard.press('Tab'); @@ -635,7 +640,7 @@ test.describe('System Console - User Attributes Management', () => { await sp.goto(); // * Verify the updated name persisted - await expect(sp.nameInputByValue('Updated_Persistent')).toBeVisible(); + await expect(sp.nameInputByValue('updated_persistent')).toBeVisible(); await cleanupFields(adminClient, await getFieldsMap(adminClient)); }); diff --git a/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes_display_name.spec.ts b/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes_display_name.spec.ts new file mode 100644 index 00000000000..ca5eccdaacb --- /dev/null +++ b/e2e-tests/playwright/specs/functional/system_console/user_attributes/user_attributes_display_name.spec.ts @@ -0,0 +1,215 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {Client4} from '@mattermost/client'; +import type {UserPropertyField} from '@mattermost/types/properties'; +import type {UserProfile} from '@mattermost/types/users'; + +import {expect, test, testConfig} from '@mattermost/playwright-lib'; +import type {PlaywrightExtended, SystemConsolePage} from '@mattermost/playwright-lib'; + +import {setupCustomProfileAttributeValuesForUser} from '../../channels/custom_profile_attributes/helpers'; + +type AdminUser = UserProfile & {password: string}; + +const IDENTIFIER_VALIDATION_MESSAGE = + 'Identifier must start with a letter or underscore and contain only letters, numbers, and underscores. Reserved CEL words are not allowed.'; + +interface TestContext { + adminClient: Client4; + adminUser: AdminUser; + systemConsolePage: SystemConsolePage; +} + +async function createAdminClient(): Promise<{adminClient: Client4; adminUser: AdminUser}> { + const adminClient = new Client4(); + adminClient.setUrl(testConfig.baseURL); + + const loggedInUser = await adminClient.login(testConfig.adminUsername, testConfig.adminPassword); + const adminUser = { + ...loggedInUser, + password: testConfig.adminPassword, + } as AdminUser; + + return {adminClient, adminUser}; +} + +async function setupTest(pw: PlaywrightExtended): Promise { + const {adminClient, adminUser} = await createAdminClient(); + + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + await systemConsolePage.goto(); + await systemConsolePage.toBeVisible(); + + return {adminClient, adminUser, systemConsolePage}; +} + +async function getTownSquareRoute(adminClient: Client4) { + const teams = await adminClient.getMyTeams(); + expect(teams.length).toBeGreaterThan(0); + + const team = teams[0]; + const channel = await adminClient.getChannelByName(team.id, 'town-square'); + + return { + teamName: team.name, + channelName: channel.name, + }; +} + +test.describe('System Console - User Attributes display names', () => { + /** + * @objective Verify that a CPA field's display_name is rendered as the user-facing + * label in the user attributes table, the profile popover, account settings, and + * the admin user detail page while the identifier remains unchanged in the API. + */ + test('renders display_name across admin and self-service surfaces', {tag: '@user_attributes'}, async ({pw}) => { + const {adminClient, adminUser, systemConsolePage} = await setupTest(pw); + const sp = systemConsolePage.systemProperties; + + const uid = Date.now(); + const identifier = `department_${uid}`; + const displayName = `Department ${uid}`; + const attributeValue = 'Engineering'; + + let createdField: UserPropertyField | undefined; + + try { + // # Navigate to User Attributes and create a new field with a display name + await sp.goto(); + + // * Verify the table exposes both Name and Display Name column headers + await expect(sp.container.getByRole('columnheader', {name: 'Display Name', exact: true})).toBeVisible(); + await expect(sp.container.getByRole('columnheader', {name: 'Name', exact: true})).toBeVisible(); + + // # Add a new row and fill identifier + display name + await sp.addAttribute(); + await sp.lastNameInput().fill(identifier); + await sp.lastNameInput().blur(); + await sp.lastDisplayNameInput().fill(displayName); + await sp.lastDisplayNameInput().blur(); + + await sp.saveAndWaitForSettled(); + + const fields = await adminClient.getCustomProfileAttributeFields(); + createdField = fields.find((field) => field.name === identifier); + + expect(createdField).toBeDefined(); + expect(createdField?.attrs?.display_name).toBe(displayName); + + // # Set a value for sysadmin and open the self profile popover in Channels + await setupCustomProfileAttributeValuesForUser( + adminClient, + [{name: identifier, value: attributeValue, type: 'text'}], + {[createdField!.id]: createdField!}, + adminUser.id, + ); + + const {teamName, channelName} = await getTownSquareRoute(adminClient); + const {channelsPage} = await pw.testBrowser.login(adminUser); + + await channelsPage.goto(teamName, channelName); + await channelsPage.postMessage(`phase-5-display-name-${uid}`); + + const lastPost = await channelsPage.getLastPost(); + await channelsPage.openProfilePopover(lastPost); + + // * Verify the profile popover and account settings render display_name + await expect( + channelsPage.page.locator(`#user-popover__custom_attributes-title-${createdField!.id}`), + ).toHaveText(displayName); + + await channelsPage.userProfilePopover.close(); + + const profileModal = await channelsPage.openProfileModal(); + const section = profileModal.container.locator('.section-min').filter({hasText: displayName}); + await expect(section).toBeVisible(); + + const editButton = profileModal.container.locator(`#customAttribute_${createdField!.id}Edit`); + await editButton.scrollIntoViewIfNeeded(); + await editButton.click(); + + const settingsInput = profileModal.container.locator(`#customAttribute_${createdField!.id}`); + await expect(settingsInput).toHaveAttribute('aria-label', displayName); + await profileModal.closeModal(); + + // # Open the admin user detail page for sysadmin + await systemConsolePage.page.goto(`/admin_console/user_management/user/${adminUser.id}`); + await systemConsolePage.users.userDetail.toBeVisible(); + + // * Verify the admin user detail label also uses display_name + await expect( + systemConsolePage.page.getByTestId(`user-detail-custom-attribute-label-${createdField!.id}`), + ).toContainText(displayName); + } finally { + if (createdField) { + await adminClient.deleteCustomProfileAttributeField(createdField.id).catch(() => undefined); + } + } + }); + + /** + * @objective Verify that invalid CPA identifiers are blocked client-side before + * any create-field API request is issued, and that a valid identifier clears the + * warning and can be saved successfully. + */ + test('blocks invalid identifiers before API submission', {tag: '@user_attributes'}, async ({pw}) => { + const {adminClient, systemConsolePage} = await setupTest(pw); + const sp = systemConsolePage.systemProperties; + const {page} = systemConsolePage; + + const apiPosts: string[] = []; + const validIdentifier = `my_field_${Date.now()}`; + + page.on('request', (request) => { + if (request.method() === 'POST' && request.url().includes('/api/v4/custom_profile_attributes/fields')) { + apiPosts.push(request.url()); + } + }); + + try { + // # Add an attribute and exercise invalid identifier inputs in the table. + // Use lastNameInput() — not positional nameInput(0) — so a leftover row from + // a prior test attempt (or a concurrent UAAE/ABAC suite) cannot shift the + // index and cause us to rename someone else's field instead of populating + // the row we just added. + await sp.goto(); + await sp.addAttribute(); + + const invalidIdentifiers = ['in', 'true', 'for']; + for (const invalidIdentifier of invalidIdentifiers) { + await sp.lastNameInput().fill(invalidIdentifier); + await sp.lastNameInput().blur(); + + // * Verify the warning appears and Save stays disabled before any POST + await expect(sp.identifierValidationError()).toHaveText(IDENTIFIER_VALIDATION_MESSAGE); + await expect(sp.saveButton).toBeDisabled(); + } + + expect(apiPosts).toHaveLength(0); + + // # Correct the identifier to a valid CEL-safe name and save it + await sp.lastNameInput().fill(validIdentifier); + await sp.lastNameInput().blur(); + + // * Verify the warning clears and the field can be created successfully + await expect(sp.identifierValidationError()).not.toBeVisible(); + await expect(sp.saveButton).toBeEnabled(); + + await sp.saveAndWaitForSettled(); + + const fields = await adminClient.getCustomProfileAttributeFields(); + const createdField = fields.find((field) => field.name === validIdentifier); + expect(createdField).toBeDefined(); + } finally { + // Look up the field by name rather than relying on a captured `createdField` + // — the assertions above can throw before that variable is assigned, and we + // still want to remove the server-side field so retries start from a clean slate. + const fields = await adminClient.getCustomProfileAttributeFields().catch(() => []); + const leftover = fields.find((field) => field.name === validIdentifier); + if (leftover) { + await adminClient.deleteCustomProfileAttributeField(leftover.id).catch(() => undefined); + } + } + }); +}); diff --git a/server/channels/api4/user.go b/server/channels/api4/user.go index aa55372b3bd..7a5aec2426d 100644 --- a/server/channels/api4/user.go +++ b/server/channels/api4/user.go @@ -76,6 +76,7 @@ func (api *API) InitUser() { api.BaseRoutes.UserByUsername.Handle("", api.APISessionRequired(getUserByUsername)).Methods(http.MethodGet) api.BaseRoutes.UserByEmail.Handle("", api.APISessionRequired(getUserByEmail)).Methods(http.MethodGet) + api.BaseRoutes.Users.Handle("/auth_data", api.APISessionRequired(getUserByAuthData)).Methods(http.MethodGet) api.BaseRoutes.User.Handle("/sessions", api.APISessionRequired(getSessions)).Methods(http.MethodGet) api.BaseRoutes.User.Handle("/sessions/revoke", api.APISessionRequired(revokeSession)).Methods(http.MethodPost) @@ -461,6 +462,61 @@ func getUserByEmail(c *Context, w http.ResponseWriter, r *http.Request) { } } +func getUserByAuthData(c *Context, w http.ResponseWriter, r *http.Request) { + if !c.IsSystemAdmin() { + c.SetPermissionError(model.PermissionManageSystem) + return + } + authData := r.URL.Query().Get("value") + if authData == "" { + c.SetInvalidParam("value") + return + } + if len(authData) > model.UserAuthDataMaxLength { + c.SetInvalidParam("value") + return + } + user, err := c.App.GetUserByAuthData(&authData) + if err != nil { + c.Err = err + return + } + + canSee, err2 := c.App.UserCanSeeOtherUser(c.AppContext, c.AppContext.Session().UserId, user.Id) + if err2 != nil { + c.Err = err2 + return + } + + if !canSee { + c.SetPermissionError(model.PermissionViewMembers) + return + } + + userTermsOfService, err := c.App.GetUserTermsOfService(user.Id) + if err != nil && err.StatusCode != http.StatusNotFound { + c.Err = err + return + } + + if userTermsOfService != nil { + user.TermsOfServiceId = userTermsOfService.TermsOfServiceId + user.TermsOfServiceCreateAt = userTermsOfService.CreateAt + } + + etag := user.Etag(*c.App.Config().PrivacySettings.ShowFullName, *c.App.Config().PrivacySettings.ShowEmailAddress) + + if c.HandleEtag(etag, "Get User", w, r) { + return + } + + c.App.SanitizeProfile(user, c.IsSystemAdmin()) + w.Header().Set(model.HeaderEtagServer, etag) + if jerr := json.NewEncoder(w).Encode(user); jerr != nil { + c.Logger.Warn("Error while writing response", mlog.Err(jerr)) + } +} + func getDefaultProfileImage(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireUserId() if c.Err != nil { diff --git a/server/channels/api4/user_local.go b/server/channels/api4/user_local.go index 7fba2cb46cc..2747a92d619 100644 --- a/server/channels/api4/user_local.go +++ b/server/channels/api4/user_local.go @@ -40,6 +40,7 @@ func (api *API) InitUserLocal() { api.BaseRoutes.UserByUsername.Handle("", api.APILocal(localGetUserByUsername)).Methods(http.MethodGet) api.BaseRoutes.UserByEmail.Handle("", api.APILocal(localGetUserByEmail)).Methods(http.MethodGet) + api.BaseRoutes.Users.Handle("/auth_data", api.APILocal(localGetUserByAuthData)).Methods(http.MethodGet) api.BaseRoutes.Users.Handle("/tokens/revoke", api.APILocal(revokeUserAccessToken)).Methods(http.MethodPost) api.BaseRoutes.User.Handle("/tokens", api.APILocal(getUserAccessTokensForUser)).Methods(http.MethodGet) @@ -427,6 +428,46 @@ func localGetUserByEmail(c *Context, w http.ResponseWriter, r *http.Request) { } } +func localGetUserByAuthData(c *Context, w http.ResponseWriter, r *http.Request) { + authData := r.URL.Query().Get("value") + if authData == "" { + c.SetInvalidParam("value") + return + } + if len(authData) > model.UserAuthDataMaxLength { + c.SetInvalidParam("value") + return + } + user, err := c.App.GetUserByAuthData(&authData) + if err != nil { + c.Err = err + return + } + + userTermsOfService, err := c.App.GetUserTermsOfService(user.Id) + if err != nil && err.StatusCode != http.StatusNotFound { + c.Err = err + return + } + + if userTermsOfService != nil { + user.TermsOfServiceId = userTermsOfService.TermsOfServiceId + user.TermsOfServiceCreateAt = userTermsOfService.CreateAt + } + + etag := user.Etag(*c.App.Config().PrivacySettings.ShowFullName, *c.App.Config().PrivacySettings.ShowEmailAddress) + + if c.HandleEtag(etag, "Get User", w, r) { + return + } + + c.App.SanitizeProfile(user, c.IsSystemAdmin()) + w.Header().Set(model.HeaderEtagServer, etag) + if jerr := json.NewEncoder(w).Encode(user); jerr != nil { + c.Logger.Warn("Error while writing response", mlog.Err(jerr)) + } +} + func localGetUploadsForUser(c *Context, w http.ResponseWriter, r *http.Request) { uss, appErr := c.App.GetUploadSessionsForUser(c.Params.UserId) if appErr != nil { diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index 82afeea8db1..127aa75a9ed 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -1487,6 +1487,169 @@ func TestGetUserByEmail(t *testing.T) { }) } +func TestGetUserByAuthData(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t) + + team := th.CreateTeamWithClient(t, th.SystemAdminClient) + regularUser := th.CreateUser(t) + th.LinkUserToTeam(t, regularUser, team) + user := th.CreateUser(t) + th.LinkUserToTeam(t, user, team) + _, err := th.App.Srv().Store().User().VerifyEmail(user.Id, user.Email) + require.NoError(t, err) + + authID := "extid-" + model.NewId() + userAuth := &model.UserAuth{ + AuthData: model.NewPointer(authID), + AuthService: model.UserAuthServiceSaml, + } + _, _, err = th.SystemAdminClient.UpdateUserAuth(context.Background(), user.Id, userAuth) + require.NoError(t, err) + + th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { + t.Run("returns user and auth fields for system admin and local", func(t *testing.T) { + ruser, resp, getErr := client.GetUserByAuthData(context.Background(), authID, "") + require.NoError(t, getErr) + require.Equal(t, user.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, authID, *ruser.AuthData) + require.Equal(t, model.UserAuthServiceSaml, ruser.AuthService) + ruser, resp, _ = client.GetUserByAuthData(context.Background(), authID, resp.Etag) + CheckEtag(t, ruser, resp) + }) + + t.Run("not found returns not found", func(t *testing.T) { + _, resp, notFoundErr := client.GetUserByAuthData(context.Background(), "nope-"+model.NewId(), "") + require.Error(t, notFoundErr) + CheckNotFoundStatus(t, resp) + }) + }) + + t.Run("returns accepted terms of service for system admin", func(t *testing.T) { + tos, appErr := th.App.CreateTermsOfService("Dummy TOS", user.Id) + require.Nil(t, appErr) + appErr = th.App.SaveUserTermsOfService(user.Id, tos.Id, true) + require.Nil(t, appErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), authID, "") + require.NoError(t, getErr) + require.Equal(t, tos.Id, ruser.TermsOfServiceId, "Terms of service ID should match") + require.NotZero(t, ruser.TermsOfServiceCreateAt, "Terms of service CreateAt should be populated") + }) + + t.Run("returns user when auth_data is an email-shaped value", func(t *testing.T) { + // ResetAuthDataToEmailForUsers sets AuthData = Email for whole batches of + // users, so email-shaped auth_data values are common in practice. Verify + // the route, Client4 path escaping (`@` -> `%40`), and server-side decoding + // all round-trip correctly. + emailUser := th.CreateUser(t) + th.LinkUserToTeam(t, emailUser, team) + emailAuth := "user-" + model.NewId() + "@example.com" + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), emailUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(emailAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), emailAuth, "") + require.NoError(t, getErr) + require.Equal(t, emailUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, emailAuth, *ruser.AuthData) + }) + + t.Run("preserves case in auth_data", func(t *testing.T) { + // auth_data is opaque and case-sensitive (unlike email, which the email + // endpoint lowercases). Non-SAML IdPs commonly issue mixed-case identifiers, + // so guard against a regression where the handler normalizes the value. + mixedUser := th.CreateUser(t) + th.LinkUserToTeam(t, mixedUser, team) + mixedAuth := "MixedCase-" + model.NewId() + "@Example.COM" + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), mixedUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(mixedAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), mixedAuth, "") + require.NoError(t, getErr) + require.Equal(t, mixedUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, mixedAuth, *ruser.AuthData) + + _, resp, lowerErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), strings.ToLower(mixedAuth), "") + require.Error(t, lowerErr) + CheckNotFoundStatus(t, resp) + }) + + t.Run("returns user when auth_data is an LDAP objectGUID hex-escape form", func(t *testing.T) { + // AD objectGUID stored under auth_service=ldap uses the LDAP filter + // hex-escape form (`\xx` per byte). Backslashes are special in URL paths + // (WHATWG rewrites them to `/`), which is why this endpoint uses a query + // parameter; this test guards the query-string round-trip for the exact + // shape the customer reported. + ldapUser := th.CreateUser(t) + th.LinkUserToTeam(t, ldapUser, team) + ldapAuth := `\61\14\e1\d1\c5\35\18\4a\b6\60\d6\78\50\fd\0d\5d` + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), ldapUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(ldapAuth), + AuthService: model.UserAuthServiceLdap, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), ldapAuth, "") + require.NoError(t, getErr) + require.Equal(t, ldapUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, ldapAuth, *ruser.AuthData) + }) + + t.Run("returns user when auth_data is SAML base64 with reserved chars", func(t *testing.T) { + // AD objectGUID stored under auth_service=saml uses standard Base64, + // which can contain `+`, `/`, and `=` padding -- all reserved in + // application/x-www-form-urlencoded. url.Values.Set escapes them + // correctly; this test guards against a future regression where someone + // rewrites the client to skip that escaping. + samlUser := th.CreateUser(t) + th.LinkUserToTeam(t, samlUser, team) + // Bytes chosen to produce all three reserved characters in the Base64 + // output: 0xfb,0xef,0xff,0x00 -> "++//AA==". + samlAuth := base64.StdEncoding.EncodeToString([]byte{0xfb, 0xef, 0xff, 0x00}) + require.Contains(t, samlAuth, "+") + require.Contains(t, samlAuth, "/") + require.Contains(t, samlAuth, "=") + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), samlUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(samlAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), samlAuth, "") + require.NoError(t, getErr) + require.Equal(t, samlUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, samlAuth, *ruser.AuthData) + }) + + t.Run("rejects non-system admin", func(t *testing.T) { + // `user` is converted to SAML below and can no longer use password login; use + // a separate team member to assert the endpoint requires a system admin. + _, _, err = th.Client.Login(context.Background(), regularUser.Email, regularUser.Password) + require.NoError(t, err) + _, resp, err := th.Client.GetUserByAuthData(context.Background(), authID, "") + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("rejects auth data over max length", func(t *testing.T) { + longData := strings.Repeat("x", model.UserAuthDataMaxLength+1) + _, resp, err := th.SystemAdminClient.GetUserByAuthData(context.Background(), longData, "") + require.Error(t, err) + CheckBadRequestStatus(t, resp) + }) +} + // This test can flake if two calls to model.NewId can return the same value. // Not much can be done about it. func TestSearchUsers(t *testing.T) { diff --git a/server/channels/app/app_test.go b/server/channels/app/app_test.go index 0ebfc035610..03792e4acf9 100644 --- a/server/channels/app/app_test.go +++ b/server/channels/app/app_test.go @@ -151,6 +151,8 @@ func TestDoAdvancedPermissionsMigration(t *testing.T) { model.PermissionManageChannelAccessRules.Id, model.PermissionManagePublicChannelAutoTranslation.Id, model.PermissionManagePrivateChannelAutoTranslation.Id, + model.PermissionManagePrivateChannelDiscoverability.Id, + model.PermissionManageChannelJoinRequests.Id, }, "team_user": { model.PermissionListTeamChannels.Id, diff --git a/server/channels/app/permissions_migrations.go b/server/channels/app/permissions_migrations.go index 0a7e788b5c2..83444fe4c6a 100644 --- a/server/channels/app/permissions_migrations.go +++ b/server/channels/app/permissions_migrations.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" ) @@ -1325,6 +1326,22 @@ func (a *App) getAddEditFileAttachmentPermissionMigration() (permissionsMap, err }, nil } +func (a *App) getAddDiscoverableChannelPermissionsMigration() (permissionsMap, error) { + return permissionsMap{ + permissionTransformation{ + On: permissionOr( + isRole(model.ChannelAdminRoleId), + isRole(model.TeamAdminRoleId), + isRole(model.SystemAdminRoleId), + ), + Add: []string{ + model.PermissionManagePrivateChannelDiscoverability.Id, + model.PermissionManageChannelJoinRequests.Id, + }, + }, + }, nil +} + // DoPermissionsMigrations execute all the permissions migrations need by the current version. func (a *App) DoPermissionsMigrations() error { return a.Srv().doPermissionsMigrations() @@ -1387,6 +1404,7 @@ func (s *Server) doPermissionsMigrations() error { {Key: model.MigrationKeyRestoreManageOAuthPermission, Migration: a.getRestoreManageOAuthPermissionMigration}, {Key: model.MigrationKeyAddManageAgentPermissions, Migration: a.getAddManageAgentPermissionsMigration}, {Key: model.MigrationKeyAddEditFileAttachmentPermission, Migration: a.getAddEditFileAttachmentPermissionMigration}, + {Key: model.MigrationKeyAddDiscoverableChannelPermissions, Migration: a.getAddDiscoverableChannelPermissionsMigration}, } roles, err := s.Store().Role().GetAll() @@ -1400,6 +1418,7 @@ func (s *Server) doPermissionsMigrations() error { return err } if err := s.doPermissionsMigration(migration.Key, migMap, roles); err != nil { + mlog.Error("Failed to run permissions migration", mlog.String("key", migration.Key), mlog.Err(err)) return err } } diff --git a/server/channels/app/user.go b/server/channels/app/user.go index 932b039dd0c..325e3bdd1a1 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -606,6 +606,24 @@ func (a *App) GetUserByAuth(authData *string, authService string) (*model.User, return user, nil } +func (a *App) GetUserByAuthData(authData *string) (*model.User, *model.AppError) { + user, err := a.ch.srv.userService.GetUserByAuthData(authData) + if err != nil { + var invErr *store.ErrInvalidInput + var nfErr *store.ErrNotFound + switch { + case errors.As(err, &invErr): + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusBadRequest).Wrap(err) + case errors.As(err, &nfErr): + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusNotFound).Wrap(err) + default: + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusInternalServerError).Wrap(err) + } + } + + return user, nil +} + func (a *App) GetUsersFromProfiles(options *model.UserGetOptions) ([]*model.User, *model.AppError) { users, err := a.ch.srv.userService.GetUsersFromProfiles(options) if err != nil { diff --git a/server/channels/app/users/users.go b/server/channels/app/users/users.go index 900184f781e..c81c0f8ccec 100644 --- a/server/channels/app/users/users.go +++ b/server/channels/app/users/users.go @@ -114,6 +114,10 @@ func (us *UserService) GetUserByAuth(authData *string, authService string) (*mod return us.store.GetByAuth(authData, authService) } +func (us *UserService) GetUserByAuthData(authData *string) (*model.User, error) { + return us.store.GetByAuthData(authData) +} + func (us *UserService) GetUsersFromProfiles(options *model.UserGetOptions) ([]*model.User, error) { return us.store.GetAllProfiles(options) } diff --git a/server/channels/db/migrations/migrations.list b/server/channels/db/migrations/migrations.list index 72f6bccb373..04fc52052f5 100644 --- a/server/channels/db/migrations/migrations.list +++ b/server/channels/db/migrations/migrations.list @@ -351,3 +351,15 @@ channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql +channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql +channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql +channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql +channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql +channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql +channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql +channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql +channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql +channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql +channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql +channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql +channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql diff --git a/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql new file mode 100644 index 00000000000..98788019071 --- /dev/null +++ b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql @@ -0,0 +1 @@ +ALTER TABLE Channels DROP COLUMN IF EXISTS Discoverable; diff --git a/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql new file mode 100644 index 00000000000..dce84a520df --- /dev/null +++ b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql @@ -0,0 +1 @@ +ALTER TABLE Channels ADD COLUMN IF NOT EXISTS Discoverable BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql new file mode 100644 index 00000000000..d3d4d6b3545 --- /dev/null +++ b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channels_discoverable_team; diff --git a/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql new file mode 100644 index 00000000000..b838ee35a52 --- /dev/null +++ b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql @@ -0,0 +1,4 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channels_discoverable_team + ON Channels (TeamId) + WHERE Discoverable = true AND Type = 'P' AND DeleteAt = 0; diff --git a/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql new file mode 100644 index 00000000000..8c692c6b8e5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ChannelJoinRequests; diff --git a/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql new file mode 100644 index 00000000000..1e8076dab23 --- /dev/null +++ b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS ChannelJoinRequests ( + Id VARCHAR(26) PRIMARY KEY, + ChannelId VARCHAR(26) NOT NULL, + UserId VARCHAR(26) NOT NULL, + Message TEXT NOT NULL DEFAULT '', + Status VARCHAR(16) NOT NULL DEFAULT 'pending', + DenialReason TEXT NOT NULL DEFAULT '', + CreateAt BIGINT NOT NULL, + UpdateAt BIGINT NOT NULL, + ReviewedBy VARCHAR(26) NOT NULL DEFAULT '', + ReviewedAt BIGINT NOT NULL DEFAULT 0 +); diff --git a/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql new file mode 100644 index 00000000000..ca606fc8a74 --- /dev/null +++ b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_pending_unique; diff --git a/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql new file mode 100644 index 00000000000..d2317fecf77 --- /dev/null +++ b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql @@ -0,0 +1,4 @@ +-- morph:nontransactional +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_pending_unique + ON ChannelJoinRequests (ChannelId, UserId) + WHERE Status = 'pending'; diff --git a/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql new file mode 100644 index 00000000000..f5a3ed0da5a --- /dev/null +++ b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_channel_status_createat; diff --git a/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql new file mode 100644 index 00000000000..dbaf927fbb5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql @@ -0,0 +1,3 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_channel_status_createat + ON ChannelJoinRequests (ChannelId, Status, CreateAt DESC); diff --git a/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql new file mode 100644 index 00000000000..134bcc459f5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_user_status_createat; diff --git a/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql new file mode 100644 index 00000000000..73271ffa61e --- /dev/null +++ b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql @@ -0,0 +1,3 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_user_status_createat + ON ChannelJoinRequests (UserId, Status, CreateAt DESC); diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index f1b8fe9950d..938cd91d1c7 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -27,6 +27,7 @@ type RetryLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore CommandStore store.CommandStore @@ -107,6 +108,10 @@ func (s *RetryLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *RetryLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { + return s.ChannelJoinRequestStore +} + func (s *RetryLayer) ChannelMemberHistory() store.ChannelMemberHistoryStore { return s.ChannelMemberHistoryStore } @@ -342,6 +347,11 @@ type RetryLayerChannelBookmarkStore struct { Root *RetryLayer } +type RetryLayerChannelJoinRequestStore struct { + store.ChannelJoinRequestStore + Root *RetryLayer +} + type RetryLayerChannelMemberHistoryStore struct { store.ChannelMemberHistoryStore Root *RetryLayer @@ -3858,6 +3868,153 @@ func (s *RetryLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan } +func (s *RetryLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.CountPending(channelId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Get(id) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + + tries := 0 + for { + result, resultVar1, err := s.ChannelJoinRequestStore.GetForChannel(channelId, opts) + if err == nil { + return result, resultVar1, nil + } + if !isRepeatableError(err) { + return result, resultVar1, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, resultVar1, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + + tries := 0 + for { + result, resultVar1, err := s.ChannelJoinRequestStore.GetForUser(userId, opts) + if err == nil { + return result, resultVar1, nil + } + if !isRepeatableError(err) { + return result, resultVar1, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, resultVar1, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.GetPendingForChannelAndUser(channelId, userId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Save(req) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Update(req) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerChannelMemberHistoryStore) DeleteOrphanedRows(limit int) (int64, error) { tries := 0 @@ -16149,6 +16306,27 @@ func (s *RetryLayerUserStore) GetByAuth(authData *string, authService string) (* } +func (s *RetryLayerUserStore) GetByAuthData(authData *string) (*model.User, error) { + + tries := 0 + for { + result, err := s.UserStore.GetByAuthData(authData) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerUserStore) GetByEmail(email string) (*model.User, error) { tries := 0 @@ -18404,6 +18582,7 @@ func New(childStore store.Store) *RetryLayer { newStore.BotStore = &RetryLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &RetryLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &RetryLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelJoinRequestStore = &RetryLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &RetryLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &RetryLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} newStore.CommandStore = &RetryLayerCommandStore{CommandStore: childStore.Command(), Root: &newStore} diff --git a/server/channels/store/retrylayer/retrylayer_test.go b/server/channels/store/retrylayer/retrylayer_test.go index 9c1e08dcfd9..7cb965e53b3 100644 --- a/server/channels/store/retrylayer/retrylayer_test.go +++ b/server/channels/store/retrylayer/retrylayer_test.go @@ -74,6 +74,7 @@ func genStore() *mocks.Store { mock.On("Recap").Return(&mocks.RecapStore{}) mock.On("TemporaryPost").Return(&mocks.TemporaryPostStore{}) mock.On("View").Return(&mocks.ViewStore{}) + mock.On("ChannelJoinRequest").Return(&mocks.ChannelJoinRequestStore{}) return mock } diff --git a/server/channels/store/sqlstore/channel_join_request_store.go b/server/channels/store/sqlstore/channel_join_request_store.go new file mode 100644 index 00000000000..700639fa70a --- /dev/null +++ b/server/channels/store/sqlstore/channel_join_request_store.go @@ -0,0 +1,244 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "database/sql" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/store" + sq "github.com/mattermost/squirrel" + "github.com/pkg/errors" +) + +const channelJoinRequestsTable = "ChannelJoinRequests" + +var channelJoinRequestColumns = []string{ + "Id", + "ChannelId", + "UserId", + "Message", + "Status", + "DenialReason", + "CreateAt", + "UpdateAt", + "ReviewedBy", + "ReviewedAt", +} + +type SqlChannelJoinRequestStore struct { + *SqlStore + + selectQuery sq.SelectBuilder +} + +func newSqlChannelJoinRequestStore(sqlStore *SqlStore) store.ChannelJoinRequestStore { + s := &SqlChannelJoinRequestStore{SqlStore: sqlStore} + s.selectQuery = s.getQueryBuilder(). + Select(channelJoinRequestColumns...). + From(channelJoinRequestsTable) + return s +} + +func (s *SqlChannelJoinRequestStore) toMap(r *model.ChannelJoinRequest) map[string]any { + return map[string]any{ + "Id": r.Id, + "ChannelId": r.ChannelId, + "UserId": r.UserId, + "Message": r.Message, + "Status": r.Status, + "DenialReason": r.DenialReason, + "CreateAt": r.CreateAt, + "UpdateAt": r.UpdateAt, + "ReviewedBy": r.ReviewedBy, + "ReviewedAt": r.ReviewedAt, + } +} + +// Save inserts a new join request. The partial unique index in Postgres +// (channelid, userid) WHERE status = 'pending' enforces at-most-one pending +// row per (channel, user). On conflict we translate the unique-violation into +// a store.ErrConflict so the app layer can return 409. +func (s *SqlChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + req.PreSave() + + if err := req.IsValid(); err != nil { + return nil, err + } + + query := s.getQueryBuilder(). + Insert(channelJoinRequestsTable). + SetMap(s.toMap(req)) + + if _, err := s.GetMaster().ExecBuilder(query); err != nil { + if IsUniqueConstraintError(err, []string{"idx_channeljoinrequests_pending_unique"}) { + return nil, store.NewErrConflict("ChannelJoinRequest", err, "channel_id="+req.ChannelId+" user_id="+req.UserId) + } + return nil, errors.Wrap(err, "failed to save ChannelJoinRequest") + } + + return req, nil +} + +func (s *SqlChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + var req model.ChannelJoinRequest + query := s.selectQuery.Where(sq.Eq{"Id": id}) + + if err := s.GetReplica().GetBuilder(&req, query); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("ChannelJoinRequest", id) + } + return nil, errors.Wrapf(err, "failed to get ChannelJoinRequest with id=%s", id) + } + + return &req, nil +} + +func (s *SqlChannelJoinRequestStore) GetPendingForChannelAndUser(channelId, userId string) (*model.ChannelJoinRequest, error) { + var req model.ChannelJoinRequest + query := s.selectQuery.Where(sq.Eq{ + "ChannelId": channelId, + "UserId": userId, + "Status": model.ChannelJoinRequestStatusPending, + }) + + if err := s.GetReplica().GetBuilder(&req, query); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("ChannelJoinRequest", "channel_id="+channelId+" user_id="+userId) + } + return nil, errors.Wrapf(err, "failed to get pending ChannelJoinRequest for channel_id=%s user_id=%s", channelId, userId) + } + + return &req, nil +} + +// applyStatusFilter applies the opts.Status filter (defaulting to pending if empty) +// to both the select and count queries. Returning the two filtered builders keeps +// list and count perfectly in sync. +func applyJoinRequestStatusFilter(opts model.GetChannelJoinRequestsOpts) sq.Eq { + status := opts.Status + if status == "" { + status = model.ChannelJoinRequestStatusPending + } + return sq.Eq{"Status": status} +} + +func paginate(opts model.GetChannelJoinRequestsOpts) (limit, offset uint64) { + perPage := opts.PerPage + if perPage <= 0 { + perPage = 60 + } + page := max(opts.Page, 0) + return uint64(perPage), uint64(page) * uint64(perPage) +} + +func (s *SqlChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + where := sq.And{sq.Eq{"ChannelId": channelId}, applyJoinRequestStatusFilter(opts)} + + limit, offset := paginate(opts) + listQuery := s.selectQuery. + Where(where). + OrderBy("CreateAt DESC", "Id DESC"). + Limit(limit). + Offset(offset) + + var rows []*model.ChannelJoinRequest + if err := s.GetReplica().SelectBuilder(&rows, listQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to list ChannelJoinRequests for channel_id=%s", channelId) + } + + countQuery := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(where) + + var total int64 + if err := s.GetReplica().GetBuilder(&total, countQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to count ChannelJoinRequests for channel_id=%s", channelId) + } + + return rows, total, nil +} + +func (s *SqlChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + where := sq.And{sq.Eq{"UserId": userId}, applyJoinRequestStatusFilter(opts)} + + limit, offset := paginate(opts) + listQuery := s.selectQuery. + Where(where). + OrderBy("CreateAt DESC", "Id DESC"). + Limit(limit). + Offset(offset) + + var rows []*model.ChannelJoinRequest + if err := s.GetReplica().SelectBuilder(&rows, listQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to list ChannelJoinRequests for user_id=%s", userId) + } + + countQuery := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(where) + + var total int64 + if err := s.GetReplica().GetBuilder(&total, countQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to count ChannelJoinRequests for user_id=%s", userId) + } + + return rows, total, nil +} + +// Update writes the mutable fields back. Id/ChannelId/UserId/CreateAt are +// immutable post-create — the partial-unique index relies on (ChannelId, UserId) +// being stable for the lifetime of a row. +func (s *SqlChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + req.PreUpdate() + + if err := req.IsValid(); err != nil { + return nil, err + } + + query := s.getQueryBuilder(). + Update(channelJoinRequestsTable). + SetMap(map[string]any{ + "Status": req.Status, + "Message": req.Message, + "DenialReason": req.DenialReason, + "UpdateAt": req.UpdateAt, + "ReviewedBy": req.ReviewedBy, + "ReviewedAt": req.ReviewedAt, + }). + Where(sq.Eq{"Id": req.Id}) + + res, err := s.GetMaster().ExecBuilder(query) + if err != nil { + return nil, errors.Wrapf(err, "failed to update ChannelJoinRequest with id=%s", req.Id) + } + + n, err := res.RowsAffected() + if err != nil { + return nil, errors.Wrap(err, "failed to read RowsAffected on ChannelJoinRequest update") + } + if n == 0 { + return nil, store.NewErrNotFound("ChannelJoinRequest", req.Id) + } + + return req, nil +} + +func (s *SqlChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + query := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(sq.Eq{ + "ChannelId": channelId, + "Status": model.ChannelJoinRequestStatusPending, + }) + + var count int64 + if err := s.GetReplica().GetBuilder(&count, query); err != nil { + return 0, errors.Wrapf(err, "failed to count pending ChannelJoinRequests for channel_id=%s", channelId) + } + return count, nil +} diff --git a/server/channels/store/sqlstore/channel_join_request_store_test.go b/server/channels/store/sqlstore/channel_join_request_store_test.go new file mode 100644 index 00000000000..bbbfdc8f52e --- /dev/null +++ b/server/channels/store/sqlstore/channel_join_request_store_test.go @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "testing" + + "github.com/mattermost/mattermost/server/v8/channels/store/storetest" +) + +func TestChannelJoinRequestStore(t *testing.T) { + StoreTest(t, storetest.TestChannelJoinRequestStore) +} diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 4691f73df3f..d98f215dacb 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -158,6 +158,7 @@ func channelSliceColumns(isSelect bool, prefix ...string) []string { p + "LastRootPostAt", p + "BannerInfo", p + "DefaultCategoryName", + p + "Discoverable", } if isSelect { @@ -196,6 +197,7 @@ func channelToSlice(channel *model.Channel) []any { channel.LastRootPostAt, channel.BannerInfo, channel.DefaultCategoryName, + channel.Discoverable, } } @@ -872,7 +874,8 @@ func (s SqlChannelStore) updateChannelT(transaction *sqlxTxWrapper, channel *mod LastRootPostAt=:LastRootPostAt, BannerInfo=:BannerInfo, DefaultCategoryName=:DefaultCategoryName, - AutoTranslation=:AutoTranslation + AutoTranslation=:AutoTranslation, + Discoverable=:Discoverable WHERE Id=:Id`, channel) if err != nil { if IsUniqueConstraintError(err, []string{"Name", "channels_name_teamid_key"}) { diff --git a/server/channels/store/sqlstore/role_store.go b/server/channels/store/sqlstore/role_store.go index c188cc1c8e1..c6ff9d8a061 100644 --- a/server/channels/store/sqlstore/role_store.go +++ b/server/channels/store/sqlstore/role_store.go @@ -101,8 +101,8 @@ func newSqlRoleStore(sqlStore *SqlStore) store.RoleStore { func (s *SqlRoleStore) Save(role *model.Role) (_ *model.Role, err error) { // Check the role is valid before proceeding. - if !role.IsValidWithoutId() { - return nil, store.NewErrInvalidInput("Role", "", fmt.Sprintf("%v", role)) + if err = role.IsValidWithoutId(); err != nil { + return nil, store.NewErrInvalidInput("Role", "", err.Error()) } if role.Id == "" { @@ -148,8 +148,8 @@ func (s *SqlRoleStore) Save(role *model.Role) (_ *model.Role, err error) { func (s *SqlRoleStore) createRole(role *model.Role, transaction *sqlxTxWrapper) (*model.Role, error) { // Check the role is valid before proceeding. - if !role.IsValidWithoutId() { - return nil, store.NewErrInvalidInput("Role", "", fmt.Sprintf("%v", role)) + if err := role.IsValidWithoutId(); err != nil { + return nil, store.NewErrInvalidInput("Role", "", err.Error()) } dbRole := NewRoleFromModel(role) diff --git a/server/channels/store/sqlstore/schema_dump_test.go b/server/channels/store/sqlstore/schema_dump_test.go index 4ca5ee3f200..a4b52ec2e62 100644 --- a/server/channels/store/sqlstore/schema_dump_test.go +++ b/server/channels/store/sqlstore/schema_dump_test.go @@ -72,7 +72,7 @@ func TestGetSchemaDefinition(t *testing.T) { if table.Name == "channels" { // Check that indexes are present assert.NotEmpty(t, table.Indexes, "channels table should have indexes") - assert.Equal(t, 12, len(table.Indexes), "channels table should have 12 indexes") + assert.Equal(t, 13, len(table.Indexes), "channels table should have 13 indexes") // Expected index definitions expectedIndexDefs := map[string]string{ @@ -88,6 +88,7 @@ func TestGetSchemaDefinition(t *testing.T) { "idx_channels_team_id_display_name": "CREATE INDEX idx_channels_team_id_display_name ON public.channels USING btree (teamid, displayname)", "idx_channels_team_id_type": "CREATE INDEX idx_channels_team_id_type ON public.channels USING btree (teamid, type)", "idx_channels_autotranslation_enabled": "CREATE INDEX idx_channels_autotranslation_enabled ON public.channels USING btree (id) WHERE (autotranslation = true)", + "idx_channels_discoverable_team": "CREATE INDEX idx_channels_discoverable_team ON public.channels USING btree (teamid) WHERE ((discoverable = true) AND (type = 'P'::channel_type) AND (deleteat = 0))", } // Verify all expected indexes are present with correct definitions diff --git a/server/channels/store/sqlstore/store.go b/server/channels/store/sqlstore/store.go index d909da831e6..0bf3b4a3200 100644 --- a/server/channels/store/sqlstore/store.go +++ b/server/channels/store/sqlstore/store.go @@ -117,6 +117,7 @@ type SqlStoreStores struct { recap store.RecapStore readReceipt store.ReadReceiptStore temporaryPost store.TemporaryPostStore + channelJoinRequest store.ChannelJoinRequestStore } type SqlStore struct { @@ -303,6 +304,7 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface store.stores.recap = newSqlRecapStore(store) store.stores.readReceipt = newSqlReadReceiptStore(store, metrics) store.stores.temporaryPost = newSqlTemporaryPostStore(store, metrics) + store.stores.channelJoinRequest = newSqlChannelJoinRequestStore(store) store.stores.preference.(*SqlPreferenceStore).deleteUnusedFeatures() @@ -926,6 +928,10 @@ func (ss *SqlStore) TemporaryPost() store.TemporaryPostStore { return ss.stores.temporaryPost } +func (ss *SqlStore) ChannelJoinRequest() store.ChannelJoinRequestStore { + return ss.stores.channelJoinRequest +} + func (ss *SqlStore) DropAllTables() { ss.masterX.Exec(`DO $func$ diff --git a/server/channels/store/sqlstore/user_store.go b/server/channels/store/sqlstore/user_store.go index 7b11d1220e5..524bc917373 100644 --- a/server/channels/store/sqlstore/user_store.go +++ b/server/channels/store/sqlstore/user_store.go @@ -1281,6 +1281,27 @@ func (us SqlUserStore) GetByRemoteID(remoteID string) (*model.User, error) { return &user, nil } +func (us SqlUserStore) GetByAuthData(authData *string) (*model.User, error) { + if authData == nil || *authData == "" { + return nil, store.NewErrInvalidInput("User", "", "empty or nil") + } + + query := us.usersQuery.Where("Users.AuthData = ?", authData) + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_by_auth_data_tosql") + } + + user := model.User{} + if err := us.GetReplica().Get(&user, queryString, args...); err == sql.ErrNoRows { + return nil, store.NewErrNotFound("User", fmt.Sprintf("authData=%s", *authData)) + } else if err != nil { + return nil, errors.Wrapf(err, "failed to find User with authData=%s", *authData) + } + return &user, nil +} + func (us SqlUserStore) GetByAuth(authData *string, authService string) (*model.User, error) { if authData == nil || *authData == "" { return nil, store.NewErrInvalidInput("User", "", "empty or nil") diff --git a/server/channels/store/store.go b/server/channels/store/store.go index d09b52d8895..596af20bbeb 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -102,6 +102,7 @@ type Store interface { Recap() RecapStore ReadReceipt() ReadReceiptStore TemporaryPost() TemporaryPostStore + ChannelJoinRequest() ChannelJoinRequestStore } type RetentionPolicyStore interface { @@ -472,6 +473,7 @@ type UserStore interface { GetByEmail(email string) (*model.User, error) GetByRemoteID(remoteID string) (*model.User, error) GetByAuth(authData *string, authService string) (*model.User, error) + GetByAuthData(authData *string) (*model.User, error) GetAllUsingAuthService(authService string) ([]*model.User, error) GetAllNotInAuthService(authServices []string) ([]*model.User, error) GetByUsername(username string) (*model.User, error) @@ -1332,6 +1334,19 @@ type ThreadMembershipImportData struct { UnreadMentions int64 } +// ChannelJoinRequestStore persists user requests to join discoverable private +// channels. Rows are never deleted; status transitions are recorded with +// reviewer and timestamps so the table doubles as an audit trail. +type ChannelJoinRequestStore interface { + Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) + Get(id string) (*model.ChannelJoinRequest, error) + GetPendingForChannelAndUser(channelId, userId string) (*model.ChannelJoinRequest, error) + GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) + GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) + Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) + CountPending(channelId string) (int64, error) +} + type RecapStore interface { SaveRecap(recap *model.Recap) (*model.Recap, error) UpdateRecap(recap *model.Recap) (*model.Recap, error) diff --git a/server/channels/store/storetest/channel_join_request_store.go b/server/channels/store/storetest/channel_join_request_store.go new file mode 100644 index 00000000000..de716535742 --- /dev/null +++ b/server/channels/store/storetest/channel_join_request_store.go @@ -0,0 +1,264 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package storetest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +func TestChannelJoinRequestStore(t *testing.T, _ request.CTX, ss store.Store) { + t.Run("Save inserts a pending row", testChannelJoinRequestSave(ss)) + t.Run("Save rejects duplicate pending row", testChannelJoinRequestSaveDuplicatePending(ss)) + t.Run("Save allows another pending row after withdrawal", testChannelJoinRequestSaveAfterWithdraw(ss)) + t.Run("Get returns NotFound for unknown id", testChannelJoinRequestGetNotFound(ss)) + t.Run("GetPendingForChannelAndUser only returns pending rows", testChannelJoinRequestGetPending(ss)) + t.Run("GetForChannel paginates and filters by status", testChannelJoinRequestGetForChannel(ss)) + t.Run("GetForUser paginates and filters by status", testChannelJoinRequestGetForUser(ss)) + t.Run("Update transitions status and stores reviewer", testChannelJoinRequestUpdate(ss)) + t.Run("CountPending only counts pending rows", testChannelJoinRequestCountPending(ss)) +} + +func newPendingRequest(channelId, userId string) *model.ChannelJoinRequest { + return &model.ChannelJoinRequest{ + ChannelId: channelId, + UserId: userId, + Message: "please let me in", + Status: model.ChannelJoinRequestStatusPending, + } +} + +func testChannelJoinRequestSave(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + req, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + require.NotEmpty(t, req.Id) + assert.Equal(t, channelId, req.ChannelId) + assert.Equal(t, userId, req.UserId) + assert.Equal(t, model.ChannelJoinRequestStatusPending, req.Status) + assert.NotZero(t, req.CreateAt) + assert.Equal(t, req.CreateAt, req.UpdateAt) + + fetched, err := ss.ChannelJoinRequest().Get(req.Id) + require.NoError(t, err) + assert.Equal(t, req.Id, fetched.Id) + assert.Equal(t, req.Message, fetched.Message) + assert.Equal(t, req.Status, fetched.Status) + } +} + +func testChannelJoinRequestSaveDuplicatePending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.Error(t, err) + var conflict *store.ErrConflict + assert.ErrorAs(t, err, &conflict, "duplicate pending row must surface store.ErrConflict") + } +} + +func testChannelJoinRequestSaveAfterWithdraw(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + first, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + first.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(first) + require.NoError(t, err) + + // Allow the millisecond-resolution UpdateAt to advance. + time.Sleep(2 * time.Millisecond) + + second, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err, "a new pending row must be insertable once the previous one is no longer pending") + assert.NotEqual(t, first.Id, second.Id) + } +} + +func testChannelJoinRequestGetNotFound(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + _, err := ss.ChannelJoinRequest().Get(model.NewId()) + require.Error(t, err) + var nf *store.ErrNotFound + assert.ErrorAs(t, err, &nf) + } +} + +func testChannelJoinRequestGetPending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + _, err := ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.Error(t, err, "must return NotFound when no row exists") + + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + got, err := ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.NoError(t, err) + assert.Equal(t, channelId, got.ChannelId) + assert.Equal(t, userId, got.UserId) + assert.Equal(t, model.ChannelJoinRequestStatusPending, got.Status) + + got.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(got) + require.NoError(t, err) + + _, err = ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.Error(t, err, "withdrawn row must not be considered pending") + } +} + +func testChannelJoinRequestGetForChannel(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + + // Three pending requests across distinct users + one denied row for the + // same channel so we can prove the status filter actually filters. + for range 3 { + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, model.NewId())) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + } + + denied := newPendingRequest(channelId, model.NewId()) + saved, err := ss.ChannelJoinRequest().Save(denied) + require.NoError(t, err) + saved.Status = model.ChannelJoinRequestStatusDenied + saved.ReviewedBy = model.NewId() + saved.ReviewedAt = model.GetMillis() + saved.DenialReason = "policy mismatch" + _, err = ss.ChannelJoinRequest().Update(saved) + require.NoError(t, err) + + rows, total, err := ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + assert.Len(t, rows, 3) + assert.Equal(t, int64(3), total) + for i := 1; i < len(rows); i++ { + assert.GreaterOrEqual(t, rows[i-1].CreateAt, rows[i].CreateAt, "list should be newest first") + } + + rows, total, err = ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{Status: model.ChannelJoinRequestStatusDenied, PerPage: 10}) + require.NoError(t, err) + assert.Equal(t, int64(1), total) + require.Len(t, rows, 1) + assert.Equal(t, "policy mismatch", rows[0].DenialReason) + + rows, total, err = ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 2, Page: 0}) + require.NoError(t, err) + assert.Len(t, rows, 2) + assert.Equal(t, int64(3), total, "TotalCount must not be truncated by paging") + } +} + +func testChannelJoinRequestGetForUser(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + userId := model.NewId() + + for range 2 { + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(model.NewId(), userId)) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + } + + denied := newPendingRequest(model.NewId(), userId) + saved, err := ss.ChannelJoinRequest().Save(denied) + require.NoError(t, err) + saved.Status = model.ChannelJoinRequestStatusDenied + saved.ReviewedBy = model.NewId() + saved.ReviewedAt = model.GetMillis() + _, err = ss.ChannelJoinRequest().Update(saved) + require.NoError(t, err) + + rows, total, err := ss.ChannelJoinRequest().GetForUser(userId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + assert.Len(t, rows, 2) + assert.Equal(t, int64(2), total) + + rows, total, err = ss.ChannelJoinRequest().GetForUser(userId, model.GetChannelJoinRequestsOpts{Status: model.ChannelJoinRequestStatusDenied, PerPage: 10}) + require.NoError(t, err) + assert.Equal(t, int64(1), total) + assert.Len(t, rows, 1) + } +} + +func testChannelJoinRequestUpdate(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + req, err := ss.ChannelJoinRequest().Save(newPendingRequest(model.NewId(), model.NewId())) + require.NoError(t, err) + originalUpdateAt := req.UpdateAt + + reviewerId := model.NewId() + reviewedAt := model.GetMillis() + 1 + req.Status = model.ChannelJoinRequestStatusApproved + req.ReviewedBy = reviewerId + req.ReviewedAt = reviewedAt + + // Allow UpdateAt to advance. + time.Sleep(2 * time.Millisecond) + updated, err := ss.ChannelJoinRequest().Update(req) + require.NoError(t, err) + assert.Equal(t, model.ChannelJoinRequestStatusApproved, updated.Status) + assert.Equal(t, reviewerId, updated.ReviewedBy) + assert.Equal(t, reviewedAt, updated.ReviewedAt) + assert.Greater(t, updated.UpdateAt, originalUpdateAt) + + fetched, err := ss.ChannelJoinRequest().Get(req.Id) + require.NoError(t, err) + assert.Equal(t, model.ChannelJoinRequestStatusApproved, fetched.Status) + assert.Equal(t, reviewerId, fetched.ReviewedBy) + } +} + +func testChannelJoinRequestCountPending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + + count, err := ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(0), count) + + for range 4 { + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, model.NewId())) + require.NoError(t, err) + } + + count, err = ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(4), count) + + // Withdraw one — count should drop by 1. + reqs, _, err := ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + require.NotEmpty(t, reqs) + first := reqs[0] + first.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(first) + require.NoError(t, err) + + count, err = ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(3), count) + } +} diff --git a/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go b/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go new file mode 100644 index 00000000000..23d71a295a3 --- /dev/null +++ b/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go @@ -0,0 +1,251 @@ +// Code generated by mockery v2.53.4. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import ( + model "github.com/mattermost/mattermost/server/public/model" + mock "github.com/stretchr/testify/mock" +) + +// ChannelJoinRequestStore is an autogenerated mock type for the ChannelJoinRequestStore type +type ChannelJoinRequestStore struct { + mock.Mock +} + +// CountPending provides a mock function with given fields: channelId +func (_m *ChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + ret := _m.Called(channelId) + + if len(ret) == 0 { + panic("no return value specified for CountPending") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(string) (int64, error)); ok { + return rf(channelId) + } + if rf, ok := ret.Get(0).(func(string) int64); ok { + r0 = rf(channelId) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: id +func (_m *ChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + ret := _m.Called(id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(string) (*model.ChannelJoinRequest, error)); ok { + return rf(id) + } + if rf, ok := ret.Get(0).(func(string) *model.ChannelJoinRequest); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForChannel provides a mock function with given fields: channelId, opts +func (_m *ChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + ret := _m.Called(channelId, opts) + + if len(ret) == 0 { + panic("no return value specified for GetForChannel") + } + + var r0 []*model.ChannelJoinRequest + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error)); ok { + return rf(channelId, opts) + } + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) []*model.ChannelJoinRequest); ok { + r0 = rf(channelId, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, model.GetChannelJoinRequestsOpts) int64); ok { + r1 = rf(channelId, opts) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(string, model.GetChannelJoinRequestsOpts) error); ok { + r2 = rf(channelId, opts) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// GetForUser provides a mock function with given fields: userId, opts +func (_m *ChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + ret := _m.Called(userId, opts) + + if len(ret) == 0 { + panic("no return value specified for GetForUser") + } + + var r0 []*model.ChannelJoinRequest + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error)); ok { + return rf(userId, opts) + } + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) []*model.ChannelJoinRequest); ok { + r0 = rf(userId, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, model.GetChannelJoinRequestsOpts) int64); ok { + r1 = rf(userId, opts) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(string, model.GetChannelJoinRequestsOpts) error); ok { + r2 = rf(userId, opts) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// GetPendingForChannelAndUser provides a mock function with given fields: channelId, userId +func (_m *ChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + ret := _m.Called(channelId, userId) + + if len(ret) == 0 { + panic("no return value specified for GetPendingForChannelAndUser") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*model.ChannelJoinRequest, error)); ok { + return rf(channelId, userId) + } + if rf, ok := ret.Get(0).(func(string, string) *model.ChannelJoinRequest); ok { + r0 = rf(channelId, userId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channelId, userId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: req +func (_m *ChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + ret := _m.Called(req) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) (*model.ChannelJoinRequest, error)); ok { + return rf(req) + } + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) *model.ChannelJoinRequest); ok { + r0 = rf(req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(*model.ChannelJoinRequest) error); ok { + r1 = rf(req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Update provides a mock function with given fields: req +func (_m *ChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + ret := _m.Called(req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) (*model.ChannelJoinRequest, error)); ok { + return rf(req) + } + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) *model.ChannelJoinRequest); ok { + r0 = rf(req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(*model.ChannelJoinRequest) error); ok { + r1 = rf(req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewChannelJoinRequestStore creates a new instance of ChannelJoinRequestStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewChannelJoinRequestStore(t interface { + mock.TestingT + Cleanup(func()) +}) *ChannelJoinRequestStore { + mock := &ChannelJoinRequestStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/server/channels/store/storetest/mocks/Store.go b/server/channels/store/storetest/mocks/Store.go index 30f21da0292..2a9565d9fbd 100644 --- a/server/channels/store/storetest/mocks/Store.go +++ b/server/channels/store/storetest/mocks/Store.go @@ -5,16 +5,13 @@ package mocks import ( - mlog "github.com/mattermost/mattermost/server/public/shared/mlog" - mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" - sql "database/sql" + time "time" + model "github.com/mattermost/mattermost/server/public/model" + mlog "github.com/mattermost/mattermost/server/public/shared/mlog" store "github.com/mattermost/mattermost/server/v8/channels/store" - - time "time" + mock "github.com/stretchr/testify/mock" ) // Store is an autogenerated mock type for the Store type @@ -162,6 +159,26 @@ func (_m *Store) ChannelBookmark() store.ChannelBookmarkStore { return r0 } +// ChannelJoinRequest provides a mock function with no fields +func (_m *Store) ChannelJoinRequest() store.ChannelJoinRequestStore { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ChannelJoinRequest") + } + + var r0 store.ChannelJoinRequestStore + if rf, ok := ret.Get(0).(func() store.ChannelJoinRequestStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.ChannelJoinRequestStore) + } + } + + return r0 +} + // ChannelMemberHistory provides a mock function with no fields func (_m *Store) ChannelMemberHistory() store.ChannelMemberHistoryStore { ret := _m.Called() diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index 0845e1e23ab..12026e7f4b6 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -655,6 +655,36 @@ func (_m *UserStore) GetByAuth(authData *string, authService string) (*model.Use return r0, r1 } +// GetByAuthData provides a mock function with given fields: authData +func (_m *UserStore) GetByAuthData(authData *string) (*model.User, error) { + ret := _m.Called(authData) + + if len(ret) == 0 { + panic("no return value specified for GetByAuthData") + } + + var r0 *model.User + var r1 error + if rf, ok := ret.Get(0).(func(*string) (*model.User, error)); ok { + return rf(authData) + } + if rf, ok := ret.Get(0).(func(*string) *model.User); ok { + r0 = rf(authData) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.User) + } + } + + if rf, ok := ret.Get(1).(func(*string) error); ok { + r1 = rf(authData) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetByEmail provides a mock function with given fields: email func (_m *UserStore) GetByEmail(email string) (*model.User, error) { ret := _m.Called(email) diff --git a/server/channels/store/storetest/store.go b/server/channels/store/storetest/store.go index be8bf0f1f14..5c54760f6d7 100644 --- a/server/channels/store/storetest/store.go +++ b/server/channels/store/storetest/store.go @@ -75,6 +75,7 @@ type Store struct { ReadReceiptStore mocks.ReadReceiptStore TemporaryPostStore mocks.TemporaryPostStore ViewStore mocks.ViewStore + ChannelJoinRequestStore mocks.ChannelJoinRequestStore } func (s *Store) Logger() mlog.LoggerIFace { return s.logger } @@ -180,6 +181,9 @@ func (s *Store) ReadReceipt() store.ReadReceiptStore { func (s *Store) TemporaryPost() store.TemporaryPostStore { return &s.TemporaryPostStore } +func (s *Store) ChannelJoinRequest() store.ChannelJoinRequestStore { + return &s.ChannelJoinRequestStore +} func (s *Store) View() store.ViewStore { return &s.ViewStore } @@ -239,5 +243,6 @@ func (s *Store) AssertExpectations(t mock.TestingT) bool { &s.ReadReceiptStore, &s.TemporaryPostStore, &s.ViewStore, + &s.ChannelJoinRequestStore, ) } diff --git a/server/channels/store/storetest/user_store.go b/server/channels/store/storetest/user_store.go index ec303b8c6c5..da4ef47f1a2 100644 --- a/server/channels/store/storetest/user_store.go +++ b/server/channels/store/storetest/user_store.go @@ -69,6 +69,7 @@ func TestUserStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) { t.Run("GetProfilesByUsernames", func(t *testing.T) { testUserStoreGetProfilesByUsernames(t, rctx, ss) }) t.Run("GetSystemAdminProfiles", func(t *testing.T) { testUserStoreGetSystemAdminProfiles(t, rctx, ss) }) t.Run("GetByEmail", func(t *testing.T) { testUserStoreGetByEmail(t, rctx, ss) }) + t.Run("GetByAuth", func(t *testing.T) { testUserStoreGetByAuth(t, rctx, ss) }) t.Run("GetByAuthData", func(t *testing.T) { testUserStoreGetByAuthData(t, rctx, ss) }) t.Run("GetByUsername", func(t *testing.T) { testUserStoreGetByUsername(t, rctx, ss) }) t.Run("GetForLogin", func(t *testing.T) { testUserStoreGetForLogin(t, rctx, ss) }) @@ -2104,7 +2105,7 @@ func testUserStoreGetByEmail(t *testing.T, rctx request.CTX, ss store.Store) { }) } -func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) { +func testUserStoreGetByAuth(t *testing.T, rctx request.CTX, ss store.Store) { teamID := model.NewId() auth1 := model.NewId() auth3 := model.NewId() @@ -2167,23 +2168,130 @@ func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) require.True(t, errors.As(err, &nfErr)) }) - t.Run("get by unknown auth, u1 service", func(t *testing.T) { - unknownAuth := "" + t.Run("get by unknown non-empty auth, u1 service", func(t *testing.T) { + unknownAuth := model.NewId() _, err := ss.User().GetByAuth(&unknownAuth, u1.AuthService) require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("get by empty auth, u1 service", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuth(&emptyAuth, u1.AuthService) + require.Error(t, err) var invErr *store.ErrInvalidInput require.True(t, errors.As(err, &invErr)) }) - t.Run("get by unknown auth, unknown service", func(t *testing.T) { - unknownAuth := "" + t.Run("get by nil auth, u1 service", func(t *testing.T) { + _, err := ss.User().GetByAuth(nil, u1.AuthService) + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) + + t.Run("get by unknown non-empty auth, unknown service", func(t *testing.T) { + unknownAuth := model.NewId() _, err := ss.User().GetByAuth(&unknownAuth, "unknown") require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("get by empty auth, unknown service", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuth(&emptyAuth, "unknown") + require.Error(t, err) var invErr *store.ErrInvalidInput require.True(t, errors.As(err, &invErr)) }) } +func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) { + teamID := model.NewId() + auth1 := model.NewId() + auth2 := model.NewId() + + u1, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u1" + model.NewId(), + AuthData: &auth1, + AuthService: "service", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u1.Id)) }() + _, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: teamID, UserId: u1.Id}, -1) + require.NoError(t, nErr) + + u2, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u2" + model.NewId(), + AuthData: &auth2, + AuthService: "service2", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u2.Id)) }() + _, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: teamID, UserId: u2.Id}, -1) + require.NoError(t, nErr) + + t.Run("returns full user when auth data matches", func(t *testing.T) { + u, err := ss.User().GetByAuthData(u1.AuthData) + require.NoError(t, err) + assert.Equal(t, u1, u) + }) + + t.Run("matches regardless of auth service", func(t *testing.T) { + u, err := ss.User().GetByAuthData(u2.AuthData) + require.NoError(t, err) + assert.Equal(t, u2.Id, u.Id) + assert.Equal(t, "service2", u.AuthService) + }) + + t.Run("returns ErrNotFound for unknown auth data", func(t *testing.T) { + unknownAuth := model.NewId() + _, err := ss.User().GetByAuthData(&unknownAuth) + require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("returns ErrInvalidInput for nil auth data", func(t *testing.T) { + _, err := ss.User().GetByAuthData(nil) + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) + + t.Run("returns ErrInvalidInput for empty auth data", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuthData(&emptyAuth) + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) + + t.Run("matches when auth data is an email-shaped value", func(t *testing.T) { + // ResetAuthDataToEmailForUsers sets AuthData = Email for whole batches of + // users, so email-shaped auth_data values are common in practice. + emailAuth := "u3-" + model.NewId() + "@example.com" + u3, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u3" + model.NewId(), + AuthData: &emailAuth, + AuthService: "service", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u3.Id)) }() + + u, err := ss.User().GetByAuthData(&emailAuth) + require.NoError(t, err) + assert.Equal(t, u3.Id, u.Id) + require.NotNil(t, u.AuthData) + assert.Equal(t, emailAuth, *u.AuthData) + }) +} + func testUserStoreGetByUsername(t *testing.T, rctx request.CTX, ss store.Store) { teamID := model.NewId() diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 82f7f583b7a..b1b4e1ed635 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -26,6 +26,7 @@ type TimerLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore CommandStore store.CommandStore @@ -106,6 +107,10 @@ func (s *TimerLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *TimerLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { + return s.ChannelJoinRequestStore +} + func (s *TimerLayer) ChannelMemberHistory() store.ChannelMemberHistoryStore { return s.ChannelMemberHistoryStore } @@ -341,6 +346,11 @@ type TimerLayerChannelBookmarkStore struct { Root *TimerLayer } +type TimerLayerChannelJoinRequestStore struct { + store.ChannelJoinRequestStore + Root *TimerLayer +} + type TimerLayerChannelMemberHistoryStore struct { store.ChannelMemberHistoryStore Root *TimerLayer @@ -3218,6 +3228,118 @@ func (s *TimerLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan return result, err } +func (s *TimerLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.CountPending(channelId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.CountPending", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Get(id) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Get", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + start := time.Now() + + result, resultVar1, err := s.ChannelJoinRequestStore.GetForChannel(channelId, opts) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetForChannel", success, elapsed) + } + return result, resultVar1, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + start := time.Now() + + result, resultVar1, err := s.ChannelJoinRequestStore.GetForUser(userId, opts) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetForUser", success, elapsed) + } + return result, resultVar1, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.GetPendingForChannelAndUser(channelId, userId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetPendingForChannelAndUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Save(req) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Save", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Update(req) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Update", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelMemberHistoryStore) DeleteOrphanedRows(limit int) (int64, error) { start := time.Now() @@ -3250,10 +3372,10 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsLeftSince(userID string return result, err } -func (s *TimerLayerChannelMemberHistoryStore) GetEverMembersInChannel(channelID string, userIDs []string) ([]string, error) { +func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(startTime int64, endTime int64) ([]string, error) { start := time.Now() - result, err := s.ChannelMemberHistoryStore.GetEverMembersInChannel(channelID, userIDs) + result, err := s.ChannelMemberHistoryStore.GetChannelsWithActivityDuring(startTime, endTime) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -3261,15 +3383,15 @@ func (s *TimerLayerChannelMemberHistoryStore) GetEverMembersInChannel(channelID if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetEverMembersInChannel", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetChannelsWithActivityDuring", success, elapsed) } return result, err } -func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(startTime int64, endTime int64) ([]string, error) { +func (s *TimerLayerChannelMemberHistoryStore) GetEverMembersInChannel(channelID string, userIDs []string) ([]string, error) { start := time.Now() - result, err := s.ChannelMemberHistoryStore.GetChannelsWithActivityDuring(startTime, endTime) + result, err := s.ChannelMemberHistoryStore.GetEverMembersInChannel(channelID, userIDs) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -3277,7 +3399,7 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetChannelsWithActivityDuring", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetEverMembersInChannel", success, elapsed) } return result, err } @@ -12742,6 +12864,22 @@ func (s *TimerLayerUserStore) GetByAuth(authData *string, authService string) (* return result, err } +func (s *TimerLayerUserStore) GetByAuthData(authData *string) (*model.User, error) { + start := time.Now() + + result, err := s.UserStore.GetByAuthData(authData) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.GetByAuthData", success, elapsed) + } + return result, err +} + func (s *TimerLayerUserStore) GetByEmail(email string) (*model.User, error) { start := time.Now() @@ -14558,6 +14696,7 @@ func New(childStore store.Store, metrics einterfaces.MetricsInterface) *TimerLay newStore.BotStore = &TimerLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &TimerLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &TimerLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelJoinRequestStore = &TimerLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &TimerLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &TimerLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} newStore.CommandStore = &TimerLayerCommandStore{CommandStore: childStore.Command(), Root: &newStore} diff --git a/server/channels/testlib/store.go b/server/channels/testlib/store.go index 9aee6d57e96..5d01e25c6ab 100644 --- a/server/channels/testlib/store.go +++ b/server/channels/testlib/store.go @@ -102,6 +102,7 @@ func GetMockStoreForSetupFunctions() *mocks.Store { systemStore.On("GetByName", model.MigrationKeyAccessControlPolicyV0_3).Return(&model.System{Name: model.MigrationKeyAccessControlPolicyV0_3, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddManageAgentPermissions).Return(&model.System{Name: model.MigrationKeyAddManageAgentPermissions, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddEditFileAttachmentPermission).Return(&model.System{Name: model.MigrationKeyAddEditFileAttachmentPermission, Value: "true"}, nil) + systemStore.On("GetByName", model.MigrationKeyAddDiscoverableChannelPermissions).Return(&model.System{Name: model.MigrationKeyAddDiscoverableChannelPermissions, Value: "true"}, nil) systemStore.On("InsertIfExists", mock.AnythingOfType("*model.System")).Return(&model.System{}, nil).Once() systemStore.On("Save", mock.AnythingOfType("*model.System")).Return(nil) diff --git a/server/cmd/mmctl/commands/permissions_test.go b/server/cmd/mmctl/commands/permissions_test.go index 51cbd4f5332..671ba68b051 100644 --- a/server/cmd/mmctl/commands/permissions_test.go +++ b/server/cmd/mmctl/commands/permissions_test.go @@ -251,6 +251,8 @@ func (s *MmctlUnitTestSuite) TestResetPermissionsCmd() { "manage_channel_access_rules", "manage_public_channel_auto_translation", "manage_private_channel_auto_translation", + "manage_private_channel_discoverability", + "manage_channel_join_requests", } expectedPatch := &model.RolePatch{ Permissions: &expectedPermissions, diff --git a/server/i18n/en.json b/server/i18n/en.json index 15a5aa99763..b46a9eb05d3 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -10738,6 +10738,10 @@ "id": "model.channel.is_valid.creator_id.app_error", "translation": "Invalid creator id." }, + { + "id": "model.channel.is_valid.discoverable.app_error", + "translation": "Only private channels can be marked as discoverable." + }, { "id": "model.channel.is_valid.display_name.app_error", "translation": "Invalid display name." @@ -10830,6 +10834,50 @@ "id": "model.channel_bookmark.is_valid.update_at.app_error", "translation": "Update at must be a valid time." }, + { + "id": "model.channel_join_request.is_valid.channel_id.app_error", + "translation": "Invalid channel id." + }, + { + "id": "model.channel_join_request.is_valid.create_at.app_error", + "translation": "Create at must be a valid time." + }, + { + "id": "model.channel_join_request.is_valid.denial_reason.app_error", + "translation": "Denial reason is too long." + }, + { + "id": "model.channel_join_request.is_valid.denial_reason_status.app_error", + "translation": "Denial reason can only be set on a denied join request." + }, + { + "id": "model.channel_join_request.is_valid.id.app_error", + "translation": "Invalid Id." + }, + { + "id": "model.channel_join_request.is_valid.message.app_error", + "translation": "Join request message is too long." + }, + { + "id": "model.channel_join_request.is_valid.reviewed_by.app_error", + "translation": "Invalid reviewer id." + }, + { + "id": "model.channel_join_request.is_valid.reviewer.app_error", + "translation": "An approved or denied join request must record the reviewer and review time." + }, + { + "id": "model.channel_join_request.is_valid.status.app_error", + "translation": "Invalid join request status." + }, + { + "id": "model.channel_join_request.is_valid.update_at.app_error", + "translation": "Update at must be a valid time." + }, + { + "id": "model.channel_join_request.is_valid.user_id.app_error", + "translation": "Invalid user id." + }, { "id": "model.channel_member.is_valid.channel_auto_follow_threads_value.app_error", "translation": "Invalid channel-auto-follow-threads value." diff --git a/server/public/model/channel.go b/server/public/model/channel.go index 5226f160c97..ef73d651fe8 100644 --- a/server/public/model/channel.go +++ b/server/public/model/channel.go @@ -108,6 +108,7 @@ type Channel struct { PolicyIsActive bool `json:"policy_is_active"` DefaultCategoryName string `json:"default_category_name"` ManagedCategoryName string `json:"managed_category_name"` + Discoverable bool `json:"discoverable"` } func (o *Channel) Auditable() map[string]any { @@ -131,6 +132,7 @@ func (o *Channel) Auditable() map[string]any { "policy_enforced": o.PolicyEnforced, "autotranslation": o.AutoTranslation, "policy_is_active": o.PolicyIsActive, // this field is only for logging purposes + "discoverable": o.Discoverable, } } @@ -160,6 +162,7 @@ type ChannelPatch struct { AutoTranslation *bool `json:"autotranslation"` ManagedCategoryName *string `json:"managed_category_name"` DefaultCategoryName *string `json:"default_category_name"` + Discoverable *bool `json:"discoverable"` } func (c *ChannelPatch) Auditable() map[string]any { @@ -169,6 +172,7 @@ func (c *ChannelPatch) Auditable() map[string]any { "purpose": c.Purpose, "default_category_name": c.DefaultCategoryName, "managed_category_name": c.ManagedCategoryName, + "discoverable": c.Discoverable, } } @@ -339,6 +343,10 @@ func (o *Channel) IsValid() *AppError { } } + if o.Discoverable && o.Type != ChannelTypePrivate { + return NewAppError("Channel.IsValid", "model.channel.is_valid.discoverable.app_error", nil, "id="+o.Id, http.StatusBadRequest) + } + return nil } @@ -459,6 +467,10 @@ func (o *Channel) Patch(patch *ChannelPatch) { if patch.DefaultCategoryName != nil { o.DefaultCategoryName = strings.TrimSpace(*patch.DefaultCategoryName) } + + if patch.Discoverable != nil { + o.Discoverable = *patch.Discoverable + } } func (o *Channel) MakeNonNil() { diff --git a/server/public/model/channel_join_request.go b/server/public/model/channel_join_request.go new file mode 100644 index 00000000000..38c0e6248dc --- /dev/null +++ b/server/public/model/channel_join_request.go @@ -0,0 +1,165 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "net/http" + "unicode/utf8" +) + +const ( + ChannelJoinRequestStatusPending = "pending" + ChannelJoinRequestStatusApproved = "approved" + ChannelJoinRequestStatusDenied = "denied" + ChannelJoinRequestStatusWithdrawn = "withdrawn" + + ChannelJoinRequestMessageMaxRunes = 500 + ChannelJoinRequestDenialReasonMaxRunes = 500 +) + +// ChannelJoinRequest records a user's request to join a discoverable private channel. +// +// Rows are append-only / status-mutating: a request transitions through +// pending → approved | denied | withdrawn. Rows are never deleted so the full +// audit history is preserved. A partial unique index in Postgres enforces at +// most one active pending row per (ChannelId, UserId). +type ChannelJoinRequest struct { + Id string `json:"id"` + ChannelId string `json:"channel_id"` + UserId string `json:"user_id"` + Message string `json:"message"` + Status string `json:"status"` + DenialReason string `json:"denial_reason"` + CreateAt int64 `json:"create_at"` + UpdateAt int64 `json:"update_at"` + ReviewedBy string `json:"reviewed_by"` + ReviewedAt int64 `json:"reviewed_at"` +} + +// ChannelJoinRequestList is the paginated response shape returned by list endpoints. +type ChannelJoinRequestList struct { + Requests []*ChannelJoinRequest `json:"requests"` + TotalCount int64 `json:"total_count"` +} + +// ChannelJoinRequestPatch represents the admin review action: approve or deny, +// with an optional denial reason that is surfaced to the requester. +type ChannelJoinRequestPatch struct { + Status string `json:"status"` + DenialReason *string `json:"denial_reason,omitempty"` +} + +// GetChannelJoinRequestsOpts filters and paginates list queries on the store. +// An empty Status means "pending". +type GetChannelJoinRequestsOpts struct { + Status string + Page int + PerPage int +} + +// IsValidChannelJoinRequestStatus reports whether the given status string is a +// recognized lifecycle value for a ChannelJoinRequest. +func IsValidChannelJoinRequestStatus(s string) bool { + switch s { + case ChannelJoinRequestStatusPending, + ChannelJoinRequestStatusApproved, + ChannelJoinRequestStatusDenied, + ChannelJoinRequestStatusWithdrawn: + return true + } + return false +} + +func (r *ChannelJoinRequest) Auditable() map[string]any { + return map[string]any{ + "id": r.Id, + "channel_id": r.ChannelId, + "user_id": r.UserId, + "status": r.Status, + "create_at": r.CreateAt, + "update_at": r.UpdateAt, + "reviewed_by": r.ReviewedBy, + "reviewed_at": r.ReviewedAt, + "has_message": r.Message != "", + "has_denial_reason": r.DenialReason != "", + } +} + +func (r *ChannelJoinRequest) LogClone() any { + return r.Auditable() +} + +func (r *ChannelJoinRequest) IsValid() *AppError { + if !IsValidId(r.Id) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.id.app_error", nil, "", http.StatusBadRequest) + } + + if !IsValidId(r.ChannelId) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.channel_id.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if !IsValidId(r.UserId) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.user_id.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.CreateAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.create_at.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.UpdateAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.update_at.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if !IsValidChannelJoinRequestStatus(r.Status) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.status.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if utf8.RuneCountInString(r.Message) > ChannelJoinRequestMessageMaxRunes { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.message.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if utf8.RuneCountInString(r.DenialReason) > ChannelJoinRequestDenialReasonMaxRunes { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.denial_reason.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + // A denial reason is only meaningful on a denied request. + if r.DenialReason != "" && r.Status != ChannelJoinRequestStatusDenied { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.denial_reason_status.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.ReviewedBy != "" && !IsValidId(r.ReviewedBy) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.reviewed_by.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + // Reviewer and reviewed-at must accompany a terminal review action. + switch r.Status { + case ChannelJoinRequestStatusApproved, ChannelJoinRequestStatusDenied: + if r.ReviewedBy == "" || r.ReviewedAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.reviewer.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + } + + return nil +} + +func (r *ChannelJoinRequest) PreSave() { + if r.Id == "" { + r.Id = NewId() + } + if r.Status == "" { + r.Status = ChannelJoinRequestStatusPending + } + if r.CreateAt == 0 { + r.CreateAt = GetMillis() + } + r.UpdateAt = r.CreateAt + r.Message = SanitizeUnicode(r.Message) + r.DenialReason = SanitizeUnicode(r.DenialReason) +} + +func (r *ChannelJoinRequest) PreUpdate() { + r.UpdateAt = GetMillis() + r.Message = SanitizeUnicode(r.Message) + r.DenialReason = SanitizeUnicode(r.DenialReason) +} diff --git a/server/public/model/channel_join_request_test.go b/server/public/model/channel_join_request_test.go new file mode 100644 index 00000000000..78f354732c4 --- /dev/null +++ b/server/public/model/channel_join_request_test.go @@ -0,0 +1,114 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func validRequest() *ChannelJoinRequest { + return &ChannelJoinRequest{ + Id: NewId(), + ChannelId: NewId(), + UserId: NewId(), + Status: ChannelJoinRequestStatusPending, + CreateAt: GetMillis(), + UpdateAt: GetMillis(), + } +} + +func TestChannelJoinRequestPreSaveDefaults(t *testing.T) { + r := &ChannelJoinRequest{ + ChannelId: NewId(), + UserId: NewId(), + } + r.PreSave() + + assert.NotEmpty(t, r.Id, "PreSave must assign an Id when missing") + assert.Equal(t, ChannelJoinRequestStatusPending, r.Status, "PreSave must default Status to pending") + assert.NotZero(t, r.CreateAt) + assert.Equal(t, r.CreateAt, r.UpdateAt, "PreSave must align UpdateAt with CreateAt") +} + +func TestChannelJoinRequestPreUpdateAdvancesUpdateAt(t *testing.T) { + r := validRequest() + originalCreate := r.CreateAt + // Seed UpdateAt to a known-old value so we can prove PreUpdate actually + // advanced it (the validRequest factory sets UpdateAt = GetMillis(), so + // a no-op PreUpdate could otherwise still pass a GreaterOrEqual check). + r.UpdateAt = 1 + r.PreUpdate() + + assert.Greater(t, r.UpdateAt, int64(1)) + assert.Equal(t, originalCreate, r.CreateAt, "PreUpdate must not mutate CreateAt") +} + +func TestChannelJoinRequestIsValid(t *testing.T) { + t.Run("happy path pending", func(t *testing.T) { + require.Nil(t, validRequest().IsValid()) + }) + + t.Run("invalid id", func(t *testing.T) { + r := validRequest() + r.Id = "not-an-id" + err := r.IsValid() + require.NotNil(t, err) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + }) + + t.Run("rejects unknown status", func(t *testing.T) { + r := validRequest() + r.Status = "weird" + require.NotNil(t, r.IsValid()) + }) + + t.Run("rejects message over rune limit", func(t *testing.T) { + r := validRequest() + r.Message = strings.Repeat("a", ChannelJoinRequestMessageMaxRunes+1) + require.NotNil(t, r.IsValid()) + }) + + t.Run("rejects denial reason on non-denied request", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusApproved + r.ReviewedBy = NewId() + r.ReviewedAt = GetMillis() + r.DenialReason = "nope" + require.NotNil(t, r.IsValid(), "denial reason must only be set on denied rows") + }) + + t.Run("requires reviewer info for terminal review", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusApproved + require.NotNil(t, r.IsValid(), "approved without reviewer must be invalid") + + r.ReviewedBy = NewId() + r.ReviewedAt = GetMillis() + require.Nil(t, r.IsValid()) + }) + + t.Run("withdrawn does not require reviewer", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusWithdrawn + require.Nil(t, r.IsValid(), "withdrawn is a self-service action, not a review") + }) +} + +func TestIsValidChannelJoinRequestStatus(t *testing.T) { + for _, s := range []string{ + ChannelJoinRequestStatusPending, + ChannelJoinRequestStatusApproved, + ChannelJoinRequestStatusDenied, + ChannelJoinRequestStatusWithdrawn, + } { + assert.True(t, IsValidChannelJoinRequestStatus(s), "%q should be a valid status", s) + } + assert.False(t, IsValidChannelJoinRequestStatus("")) + assert.False(t, IsValidChannelJoinRequestStatus("approved ")) +} diff --git a/server/public/model/channel_test.go b/server/public/model/channel_test.go index 8d7a3b2ad09..21143f3b7f5 100644 --- a/server/public/model/channel_test.go +++ b/server/public/model/channel_test.go @@ -35,6 +35,64 @@ func TestChannelPatch(t *testing.T) { require.Equal(t, *p.GroupConstrained, *o.GroupConstrained) } +func TestChannelPatchDiscoverable(t *testing.T) { + t.Run("applies discoverable when set", func(t *testing.T) { + on := true + p := &ChannelPatch{Discoverable: &on} + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate} + o.Patch(p) + require.True(t, o.Discoverable) + }) + + t.Run("clears discoverable when set to false", func(t *testing.T) { + off := false + p := &ChannelPatch{Discoverable: &off} + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate, Discoverable: true} + o.Patch(p) + require.False(t, o.Discoverable) + }) + + t.Run("nil discoverable leaves channel untouched", func(t *testing.T) { + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate, Discoverable: true} + o.Patch(&ChannelPatch{}) + require.True(t, o.Discoverable) + }) +} + +func TestChannelIsValidDiscoverable(t *testing.T) { + base := Channel{ + Id: NewId(), + CreateAt: GetMillis(), + UpdateAt: GetMillis(), + DisplayName: "x", + Name: "valid-name", + Header: "h", + Purpose: "p", + } + + t.Run("discoverable=false is valid on any type", func(t *testing.T) { + c := base + c.Type = ChannelTypeOpen + require.Nil(t, c.IsValid()) + }) + + t.Run("discoverable=true requires private channel", func(t *testing.T) { + c := base + c.Type = ChannelTypeOpen + c.Discoverable = true + require.NotNil(t, c.IsValid(), "discoverable=true on public channel must be rejected") + + c.Type = ChannelTypeDirect + require.NotNil(t, c.IsValid()) + + c.Type = ChannelTypeGroup + require.NotNil(t, c.IsValid()) + + c.Type = ChannelTypePrivate + require.Nil(t, c.IsValid()) + }) +} + func TestChannelIsValid(t *testing.T) { o := Channel{} diff --git a/server/public/model/client4.go b/server/public/model/client4.go index 179940aff96..c3eb8c76e43 100644 --- a/server/public/model/client4.go +++ b/server/public/model/client4.go @@ -1186,6 +1186,18 @@ func (c *Client4) GetUserByEmail(ctx context.Context, email, etag string) (*User return DecodeJSONFromResponse[*User](r) } +// GetUserByAuthData returns a user by auth_data (external AuthData). +func (c *Client4) GetUserByAuthData(ctx context.Context, authData, etag string) (*User, *Response, error) { + values := url.Values{} + values.Set("value", authData) + r, err := c.doAPIGetWithQuery(ctx, c.usersRoute().Join("auth_data"), values, etag) + if err != nil { + return nil, BuildResponse(r), err + } + defer closeBody(r) + return DecodeJSONFromResponse[*User](r) +} + // AutocompleteUsersInTeam returns the users on a team based on search term. func (c *Client4) AutocompleteUsersInTeam(ctx context.Context, teamId string, username string, limit int, etag string) (*UserAutocomplete, *Response, error) { values := url.Values{} diff --git a/server/public/model/feature_flags.go b/server/public/model/feature_flags.go index ac8d06a3092..68465f32025 100644 --- a/server/public/model/feature_flags.go +++ b/server/public/model/feature_flags.go @@ -120,6 +120,11 @@ type FeatureFlags struct { // ManagedChannelCategories enables server-side managed sidebar category enforcement (Enterprise). ManagedChannelCategories bool + + // FEATURE_FLAG_REMOVAL: DiscoverableChannels - Remove this when the feature is GA. + // Gates the per-channel Discoverable toggle and the channel-join-request flow that lets + // non-members find a private channel in Browse Channels and request to join it. + DiscoverableChannels bool } func (f *FeatureFlags) SetDefaults() { @@ -176,6 +181,8 @@ func (f *FeatureFlags) SetDefaults() { f.AggregatePluginMetrics = false f.ManagedChannelCategories = false + + f.DiscoverableChannels = false } // ToMap returns the feature flags as a map[string]string diff --git a/server/public/model/migration.go b/server/public/model/migration.go index f29087a19f5..7ece8d257ba 100644 --- a/server/public/model/migration.go +++ b/server/public/model/migration.go @@ -65,4 +65,5 @@ const ( MigrationKeyAccessControlPolicyV0_3 = "access_control_policy_v0_3_migration" MigrationKeyAddManageAgentPermissions = "add_manage_agent_permissions" MigrationKeyAddEditFileAttachmentPermission = "add_edit_file_attachment_permission" + MigrationKeyAddDiscoverableChannelPermissions = "add_discoverable_channel_permissions" ) diff --git a/server/public/model/permission.go b/server/public/model/permission.go index 63ab17e5369..921d46ff894 100644 --- a/server/public/model/permission.go +++ b/server/public/model/permission.go @@ -49,6 +49,8 @@ var PermissionManagePublicChannelProperties *Permission var PermissionManagePrivateChannelProperties *Permission var PermissionManagePublicChannelAutoTranslation *Permission var PermissionManagePrivateChannelAutoTranslation *Permission +var PermissionManagePrivateChannelDiscoverability *Permission +var PermissionManageChannelJoinRequests *Permission var PermissionListPublicTeams *Permission var PermissionJoinPublicTeams *Permission var PermissionListPrivateTeams *Permission @@ -568,6 +570,18 @@ func initializePermissions() { "authentication.permissions.manage_private_channel_auto_translation.description", PermissionScopeChannel, } + PermissionManagePrivateChannelDiscoverability = &Permission{ + "manage_private_channel_discoverability", + "authentication.permissions.manage_private_channel_discoverability.name", + "authentication.permissions.manage_private_channel_discoverability.description", + PermissionScopeChannel, + } + PermissionManageChannelJoinRequests = &Permission{ + "manage_channel_join_requests", + "authentication.permissions.manage_channel_join_requests.name", + "authentication.permissions.manage_channel_join_requests.description", + PermissionScopeChannel, + } PermissionListPublicTeams = &Permission{ "list_public_teams", "authentication.permissions.list_public_teams.name", @@ -2631,6 +2645,8 @@ func initializePermissions() { PermissionManagePrivateChannelBanner, PermissionManageChannelAccessRules, PermissionEditFileAttachment, + PermissionManagePrivateChannelDiscoverability, + PermissionManageChannelJoinRequests, } GroupScopedPermissions := []*Permission{ diff --git a/server/public/model/role.go b/server/public/model/role.go index e7b84a5d829..17f2807f3d7 100644 --- a/server/public/model/role.go +++ b/server/public/model/role.go @@ -778,25 +778,28 @@ func (r *Role) RolePatchFromChannelModerationsPatch(channelModerationsPatch []*C return &RolePatch{Permissions: &patchPermissions} } -func (r *Role) IsValid() bool { +func (r *Role) IsValid() error { if !IsValidId(r.Id) { - return false + return fmt.Errorf("invalid role id %q", r.Id) } return r.IsValidWithoutId() } -func (r *Role) IsValidWithoutId() bool { +func (r *Role) IsValidWithoutId() error { if !IsValidRoleName(r.Name) { - return false + return fmt.Errorf("invalid role name %q", r.Name) } - if r.DisplayName == "" || len(r.DisplayName) > RoleDisplayNameMaxLength { - return false + if r.DisplayName == "" { + return fmt.Errorf("role display name must not be empty") + } + if len(r.DisplayName) > RoleDisplayNameMaxLength { + return fmt.Errorf("role display name %q exceeds maximum length of %d", r.DisplayName, RoleDisplayNameMaxLength) } if len(r.Description) > RoleDescriptionMaxLength { - return false + return fmt.Errorf("role description exceeds maximum length of %d", RoleDescriptionMaxLength) } check := func(perms []*Permission, permission string) bool { @@ -808,13 +811,12 @@ func (r *Role) IsValidWithoutId() bool { return false } for _, permission := range r.Permissions { - permissionValidated := check(AllPermissions, permission) || check(DeprecatedPermissions, permission) - if !permissionValidated { - return false + if !check(AllPermissions, permission) && !check(DeprecatedPermissions, permission) { + return fmt.Errorf("unknown permission %q", permission) } } - return true + return nil } func CleanRoleNames(roleNames []string) ([]string, bool) { @@ -930,6 +932,8 @@ func MakeDefaultRoles() map[string]*Role { PermissionManageChannelAccessRules.Id, PermissionManagePublicChannelAutoTranslation.Id, PermissionManagePrivateChannelAutoTranslation.Id, + PermissionManagePrivateChannelDiscoverability.Id, + PermissionManageChannelJoinRequests.Id, }, SchemeManaged: true, BuiltIn: true, diff --git a/server/public/model/role_test.go b/server/public/model/role_test.go index 9abf2c1c81e..0550509cbba 100644 --- a/server/public/model/role_test.go +++ b/server/public/model/role_test.go @@ -5,6 +5,7 @@ package model import ( "slices" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -362,6 +363,106 @@ func TestManageAgentPermissionsDefinition(t *testing.T) { }), "manage_others_agent should be in AllPermissions") } +func TestRoleIsValidWithoutId(t *testing.T) { + validRole := func() *Role { + return &Role{ + Name: "test_role", + DisplayName: "Test Role", + Description: "A test role.", + Permissions: []string{PermissionCreatePost.Id}, + } + } + + t.Run("valid role returns nil", func(t *testing.T) { + assert.NoError(t, validRole().IsValidWithoutId()) + }) + + t.Run("empty name", func(t *testing.T) { + r := validRole() + r.Name = "" + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("name too long", func(t *testing.T) { + r := validRole() + r.Name = strings.Repeat("a", RoleNameMaxLength+1) + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("name with invalid characters", func(t *testing.T) { + r := validRole() + r.Name = "invalid-name" + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("empty display name", func(t *testing.T) { + r := validRole() + r.DisplayName = "" + assert.ErrorContains(t, r.IsValidWithoutId(), "display name must not be empty") + }) + + t.Run("display name too long", func(t *testing.T) { + r := validRole() + r.DisplayName = strings.Repeat("a", RoleDisplayNameMaxLength+1) + err := r.IsValidWithoutId() + assert.ErrorContains(t, err, "display name") + assert.ErrorContains(t, err, "exceeds maximum length") + }) + + t.Run("description too long", func(t *testing.T) { + r := validRole() + r.Description = strings.Repeat("a", RoleDescriptionMaxLength+1) + assert.ErrorContains(t, r.IsValidWithoutId(), "description exceeds maximum length") + }) + + t.Run("unknown permission", func(t *testing.T) { + r := validRole() + r.Permissions = []string{"not_a_real_permission"} + err := r.IsValidWithoutId() + require.ErrorContains(t, err, "unknown permission") + assert.ErrorContains(t, err, "not_a_real_permission") + }) + + t.Run("no permissions is valid", func(t *testing.T) { + r := validRole() + r.Permissions = nil + assert.NoError(t, r.IsValidWithoutId()) + }) +} + +func TestRoleIsValid(t *testing.T) { + validRole := func() *Role { + return &Role{ + Id: NewId(), + Name: "test_role", + DisplayName: "Test Role", + Permissions: []string{PermissionCreatePost.Id}, + } + } + + t.Run("valid role returns nil", func(t *testing.T) { + assert.NoError(t, validRole().IsValid()) + }) + + t.Run("empty id", func(t *testing.T) { + r := validRole() + r.Id = "" + assert.ErrorContains(t, r.IsValid(), "invalid role id") + }) + + t.Run("invalid id", func(t *testing.T) { + r := validRole() + r.Id = "not-a-valid-id!" + assert.ErrorContains(t, r.IsValid(), "invalid role id") + }) + + t.Run("propagates IsValidWithoutId error", func(t *testing.T) { + r := validRole() + r.DisplayName = "" + assert.ErrorContains(t, r.IsValid(), "display name must not be empty") + }) +} + func TestManageAgentPermissionsDefaultRoles(t *testing.T) { roles := MakeDefaultRoles() diff --git a/server/public/model/websocket_message.go b/server/public/model/websocket_message.go index c816d3234a7..87f9ead3544 100644 --- a/server/public/model/websocket_message.go +++ b/server/public/model/websocket_message.go @@ -117,6 +117,8 @@ const ( WebsocketEventFileDownloadRejected WebsocketEventType = "file_download_rejected" WebsocketEventShowToast WebsocketEventType = "show_toast" WebsocketEventSharedChannelRemoteUpdated WebsocketEventType = "shared_channel_remote_updated" + WebsocketEventChannelJoinRequestCreated WebsocketEventType = "channel_join_request_created" + WebsocketEventChannelJoinRequestUpdated WebsocketEventType = "channel_join_request_updated" WebSocketMsgTypeResponse = "response" WebSocketMsgTypeEvent = "event" diff --git a/tools/mattermost-govet/apiAuditLogs/whitelist.go b/tools/mattermost-govet/apiAuditLogs/whitelist.go index b79b3f6b7b5..9ec9502a30a 100644 --- a/tools/mattermost-govet/apiAuditLogs/whitelist.go +++ b/tools/mattermost-govet/apiAuditLogs/whitelist.go @@ -121,6 +121,7 @@ var whiteList = map[string]bool{ "getUserAccessToken": true, "getUserAccessTokens": true, "getUserAccessTokensForUser": true, + "getUserByAuthData": true, "getUserByEmail": true, "getUserByUsername": true, "getUsers": true, @@ -133,6 +134,7 @@ var whiteList = map[string]bool{ "getWebappPlugins": true, "listAutocompleteCommands": true, "listCommands": true, + "localGetUserByAuthData": true, "openDialog": true, "patchChannelModerations": true, "pinPost": true, diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/attribute_selector_menu.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/attribute_selector_menu.tsx index be7101c5c4e..a24032fef18 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/attribute_selector_menu.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/attribute_selector_menu.tsx @@ -23,8 +23,22 @@ import type {UserPropertyField} from '@mattermost/types/properties'; import * as Menu from 'components/menu'; +import {getUserPropertyFieldLabel} from 'utils/properties'; + import './selector_menus.scss'; +type AttributeLabelProps = { + displayName: string; + name: string; +}; + +const AttributeLabel = ({displayName, name}: AttributeLabelProps) => ( + + {displayName} + {name} + +); + // Define AttributeIcon outside the main component const AttributeIcon = (props: IconProps & { attribute?: UserPropertyField }) => { const {attribute, ...iconProps} = props; @@ -76,8 +90,12 @@ const AttributeSelectorMenu = ({currentAttribute, availableAttributes, disabled, }, []); // setFilter is stable const options = useMemo(() => { + const q = filter.toLowerCase(); return availableAttributes.filter((attr) => { - return attr.name.toLowerCase().includes(filter.toLowerCase()); + return ( + attr.name.toLowerCase().includes(q) || + getUserPropertyFieldLabel(attr).toLowerCase().includes(q) + ); }); }, [availableAttributes, filter]); @@ -90,6 +108,13 @@ const AttributeSelectorMenu = ({currentAttribute, availableAttributes, disabled, return availableAttributes.find((attr) => attr.name === currentAttribute); }, [currentAttribute, availableAttributes]); + let selectedAttributeLabel; + if (selectedAttributeObject) { + selectedAttributeLabel = getUserPropertyFieldLabel(selectedAttributeObject); + } else { + selectedAttributeLabel = currentAttribute || formatMessage({id: 'admin.access_control.table_editor.selector.select_attribute', defaultMessage: 'Select attribute'}); + } + useEffect(() => { if (autoOpen && !prevAutoOpen.current) { const buttonElement = document.getElementById(buttonId); @@ -111,7 +136,7 @@ const AttributeSelectorMenu = ({currentAttribute, availableAttributes, disabled, children: ( <> - {currentAttribute || formatMessage({id: 'admin.access_control.table_editor.selector.select_attribute', defaultMessage: 'Select attribute'})} + {selectedAttributeLabel} ), dataTestId: 'attributeSelectorMenuButton', @@ -134,6 +159,10 @@ const AttributeSelectorMenu = ({currentAttribute, availableAttributes, disabled, /> {options.map((option) => { const {name} = option; + const displayName = option.attrs?.display_name; + + // hasSpaces checks the CEL identifier (name), not the display label. + // New fields cannot have spaces in name but leaving this check for backwards compatibility with grandfathered legacy fields. const hasSpaces = name.includes(' '); const isSynced = option.attrs?.ldap || option.attrs?.saml; const isAdminManaged = option.attrs?.managed === 'admin'; @@ -148,7 +177,14 @@ const AttributeSelectorMenu = ({currentAttribute, availableAttributes, disabled, forceCloseOnSelect={true} aria-checked={name === currentAttribute} onClick={hasSpaces ? undefined : () => handleAttributeChange(name)} - labels={{name}} + labels={ + displayName ? ( + + ) : {name} + } disabled={hasSpaces || !allowed} leadingElement={ { return userAttributes.find((attr) => { - const hasSpaces = attr.name.includes(' '); + const isValidCELIdentifier = CPA_FIELD_NAME_PATTERN.test(attr.name); const isSynced = attr.attrs?.ldap || attr.attrs?.saml; const isAdminManaged = attr.attrs?.managed === 'admin'; const isProtected = attr.attrs?.protected; const allowed = isSynced || isAdminManaged || isProtected || enableUserManagedAttributes; - return !hasSpaces && allowed; + return isValidCELIdentifier && allowed; }); }; diff --git a/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.test.tsx b/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.test.tsx index e84cfbbe448..aea9e4b80e3 100644 --- a/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.test.tsx +++ b/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.test.tsx @@ -273,4 +273,63 @@ describe('components/admin_console/custom_profile_attributes/CustomProfileAttrib const warning = await screen.findByText((content) => content.includes('This attribute will be converted to a TEXT attribute')); expect(warning).toBeInTheDocument(); }); + + describe('display_name labels', () => { + test('should render TextSetting label and help text using display_name', async () => { + const displayNameAttr: UserPropertyField = { + ...baseField, + id: 'attr_display', + name: 'my_field', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + ldap: 'department', + display_name: 'My Display Name', + }, + }; + const state = createInitialState({displayNameAttr}); + + renderWithContext( + , + state, + ); + + const labelEl = await screen.findByTestId('custom_profile_attribute-my_fieldlabel'); + expect(labelEl.tagName).toBe('LABEL'); + expect(labelEl).toHaveTextContent('My Display Name'); + expect(labelEl).not.toHaveTextContent('my_field'); + + const helpTextEl = screen.getByTestId('custom_profile_attribute-my_fieldhelp-text'); + expect(helpTextEl).toHaveTextContent(/users cannot edit their My Display Name/); + expect(helpTextEl).not.toHaveTextContent('users cannot edit their my_field'); + }); + + test('should fall back to name when display_name is missing', async () => { + const fallbackAttr: UserPropertyField = { + ...baseField, + id: 'attr_fallback', + name: 'my_field', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + ldap: 'department', + }, + }; + const state = createInitialState({fallbackAttr}); + + renderWithContext( + , + state, + ); + + const labelEl = await screen.findByTestId('custom_profile_attribute-my_fieldlabel'); + expect(labelEl.tagName).toBe('LABEL'); + expect(labelEl).toHaveTextContent('my_field'); + + const helpTextEl = screen.getByTestId('custom_profile_attribute-my_fieldhelp-text'); + expect(helpTextEl).toHaveTextContent(/users cannot edit their my_field/); + }); + }); }); diff --git a/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.tsx b/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.tsx index c58b936affa..32ae1d46061 100644 --- a/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.tsx +++ b/webapp/channels/src/components/admin_console/custom_profile_attributes/custom_profile_attributes.tsx @@ -21,6 +21,8 @@ import {getPluginDisplayName} from 'selectors/plugins'; import SettingsGroup from 'components/admin_console/settings_group'; import TextSetting from 'components/admin_console/text_setting'; +import {getUserPropertyFieldLabel} from 'utils/properties'; + import type {GlobalState} from 'types/store'; type AttributeHelpTextProps = { @@ -168,7 +170,7 @@ const CustomProfileAttributes: React.FC = (props: Props): JSX.Element | n { setAttributes((prevAttrs) => prevAttrs.map((a) => { @@ -194,7 +196,7 @@ const CustomProfileAttributes: React.FC = (props: Props): JSX.Element | n ) : ( ) diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.test.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.test.tsx index 85c397a56e4..717024bb1d2 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.test.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.test.tsx @@ -9,6 +9,7 @@ import {openModal} from 'actions/views/modals'; import {renderHookWithContext, renderWithContext, screen, userEvent} from 'tests/react_testing_utils'; import {ModalIdentifiers} from 'utils/constants'; +import {getUserPropertyFieldLabel} from 'utils/properties'; import RemoveUserPropertyFieldModal, {useUserPropertyFieldDelete} from './user_properties_delete_modal'; @@ -95,7 +96,7 @@ describe('useUserPropertyFieldDelete', () => { modalId: ModalIdentifiers.USER_PROPERTY_FIELD_DELETE, dialogType: RemoveUserPropertyFieldModal, dialogProps: { - name: baseField.name, + name: getUserPropertyFieldLabel(baseField), onConfirm: expect.any(Function), isOrphaned: false, sourcePluginId: undefined, @@ -103,6 +104,23 @@ describe('useUserPropertyFieldDelete', () => { }); }); + it('passes display_name as the modal name when set', () => { + const {result} = renderHookWithContext(() => useUserPropertyFieldDelete()); + const fieldWithDisplayName = { + ...baseField, + name: 'department', + attrs: {...baseField.attrs, display_name: 'Department Head'}, + }; + + result.current.promptDelete(fieldWithDisplayName); + + expect(openModal).toHaveBeenCalledWith(expect.objectContaining({ + dialogProps: expect.objectContaining({ + name: 'Department Head', + }), + })); + }); + it('returns a promise that resolves when onConfirm is called', async () => { const {result} = renderHookWithContext(() => useUserPropertyFieldDelete()); diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.tsx index 0892a1b9bf4..05a9c0d86cd 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_delete_modal.tsx @@ -11,6 +11,7 @@ import type {UserPropertyField} from '@mattermost/types/properties'; import {openModal} from 'actions/views/modals'; import {ModalIdentifiers} from 'utils/constants'; +import {getUserPropertyFieldLabel} from 'utils/properties'; type Props = { name: string; @@ -31,7 +32,7 @@ export const useUserPropertyFieldDelete = () => { modalId: ModalIdentifiers.USER_PROPERTY_FIELD_DELETE, dialogType: RemoveUserPropertyFieldModal, dialogProps: { - name: field.name, + name: getUserPropertyFieldLabel(field), onConfirm: () => resolve(true), isOrphaned, sourcePluginId: field.attrs?.source_plugin_id as string | undefined, diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx index c448a7e2325..d95062fa138 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx @@ -6,11 +6,14 @@ import React from 'react'; import type {UserPropertyField} from '@mattermost/types/properties'; +import {Client4} from 'mattermost-redux/client'; + import ModalController from 'components/modal_controller'; import {renderWithContext, screen, userEvent, waitFor} from 'tests/react_testing_utils'; import DotMenu from './user_properties_dot_menu'; +import {useUserPropertyFields} from './user_properties_utils'; describe('UserPropertyDotMenu', () => { const baseField: UserPropertyField = { @@ -36,6 +39,7 @@ describe('UserPropertyDotMenu', () => { const updateField = jest.fn(); const deleteField = jest.fn(); const createField = jest.fn(); + const getFields = jest.spyOn(Client4, 'getCustomProfileAttributeFields'); const renderComponent = (field: UserPropertyField = baseField, dotMenuProps?: Partial>) => { return renderWithContext( @@ -55,6 +59,11 @@ describe('UserPropertyDotMenu', () => { ); }; + beforeEach(() => { + jest.clearAllMocks(); + getFields.mockReset(); + }); + it('renders dot menu button', () => { renderComponent(); @@ -221,14 +230,66 @@ describe('UserPropertyDotMenu', () => { // Wait for createField to be called await waitFor(() => { - // Verify createField was called with the correct parameters + // Verify createField was called with the slugified snake_case name + // ('Test Field' -> 'test_field') plus the _copy suffix. expect(createField).toHaveBeenCalledWith(expect.objectContaining({ id: baseField.id, - name: 'Test Field (copy)', + name: 'test_field_copy', })); }); }); + it('duplicate produces _2 suffix when base name is already taken', async () => { + const existingCopy = { + ...baseField, + id: 'copy-id', + name: 'test_field_copy', + attrs: { + ...baseField.attrs, + sort_order: 1, + }, + }; + getFields.mockResolvedValueOnce([baseField, existingCopy]); + + const Harness = () => { + const [fields, readIO,, itemOps] = useUserPropertyFields(); + + if (readIO.loading || !fields.data[baseField.id]) { + return null; + } + + return ( +
+ + {fields.order.map((id) => ( + + {fields.data[id].name} + + ))} +
+ ); + }; + + renderWithContext(); + + const menuButton = await screen.findByTestId(`user-property-field_dotmenu-${baseField.id}`); + await userEvent.click(menuButton); + await userEvent.click(screen.getByText(/Duplicate attribute/)); + + await waitFor(() => { + expect(screen.getByText('test_field_copy_2')).toBeInTheDocument(); + }); + }); + it('hides field duplication when at field limit', async () => { renderComponent(undefined, {canCreate: false}); diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx index 162e2d54544..eb9d08fc653 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx @@ -2,7 +2,7 @@ // See LICENSE.txt for license information. import React from 'react'; -import {FormattedMessage, useIntl} from 'react-intl'; +import {FormattedMessage} from 'react-intl'; import {useDispatch} from 'react-redux'; import {CheckIcon, ChevronRightIcon, DotsHorizontalIcon, EyeOutlineIcon, LockOutlineIcon, PencilOutlineIcon, SyncIcon, TrashCanOutlineIcon, ContentCopyIcon} from '@mattermost/compass-icons/components'; @@ -14,6 +14,7 @@ import * as Menu from 'components/menu'; import Toggle from 'components/toggle'; import {ModalIdentifiers} from 'utils/constants'; +import {slugifyForCEL} from 'utils/properties'; import AttributeModal from './attribute_modal'; import {useUserPropertyFieldDelete} from './user_properties_delete_modal'; @@ -114,18 +115,13 @@ const DotMenu = ({ updateField, deleteField, }: Props) => { - const {formatMessage} = useIntl(); const {promptDelete} = useUserPropertyFieldDelete(); const {promptEditLdapLink, promptEditSamlLink} = useAttributeLinkModal(field, updateField); const isProtected = Boolean(field.attrs?.protected); const handleDuplicate = () => { - const name = formatMessage({ - id: 'admin.system_properties.user_properties.dotmenu.duplicate.name_copy', - defaultMessage: '{fieldName} (copy)', - }, {fieldName: field.name}); - + const name = `${slugifyForCEL(field.name)}_copy`; createField({...field, attrs: {...field.attrs}, name}); }; diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_table.test.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_table.test.tsx index 49409f4add5..32dfc3960e8 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_table.test.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_table.test.tsx @@ -1,14 +1,19 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. +import {act} from '@testing-library/react'; import React from 'react'; import type {UserPropertyField} from '@mattermost/types/properties'; import {collectionFromArray} from '@mattermost/types/utilities'; +import {Client4} from 'mattermost-redux/client'; + import {fireEvent, renderWithContext, screen, userEvent, waitFor} from 'tests/react_testing_utils'; +import Constants from 'utils/constants'; -import {UserPropertiesTable} from './user_properties_table'; +import {UserPropertiesTable, useUserPropertiesTable} from './user_properties_table'; +import {ValidationWarningNameInvalidCEL} from './user_properties_utils'; jest.mock('./user_properties_delete_modal', () => ({ useUserPropertyFieldDelete: jest.fn(() => ({ @@ -67,9 +72,7 @@ describe('UserPropertiesTable', () => { const deleteField = jest.fn(); const reorderField = jest.fn(); - const renderComponent = (fields = baseFields) => { - const collection = collectionFromArray(fields); - + const renderComponent = (fields = baseFields, collection = collectionFromArray(fields)) => { return renderWithContext( { renderComponent(); // Check column headers - expect(screen.getByText('Attribute')).toBeInTheDocument(); + expect(screen.getByText('Name')).toBeInTheDocument(); + expect(screen.getByText('Display Name')).toBeInTheDocument(); expect(screen.getByText('Type')).toBeInTheDocument(); expect(screen.getByText('Values')).toBeInTheDocument(); expect(screen.getByText('Actions')).toBeInTheDocument(); @@ -96,6 +100,19 @@ describe('UserPropertiesTable', () => { expect(screen.getByDisplayValue('Field 2')).toBeInTheDocument(); expect(screen.getByText('Text')).toBeInTheDocument(); expect(screen.getByText('Select')).toBeInTheDocument(); + expect(screen.getAllByTestId('property-display-name-input')[0]).toHaveValue(''); + }); + + it('shows display_name value in the Display name column', () => { + const fields = baseFields.map((field, index) => ({ + ...field, + attrs: {...field.attrs, display_name: `Display ${index + 1}`}, + })); + + renderComponent(fields); + + expect(screen.getByDisplayValue('Display 1')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Display 2')).toBeInTheDocument(); }); it('allows editing field names', async () => { @@ -103,17 +120,28 @@ describe('UserPropertiesTable', () => { const field1Input = screen.getByDisplayValue('Field 1'); await userEvent.clear(field1Input); - await userEvent.type(field1Input, 'Edited Field 1'); + await userEvent.type(field1Input, 'EditedField1'); - // Trigger blur to save the edited field name - fireEvent used because userEvent doesn't have direct focus/blur methods fireEvent.blur(field1Input); expect(updateField).toHaveBeenCalledWith({ ...baseFields[0], - name: 'Edited Field 1', + name: 'EditedField1', }); }); + it('calls updateField with updated display_name on blur', async () => { + renderComponent(); + + const displayNameInput = screen.getAllByTestId('property-display-name-input')[0]; + await userEvent.type(displayNameInput, 'Department Head'); + fireEvent.blur(displayNameInput); + + expect(updateField).toHaveBeenCalledWith(expect.objectContaining({ + attrs: expect.objectContaining({display_name: 'Department Head'}), + })); + }); + it('shows type selection menu', () => { renderComponent(); @@ -174,6 +202,49 @@ describe('UserPropertiesTable', () => { }); }); + it('shows CEL validation error for invalid identifiers', async () => { + const collection = collectionFromArray(baseFields); + collection.warnings = { + field1: {name: ValidationWarningNameInvalidCEL}, + }; + + renderComponent(baseFields, collection); + + await waitFor(() => { + expect(screen.getByText(/Identifier must start with a letter or underscore/)).toBeInTheDocument(); + }); + expect(screen.getByTestId('property-field-validation-error')).toBeInTheDocument(); + }); + + it('editing display_name of a legacy invalid-named field does not fire CEL warning', async () => { + const legacyField = { + ...baseFields[0], + name: 'My Legacy Field', + }; + + renderComponent([legacyField]); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + await userEvent.type(displayNameInput, 'Legacy Display'); + fireEvent.blur(displayNameInput); + + expect(screen.queryByText(/Identifier must start with a letter or underscore/)).not.toBeInTheDocument(); + }); + + it('editing identifier of a legacy field triggers CEL validation', async () => { + const legacyField = {...baseFields[0], name: 'My Legacy Field'}; + const collection = collectionFromArray([legacyField]); + collection.warnings = { + [legacyField.id]: {name: ValidationWarningNameInvalidCEL}, + }; + + renderComponent([legacyField], collection); + + await waitFor(() => { + expect(screen.getByText(/Reserved CEL words are not allowed/)).toBeInTheDocument(); + }); + }); + it('autofocuses name input for new text field', () => { const pendingTextField: UserPropertyField = { id: 'pending-text', @@ -197,9 +268,9 @@ describe('UserPropertiesTable', () => { renderComponent([...baseFields, pendingTextField]); - // The name input for the new text field should be autofocused - const nameInputs = screen.getAllByTestId('property-field-input'); - expect(document.activeElement).toBe(nameInputs[2]); + // The display name input for the new text field should be autofocused + const displayNameInputs = screen.getAllByTestId('property-display-name-input'); + expect(document.activeElement).toBe(displayNameInputs[2]); }); it('autofocuses values input for new select field', () => { @@ -260,3 +331,442 @@ describe('UserPropertiesTable', () => { expect(document.activeElement).toBe(comboboxes[comboboxes.length - 1]); }); }); + +describe('UserPropertiesTable input filtering', () => { + const baseFields: UserPropertyField[] = [ + { + id: 'field1', + name: 'Field1', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 1736541716295, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }, + ]; + + const createField = jest.fn(); + const updateField = jest.fn(); + const deleteField = jest.fn(); + const reorderField = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('Name column has sanitize prop that strips invalid characters', () => { + renderWithContext( + , + ); + + const nameInput = screen.getByTestId('property-field-input'); + expect(nameInput).toBeInTheDocument(); + + // filterCELIdentifier is tested in detail in properties.test.ts. + // Here we verify it's wired up: the input exists and carries an + // aria-label confirming it's the Name column input. + expect(nameInput).toHaveAttribute('aria-label', 'Attribute Name'); + }); + + it('Display Name auto-fills Name for new pending text field via live preview', async () => { + const pendingField: UserPropertyField = { + id: 'pending-new', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 1, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + // Type a single character in Display Name - each keystroke triggers + // handleDisplayNameChange which updates the Name column's live preview. + const displayNameInput = screen.getByTestId('property-display-name-input'); + await userEvent.type(displayNameInput, 'D'); + + // The Name input should show the slugified snake_case preview + const nameInput = screen.getByTestId('property-field-input'); + await waitFor(() => { + expect(nameInput).toHaveValue('d'); + }); + }); + + it('Auto-fill deactivates permanently when user manually edits the Name input', async () => { + const pendingField: UserPropertyField = { + id: 'pending-deactivate', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + const nameInput = screen.getByTestId('property-field-input'); + + // Manually type in the Name input to diverge from any auto-fill + await userEvent.type(nameInput, 'custom'); + fireEvent.blur(nameInput); + + updateField.mockClear(); + + // Now type in Display Name and blur + const displayNameInput = screen.getByTestId('property-display-name-input'); + await userEvent.type(displayNameInput, 'Department'); + fireEvent.blur(displayNameInput); + + // updateField should be called with display_name change but NOT with + // a name override — the manual edit deactivated auto-fill + const nameChangeCalls = updateField.mock.calls.filter( + (call: [UserPropertyField]) => call[0].name && call[0].name !== 'custom', + ); + expect(nameChangeCalls).toHaveLength(0); + }); + + it('Auto-fill does not fire for existing fields', async () => { + renderWithContext( + , + ); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + await userEvent.type(displayNameInput, 'New Display'); + fireEvent.blur(displayNameInput); + + const nameUpdateCalls = updateField.mock.calls.filter( + (call: [UserPropertyField]) => call[0].name !== baseFields[0].name, + ); + expect(nameUpdateCalls).toHaveLength(0); + }); + + it('auto-fill freezes when Display Name slugifies to a reserved word', async () => { + const pendingField: UserPropertyField = { + id: 'pending-reserved', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + const nameInput = screen.getByTestId('property-field-input'); + + fireEvent.change(displayNameInput, {target: {value: 'function'}}); + expect(nameInput).toHaveValue(''); + + await userEvent.clear(displayNameInput); + await userEvent.type(displayNameInput, 'dept'); + await waitFor(() => { + expect(nameInput).toHaveValue('dept'); + }); + + fireEvent.change(displayNameInput, {target: {value: 'true'}}); + expect(nameInput).toHaveValue('dept'); + }); + + it('auto-fill converts multi-word display names to snake_case', async () => { + const pendingField: UserPropertyField = { + id: 'pending-snake-case', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + const nameInput = screen.getByTestId('property-field-input'); + + fireEvent.change(displayNameInput, {target: {value: 'My Field Name'}}); + await waitFor(() => { + expect(nameInput).toHaveValue('my_field_name'); + }); + + fireEvent.change(displayNameInput, {target: {value: 'XMLParser'}}); + await waitFor(() => { + expect(nameInput).toHaveValue('xml_parser'); + }); + + fireEvent.change(displayNameInput, {target: {value: 'Does this work?'}}); + await waitFor(() => { + expect(nameInput).toHaveValue('does_this_work'); + }); + }); + + it('auto-fill prepends underscore for leading-digit display names', async () => { + const pendingField: UserPropertyField = { + id: 'pending-leading-digit', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + await userEvent.type(displayNameInput, '7Department'); + + const nameInput = screen.getByTestId('property-field-input'); + await waitFor(() => { + // snake_case inserts a boundary between the digit and the + // following uppercase letter before lowercasing. + expect(nameInput).toHaveValue('_7_department'); + }); + }); + + it('auto-fill truncates display names longer than the max attribute name length', async () => { + const pendingField: UserPropertyField = { + id: 'pending-truncate', + name: '', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + renderWithContext( + , + ); + + const displayNameInput = screen.getByTestId('property-display-name-input'); + + // fireEvent.change bypasses the input's maxLength so an oversize value + // reaches the onChange handler and exercises the truncation branch. + const oversize = Constants.MAX_CUSTOM_ATTRIBUTE_NAME_LENGTH + 1; + fireEvent.change(displayNameInput, {target: {value: 'a'.repeat(oversize)}}); + + const nameInput = screen.getByTestId('property-field-input') as HTMLInputElement; + await waitFor(() => { + expect(nameInput.value).toHaveLength(Constants.MAX_CUSTOM_ATTRIBUTE_NAME_LENGTH); + }); + }); +}); + +describe('useUserPropertiesTable grandfather regression', () => { + const getFields = jest.spyOn(Client4, 'getCustomProfileAttributeFields'); + const patchField = jest.spyOn(Client4, 'patchCustomProfileAttributeField'); + + afterEach(() => { + getFields.mockReset(); + patchField.mockReset(); + }); + + it('rename of legacy field clears grandfather: subsequent edits to the now-valid name trigger validation', async () => { + const legacyField: UserPropertyField = { + id: 'legacy-field', + name: 'dept head', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 1736541716295, + delete_at: 0, + update_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + }, + }; + + getFields.mockResolvedValue([legacyField]); + patchField.mockImplementation(async (id, patch) => ({ + ...legacyField, + ...patch, + id, + attrs: { + ...legacyField.attrs, + ...patch.attrs, + }, + update_at: Date.now(), + } as UserPropertyField)); + + let latestSection!: ReturnType; + const HookHarness = () => { + latestSection = useUserPropertiesTable(); + return <>{latestSection.content}; + }; + + renderWithContext(); + + await waitFor(() => { + expect(screen.getByDisplayValue('dept head')).toBeInTheDocument(); + }); + + const identifierInput = screen.getByDisplayValue('dept head'); + await userEvent.clear(identifierInput); + await userEvent.type(identifierInput, 'dept_head'); + fireEvent.blur(identifierInput); + + await act(async () => { + await latestSection.save(); + }); + + await waitFor(() => { + expect(patchField).toHaveBeenCalledWith('legacy-field', expect.objectContaining({name: 'dept_head'})); + expect(screen.getByDisplayValue('dept_head')).toBeInTheDocument(); + }); + + const renamedInput = screen.getByDisplayValue('dept_head'); + await userEvent.clear(renamedInput); + await userEvent.type(renamedInput, 'for'); + fireEvent.blur(renamedInput); + + await waitFor(() => { + expect(screen.getByText(/Identifier must start with a letter or underscore/)).toBeInTheDocument(); + }); + }); +}); diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_table.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_table.tsx index f30423460de..8c0e5be3257 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_table.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_table.tsx @@ -3,17 +3,19 @@ import {createColumnHelper, getCoreRowModel, getSortedRowModel, useReactTable, type ColumnDef} from '@tanstack/react-table'; import type {ReactNode} from 'react'; -import React, {useEffect, useMemo, useState} from 'react'; +import React, {useCallback, useEffect, useMemo, useRef, useState} from 'react'; import {FormattedMessage, useIntl} from 'react-intl'; import styled from 'styled-components'; -import {PlusIcon} from '@mattermost/compass-icons/components'; +import {InformationOutlineIcon, PlusIcon} from '@mattermost/compass-icons/components'; +import {WithTooltip} from '@mattermost/shared/components/tooltip'; import {supportsOptions, type UserPropertyField} from '@mattermost/types/properties'; import {collectionToArray} from '@mattermost/types/utilities'; import LoadingScreen from 'components/loading_screen'; import Constants from 'utils/constants'; +import {CPA_FIELD_NAME_RESERVED_WORDS, filterCELIdentifier, slugifyForCEL} from 'utils/properties'; import {DangerText, BorderlessInput, LinkButton} from './controls'; import {useIsFieldOrphaned} from './orphaned_fields_utils'; @@ -22,11 +24,13 @@ import DotMenu from './user_properties_dot_menu'; import OrphanedFieldDeleteButton from './user_properties_orphaned_delete_button'; import SelectType from './user_properties_type_menu'; import type {UserPropertyFields} from './user_properties_utils'; -import {isCreatePending, useUserPropertyFields, ValidationWarningNameRequired, ValidationWarningNameTaken, ValidationWarningNameUnique} from './user_properties_utils'; +import {isCreatePending, useUserPropertyFields, ValidationWarningNameInvalidCEL, ValidationWarningNameRequired, ValidationWarningNameTaken, ValidationWarningNameUnique} from './user_properties_utils'; import UserPropertyValues from './user_properties_values'; import {AdminConsoleListTable} from '../list_table'; +const columnHelper = createColumnHelper(); + type FieldActions = { createField: (field: UserPropertyField) => void; updateField: (field: UserPropertyField) => void; @@ -104,18 +108,149 @@ export function UserPropertiesTable({ }: Props & FieldActions) { const {formatMessage} = useIntl(); const data = collectionToArray(collection); - const col = createColumnHelper(); + + const autoFillActiveRef = useRef>(new Set()); + const nameOverridesRef = useRef>({}); + const [, forceNameUpdate] = useState(0); + + const computeAutoFillSlug = useCallback((displayName: string): string | null => { + let slug = slugifyForCEL(displayName); + + // slugifyForCEL returns '_copy' when the input normalizes to empty; + // treat that as "nothing to auto-fill" rather than writing '_copy' + // into the Name field. + if (slug === '_copy' || CPA_FIELD_NAME_RESERVED_WORDS.has(slug)) { + return null; + } + const runes = [...slug]; + if (runes.length > Constants.MAX_CUSTOM_ATTRIBUTE_NAME_LENGTH) { + slug = runes.slice(0, Constants.MAX_CUSTOM_ATTRIBUTE_NAME_LENGTH).join(''); + } + return slug; + }, []); + + const handleDisplayNameChange = useCallback((rowId: string, value: string) => { + if (!autoFillActiveRef.current.has(rowId)) { + return; + } + const slug = computeAutoFillSlug(value); + if (slug === null) { + return; + } + if (nameOverridesRef.current[rowId] !== slug) { + nameOverridesRef.current = {...nameOverridesRef.current, [rowId]: slug}; + forceNameUpdate((n) => n + 1); + } + }, [computeAutoFillSlug]); + + // Returns the auto-filled name slug if auto-fill is active for this row, + // or null if auto-fill is inactive or the slug is invalid/reserved. + const getAutoFillSlug = useCallback((rowId: string, displayNameValue: string): string | null => { + if (!autoFillActiveRef.current.has(rowId)) { + return null; + } + return computeAutoFillSlug(displayNameValue); + }, [computeAutoFillSlug]); + + // This callback only fires from manual user edits to the Name (the + // DOM onChange event). Auto-fill updates via liveValue → useEffect → setValue + // bypass the onChange handler entirely, so this comparison is always between + // what the user manually typed and the expected slug derived from the + // *committed* display_name. This invariant is what makes the deactivation + // check correct — do not refactor liveValue to go through onChange. + const handleNameChange = useCallback((rowId: string, value: string, currentField: UserPropertyField) => { + const displayName = currentField.attrs?.display_name ?? ''; + const expectedSlug = slugifyForCEL(displayName); + if (value !== expectedSlug) { + autoFillActiveRef.current.delete(rowId); + if (Object.prototype.hasOwnProperty.call(nameOverridesRef.current, rowId)) { + const next = {...nameOverridesRef.current}; + Reflect.deleteProperty(next, rowId); + nameOverridesRef.current = next; + } + } + }, []); + + // Activate auto-fill for newly created pending rows with empty names + useEffect(() => { + for (const field of data) { + if (isCreatePending(field) && field.name === '' && !autoFillActiveRef.current.has(field.id)) { + autoFillActiveRef.current.add(field.id); + } + } + }, [data]); + const columns = useMemo>>(() => { return [ - col.accessor('name', { + columnHelper.accessor((row) => row.attrs?.display_name ?? '', { + id: 'display_name', + size: 200, + header: () => ( + + + + ), + cell: ({getValue, row}) => { + const toDelete = row.original.delete_at !== 0; + const isProtected = Boolean(row.original.attrs?.protected); + + return ( + { + handleDisplayNameChange(row.original.id, value); + }} + setValue={(value: string) => { + const slug = getAutoFillSlug(row.original.id, value); + updateField({ + ...row.original, + ...(slug === null ? {} : {name: slug}), + attrs: { + ...row.original.attrs, + display_name: value.trim() || undefined, + }, + }); + }} + /> + ); + }, + enableHiding: false, + enableSorting: false, + }), + columnHelper.accessor('name', { size: 180, header: () => { return ( - + + + + + + + + ); }, @@ -150,18 +285,30 @@ export function UserPropertiesTable({ defaultMessage='Attribute name already taken.' /> ); + } else if (warningId === ValidationWarningNameInvalidCEL) { + warning = ( + + + + ); } return ( <> { + handleNameChange(row.original.id, value, row.original); + }} setValue={(value: string) => { updateField({...row.original, name: value.trim()}); }} @@ -174,7 +321,7 @@ export function UserPropertiesTable({ enableHiding: false, enableSorting: false, }), - col.accessor('type', { + columnHelper.accessor('type', { size: 100, header: () => { return ( @@ -197,7 +344,7 @@ export function UserPropertiesTable({ enableHiding: false, enableSorting: false, }), - col.display({ + columnHelper.display({ id: 'options', size: 300, header: () => ( @@ -220,7 +367,7 @@ export function UserPropertiesTable({ enableHiding: false, enableSorting: false, }), - col.display({ + columnHelper.display({ id: 'actions', size: 40, header: () => { @@ -248,7 +395,8 @@ export function UserPropertiesTable({ enableSorting: false, }), ]; - }, [createField, updateField, deleteField, collection.warnings, canCreate]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [createField, updateField, deleteField, collection.warnings, canCreate, handleDisplayNameChange, getAutoFillSlug, handleNameChange, formatMessage]); const table = useReactTable({ data, @@ -339,6 +487,19 @@ const ColHeaderRight = styled.div` text-align: right; `; +const NameHeaderLabel = styled.span` + display: inline-flex; + align-items: center; + gap: 4px; +`; + +const InfoIconWrapper = styled.span` + display: inline-flex; + align-items: center; + color: rgba(var(--center-channel-color-rgb), 0.56); + cursor: pointer; +`; + const ActionsRoot = styled.div` text-align: right; `; @@ -376,9 +537,12 @@ const ActionsCell = ({field, canCreate, createField, updateField, deleteField}: type EditCellProps = { value: string; + liveValue?: string; label?: string; testid?: string; setValue: (value: string) => void; + onChange?: (value: string) => void; + sanitize?: (value: string) => string; autoFocus?: boolean; disabled?: boolean; deleted?: boolean; @@ -393,6 +557,12 @@ const EditCell = (props: EditCellProps) => { setValue(props.value); }, [props.value]); + useEffect(() => { + if (props.liveValue !== undefined) { + setValue(props.liveValue); + } + }, [props.liveValue]); + return ( <> { }} value={value} onChange={(e: React.ChangeEvent) => { - setValue(e.target.value); + let next = e.target.value; + if (props.sanitize) { + next = props.sanitize(next); + } + setValue(next); + props.onChange?.(next); }} onBlur={() => { if (value !== props.value) { diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.test.ts b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.test.ts index 6870d718b1b..8aa164441e5 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.test.ts +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.test.ts @@ -16,6 +16,7 @@ import type {GlobalState} from 'types/store'; import { useUserPropertyFields, + ValidationWarningNameInvalidCEL, ValidationWarningNameRequired, ValidationWarningNameTaken, ValidationWarningNameUnique, @@ -50,7 +51,7 @@ describe('useUserPropertyFields', () => { const baseField: UserPropertyField = { id: 'test-id', - name: 'Test Field', + name: 'test_field', type: 'text' as const, group_id: 'custom_profile_attributes', create_at: 1736541716295, @@ -68,10 +69,10 @@ describe('useUserPropertyFields', () => { }, }; - const field0: UserPropertyField = {...baseField, id: 'test-id-0', name: 'test attribute 0', attrs: {...baseField.attrs, sort_order: 0}}; - const field1: UserPropertyField = {...baseField, id: 'test-id-1', name: 'test attribute 1', attrs: {...baseField.attrs, sort_order: 1}}; - const field2: UserPropertyField = {...baseField, id: 'test-id-2', name: 'test attribute 2', attrs: {...baseField.attrs, sort_order: 2}}; - const field3: UserPropertyField = {...baseField, id: 'test-id-3', name: 'test attribute 3', attrs: {...baseField.attrs, sort_order: 3}}; + const field0: UserPropertyField = {...baseField, id: 'test-id-0', name: 'test_attribute_0', attrs: {...baseField.attrs, sort_order: 0}}; + const field1: UserPropertyField = {...baseField, id: 'test-id-1', name: 'test_attribute_1', attrs: {...baseField.attrs, sort_order: 1}}; + const field2: UserPropertyField = {...baseField, id: 'test-id-2', name: 'test_attribute_2', attrs: {...baseField.attrs, sort_order: 2}}; + const field3: UserPropertyField = {...baseField, id: 'test-id-3', name: 'test_attribute_3', attrs: {...baseField.attrs, sort_order: 3}}; getFields.mockResolvedValue([field0, field1, field2, field3]); @@ -121,12 +122,12 @@ describe('useUserPropertyFields', () => { const [fields2,,, ops2] = result.current; act(() => { - ops2.update({...fields2.data[field1.id], name: 'changed attribute value'}); + ops2.update({...fields2.data[field1.id], name: 'changed_attribute_value'}); }); rerender(); const [fields3, readIO3, pendingIO3] = result.current; - expect(fields3.data[field1.id].name).toBe('changed attribute value'); + expect(fields3.data[field1.id].name).toBe('changed_attribute_value'); expect(pendingIO3.hasChanges).toBe(true); patchField.mockResolvedValue({...fields3.data[field1.id]}); @@ -145,12 +146,12 @@ describe('useUserPropertyFields', () => { expect(pending.saving).toBe(false); }); - expect(patchField).toHaveBeenCalledWith(field1.id, {type: 'text', name: 'changed attribute value', attrs: {sort_order: 1, value_type: '', visibility: 'when_set'}}); + expect(patchField).toHaveBeenCalledWith(field1.id, {type: 'text', name: 'changed_attribute_value', attrs: {sort_order: 1, value_type: '', visibility: 'when_set'}}); const [fields4,, pendingIO4] = result.current; expect(pendingIO4.hasChanges).toBe(false); expect(pendingIO4.error).toBe(undefined); - expect(fields4.data[field1.id].name).toBe('changed attribute value'); + expect(fields4.data[field1.id].name).toBe('changed_attribute_value'); }); it('should successfully handle reordering', async () => { @@ -197,8 +198,8 @@ describe('useUserPropertyFields', () => { expect(pending.saving).toBe(false); }); - expect(patchField).toHaveBeenCalledWith(field1.id, {type: 'text', name: 'test attribute 1', attrs: {sort_order: 0, value_type: '', visibility: 'when_set'}}); - expect(patchField).toHaveBeenCalledWith(field0.id, {type: 'text', name: 'test attribute 0', attrs: {sort_order: 1, value_type: '', visibility: 'when_set'}}); + expect(patchField).toHaveBeenCalledWith(field1.id, {type: 'text', name: 'test_attribute_1', attrs: {sort_order: 0, value_type: '', visibility: 'when_set'}}); + expect(patchField).toHaveBeenCalledWith(field0.id, {type: 'text', name: 'test_attribute_0', attrs: {sort_order: 1, value_type: '', visibility: 'when_set'}}); const [fields4,, pendingIO4] = result.current; expect(pendingIO4.hasChanges).toBe(false); @@ -294,11 +295,11 @@ describe('useUserPropertyFields', () => { expect(pendingIO4.saving).toBe(false); }); - expect(createField).toHaveBeenCalledWith({type: 'text', name: 'Text', attrs: {sort_order: 4, value_type: '', visibility: 'when_set'}}); + expect(createField).toHaveBeenCalledWith({type: 'text', name: '', attrs: {sort_order: 4, value_type: '', visibility: 'when_set'}}); const [fields4,,,] = result.current; expect(Object.values(fields4.data)).toEqual(expect.arrayContaining([ - expect.objectContaining({name: 'Text'}), + expect.objectContaining({name: ''}), ])); expect(fields4.order).toEqual(expect.arrayContaining(Object.keys(fields4.data))); @@ -321,12 +322,12 @@ describe('useUserPropertyFields', () => { act(() => { const [fields,,, ops] = result.current; - ops.update({...fields.data[field0.id], name: 'test attribute 1'}); + ops.update({...fields.data[field0.id], name: 'test_attribute_1'}); }); rerender(); const [fields,, pendingIO3] = result.current; - expect(fields.data[field0.id].name).toBe('test attribute 1'); + expect(fields.data[field0.id].name).toBe('test_attribute_1'); expect(pendingIO3.hasChanges).toBe(true); expect(fields.warnings).toEqual(expect.objectContaining({ [field0.id]: {name: ValidationWarningNameUnique}, @@ -352,8 +353,8 @@ describe('useUserPropertyFields', () => { act(() => { const [fields,,, ops] = result.current; - ops.update({...fields.data[field0.id], name: 'test attribute 1'}); - ops.update({...fields.data[field1.id], name: 'Test something else'}); + ops.update({...fields.data[field0.id], name: 'test_attribute_1'}); + ops.update({...fields.data[field1.id], name: 'Test_something_else'}); }); rerender(); @@ -402,4 +403,223 @@ describe('useUserPropertyFields', () => { [field0.id]: {name: ValidationWarningNameRequired}, })); }); + + it('should NOT trigger Required warning for a freshly-created untouched field', async () => { + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [,,, ops] = result.current; + ops.create(); + }); + rerender(); + + const [fields] = result.current; + const createdId = fields.order[fields.order.length - 1]; + + expect(fields.data[createdId].create_at).toBe(0); + expect(fields.data[createdId].name).toBe(''); + expect(fields.data[createdId].attrs?.display_name).toBeUndefined(); + expect(fields.warnings?.[createdId]).toBeUndefined(); + }); + + it('should preserve required warning precedence when multiple names are empty', async () => { + const {result, rerender} = renderHookWithContext(() => { + return useUserPropertyFields(); + }, getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [fields,,, ops] = result.current; + ops.update({...fields.data[field0.id], name: ''}); + ops.update({...fields.data[field1.id], name: ''}); + }); + rerender(); + + const [fields] = result.current; + expect(fields.warnings).toEqual(expect.objectContaining({ + [field0.id]: {name: ValidationWarningNameRequired}, + [field1.id]: {name: ValidationWarningNameRequired}, + })); + }); + + it('should flag CEL invalid name on new field', async () => { + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [,,, ops] = result.current; + ops.create({name: 'My Field'}); + }); + rerender(); + + const [fields] = result.current; + const createdId = fields.order[fields.order.length - 1]; + expect(fields.warnings).toEqual(expect.objectContaining({ + [createdId]: {name: ValidationWarningNameInvalidCEL}, + })); + }); + + it('should flag reserved word name', async () => { + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [fields,,, ops] = result.current; + ops.update({...fields.data[field0.id], name: 'in'}); + }); + rerender(); + + const [fields] = result.current; + expect(fields.warnings).toEqual(expect.objectContaining({ + [field0.id]: {name: ValidationWarningNameInvalidCEL}, + })); + }); + + it('should NOT flag CEL error when name is unchanged (grandfather)', async () => { + const legacyField = {...field0, name: 'My Legacy Field'}; + getFields.mockResolvedValueOnce([legacyField, field1, field2, field3]); + + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [fields,,, ops] = result.current; + ops.update({ + ...fields.data[legacyField.id], + attrs: { + ...fields.data[legacyField.id].attrs, + display_name: 'My Legacy Label', + }, + }); + }); + rerender(); + + const [fields] = result.current; + expect(fields.warnings?.[legacyField.id]?.name).not.toBe(ValidationWarningNameInvalidCEL); + }); + + it('should flag CEL error when renaming a legacy invalid-named field to another invalid name', async () => { + const legacyField = {...field0, name: 'My Legacy Field'}; + getFields.mockResolvedValueOnce([legacyField, field1, field2, field3]); + + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [fields,,, ops] = result.current; + ops.update({...fields.data[legacyField.id], name: '7invalid'}); + }); + rerender(); + + const [fields] = result.current; + expect(fields.warnings).toEqual(expect.objectContaining({ + [legacyField.id]: {name: ValidationWarningNameInvalidCEL}, + })); + }); + + it('should accept a valid rename of a legacy invalid-named field', async () => { + const legacyField = {...field0, name: 'My Legacy Field'}; + getFields.mockResolvedValueOnce([legacyField, field1, field2, field3]); + + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [fields,,, ops] = result.current; + ops.update({...fields.data[legacyField.id], name: 'my_legacy_field'}); + }); + rerender(); + + const [fields] = result.current; + expect(fields.warnings?.[legacyField.id]?.name).not.toBe(ValidationWarningNameInvalidCEL); + }); + + it('should use CEL-safe collision suffixes for created field names', async () => { + const {result, rerender} = renderHookWithContext(() => useUserPropertyFields(), getBaseState()); + + act(() => { + jest.runAllTimers(); + }); + rerender(); + + await waitFor(() => { + const [, read] = result.current; + expect(read.loading).toBe(false); + }); + + act(() => { + const [,,, ops] = result.current; + ops.create({name: 'Text'}); + ops.create({name: 'Text'}); + }); + rerender(); + + const [fields] = result.current; + const pendingNames = fields.order.slice(-2).map((id) => fields.data[id].name); + expect(pendingNames).toEqual(['Text', 'Text_2']); + }); }); diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts index dd37a7f3094..c2cd2088f50 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts @@ -13,6 +13,7 @@ import type {IDMappedCollection, IDMappedObjects} from '@mattermost/types/utilit import {Client4} from 'mattermost-redux/client'; import {insertWithoutDuplicates} from 'mattermost-redux/utils/array_utils'; +import {validateCPAFieldName} from 'utils/properties'; import {generateId} from 'utils/utils'; import type {CollectionIO} from './section_utils'; @@ -141,9 +142,27 @@ export const useUserPropertyFields = () => { } if (!field.name) { - // name not provided - acc[field.id] = {name: ValidationWarningNameRequired}; - } else if (pendingByName[field.name.toLowerCase()]?.filter((x) => x.delete_at === 0)?.length > 1) { + // name not provided — suppress for brand-new fields that + // haven't been interacted with yet (user just clicked "Add attribute") + const hasDisplayName = Boolean(field.attrs?.display_name?.trim()); + if (field.create_at !== 0 || hasDisplayName) { + acc[field.id] = {name: ValidationWarningNameRequired}; + } + return acc; + } + + // Lenient grandfather: only validate CEL names after a rename. + // Newly created fields always validate because they have no + // server-persisted identifier to grandfather from. + const originalName = current.data[field.id]?.name; + const nameChanged = field.create_at === 0 || field.name !== originalName; + + if (nameChanged && validateCPAFieldName(field.name)) { + acc[field.id] = {name: ValidationWarningNameInvalidCEL}; + return acc; + } + + if (pendingByName[field.name.toLowerCase()]?.filter((x) => x.delete_at === 0)?.length > 1) { // duplicate pending name acc[field.id] = {name: ValidationWarningNameUnique}; } else if ( @@ -206,10 +225,14 @@ export const useUserPropertyFields = () => { pendingIO.apply((pending) => { const nextOrder = Object.values(pending.data).filter((x) => !isDeletePending(x)).length; + const name = patch?.name ? + getIncrementedCELName(patch.name, pending) : + ''; + const field = newPendingField({ type: 'text', ...patch, - name: getIncrementedName(patch?.name ?? 'Text', pending), + name, attrs: { visibility: 'when_set', value_type: '', @@ -263,15 +286,20 @@ export const useUserPropertyFields = () => { export const ValidationWarningNameRequired = 'user_properties.validation.name_required'; export const ValidationWarningNameUnique = 'user_properties.validation.name_unique'; export const ValidationWarningNameTaken = 'user_properties.validation.name_taken'; +export const ValidationWarningNameInvalidCEL = 'user_properties.validation.name_invalid_cel'; export const ValidationWarningOptionsRequired = 'user_properties.validation.options_required'; -const getIncrementedName = (desiredName: string, collection: UserPropertyFields) => { - const names = new Set(Object.values(collection.data).map(({name}) => name)); +const getIncrementedCELName = (desiredName: string, collection: UserPropertyFields) => { + const names = new Set( + Object.values(collection.data). + filter(({delete_at: deleteAt}) => deleteAt === 0). + map(({name}) => name.toLowerCase()), + ); let newName = desiredName; let n = 1; - while (names.has(newName)) { + while (names.has(newName.toLowerCase())) { n++; - newName = `${desiredName} ${n}`; + newName = `${desiredName}_${n}`; } return newName; }; diff --git a/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.test.tsx b/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.test.tsx index c34365f2e60..ca683547e5c 100644 --- a/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.test.tsx +++ b/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.test.tsx @@ -8,6 +8,7 @@ import React from 'react'; import type {IntlShape} from 'react-intl'; import type {RouteComponentProps} from 'react-router-dom'; +import type {UserPropertyField} from '@mattermost/types/properties'; import type {UserProfile} from '@mattermost/types/users'; import SystemUserDetail, {getUserAuthenticationTextField} from 'components/admin_console/system_user_detail/system_user_detail'; @@ -370,6 +371,62 @@ describe('SystemUserDetail', () => { consoleSpy.mockRestore(); }); }); + + describe('CPA field labels', () => { + const buildCPAField = (overrides: Partial = {}): UserPropertyField => ({ + id: 'cpa-1', + name: 'department', + type: 'text', + group_id: 'custom_profile_attributes', + create_at: 0, + update_at: 0, + delete_at: 0, + created_by: '', + updated_by: '', + target_id: '', + target_type: '', + object_type: '', + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + ...overrides, + }, + }); + + test('should render CPA label using display_name', async () => { + const cpaField = buildCPAField({display_name: 'Engineering Department'}); + const props = { + ...defaultProps, + customProfileAttributeFields: [cpaField], + getCustomProfileAttributeFields: jest.fn().mockResolvedValue({data: [cpaField]}), + }; + + renderWithContext(); + + await waitForLoadingToFinish(); + + const labelEl = await screen.findByTestId('user-detail-custom-attribute-label-cpa-1'); + expect(labelEl).toHaveTextContent('Engineering Department'); + expect(labelEl).not.toHaveTextContent('department'); + }); + + test('should fall back to name when display_name is empty', async () => { + const cpaField = buildCPAField({display_name: ''}); + const props = { + ...defaultProps, + customProfileAttributeFields: [cpaField], + getCustomProfileAttributeFields: jest.fn().mockResolvedValue({data: [cpaField]}), + }; + + renderWithContext(); + + await waitForLoadingToFinish(); + + const labelEl = await screen.findByTestId('user-detail-custom-attribute-label-cpa-1'); + expect(labelEl).toHaveTextContent('department'); + }); + }); }); describe('getUserAuthenticationTextField', () => { diff --git a/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.tsx b/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.tsx index 9caf9805026..1534491a984 100644 --- a/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.tsx +++ b/webapp/channels/src/components/admin_console/system_user_detail/system_user_detail.tsx @@ -43,6 +43,7 @@ import ShieldOutlineIcon from 'components/widgets/icons/shield_outline_icon'; import LoadingSpinner from 'components/widgets/loading/loading_spinner'; import {Constants, ModalIdentifiers} from 'utils/constants'; +import {getUserPropertyFieldLabel} from 'utils/properties'; import {validHttpUrl} from 'utils/url'; import {toTitleCase} from 'utils/utils'; @@ -653,11 +654,12 @@ export class SystemUserDetail extends PureComponent {