Skip to content

Commit

Permalink
Add zone ID to expiring codes
Browse files Browse the repository at this point in the history
  • Loading branch information
fhanik committed May 12, 2017
1 parent 7db5e58 commit eb3f860
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2014] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -12,6 +12,7 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;

import java.sql.Timestamp;
Expand All @@ -20,17 +21,17 @@ public interface ExpiringCodeStore {

/**
* Generate and persist a one-time code with an expiry date.
*
*
* @param data JSON object to be associated with the code
* @return code the generated one-time code
* @throws java.lang.NullPointerException if data or expiresAt is null
* @throws java.lang.NullPointerException if data or expiresAt is null
* @throws java.lang.IllegalArgumentException if expiresAt is in the past
*/
ExpiringCode generateCode(String data, Timestamp expiresAt);

/**
* Retrieve a code and delete it if it exists.
*
*
* @param code the one-time code to look for
* @return code or null if the code is not found
* @throws java.lang.NullPointerException if the code is null
Expand All @@ -39,8 +40,20 @@ public interface ExpiringCodeStore {

/**
* Set the code generator for this store.
*
*
* @param generator Code generator
*/
void setGenerator(RandomValueStringGenerator generator);

default String zonifyCode(String code) {
return code + "[zone[" + IdentityZoneHolder.get().getId() + "]]";
}

default String extractCode(String zoneCode) {
int endIndex = zoneCode.indexOf("[zone[" + IdentityZoneHolder.get().getId() + "]]");
if (endIndex < 0) {
return zoneCode;
}
return zoneCode.substring(0, endIndex);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2014] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand Down Expand Up @@ -47,7 +47,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt) {

ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data);

ExpiringCode duplicate = store.putIfAbsent(code, expiringCode);
ExpiringCode duplicate = store.putIfAbsent(zonifyCode(code), expiringCode);
if (duplicate != null) {
throw new DataIntegrityViolationException("Duplicate code: " + code);
}
Expand All @@ -61,7 +61,7 @@ public ExpiringCode retrieveCode(String code) {
throw new NullPointerException();
}

ExpiringCode expiringCode = store.remove(code);
ExpiringCode expiringCode = store.remove(zonifyCode(code));

