Skip to content

Commit

Permalink
Add zone ID to expiring codes
Browse files Browse the repository at this point in the history
  • Loading branch information
fhanik committed May 12, 2017
1 parent 1c42cf7 commit bbf6751
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 54 deletions.
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -12,6 +12,7 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;

import java.sql.Timestamp;
Expand All @@ -20,7 +21,7 @@ public interface ExpiringCodeStore {

/**
* Generate and persist a one-time code with an expiry date.
*
*
* @param data JSON object to be associated with the code
* @param intent An optional key (not necessarily unique) for looking up codes
* @return code the generated one-time code
Expand All @@ -31,7 +32,7 @@ public interface ExpiringCodeStore {

/**
* Retrieve a code and delete it if it exists.
*
*
* @param code the one-time code to look for
* @return code or null if the code is not found
* @throws java.lang.NullPointerException if the code is null
Expand All @@ -40,7 +41,7 @@ public interface ExpiringCodeStore {

/**
* Set the code generator for this store.
*
*
* @param generator Code generator
*/
void setGenerator(RandomValueStringGenerator generator);
Expand All @@ -51,4 +52,16 @@ public interface ExpiringCodeStore {
* @param intent Intent of codes to remove
*/
void expireByIntent(String intent);

default String zonifyCode(String code) {
return code + "[zone[" + IdentityZoneHolder.get().getId()+"]]";
}

default String extractCode(String zoneCode) {
int endIndex = zoneCode.indexOf("[zone[" + IdentityZoneHolder.get().getId()+"]]");
if (endIndex<0) {
return zoneCode;
}
return zoneCode.substring(0, endIndex);
}
}
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -12,6 +12,7 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -40,7 +41,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent

ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent);

ExpiringCode duplicate = store.putIfAbsent(code, expiringCode);
ExpiringCode duplicate = store.putIfAbsent(zonifyCode(code), expiringCode);
if (duplicate != null) {
throw new DataIntegrityViolationException("Duplicate code: " + code);
}
Expand All @@ -54,7 +55,7 @@ public ExpiringCode retrieveCode(String code) {
throw new NullPointerException();
}

ExpiringCode expiringCode = store.remove(code);
ExpiringCode expiringCode = store.remove(zonifyCode(code));

