Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring: testable code in JdbcScimUserProvisioning #2863

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ public interface AttributeNameMapper {

String mapToInternal(String attr);

String[] mapToInternal(String[] attr);

String mapFromInternal(String attr);

String[] mapFromInternal(String[] attr);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.cloudfoundry.identity.uaa.resources;

/**
* Support table joins using a prefixed attribute mapping, e.g.
* select * from table1 joinName join table2 joinName2 on joinName.origin = joinName2.origin_key ...
* Used in SearchQueryConverter
*/
public class JoinAttributeNameMapper implements AttributeNameMapper {
strehle marked this conversation as resolved.
Show resolved Hide resolved

private final String name;
private final String joinPrefix;
private final int prefixLength;

public JoinAttributeNameMapper(String name) {
this.name = name;
joinPrefix = name + ".";
prefixLength = joinPrefix.length();
}

@Override
public String mapToInternal(String attr) {
return joinPrefix + attr;
}

@Override
public String mapFromInternal(String attr) {
return attr.substring(prefixLength);
}

public String getName() {
return name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@ public String mapToInternal(String attr) {
return mappedAttr;
}

@Override
public String[] mapToInternal(String[] attr) {
String[] result = new String[attr.length];
int x = 0;
for (String a : attr) {
result[x++] = mapToInternal(a);
}
return result;
}

@Override
public String mapFromInternal(String attr) {
String mappedAttr = attr;
Expand All @@ -50,14 +40,4 @@ public String mapFromInternal(String attr) {
}
return mappedAttr;
}

@Override
public String[] mapFromInternal(String[] attr) {
String[] result = new String[attr.length];
int x = 0;
for (String a : attr) {
result[x++] = mapFromInternal(a);
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import java.util.stream.Collectors;

import static com.google.common.primitives.Ints.tryParse;
import static org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter.ProcessedFilter.ORDER_BY;

public abstract class AbstractQueryable<T> implements Queryable<T> {

private NamedParameterJdbcTemplate namedParameterJdbcTemplate;
protected NamedParameterJdbcTemplate namedParameterJdbcTemplate;

protected final JdbcPagingListFactory pagingListFactory;

Expand All @@ -43,6 +44,10 @@ public void setQueryConverter(SearchQueryConverter queryConverter) {
this.queryConverter = queryConverter;
}

public void setNamedParameterJdbcTemplate(NamedParameterJdbcTemplate namedParameterJdbcTemplate) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be removed / refactored, if #2864 is merged before

this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
}

/**
* The maximum number of items fetched from the database in one hit. If less
* than or equal to zero, then there is no
Expand Down Expand Up @@ -87,7 +92,7 @@ public List<T> query(String filter, String sortBy, boolean ascending, String zon

private String getQuerySQL(SearchQueryConverter.ProcessedFilter where) {
if (where.hasOrderBy()) {
return getBaseSqlQuery() + " where (" + where.getSql().replace(where.ORDER_BY, ")" + where.ORDER_BY);
return getBaseSqlQuery() + " where (" + where.getSql().replace(ORDER_BY, ")" + ORDER_BY);
} else {
return getBaseSqlQuery() + " where (" + where.getSql() + ")";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,5 @@ public String toString() {

String map(String attribute);

String getJoinName();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import com.unboundid.scim.sdk.SCIMException;
import com.unboundid.scim.sdk.SCIMFilter;
import org.cloudfoundry.identity.uaa.resources.AttributeNameMapper;
import org.cloudfoundry.identity.uaa.resources.JoinAttributeNameMapper;
import org.cloudfoundry.identity.uaa.resources.SimpleAttributeNameMapper;
import org.cloudfoundry.identity.uaa.util.AlphanumericRandomValueStringGenerator;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -347,4 +349,8 @@ private Object getStringOrDate(String s) {
public String map(String attribute) {
return hasText(attribute) ? mapper.mapToInternal(attribute) : attribute;
}

public String getJoinName() {
return (mapper instanceof JoinAttributeNameMapper joinAttributeNameMapper) ? joinAttributeNameMapper.getName() : UaaStringUtils.EMPTY_STRING;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@
import java.util.Map;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import org.cloudfoundry.identity.uaa.audit.event.SystemDeletable;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.resources.AttributeNameMapper;
import org.cloudfoundry.identity.uaa.resources.ResourceMonitor;
import org.cloudfoundry.identity.uaa.resources.jdbc.AbstractQueryable;
import org.cloudfoundry.identity.uaa.resources.jdbc.JdbcPagingListFactory;
import org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter;
import org.cloudfoundry.identity.uaa.resources.jdbc.SearchQueryConverter.ProcessedFilter;
import org.cloudfoundry.identity.uaa.resources.jdbc.SimpleSearchQueryConverter;
import org.cloudfoundry.identity.uaa.scim.ScimMeta;
Expand Down Expand Up @@ -66,7 +65,6 @@
import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -139,7 +137,7 @@ public Logger getLogger() {
private final JdbcIdentityZoneProvisioning jdbcIdentityZoneProvisioning;
private final IdentityZoneManager identityZoneManager;

private boolean useCaseInsensitiveQueries = false;
private SearchQueryConverter joinConverter;

public JdbcScimUserProvisioning(
final JdbcTemplate jdbcTemplate,
Expand All @@ -149,7 +147,7 @@ public JdbcScimUserProvisioning(
final JdbcIdentityZoneProvisioning jdbcIdentityZoneProvisioning
) {
super(jdbcTemplate, pagingListFactory, mapper);
Assert.notNull(jdbcTemplate);
Assert.notNull(jdbcTemplate, "JDBC must not be null");
this.jdbcTemplate = jdbcTemplate;
setQueryConverter(new SimpleSearchQueryConverter());
this.passwordEncoder = passwordEncoder;
Expand All @@ -161,8 +159,9 @@ public void setTimeService(TimeService timeService) {
this.timeService = timeService;
}

public void setUseCaseInsensitiveQueries(final boolean useCaseInsensitiveQueries) {
this.useCaseInsensitiveQueries = useCaseInsensitiveQueries;

public void setJoinConverter(SearchQueryConverter joinConverter) {
this.joinConverter = joinConverter;
}

@Override
Expand Down Expand Up @@ -191,45 +190,13 @@ public List<ScimUser> retrieveByScimFilterOnlyActive(
final boolean ascending,
final String zoneId
) {
/* We cannot reuse the query converter from the superclass here since the later query operates on both the
* "users" and the "identity_provider" table and they both have a column named "id". Since the SCIM filter might
* contain clauses on the "id" field, we must ensure that the "id" of the "users" table is used, which is done
* by attaching an AttributeNameMapper. */
final SimpleSearchQueryConverter queryConverter = new SimpleSearchQueryConverter();

// ensure that the generated query handles the case-insensitivity of the underlying DB correctly
queryConverter.setDbCaseInsensitive(useCaseInsensitiveQueries);

validateOrderBy(queryConverter.map(sortBy));

validateOrderBy(sortBy);
/* since the two tables used in the query ('users' and 'identity_provider') have columns with identical names,
* we must ensure that the columns of 'users' are used in the WHERE clause generated for the SCIM filter */
final AttributeNameMapper attributeNameMapper = new AttributeNameMapper() {
@Override
public String mapToInternal(final String attr) {
// in the later query, 'users' will have the alias 'u'
return "u." + attr;
}

@Override
public String[] mapToInternal(final String[] attr) {
return Stream.of(attr).map(this::mapToInternal).toArray(String[]::new);
}

@Override
public String mapFromInternal(final String attr) {
return attr.substring(2);
}

@Override
public String[] mapFromInternal(final String[] attr) {
return Stream.of(attr).map(this::mapFromInternal).toArray(String[]::new);
}
};
queryConverter.setAttributeNameMapper(attributeNameMapper);
String joinName = joinConverter.getJoinName();

// build WHERE clause
final ProcessedFilter where = queryConverter.convert(filter, sortBy, ascending, zoneId);
final ProcessedFilter where = joinConverter.convert(filter, sortBy, ascending, zoneId);
final String whereClauseScimFilter = where.getSql();
String whereClause = "idp.active is true and (";
if (where.hasOrderBy()) {
Expand All @@ -239,19 +206,21 @@ public String[] mapFromInternal(final String[] attr) {
}

final String userFieldsWithPrefix = Arrays.stream(USER_FIELDS.split(","))
.map(field -> "u." + field)
.map(field -> joinName + "." + field)
.collect(joining(", "));
String joinStatement = String.format(
"%s join identity_provider idp on %s.origin = idp.origin_key and %s.identity_zone_id = idp.identity_zone_id", joinName, joinName, joinName);
final String sql = String.format(
"select %s from users u join identity_provider idp on u.origin = idp.origin_key and u.identity_zone_id = idp.identity_zone_id where %s",
"select %s from users %s where %s",
userFieldsWithPrefix,
joinStatement,
whereClause
);

if (getPageSize() > 0 && getPageSize() < Integer.MAX_VALUE) {
return pagingListFactory.createJdbcPagingList(sql, where.getParams(), rowMapper, getPageSize());
}

final NamedParameterJdbcTemplate namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
return namedParameterJdbcTemplate.query(sql, where.getParams(), rowMapper);
}

Expand Down Expand Up @@ -571,7 +540,6 @@ public int deleteByUser(String userId, String zoneId) {
return 1;
}


private static final class ScimUserRowMapper implements RowMapper<ScimUser> {
@Override
public ScimUser mapRow(ResultSet rs, int rowNum) throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.cloudfoundry.identity.uaa.resources.jdbc;

import org.cloudfoundry.identity.uaa.resources.JoinAttributeNameMapper;
import org.cloudfoundry.identity.uaa.test.ModelTestUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -9,6 +10,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.cloudfoundry.identity.uaa.util.AssertThrowsWithMessage.assertThrowsWithMessageThat;
import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -94,4 +96,27 @@ void invalidOperator(final String operator) {
() -> converter.getFilterValues(query, validAttributes),
is("[" + operator + "] operator is not supported."));
}
}

@Test
void testJoinName() {
assertEquals("", converter.getJoinName());
converter.setAttributeNameMapper(new JoinAttributeNameMapper("myTable"));
assertEquals("myTable", converter.getJoinName());
}

@Test
void testJoinFilterAttributes() {
String query = "origin eq \"origin-value\" and id eq \"group-value\"";
List<String> validAttributes = Arrays.asList("origin", "id".toLowerCase());
JoinAttributeNameMapper joinAttributeNameMapper = new JoinAttributeNameMapper("prefix");
converter.setAttributeNameMapper(joinAttributeNameMapper);
Map filterValues = converter.getFilterValues(query, validAttributes);
assertNotNull(filterValues);
assertEquals("[origin-value]", filterValues.get("origin").toString());
assertEquals("[group-value]", filterValues.get("id").toString());
assertEquals("prefix.origin", converter.map("origin"));
assertEquals("prefix.id", converter.map("id"));
assertEquals("prefix", converter.getJoinName());
assertEquals("origin", joinAttributeNameMapper.mapFromInternal("prefix.origin"));
}
}
Loading
Loading