diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/AttributeNameMapper.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/AttributeNameMapper.java index 88eeef5d5a2..f687e0a7d4c 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/AttributeNameMapper.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/AttributeNameMapper.java @@ -20,10 +20,6 @@ public interface AttributeNameMapper { String mapToInternal(String attr); - String[] mapToInternal(String[] attr); - String mapFromInternal(String attr); - String[] mapFromInternal(String[] attr); - } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/JoinAttributeNameMapper.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/JoinAttributeNameMapper.java new file mode 100644 index 00000000000..2e490156ced --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/JoinAttributeNameMapper.java @@ -0,0 +1,33 @@ +package org.cloudfoundry.identity.uaa.resources; + +/** + * Support table joins using a prefixed attribute mapping, e.g. + * select * from table1 joinName join table2 joinName2 on joinName.origin = joinName2.origin_key ... + * Used in SearchQueryConverter + */ +public class JoinAttributeNameMapper implements AttributeNameMapper { + + private final String name; + private final String joinPrefix; + private final int prefixLength; + + public JoinAttributeNameMapper(String name) { + this.name = name; + joinPrefix = name + "."; + prefixLength = joinPrefix.length(); + } + + @Override + public String mapToInternal(String attr) { + return joinPrefix + attr; + } + + @Override + public String mapFromInternal(String attr) { + return attr.substring(prefixLength); + } + + public String getName() { + return name; + } +} diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/SimpleAttributeNameMapper.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/SimpleAttributeNameMapper.java index ce7cf381392..56bed85f476 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/SimpleAttributeNameMapper.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/SimpleAttributeNameMapper.java @@ -32,16 +32,6 @@ public String mapToInternal(String attr) { return mappedAttr; } - @Override - public String[] mapToInternal(String[] attr) { - String[] result = new String[attr.length]; - int x = 0; - for (String a : attr) { - result[x++] = mapToInternal(a); - } - return result; - } - @Override public String mapFromInternal(String attr) { String mappedAttr = attr; @@ -50,14 +40,4 @@ public String mapFromInternal(String attr) { } return mappedAttr; } - - @Override - public String[] mapFromInternal(String[] attr) { - String[] result = new String[attr.length]; - int x = 0; - for (String a : attr) { - result[x++] = mapFromInternal(a); - } - return result; - } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/AbstractQueryable.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/AbstractQueryable.java index 28292abf8b2..934ce863407 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/AbstractQueryable.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/AbstractQueryable.java @@ -14,10 +14,11 @@ import java.util.stream.Collectors; import static com.google.common.primitives.Ints.tryParse; +import static org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter.ProcessedFilter.ORDER_BY; public abstract class AbstractQueryable implements Queryable { - private NamedParameterJdbcTemplate namedParameterJdbcTemplate; + protected NamedParameterJdbcTemplate namedParameterJdbcTemplate; protected final JdbcPagingListFactory pagingListFactory; @@ -43,6 +44,10 @@ public void setQueryConverter(SearchQueryConverter queryConverter) { this.queryConverter = queryConverter; } + public void setNamedParameterJdbcTemplate(NamedParameterJdbcTemplate namedParameterJdbcTemplate) { + this.namedParameterJdbcTemplate = namedParameterJdbcTemplate; + } + /** * The maximum number of items fetched from the database in one hit. If less * than or equal to zero, then there is no @@ -87,7 +92,7 @@ public List query(String filter, String sortBy, boolean ascending, String zon private String getQuerySQL(SearchQueryConverter.ProcessedFilter where) { if (where.hasOrderBy()) { - return getBaseSqlQuery() + " where (" + where.getSql().replace(where.ORDER_BY, ")" + where.ORDER_BY); + return getBaseSqlQuery() + " where (" + where.getSql().replace(ORDER_BY, ")" + ORDER_BY); } else { return getBaseSqlQuery() + " where (" + where.getSql() + ")"; } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SearchQueryConverter.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SearchQueryConverter.java index d7fadc9cbaf..139d543eebb 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SearchQueryConverter.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SearchQueryConverter.java @@ -66,4 +66,5 @@ public String toString() { String map(String attribute); + String getJoinName(); } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverter.java b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverter.java index b1774ba674a..c7c3169cd78 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverter.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverter.java @@ -5,8 +5,10 @@ import com.unboundid.scim.sdk.SCIMException; import com.unboundid.scim.sdk.SCIMFilter; import org.cloudfoundry.identity.uaa.resources.AttributeNameMapper; +import org.cloudfoundry.identity.uaa.resources.JoinAttributeNameMapper; import org.cloudfoundry.identity.uaa.resources.SimpleAttributeNameMapper; import org.cloudfoundry.identity.uaa.util.AlphanumericRandomValueStringGenerator; +import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.lang.Nullable; @@ -347,4 +349,8 @@ private Object getStringOrDate(String s) { public String map(String attribute) { return hasText(attribute) ? mapper.mapToInternal(attribute) : attribute; } + + public String getJoinName() { + return (mapper instanceof JoinAttributeNameMapper joinAttributeNameMapper) ? joinAttributeNameMapper.getName() : UaaStringUtils.EMPTY_STRING; + } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioning.java b/server/src/main/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioning.java index 2771ad26125..98664ed793a 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioning.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioning.java @@ -29,14 +29,13 @@ import java.util.Map; import java.util.UUID; import java.util.regex.Pattern; -import java.util.stream.Stream; import org.cloudfoundry.identity.uaa.audit.event.SystemDeletable; import org.cloudfoundry.identity.uaa.constants.OriginKeys; -import org.cloudfoundry.identity.uaa.resources.AttributeNameMapper; import org.cloudfoundry.identity.uaa.resources.ResourceMonitor; import org.cloudfoundry.identity.uaa.resources.jdbc.AbstractQueryable; import org.cloudfoundry.identity.uaa.resources.jdbc.JdbcPagingListFactory; +import org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter; import org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter.ProcessedFilter; import org.cloudfoundry.identity.uaa.resources.jdbc.SimpleSearchQueryConverter; import org.cloudfoundry.identity.uaa.scim.ScimMeta; @@ -66,7 +65,6 @@ import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; -import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.util.Assert; @@ -139,7 +137,7 @@ public Logger getLogger() { private final JdbcIdentityZoneProvisioning jdbcIdentityZoneProvisioning; private final IdentityZoneManager identityZoneManager; - private boolean useCaseInsensitiveQueries = false; + private SearchQueryConverter joinConverter; public JdbcScimUserProvisioning( final JdbcTemplate jdbcTemplate, @@ -149,7 +147,7 @@ public JdbcScimUserProvisioning( final JdbcIdentityZoneProvisioning jdbcIdentityZoneProvisioning ) { super(jdbcTemplate, pagingListFactory, mapper); - Assert.notNull(jdbcTemplate); + Assert.notNull(jdbcTemplate, "JDBC must not be null"); this.jdbcTemplate = jdbcTemplate; setQueryConverter(new SimpleSearchQueryConverter()); this.passwordEncoder = passwordEncoder; @@ -161,8 +159,9 @@ public void setTimeService(TimeService timeService) { this.timeService = timeService; } - public void setUseCaseInsensitiveQueries(final boolean useCaseInsensitiveQueries) { - this.useCaseInsensitiveQueries = useCaseInsensitiveQueries; + + public void setJoinConverter(SearchQueryConverter joinConverter) { + this.joinConverter = joinConverter; } @Override @@ -191,45 +190,13 @@ public List retrieveByScimFilterOnlyActive( final boolean ascending, final String zoneId ) { - /* We cannot reuse the query converter from the superclass here since the later query operates on both the - * "users" and the "identity_provider" table and they both have a column named "id". Since the SCIM filter might - * contain clauses on the "id" field, we must ensure that the "id" of the "users" table is used, which is done - * by attaching an AttributeNameMapper. */ - final SimpleSearchQueryConverter queryConverter = new SimpleSearchQueryConverter(); - - // ensure that the generated query handles the case-insensitivity of the underlying DB correctly - queryConverter.setDbCaseInsensitive(useCaseInsensitiveQueries); - - validateOrderBy(queryConverter.map(sortBy)); - + validateOrderBy(sortBy); /* since the two tables used in the query ('users' and 'identity_provider') have columns with identical names, * we must ensure that the columns of 'users' are used in the WHERE clause generated for the SCIM filter */ - final AttributeNameMapper attributeNameMapper = new AttributeNameMapper() { - @Override - public String mapToInternal(final String attr) { - // in the later query, 'users' will have the alias 'u' - return "u." + attr; - } - - @Override - public String[] mapToInternal(final String[] attr) { - return Stream.of(attr).map(this::mapToInternal).toArray(String[]::new); - } - - @Override - public String mapFromInternal(final String attr) { - return attr.substring(2); - } - - @Override - public String[] mapFromInternal(final String[] attr) { - return Stream.of(attr).map(this::mapFromInternal).toArray(String[]::new); - } - }; - queryConverter.setAttributeNameMapper(attributeNameMapper); + String joinName = joinConverter.getJoinName(); // build WHERE clause - final ProcessedFilter where = queryConverter.convert(filter, sortBy, ascending, zoneId); + final ProcessedFilter where = joinConverter.convert(filter, sortBy, ascending, zoneId); final String whereClauseScimFilter = where.getSql(); String whereClause = "idp.active is true and ("; if (where.hasOrderBy()) { @@ -239,11 +206,14 @@ public String[] mapFromInternal(final String[] attr) { } final String userFieldsWithPrefix = Arrays.stream(USER_FIELDS.split(",")) - .map(field -> "u." + field) + .map(field -> joinName + "." + field) .collect(joining(", ")); + String joinStatement = String.format( + "%s join identity_provider idp on %s.origin = idp.origin_key and %s.identity_zone_id = idp.identity_zone_id", joinName, joinName, joinName); final String sql = String.format( - "select %s from users u join identity_provider idp on u.origin = idp.origin_key and u.identity_zone_id = idp.identity_zone_id where %s", + "select %s from users %s where %s", userFieldsWithPrefix, + joinStatement, whereClause ); @@ -251,7 +221,6 @@ public String[] mapFromInternal(final String[] attr) { return pagingListFactory.createJdbcPagingList(sql, where.getParams(), rowMapper, getPageSize()); } - final NamedParameterJdbcTemplate namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate); return namedParameterJdbcTemplate.query(sql, where.getParams(), rowMapper); } @@ -571,7 +540,6 @@ public int deleteByUser(String userId, String zoneId) { return 1; } - private static final class ScimUserRowMapper implements RowMapper { @Override public ScimUser mapRow(ResultSet rs, int rowNum) throws SQLException { diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverterTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverterTests.java index 59497461c6b..1628f7f19eb 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverterTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/resources/jdbc/SimpleSearchQueryConverterTests.java @@ -1,5 +1,6 @@ package org.cloudfoundry.identity.uaa.resources.jdbc; +import org.cloudfoundry.identity.uaa.resources.JoinAttributeNameMapper; import org.cloudfoundry.identity.uaa.test.ModelTestUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -9,6 +10,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Map; import static org.cloudfoundry.identity.uaa.util.AssertThrowsWithMessage.assertThrowsWithMessageThat; import static org.hamcrest.MatcherAssert.assertThat; @@ -94,4 +96,27 @@ void invalidOperator(final String operator) { () -> converter.getFilterValues(query, validAttributes), is("[" + operator + "] operator is not supported.")); } -} \ No newline at end of file + + @Test + void testJoinName() { + assertEquals("", converter.getJoinName()); + converter.setAttributeNameMapper(new JoinAttributeNameMapper("myTable")); + assertEquals("myTable", converter.getJoinName()); + } + + @Test + void testJoinFilterAttributes() { + String query = "origin eq \"origin-value\" and id eq \"group-value\""; + List validAttributes = Arrays.asList("origin", "id".toLowerCase()); + JoinAttributeNameMapper joinAttributeNameMapper = new JoinAttributeNameMapper("prefix"); + converter.setAttributeNameMapper(joinAttributeNameMapper); + Map filterValues = converter.getFilterValues(query, validAttributes); + assertNotNull(filterValues); + assertEquals("[origin-value]", filterValues.get("origin").toString()); + assertEquals("[group-value]", filterValues.get("id").toString()); + assertEquals("prefix.origin", converter.map("origin")); + assertEquals("prefix.id", converter.map("id")); + assertEquals("prefix", converter.getJoinName()); + assertEquals("origin", joinAttributeNameMapper.mapFromInternal("prefix.origin")); + } +} diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioningTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioningTests.java index 5196d8585ae..9c484a55136 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioningTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/scim/jdbc/JdbcScimUserProvisioningTests.java @@ -17,8 +17,14 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.contains; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.sql.Timestamp; @@ -37,6 +43,7 @@ import org.cloudfoundry.identity.uaa.audit.event.EntityDeletedEvent; import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.provider.IdentityProvider; +import org.cloudfoundry.identity.uaa.resources.JoinAttributeNameMapper; import org.cloudfoundry.identity.uaa.resources.SimpleAttributeNameMapper; import org.cloudfoundry.identity.uaa.resources.jdbc.JdbcPagingListFactory; import org.cloudfoundry.identity.uaa.resources.jdbc.LimitSqlAdapter; @@ -67,6 +74,8 @@ import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.http.HttpStatus; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; import org.springframework.security.crypto.password.PasswordEncoder; @@ -120,6 +129,9 @@ void setUp(@Autowired LimitSqlAdapter limitSqlAdapter) { replaceWith.put("phoneNumbers\\.value", "phoneNumber"); filterConverter.setAttributeNameMapper(new SimpleAttributeNameMapper(replaceWith)); jdbcScimUserProvisioning.setQueryConverter(filterConverter); + SimpleSearchQueryConverter joinConverter = new SimpleSearchQueryConverter(); + joinConverter.setAttributeNameMapper(new JoinAttributeNameMapper("u")); + jdbcScimUserProvisioning.setJoinConverter(joinConverter); addUser(jdbcTemplate, joeId, JOE_NAME, passwordEncoder.encode("joespassword"), joeEmail, "Joe", "User", "+1-222-1234567", currentIdentityZoneId); @@ -297,6 +309,98 @@ void retrieveByScimFilterOnlyActive() { Assertions.assertThat(usernames).isEmpty(); } + @Test + void retrieveByScimFilterNoPaging() { + JdbcPagingListFactory notInUse = mock(JdbcPagingListFactory.class); + jdbcScimUserProvisioning = new JdbcScimUserProvisioning(jdbcTemplate, notInUse, passwordEncoder, new IdentityZoneManagerImpl(), + new JdbcIdentityZoneProvisioning(jdbcTemplate)); + SimpleSearchQueryConverter joinConverter = new SimpleSearchQueryConverter(); + joinConverter.setAttributeNameMapper(new JoinAttributeNameMapper("u")); + jdbcScimUserProvisioning.setJoinConverter(joinConverter); + String originActive = randomString(); + addIdentityProvider(jdbcTemplate, currentIdentityZoneId, originActive, true); + + String originInactive = randomString(); + addIdentityProvider(jdbcTemplate, currentIdentityZoneId, originInactive, false); + + ScimUser user1 = new ScimUser(null, "jo@foo.com", "Jo", "User"); + user1.addEmail("jo@blah.com"); + user1.setOrigin(originActive); + ScimUser created1 = jdbcScimUserProvisioning.createUser(user1, "j8hyqpassX", currentIdentityZoneId); + + ScimUser user2 = new ScimUser(null, "jo2@foo.com", "Jo", "User"); + user2.addEmail("jo2@blah.com"); + user2.setOrigin(originInactive); + ScimUser created2 = jdbcScimUserProvisioning.createUser(user2, "j8hyqpassX", currentIdentityZoneId); + + String scimFilter = String.format("id eq '%s' or username eq '%s' or origin eq '%s'", created1.getId(), created2.getUserName(), created2.getOrigin()); + jdbcScimUserProvisioning.setPageSize(0); + List result = jdbcScimUserProvisioning.retrieveByScimFilterOnlyActive( + scimFilter, + null, + false, + currentIdentityZoneId + ); + Assertions.assertThat(result).isNotNull(); + List usernames = result.stream().map(ScimUser::getUserName).collect(toList()); + Assertions.assertThat(usernames).isSorted(); + verify(notInUse, never()).createJdbcPagingList(anyString(), any(Map.class), any(RowMapper.class), any(Integer.class)); + // another option to query without paging + jdbcScimUserProvisioning.setPageSize(Integer.MAX_VALUE); + jdbcScimUserProvisioning.setPageSize(0); + jdbcScimUserProvisioning.retrieveByScimFilterOnlyActive( + scimFilter, + null, + false, + currentIdentityZoneId + ); + verify(notInUse, never()).createJdbcPagingList(anyString(), any(Map.class), any(RowMapper.class), any(Integer.class)); + // positive check, now with paging + jdbcScimUserProvisioning.setPageSize(1); + jdbcScimUserProvisioning.retrieveByScimFilterOnlyActive( + scimFilter, + null, + false, + currentIdentityZoneId + ); + verify(notInUse, times(1)).createJdbcPagingList(anyString(), any(Map.class), any(RowMapper.class), any(Integer.class)); + } + + @Test + void retrieveByScimFilterUsingLower() { + JdbcPagingListFactory notInUse = mock(JdbcPagingListFactory.class); + NamedParameterJdbcTemplate mockedJdbcTemplate = mock(NamedParameterJdbcTemplate.class); + SimpleSearchQueryConverter joinConverter = new SimpleSearchQueryConverter(); + joinConverter.setAttributeNameMapper(new JoinAttributeNameMapper("u")); + jdbcScimUserProvisioning.setJoinConverter(joinConverter); + + String scimFilter = "id eq '1111' or username eq 'j4hyqpassX' or origin eq 'uaa'"; + jdbcScimUserProvisioning.setPageSize(0); + jdbcScimUserProvisioning.setNamedParameterJdbcTemplate(mockedJdbcTemplate); + // MYSQL default, no LOWER statement in query + joinConverter.setDbCaseInsensitive(true); + List result = jdbcScimUserProvisioning.retrieveByScimFilterOnlyActive( + scimFilter, + null, + false, + currentIdentityZoneId + ); + Assertions.assertThat(result).isNotNull(); + verify(mockedJdbcTemplate).query(contains("u.id = "), any(Map.class), any(RowMapper.class)); + verify(mockedJdbcTemplate, never()).query(contains("LOWER(u.id) = LOWER("), any(Map.class), any(RowMapper.class)); + // POSTGRESQL and HSQL default + joinConverter.setDbCaseInsensitive(false); + result = jdbcScimUserProvisioning.retrieveByScimFilterOnlyActive( + scimFilter, + null, + false, + currentIdentityZoneId + ); + Assertions.assertThat(result).isNotNull(); + verify(notInUse, never()).createJdbcPagingList(anyString(), any(Map.class), any(RowMapper.class), any(Integer.class)); + verify(mockedJdbcTemplate).query(contains("LOWER(u.id) = LOWER("), any(Map.class), any(RowMapper.class)); + } + @Test void retrieveByScimFilter_IncludeInactive() { final String originActive = randomString(); diff --git a/uaa/src/main/webapp/WEB-INF/spring/scim-endpoints.xml b/uaa/src/main/webapp/WEB-INF/spring/scim-endpoints.xml index a15e6be746b..3df4cef469a 100644 --- a/uaa/src/main/webapp/WEB-INF/spring/scim-endpoints.xml +++ b/uaa/src/main/webapp/WEB-INF/spring/scim-endpoints.xml @@ -24,14 +24,23 @@ + + + + + + + + + + -