From bbf6751bc0d87c4a3aaf21b54e26ce328ab998b3 Mon Sep 17 00:00:00 2001 From: Filip Hanik Date: Fri, 12 May 2017 09:51:14 -0700 Subject: [PATCH] Add zone ID to expiring codes [#145313231] https://www.pivotaltracker.com/story/show/145313231 --- .../uaa/codestore/ExpiringCodeStore.java | 21 +++++-- .../codestore/InMemoryExpiringCodeStore.java | 11 ++-- .../uaa/codestore/JdbcExpiringCodeStore.java | 59 +++++++++---------- .../uaa/codestore/ExpiringCodeStoreTests.java | 47 +++++++++++++-- .../InvitationsEndpointMockMvcTests.java | 22 ++++--- .../login/InvitationsServiceMockMvcTests.java | 3 + .../uaa/mock/ldap/LdapMockMvcTests.java | 9 ++- .../ScimUserEndpointsMockMvcTests.java | 10 ++++ 8 files changed, 128 insertions(+), 54 deletions(-) diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStore.java b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStore.java index 535cea240da..366a6ea8cfe 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStore.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStore.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Cloud Foundry + * Cloud Foundry * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved. * * This product is licensed to you under the Apache License, Version 2.0 (the "License"). @@ -12,6 +12,7 @@ *******************************************************************************/ package org.cloudfoundry.identity.uaa.codestore; +import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.springframework.security.oauth2.common.util.RandomValueStringGenerator; import java.sql.Timestamp; @@ -20,7 +21,7 @@ public interface ExpiringCodeStore { /** * Generate and persist a one-time code with an expiry date. - * + * * @param data JSON object to be associated with the code * @param intent An optional key (not necessarily unique) for looking up codes * @return code the generated one-time code @@ -31,7 +32,7 @@ public interface ExpiringCodeStore { /** * Retrieve a code and delete it if it exists. - * + * * @param code the one-time code to look for * @return code or null if the code is not found * @throws java.lang.NullPointerException if the code is null @@ -40,7 +41,7 @@ public interface ExpiringCodeStore { /** * Set the code generator for this store. - * + * * @param generator Code generator */ void setGenerator(RandomValueStringGenerator generator); @@ -51,4 +52,16 @@ public interface ExpiringCodeStore { * @param intent Intent of codes to remove */ void expireByIntent(String intent); + + default String zonifyCode(String code) { + return code + "[zone[" + IdentityZoneHolder.get().getId()+"]]"; + } + + default String extractCode(String zoneCode) { + int endIndex = zoneCode.indexOf("[zone[" + IdentityZoneHolder.get().getId()+"]]"); + if (endIndex<0) { + return zoneCode; + } + return zoneCode.substring(0, endIndex); + } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/InMemoryExpiringCodeStore.java b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/InMemoryExpiringCodeStore.java index 5c586c02018..602965c9c29 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/InMemoryExpiringCodeStore.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/InMemoryExpiringCodeStore.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Cloud Foundry + * Cloud Foundry * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved. * * This product is licensed to you under the Apache License, Version 2.0 (the "License"). @@ -12,6 +12,7 @@ *******************************************************************************/ package org.cloudfoundry.identity.uaa.codestore; +import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.springframework.dao.DataIntegrityViolationException; import org.springframework.security.oauth2.common.util.RandomValueStringGenerator; import org.springframework.util.Assert; @@ -40,7 +41,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent); - ExpiringCode duplicate = store.putIfAbsent(code, expiringCode); + ExpiringCode duplicate = store.putIfAbsent(zonifyCode(code), expiringCode); if (duplicate != null) { throw new DataIntegrityViolationException("Duplicate code: " + code); } @@ -54,7 +55,7 @@ public ExpiringCode retrieveCode(String code) { throw new NullPointerException(); } - ExpiringCode expiringCode = store.remove(code); + ExpiringCode expiringCode = store.remove(zonifyCode(code)); if (expiringCode == null || expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) { expiringCode = null; @@ -71,7 +72,7 @@ public void setGenerator(RandomValueStringGenerator generator) { @Override public void expireByIntent(String intent) { Assert.hasText(intent); - - store.values().stream().filter(c -> intent.equals(c.getIntent())).forEach(c -> store.remove(c.getCode())); + String id = IdentityZoneHolder.get().getId(); + store.entrySet().stream().filter(c -> c.getKey().contains(id) && intent.equals(c.getValue().getIntent())).forEach(c -> store.remove(c.getKey())); } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/JdbcExpiringCodeStore.java b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/JdbcExpiringCodeStore.java index 6184fc08e03..cf936d60641 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/JdbcExpiringCodeStore.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/codestore/JdbcExpiringCodeStore.java @@ -12,13 +12,6 @@ *******************************************************************************/ package org.cloudfoundry.identity.uaa.codestore; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.util.concurrent.atomic.AtomicLong; - -import javax.sql.DataSource; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.dao.DataIntegrityViolationException; @@ -28,16 +21,26 @@ import org.springframework.security.oauth2.common.util.RandomValueStringGenerator; import org.springframework.util.Assert; +import javax.sql.DataSource; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.concurrent.atomic.AtomicLong; + public class JdbcExpiringCodeStore implements ExpiringCodeStore { public static final String tableName = "expiring_code_store"; public static final String fields = "code, expiresat, data, intent"; - public static final String insert = "insert into " + tableName + " (" + fields + ") values (?,?,?,?)"; - public static final String delete = "delete from " + tableName + " where code = ?"; - public static final String deleteIntent = "delete from " + tableName + " where intent = ?"; - public static final String deleteExpired = "delete from " + tableName + " where expiresat < ?"; - public static final String select = "select " + fields + " from " + tableName + " where code = ?"; + protected static final String insert = "insert into " + tableName + " (" + fields + ") values (?,?,?,?)"; + protected static final String delete = "delete from " + tableName + " where code = ?"; + protected static final String deleteIntent = "delete from " + tableName + " where intent = ? and code LIKE ?"; + protected static final String deleteExpired = "delete from " + tableName + " where expiresat < ?"; + + private final JdbcExpiringCodeMapper rowMapper = new JdbcExpiringCodeMapper(); + + protected static final String selectAllFields = "select " + fields + " from " + tableName + " where code = ?"; + public static final String SELECT_BY_EMAIL_AND_CLIENT_ID = "select " + fields + " from " + tableName + " where data like '%%\"email\":\"%s\"%%' and data like '%%\"client_id\":\"%s\"%%' ORDER BY expiresat DESC LIMIT 1"; @@ -87,7 +90,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent count++; String code = generator.generate(); try { - int update = jdbcTemplate.update(insert, code, expiresAt.getTime(), data, intent); + int update = jdbcTemplate.update(insert, zonifyCode(code), expiresAt.getTime(), data, intent); if (update == 1) { ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent); return expiringCode; @@ -113,17 +116,14 @@ public ExpiringCode retrieveCode(String code) { } try { - ExpiringCode expiringCode = jdbcTemplate.queryForObject(select, new JdbcExpiringCodeMapper(), code); - try { - if (expiringCode != null) { - jdbcTemplate.update(delete, code); - } - if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) { - expiringCode = null; - } - } finally { - return expiringCode; + ExpiringCode expiringCode = jdbcTemplate.queryForObject(selectAllFields, rowMapper, zonifyCode(code)); + if (expiringCode != null) { + jdbcTemplate.update(delete, zonifyCode(code)); + } + if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) { + expiringCode = null; } + return expiringCode; } catch (EmptyResultDataAccessException x) { return null; } @@ -138,7 +138,7 @@ public void setGenerator(RandomValueStringGenerator generator) { public void expireByIntent(String intent) { Assert.hasText(intent); - jdbcTemplate.update(deleteIntent, intent); + jdbcTemplate.update(deleteIntent, intent, zonifyCode("%")+"%"); } public int cleanExpiredEntries() { @@ -154,15 +154,14 @@ public int cleanExpiredEntries() { return 0; } - protected static class JdbcExpiringCodeMapper implements RowMapper { + protected class JdbcExpiringCodeMapper implements RowMapper { @Override public ExpiringCode mapRow(ResultSet rs, int rowNum) throws SQLException { - int pos = 1; - String code = rs.getString(pos++); - Timestamp expiresAt = new Timestamp(rs.getLong(pos++)); - String data = rs.getString(pos++); - String intent = rs.getString(pos++); + String code = extractCode(rs.getString("code")); + Timestamp expiresAt = new Timestamp(rs.getLong("expiresat")); + String intent = rs.getString("intent"); + String data = rs.getString("data"); return new ExpiringCode(code, expiresAt, data, intent); } diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStoreTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStoreTests.java index 5003f02cc53..4b583dc26ee 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStoreTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/codestore/ExpiringCodeStoreTests.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Cloud Foundry + * Cloud Foundry * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved. * * This product is licensed to you under the Apache License, Version 2.0 (the "License"). @@ -14,6 +14,8 @@ import org.cloudfoundry.identity.uaa.test.JdbcTestBase; import org.cloudfoundry.identity.uaa.test.TestUtils; +import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; +import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -24,12 +26,15 @@ import org.springframework.dao.DataAccessException; import org.springframework.dao.DataIntegrityViolationException; import org.springframework.dao.EmptyResultDataAccessException; +import org.springframework.jdbc.core.RowMapper; import org.springframework.security.oauth2.common.util.RandomValueStringGenerator; +import org.springframework.test.util.ReflectionTestUtils; import java.sql.SQLException; import java.sql.Timestamp; import java.util.Arrays; import java.util.Collection; +import java.util.Map; @RunWith(Parameterized.class) public class ExpiringCodeStoreTests extends JdbcTestBase { @@ -63,6 +68,16 @@ public void initExpiringCodeStoreTests() throws Exception { } } + public int countCodes() { + if (expiringCodeStore instanceof InMemoryExpiringCodeStore) { + Map map = (Map) ReflectionTestUtils.getField(expiringCodeStore, "store"); + return map.size(); + } else { + // confirm that everything is clean prior to test. + return jdbcTemplate.queryForObject("select count(*) from expiring_code_store", Integer.class); + } + } + @Test public void testGenerateCode() throws Exception { String data = "{}"; @@ -125,6 +140,22 @@ public void testRetrieveCode() throws Exception { Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode())); } + @Test + public void testRetrieveCode_In_Another_Zone() throws Exception { + String data = "{}"; + Timestamp expiresAt = new Timestamp(System.currentTimeMillis() + 60000); + ExpiringCode generatedCode = expiringCodeStore.generateCode(data, expiresAt, null); + + IdentityZoneHolder.set(MultitenancyFixture.identityZone("other", "other")); + Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode())); + + IdentityZoneHolder.clear(); + ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(generatedCode.getCode()); + Assert.assertEquals(generatedCode, retrievedCode); + + + } + @Test public void testRetrieveCodeWithCodeNotFound() throws Exception { ExpiringCode retrievedCode = expiringCodeStore.retrieveCode("unknown"); @@ -143,7 +174,7 @@ public void testStoreLargeData() throws Exception { Arrays.fill(oneMb, 'a'); String aaaString = new String(oneMb); ExpiringCode expiringCode = expiringCodeStore.generateCode(aaaString, new Timestamp( - System.currentTimeMillis() + 60000), null); + System.currentTimeMillis() + 60000), null); String code = expiringCode.getCode(); ExpiringCode actualCode = expiringCodeStore.retrieveCode(code); Assert.assertEquals(expiringCode, actualCode); @@ -164,10 +195,16 @@ public void testExpiredCodeReturnsNull() throws Exception { public void testExpireCodeByIntent() throws Exception { ExpiringCode code = expiringCodeStore.generateCode("{}", new Timestamp(System.currentTimeMillis() + 60000), "Test Intent"); + Assert.assertEquals(1, countCodes()); + + IdentityZoneHolder.set(MultitenancyFixture.identityZone("id","id")); expiringCodeStore.expireByIntent("Test Intent"); + Assert.assertEquals(1, countCodes()); + IdentityZoneHolder.clear(); + expiringCodeStore.expireByIntent("Test Intent"); ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(code.getCode()); - + Assert.assertEquals(0, countCodes()); Assert.assertNull(retrievedCode); } @@ -194,8 +231,8 @@ public void testExpirationCleaner() throws Exception { if (JdbcExpiringCodeStore.class == expiringCodeStoreClass) { jdbcTemplate.update(JdbcExpiringCodeStore.insert, "test", System.currentTimeMillis() - 1000, "{}", null); ((JdbcExpiringCodeStore) expiringCodeStore).cleanExpiredEntries(); - jdbcTemplate.queryForObject(JdbcExpiringCodeStore.select, - new JdbcExpiringCodeStore.JdbcExpiringCodeMapper(), "test"); + jdbcTemplate.queryForObject(JdbcExpiringCodeStore.selectAllFields, + (RowMapper) ReflectionTestUtils.getField(expiringCodeStore, "rowMapper"), "test"); } else { throw new EmptyResultDataAccessException(1); } diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/invitations/InvitationsEndpointMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/invitations/InvitationsEndpointMockMvcTests.java index e026b60d7dd..c64a8dfc052 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/invitations/InvitationsEndpointMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/invitations/InvitationsEndpointMockMvcTests.java @@ -4,11 +4,13 @@ import org.cloudfoundry.identity.uaa.codestore.ExpiringCode; import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeStore; import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeType; +import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore; import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.mock.InjectedMockContextTest; import org.cloudfoundry.identity.uaa.mock.util.MockMvcUtils; import org.cloudfoundry.identity.uaa.provider.IdentityProvider; import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning; +import org.cloudfoundry.identity.uaa.provider.JdbcIdentityProviderProvisioning; import org.cloudfoundry.identity.uaa.provider.UaaIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.scim.ScimUser; import org.cloudfoundry.identity.uaa.util.JsonUtils; @@ -90,7 +92,8 @@ public void setUp() throws Exception { @After public void cleanUpDomainList() throws Exception { - IdentityProvider uaaProvider = getWebApplicationContext().getBean(IdentityProviderProvisioning.class).retrieveByOrigin(UAA, IdentityZone.getUaa().getId()); + IdentityZoneHolder.clear(); + IdentityProvider uaaProvider = getWebApplicationContext().getBean(JdbcIdentityProviderProvisioning.class).retrieveByOrigin(UAA, IdentityZone.getUaa().getId()); uaaProvider.getConfig().setEmailDomain(null); getWebApplicationContext().getBean(IdentityProviderProvisioning.class).update(uaaProvider); } @@ -146,7 +149,7 @@ public void invite_User_In_Zone_With_DefaultZone_UaaAdmin() throws Exception { InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class); BaseClientDetails defaultClientDetails = new BaseClientDetails(); defaultClientDetails.setClientId("admin"); - assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, defaultClientDetails); + assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, defaultClientDetails); } @@ -180,7 +183,7 @@ public void invite_User_In_Zone_With_DefaultZone_ZoneAdmin() throws Exception { .andReturn(); InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class); - assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, zonifiedScimInviteClientDetails); + assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, zonifiedScimInviteClientDetails); } @@ -214,7 +217,7 @@ public void invite_User_In_Zone_With_DefaultZone_ScimInvite() throws Exception { .andReturn(); InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class); - assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, zonifiedScimInviteClientDetails); + assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, zonifiedScimInviteClientDetails); } @@ -233,7 +236,7 @@ public void invite_User_Within_Zone() throws Exception { String redirectUrl = "example.com"; InvitationsResponse response = sendRequestWithTokenAndReturnResponse(zonedScimInviteToken, result.getIdentityZone().getSubdomain(), zonedClientDetails.getClientId(), redirectUrl, email); - assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), response, zonedClientDetails); + assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), response, zonedClientDetails); } @Test @@ -325,6 +328,7 @@ public void invitations_Accept_Get_Security() throws Exception { sendRequestWithToken(userToken, null, clientId, "example.com", "user1@"+domain); String code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("SELECT code FROM expiring_code_store", String.class); + code = new InMemoryExpiringCodeStore().extractCode(code); assertNotNull("Invite Code Must be Present", code); MockHttpServletRequestBuilder accept = get("/invitations/accept") @@ -350,7 +354,7 @@ public static void sendRequestWithToken(String token, String subdomain, String c assertThat(response.getFailedInvites().size(), is(0)); } - private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, String subdomain, InvitationsResponse response, ClientDetails clientDetails) { + private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, IdentityZone zone, InvitationsResponse response, ClientDetails clientDetails) { for (int i = 0; i < emails.length; i++) { assertThat(response.getNewInvites().size(), is(emails.length)); assertThat(response.getNewInvites().get(i).getEmail(), is(emails[i])); @@ -361,8 +365,9 @@ private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, S String link = response.getNewInvites().get(i).getInviteLink().toString(); assertFalse(contains(link, "@")); assertFalse(contains(link, "%40")); - if (StringUtils.hasText(subdomain)) { - assertThat(link, startsWith("http://" + subdomain + ".localhost/invitations/accept")); + if (zone != null && StringUtils.hasText(zone.getSubdomain())) { + assertThat(link, startsWith("http://" + zone.getSubdomain() + ".localhost/invitations/accept")); + IdentityZoneHolder.set(zone); } else { assertThat(link, startsWith("http://localhost/invitations/accept")); } @@ -371,6 +376,7 @@ private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, S assertThat(query, startsWith("code=")); String code = query.split("=")[1]; ExpiringCode expiringCode = codeStore.retrieveCode(code); + IdentityZoneHolder.clear(); assertThat(expiringCode.getExpiresAt().getTime(), is(greaterThan(System.currentTimeMillis()))); assertThat(expiringCode.getIntent(), is(ExpiringCodeType.INVITATION.name())); Map data = readValue(expiringCode.getData(), new TypeReference>() {}); diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/login/InvitationsServiceMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/login/InvitationsServiceMockMvcTests.java index 383bc14075d..2542073f07c 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/login/InvitationsServiceMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/login/InvitationsServiceMockMvcTests.java @@ -14,6 +14,8 @@ package org.cloudfoundry.identity.uaa.login; +import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore; +import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.message.EmailService; import org.cloudfoundry.identity.uaa.provider.AbstractIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.constants.OriginKeys; @@ -243,6 +245,7 @@ public void accept_invitation_sets_your_password() throws Exception { .andReturn(); code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("select code from expiring_code_store", String.class); + code = new InMemoryExpiringCodeStore().extractCode(code); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); result = getMockMvc().perform( post("/invitations/accept.do") diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/ldap/LdapMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/ldap/LdapMockMvcTests.java index caf23b29f23..fe4edd25e99 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/ldap/LdapMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/ldap/LdapMockMvcTests.java @@ -16,6 +16,7 @@ import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication; import org.cloudfoundry.identity.uaa.authentication.manager.AuthzAuthenticationManager; import org.cloudfoundry.identity.uaa.authentication.manager.DynamicZoneAwareAuthenticationManager; +import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore; import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.mock.util.ApacheDSHelper; import org.cloudfoundry.identity.uaa.mock.util.MockMvcUtils.ZoneScimInviteData; @@ -269,7 +270,9 @@ public void acceptInvitation_for_ldap_user_whose_username_is_not_email() throws .andReturn(); code = mainContext.getBean(JdbcTemplate.class).queryForObject("select code from expiring_code_store", String.class); - + IdentityZoneHolder.set(zone.getZone().getIdentityZone()); + code = new InMemoryExpiringCodeStore().extractCode(code); + IdentityZoneHolder.clear(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); mockMvc.perform(post("/invitations/accept_enterprise.do") .session(session) @@ -309,7 +312,9 @@ public void acceptInvitation_for_ldap_user_whose_username_is_not_email() throws .andReturn(); code = mainContext.getBean(JdbcTemplate.class).queryForObject("select code from expiring_code_store", String.class); - + IdentityZoneHolder.set(zone.getZone().getIdentityZone()); + code = new InMemoryExpiringCodeStore().extractCode(code); + IdentityZoneHolder.clear(); session = (MockHttpSession) result.getRequest().getSession(false); mockMvc.perform(post("/invitations/accept_enterprise.do") .session(session) diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/scim/endpoints/ScimUserEndpointsMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/scim/endpoints/ScimUserEndpointsMockMvcTests.java index 84c25e39078..306b64b3c47 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/scim/endpoints/ScimUserEndpointsMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/scim/endpoints/ScimUserEndpointsMockMvcTests.java @@ -33,6 +33,7 @@ import org.cloudfoundry.identity.uaa.zone.IdentityZoneSwitchingFilter; import org.hamcrest.MatcherAssert; import org.json.JSONObject; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.springframework.http.HttpStatus; @@ -99,6 +100,11 @@ public void setUp() throws Exception { uaaAdminToken = testClient.getClientCredentialsOAuthAccessToken(clientId, clientSecret, "uaa.admin"); } + @After + public void clear() { + IdentityZoneHolder.clear(); + } + private ScimUser createUser(String token) throws Exception { return createUser(token, null); } @@ -274,7 +280,9 @@ public void verification_link_in_non_default_zone() throws Exception { String code = getQueryStringParam(query, "code"); assertThat(code, is(notNullValue())); + IdentityZoneHolder.set(zoneResult.getIdentityZone()); ExpiringCode expiringCode = codeStore.retrieveCode(code); + IdentityZoneHolder.clear(); assertThat(expiringCode.getExpiresAt().getTime(), is(greaterThan(System.currentTimeMillis()))); assertThat(expiringCode.getIntent(), is(REGISTRATION.name())); Map data = JsonUtils.readValue(expiringCode.getData(), new TypeReference>() {}); @@ -311,7 +319,9 @@ public void verification_link_in_non_default_zone_using_switch() throws Exceptio String code = getQueryStringParam(query, "code"); assertThat(code, is(notNullValue())); + IdentityZoneHolder.set(zoneResult.getIdentityZone()); ExpiringCode expiringCode = codeStore.retrieveCode(code); + IdentityZoneHolder.clear(); assertThat(expiringCode.getExpiresAt().getTime(), is(greaterThan(System.currentTimeMillis()))); assertThat(expiringCode.getIntent(), is(REGISTRATION.name())); Map data = JsonUtils.readValue(expiringCode.getData(), new TypeReference>() {});