if (expiringCode == null || expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
Expand All @@ -71,7 +72,7 @@ public void setGenerator(RandomValueStringGenerator generator) {
@Override
public void expireByIntent(String intent) {
Assert.hasText(intent);

store.values().stream().filter(c -> intent.equals(c.getIntent())).forEach(c -> store.remove(c.getCode()));
String id = IdentityZoneHolder.get().getId();
store.entrySet().stream().filter(c -> c.getKey().contains(id) && intent.equals(c.getValue().getIntent())).forEach(c -> store.remove(c.getKey()));
}
}
Expand Up @@ -12,13 +12,6 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.concurrent.atomic.AtomicLong;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.DataIntegrityViolationException;
Expand All @@ -28,16 +21,26 @@
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.util.Assert;

import javax.sql.DataSource;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.concurrent.atomic.AtomicLong;

public class JdbcExpiringCodeStore implements ExpiringCodeStore {

public static final String tableName = "expiring_code_store";
public static final String fields = "code, expiresat, data, intent";

public static final String insert = "insert into " + tableName + " (" + fields + ") values (?,?,?,?)";
public static final String delete = "delete from " + tableName + " where code = ?";
public static final String deleteIntent = "delete from " + tableName + " where intent = ?";
public static final String deleteExpired = "delete from " + tableName + " where expiresat < ?";
public static final String select = "select " + fields + " from " + tableName + " where code = ?";
protected static final String insert = "insert into " + tableName + " (" + fields + ") values (?,?,?,?)";
protected static final String delete = "delete from " + tableName + " where code = ?";
protected static final String deleteIntent = "delete from " + tableName + " where intent = ? and code LIKE ?";
protected static final String deleteExpired = "delete from " + tableName + " where expiresat < ?";

private final JdbcExpiringCodeMapper rowMapper = new JdbcExpiringCodeMapper();

protected static final String selectAllFields = "select " + fields + " from " + tableName + " where code = ?";

public static final String SELECT_BY_EMAIL_AND_CLIENT_ID = "select " + fields + " from " + tableName +
" where data like '%%\"email\":\"%s\"%%' and data like '%%\"client_id\":\"%s\"%%' ORDER BY expiresat DESC LIMIT 1";

Expand Down Expand Up @@ -87,7 +90,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent
count++;
String code = generator.generate();
try {
int update = jdbcTemplate.update(insert, code, expiresAt.getTime(), data, intent);
int update = jdbcTemplate.update(insert, zonifyCode(code), expiresAt.getTime(), data, intent);
if (update == 1) {
ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent);
return expiringCode;
Expand All @@ -113,17 +116,14 @@ public ExpiringCode retrieveCode(String code) {
}

try {
ExpiringCode expiringCode = jdbcTemplate.queryForObject(select, new JdbcExpiringCodeMapper(), code);
try {
if (expiringCode != null) {
jdbcTemplate.update(delete, code);
}
if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
}
} finally {
return expiringCode;
ExpiringCode expiringCode = jdbcTemplate.queryForObject(selectAllFields, rowMapper, zonifyCode(code));
if (expiringCode != null) {
jdbcTemplate.update(delete, zonifyCode(code));
}
if (expiringCode.getExpiresAt().getTime() < System.currentTimeMillis()) {
expiringCode = null;
}
return expiringCode;
} catch (EmptyResultDataAccessException x) {
return null;
}
Expand All @@ -138,7 +138,7 @@ public void setGenerator(RandomValueStringGenerator generator) {
public void expireByIntent(String intent) {
Assert.hasText(intent);

jdbcTemplate.update(deleteIntent, intent);
jdbcTemplate.update(deleteIntent, intent, zonifyCode("%")+"%");
}

public int cleanExpiredEntries() {
Expand All @@ -154,15 +154,14 @@ public int cleanExpiredEntries() {
return 0;
}

protected static class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {
protected class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {

@Override
public ExpiringCode mapRow(ResultSet rs, int rowNum) throws SQLException {
int pos = 1;
String code = rs.getString(pos++);
Timestamp expiresAt = new Timestamp(rs.getLong(pos++));
String data = rs.getString(pos++);
String intent = rs.getString(pos++);
String code = extractCode(rs.getString("code"));
Timestamp expiresAt = new Timestamp(rs.getLong("expiresat"));
String intent = rs.getString("intent");
String data = rs.getString("data");
return new ExpiringCode(code, expiresAt, data, intent);
}

Expand Down
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -14,6 +14,8 @@

import org.cloudfoundry.identity.uaa.test.JdbcTestBase;
import org.cloudfoundry.identity.uaa.test.TestUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -24,12 +26,15 @@
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.test.util.ReflectionTestUtils;

import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;

@RunWith(Parameterized.class)
public class ExpiringCodeStoreTests extends JdbcTestBase {
Expand Down Expand Up @@ -63,6 +68,16 @@ public void initExpiringCodeStoreTests() throws Exception {
}
}

public int countCodes() {
if (expiringCodeStore instanceof InMemoryExpiringCodeStore) {
Map map = (Map) ReflectionTestUtils.getField(expiringCodeStore, "store");
return map.size();
} else {
// confirm that everything is clean prior to test.
return jdbcTemplate.queryForObject("select count(*) from expiring_code_store", Integer.class);
}
}

@Test
public void testGenerateCode() throws Exception {
String data = "{}";
Expand Down Expand Up @@ -125,6 +140,22 @@ public void testRetrieveCode() throws Exception {
Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));
}

@Test
public void testRetrieveCode_In_Another_Zone() throws Exception {
String data = "{}";
Timestamp expiresAt = new Timestamp(System.currentTimeMillis() + 60000);
ExpiringCode generatedCode = expiringCodeStore.generateCode(data, expiresAt, null);

IdentityZoneHolder.set(MultitenancyFixture.identityZone("other", "other"));
Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));

IdentityZoneHolder.clear();
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(generatedCode.getCode());
Assert.assertEquals(generatedCode, retrievedCode);


}

@Test
public void testRetrieveCodeWithCodeNotFound() throws Exception {
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode("unknown");
Expand All @@ -143,7 +174,7 @@ public void testStoreLargeData() throws Exception {
Arrays.fill(oneMb, 'a');
String aaaString = new String(oneMb);
ExpiringCode expiringCode = expiringCodeStore.generateCode(aaaString, new Timestamp(
System.currentTimeMillis() + 60000), null);
System.currentTimeMillis() + 60000), null);
String code = expiringCode.getCode();
ExpiringCode actualCode = expiringCodeStore.retrieveCode(code);
Assert.assertEquals(expiringCode, actualCode);
Expand All @@ -164,10 +195,16 @@ public void testExpiredCodeReturnsNull() throws Exception {
public void testExpireCodeByIntent() throws Exception {
ExpiringCode code = expiringCodeStore.generateCode("{}", new Timestamp(System.currentTimeMillis() + 60000), "Test Intent");

Assert.assertEquals(1, countCodes());

IdentityZoneHolder.set(MultitenancyFixture.identityZone("id","id"));
expiringCodeStore.expireByIntent("Test Intent");
Assert.assertEquals(1, countCodes());

IdentityZoneHolder.clear();
expiringCodeStore.expireByIntent("Test Intent");
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(code.getCode());

Assert.assertEquals(0, countCodes());
Assert.assertNull(retrievedCode);
}

Expand All @@ -194,8 +231,8 @@ public void testExpirationCleaner() throws Exception {
if (JdbcExpiringCodeStore.class == expiringCodeStoreClass) {
jdbcTemplate.update(JdbcExpiringCodeStore.insert, "test", System.currentTimeMillis() - 1000, "{}", null);
((JdbcExpiringCodeStore) expiringCodeStore).cleanExpiredEntries();
jdbcTemplate.queryForObject(JdbcExpiringCodeStore.select,
new JdbcExpiringCodeStore.JdbcExpiringCodeMapper(), "test");
jdbcTemplate.queryForObject(JdbcExpiringCodeStore.selectAllFields,
(RowMapper<ExpiringCode>) ReflectionTestUtils.getField(expiringCodeStore, "rowMapper"), "test");
} else {
throw new EmptyResultDataAccessException(1);
}
Expand Down

0 comments on commit bbf6751

Please sign in to comment.