if (expiringCode == null || expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.concurrent.atomic.AtomicLong;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.DataIntegrityViolationException;
Expand All @@ -27,6 +20,12 @@
import org.springframework.jdbc.core.RowMapper;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;

import javax.sql.DataSource;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.concurrent.atomic.AtomicLong;

public class JdbcExpiringCodeStore implements ExpiringCodeStore {

public static final String tableName = "expiring_code_store";
Expand All @@ -48,6 +47,8 @@ public class JdbcExpiringCodeStore implements ExpiringCodeStore {
private AtomicLong lastExpired = new AtomicLong();
private long expirationInterval = 60 * 1000; // once a minute

private RowMapper<ExpiringCode> rowMapper = new JdbcExpiringCodeMapper();

public long getExpirationInterval() {
return expirationInterval;
}
Expand Down Expand Up @@ -85,7 +86,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt) {
count++;
String code = generator.generate();
try {
int update = jdbcTemplate.update(insert, code, expiresAt.getTime(), data);
int update = jdbcTemplate.update(insert, zonifyCode(code), expiresAt.getTime(), data);
if (update == 1) {
ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data);
return expiringCode;
Expand All @@ -111,17 +112,14 @@ public ExpiringCode retrieveCode(String code) {
}

try {
ExpiringCode expiringCode = jdbcTemplate.queryForObject(select, new JdbcExpiringCodeMapper(), code);
try {
if (expiringCode != null) {
jdbcTemplate.update(delete, code);
}
if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
}
} finally {
return expiringCode;
ExpiringCode expiringCode = jdbcTemplate.queryForObject(select, rowMapper, zonifyCode(code));
if (expiringCode != null) {
jdbcTemplate.update(delete, zonifyCode(code));
}
if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
}
return expiringCode;
} catch (EmptyResultDataAccessException x) {
return null;
}
Expand All @@ -145,14 +143,14 @@ public int cleanExpiredEntries() {
return 0;
}

protected static class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {
protected class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {

@Override
public ExpiringCode mapRow(ResultSet rs, int rowNum) throws SQLException {
int pos = 1;
String code = rs.getString(pos++);
String code = extractCode(rs.getString(pos++));
Timestamp expiresAt = new Timestamp(rs.getLong(pos++));
String data = rs.getString(pos++).toString();
String data = rs.getString(pos++);
return new ExpiringCode(code, expiresAt, data);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2014] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -12,11 +12,27 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import org.cloudfoundry.identity.uaa.test.JdbcTestBase;
import org.cloudfoundry.identity.uaa.test.TestUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.test.util.ReflectionTestUtils;

import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
Expand All @@ -26,18 +42,6 @@
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.cloudfoundry.identity.uaa.test.JdbcTestBase;
import org.cloudfoundry.identity.uaa.test.TestUtils;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;

@RunWith(Parameterized.class)
public class ExpiringCodeStoreTests extends JdbcTestBase {
Expand Down Expand Up @@ -71,6 +75,16 @@ public void initExpiringCodeStoreTests() throws Exception {
}
}

public int countCodes() {
if (expiringCodeStore instanceof InMemoryExpiringCodeStore) {
Map map = (Map) ReflectionTestUtils.getField(expiringCodeStore, "store");
return map.size();
} else {
// confirm that everything is clean prior to test.
return jdbcTemplate.queryForObject("select count(*) from expiring_code_store", Integer.class);
}
}

@Test
public void testGenerateCode() throws Exception {
String data = "{}";
Expand Down Expand Up @@ -133,6 +147,22 @@ public void testRetrieveCode() throws Exception {
assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));
}

@Test
public void testRetrieveCode_In_Another_Zone() throws Exception {
String data = "{}";
Timestamp expiresAt = new Timestamp(System.currentTimeMillis() + 60000);
ExpiringCode generatedCode = expiringCodeStore.generateCode(data, expiresAt);

IdentityZoneHolder.set(MultitenancyFixture.identityZone("other", "other"));
Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));

IdentityZoneHolder.clear();
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(generatedCode.getCode());
Assert.assertEquals(generatedCode, retrievedCode);


}

@Test
public void testRetrieveCodeWithCodeNotFound() throws Exception {
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode("unknown");
Expand All @@ -151,7 +181,7 @@ public void testStoreLargeData() throws Exception {
Arrays.fill(oneMb, 'a');
String aaaString = new String(oneMb);
ExpiringCode expiringCode = expiringCodeStore.generateCode(aaaString, new Timestamp(
System.currentTimeMillis() + 60000));
System.currentTimeMillis() + 60000));
String code = expiringCode.getCode();
ExpiringCode actualCode = expiringCodeStore.retrieveCode(code);
assertEquals(expiringCode, actualCode);
Expand Down Expand Up @@ -192,7 +222,7 @@ public void testExpirationCleaner() throws Exception {
jdbcTemplate.update(JdbcExpiringCodeStore.insert, "test", System.currentTimeMillis() - 1000, "{}");
((JdbcExpiringCodeStore) expiringCodeStore).cleanExpiredEntries();
jdbcTemplate.queryForObject(JdbcExpiringCodeStore.select,
new JdbcExpiringCodeStore.JdbcExpiringCodeMapper(), "test");
(RowMapper<ExpiringCode>) ReflectionTestUtils.getField(expiringCodeStore, "rowMapper"), "test");
} else {
throw new EmptyResultDataAccessException(1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.cloudfoundry.identity.uaa.authentication.Origin;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCode;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeStore;
import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore;
import org.cloudfoundry.identity.uaa.mock.InjectedMockContextTest;
import org.cloudfoundry.identity.uaa.mock.util.MockMvcUtils;
import org.cloudfoundry.identity.uaa.scim.ScimUser;
Expand All @@ -12,6 +13,7 @@
import org.cloudfoundry.identity.uaa.zone.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.cloudfoundry.identity.uaa.zone.JdbcIdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.zone.MultitenantJdbcClientDetailsService;
import org.cloudfoundry.identity.uaa.zone.UaaIdentityProviderDefinition;
import org.flywaydb.core.internal.util.StringUtils;
Expand Down Expand Up @@ -81,7 +83,8 @@ public void setUp() throws Exception {

@After
public void cleanUpDomainList() throws Exception {
IdentityProvider<UaaIdentityProviderDefinition> uaaProvider = getWebApplicationContext().getBean(IdentityProviderProvisioning.class).retrieveByOrigin(UAA, IdentityZone.getUaa().getId());
IdentityZoneHolder.clear();
IdentityProvider<UaaIdentityProviderDefinition> uaaProvider = getWebApplicationContext().getBean(JdbcIdentityProviderProvisioning.class).retrieveByOrigin(UAA, IdentityZone.getUaa().getId());
uaaProvider.getConfig().setEmailDomain(null);
getWebApplicationContext().getBean(IdentityProviderProvisioning.class).update(uaaProvider);
}
Expand Down Expand Up @@ -113,7 +116,7 @@ public void invite_User_With_User_Credentials() throws Exception {

@Test
public void invite_User_Within_Zone() throws Exception {
String subdomain = generator.generate();
String subdomain = generator.generate().toLowerCase();
MockMvcUtils.IdentityZoneCreationResult result = utils().createOtherIdentityZoneAndReturnResult(subdomain, getMockMvc(), getWebApplicationContext(), null);

String zonedClientId = "zonedClientId";
Expand All @@ -126,7 +129,7 @@ public void invite_User_Within_Zone() throws Exception {
String redirectUrl = "example.com";
InvitationsResponse response = sendRequestWithTokenAndReturnResponse(zonedScimInviteToken, subdomain, zonedClientDetails.getClientId(), redirectUrl, email);

assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, subdomain, response, zonedClientDetails);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), response, zonedClientDetails);
}

@Test
Expand Down Expand Up @@ -206,6 +209,7 @@ public void invitations_Accept_Get_Security() throws Exception {
sendRequestWithToken(userToken, null, clientId, "example.com", "user1@"+domain);

String code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("SELECT code FROM expiring_code_store", String.class);
code = new InMemoryExpiringCodeStore().extractCode(code);
assertNotNull("Invite Code Must be Present", code);

MockHttpServletRequestBuilder accept = get("/invitations/accept")
Expand All @@ -231,16 +235,17 @@ public static void sendRequestWithToken(String token, String subdomain, String c
assertThat(response.getFailedInvites().size(), is(0));
}

private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, String subdomain, InvitationsResponse response, ClientDetails clientDetails) {
private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, IdentityZone zone, InvitationsResponse response, ClientDetails clientDetails) {
for (int i = 0; i < emails.length; i++) {
assertThat(response.getNewInvites().size(), is(emails.length));
assertThat(response.getNewInvites().get(i).getEmail(), is(emails[i]));
assertThat(response.getNewInvites().get(i).getOrigin(), is(Origin.UAA));
assertThat(response.getNewInvites().get(i).getUserId(), is(notNullValue()));
assertThat(response.getNewInvites().get(i).getErrorCode(), is(nullValue()));
assertThat(response.getNewInvites().get(i).getErrorMessage(), is(nullValue()));
if (StringUtils.hasText(subdomain)) {
assertThat(response.getNewInvites().get(i).getInviteLink().toString(), startsWith("http://" + subdomain + ".localhost/invitations/accept"));
if (zone != null && StringUtils.hasText(zone.getSubdomain())) {
assertThat(response.getNewInvites().get(i).getInviteLink().toString(), startsWith("http://" + zone.getSubdomain() + ".localhost/invitations/accept"));
IdentityZoneHolder.set(zone);
} else {
assertThat(response.getNewInvites().get(i).getInviteLink().toString(), startsWith("http://localhost/invitations/accept"));
}
Expand All @@ -249,6 +254,7 @@ private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, S
assertThat(query, startsWith("code="));
String code = query.split("=")[1];
ExpiringCode expiringCode = codeStore.retrieveCode(code);
IdentityZoneHolder.clear();
assertThat(expiringCode.getExpiresAt().getTime(), is(greaterThan(System.currentTimeMillis())));
Map<String, String> data = JsonUtils.readValue(expiringCode.getData(), new TypeReference<Map<String, String>>() {});
assertThat(data.get(InvitationConstants.USER_ID), is(notNullValue()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import org.cloudfoundry.identity.uaa.AbstractIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.authentication.Origin;
import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore;
import org.cloudfoundry.identity.uaa.ldap.LdapIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.login.saml.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.login.util.FakeJavaMailSender;
Expand Down Expand Up @@ -223,6 +224,7 @@ public void accept_invitation_sets_your_password() throws Exception {
.andReturn();

code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("select code from expiring_code_store", String.class);
code = new InMemoryExpiringCodeStore().extractCode(code);
MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false);
result = getMockMvc().perform(
post("/invitations/accept.do")
Expand Down
Loading

0 comments on commit eb3f860

Please sign in to comment.