diff --git a/spring-session-core/src/main/java/org/springframework/session/MapSession.java b/spring-session-core/src/main/java/org/springframework/session/MapSession.java index bd2413575..cafdc1ff9 100644 --- a/spring-session-core/src/main/java/org/springframework/session/MapSession.java +++ b/spring-session-core/src/main/java/org/springframework/session/MapSession.java @@ -25,6 +25,8 @@ import java.util.Set; import java.util.UUID; +import org.springframework.util.Assert; + /** *

* A {@link Session} implementation that is backed by a {@link java.util.Map}. The @@ -74,6 +76,9 @@ public final class MapSession implements Session, Serializable { */ private Duration maxInactiveInterval = DEFAULT_MAX_INACTIVE_INTERVAL; + private transient SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy + .getInstance(); + /** * Creates a new instance with a secure randomly generated identifier. */ @@ -81,6 +86,17 @@ public MapSession() { this(generateId()); } + /** + * Creates a new instance using the specified {@link SessionIdGenerationStrategy} to + * generate the session id. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use. + * @since 3.2 + */ + public MapSession(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this(sessionIdGenerationStrategy.generate()); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + /** * Creates a new instance with the specified id. This is preferred to the default * constructor when the id is known to prevent unnecessary consumption on entropy @@ -141,7 +157,7 @@ public String getOriginalId() { @Override public String changeSessionId() { - String changedId = generateId(); + String changedId = this.sessionIdGenerationStrategy.generate(); setId(changedId); return changedId; } @@ -232,6 +248,17 @@ private static String generateId() { return UUID.randomUUID().toString(); } + /** + * Sets the {@link SessionIdGenerationStrategy} to use when generating a new session + * id. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use. + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + private static final long serialVersionUID = 7160779239673823561L; } diff --git a/spring-session-core/src/main/java/org/springframework/session/MapSessionRepository.java b/spring-session-core/src/main/java/org/springframework/session/MapSessionRepository.java index bff80f36a..cea27dcf1 100644 --- a/spring-session-core/src/main/java/org/springframework/session/MapSessionRepository.java +++ b/spring-session-core/src/main/java/org/springframework/session/MapSessionRepository.java @@ -43,6 +43,8 @@ public class MapSessionRepository implements SessionRepository { private final Map sessions; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Creates a new instance backed by the provided {@link java.util.Map}. This allows * injecting a distributed {@link java.util.Map}. @@ -94,9 +96,14 @@ public void deleteById(String id) { @Override public MapSession createSession() { - MapSession result = new MapSession(); + MapSession result = new MapSession(this.sessionIdGenerationStrategy.generate()); result.setMaxInactiveInterval(this.defaultMaxInactiveInterval); return result; } + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-core/src/main/java/org/springframework/session/ReactiveMapSessionRepository.java b/spring-session-core/src/main/java/org/springframework/session/ReactiveMapSessionRepository.java index 3bd1cb030..bddeca558 100644 --- a/spring-session-core/src/main/java/org/springframework/session/ReactiveMapSessionRepository.java +++ b/spring-session-core/src/main/java/org/springframework/session/ReactiveMapSessionRepository.java @@ -45,6 +45,8 @@ public class ReactiveMapSessionRepository implements ReactiveSessionRepository sessions; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Creates a new instance backed by the provided {@link Map}. This allows injecting a * distributed {@link Map}. @@ -84,6 +86,7 @@ public Mono findById(String id) { return Mono.defer(() -> Mono.justOrEmpty(this.sessions.get(id)) .filter((session) -> !session.isExpired()) .map(MapSession::new) + .doOnNext((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy)) .switchIfEmpty(deleteById(id).then(Mono.empty()))); // @formatter:on } @@ -96,10 +99,21 @@ public Mono deleteById(String id) { @Override public Mono createSession() { return Mono.defer(() -> { - MapSession result = new MapSession(); + MapSession result = new MapSession(this.sessionIdGenerationStrategy); result.setMaxInactiveInterval(this.defaultMaxInactiveInterval); return Mono.just(result); }); } + /** + * Sets the {@link SessionIdGenerationStrategy} to use. + * @param sessionIdGenerationStrategy the non-null {@link SessionIdGenerationStrategy} + * to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-core/src/main/java/org/springframework/session/SessionIdGenerationStrategy.java b/spring-session-core/src/main/java/org/springframework/session/SessionIdGenerationStrategy.java new file mode 100644 index 000000000..cc20530f6 --- /dev/null +++ b/spring-session-core/src/main/java/org/springframework/session/SessionIdGenerationStrategy.java @@ -0,0 +1,32 @@ +/* + * Copyright 2014-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.session; + +import org.springframework.lang.NonNull; + +/** + * An interface for specifying a strategy for generating session identifiers. + * + * @author Marcus da Coregio + * @since 3.2 + */ +public interface SessionIdGenerationStrategy { + + @NonNull + String generate(); + +} diff --git a/spring-session-core/src/main/java/org/springframework/session/UuidSessionIdGenerationStrategy.java b/spring-session-core/src/main/java/org/springframework/session/UuidSessionIdGenerationStrategy.java new file mode 100644 index 000000000..42316f170 --- /dev/null +++ b/spring-session-core/src/main/java/org/springframework/session/UuidSessionIdGenerationStrategy.java @@ -0,0 +1,51 @@ +/* + * Copyright 2014-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.session; + +import java.util.UUID; + +import org.springframework.lang.NonNull; + +/** + * A {@link SessionIdGenerationStrategy} that generates a random UUID to be used as the + * session id. + * + * @author Marcus da Coregio + * @since 3.2 + */ +public final class UuidSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + private static final UuidSessionIdGenerationStrategy INSTANCE = new UuidSessionIdGenerationStrategy(); + + private UuidSessionIdGenerationStrategy() { + } + + @Override + @NonNull + public String generate() { + return UUID.randomUUID().toString(); + } + + /** + * Returns the singleton instance of {@link UuidSessionIdGenerationStrategy}. + * @return the singleton instance of {@link UuidSessionIdGenerationStrategy} + */ + public static UuidSessionIdGenerationStrategy getInstance() { + return INSTANCE; + } + +} diff --git a/spring-session-core/src/test/java/org/springframework/session/MapSessionTests.java b/spring-session-core/src/test/java/org/springframework/session/MapSessionTests.java index 36221f5e9..09ac62db3 100644 --- a/spring-session-core/src/test/java/org/springframework/session/MapSessionTests.java +++ b/spring-session-core/src/test/java/org/springframework/session/MapSessionTests.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Set; +import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -42,6 +43,19 @@ void constructorNullSession() { .withMessage("session cannot be null"); } + @Test + void constructorWhenSessionIdGenerationStrategyThenUsesStrategy() { + MapSession session = new MapSession(new FixedSessionIdGenerationStrategy("my-id")); + assertThat(session.getId()).isEqualTo("my-id"); + } + + @Test + void constructorWhenDefaultThenUuid() { + String id = this.session.getId(); + UUID uuid = UUID.fromString(id); + assertThat(uuid).isNotNull(); + } + @Test void getAttributeWhenNullThenNull() { String result = this.session.getAttribute("attrName"); @@ -143,6 +157,41 @@ void getAttributeNamesAndRemove() { assertThat(this.session.getAttributeNames()).isEmpty(); } + @Test + void changeSessionIdWhenSessionIdStrategyThenUsesStrategy() { + MapSession session = new MapSession(new IncrementalSessionIdGenerationStrategy()); + String idBeforeChange = session.getId(); + String idAfterChange = session.changeSessionId(); + assertThat(idBeforeChange).isEqualTo("1"); + assertThat(idAfterChange).isEqualTo("2"); + } + + static class FixedSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + private final String id; + + FixedSessionIdGenerationStrategy(String id) { + this.id = id; + } + + @Override + public String generate() { + return this.id; + } + + } + + static class IncrementalSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + private int counter = 1; + + @Override + public String generate() { + return String.valueOf(this.counter++); + } + + } + static class CustomSession implements Session { @Override diff --git a/spring-session-core/src/test/java/org/springframework/session/ReactiveMapSessionRepositoryTests.java b/spring-session-core/src/test/java/org/springframework/session/ReactiveMapSessionRepositoryTests.java index ee1478f22..9210b1f2a 100644 --- a/spring-session-core/src/test/java/org/springframework/session/ReactiveMapSessionRepositoryTests.java +++ b/spring-session-core/src/test/java/org/springframework/session/ReactiveMapSessionRepositoryTests.java @@ -152,4 +152,31 @@ void getAttributeNamesAndRemove() { assertThat(session.getAttributeNames()).isEmpty(); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + MapSession session = this.repository.createSession().block(); + assertThat(session.getId()).isEqualTo("test"); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + + MapSession session = this.repository.createSession().block(); + this.repository.save(session).block(); + + MapSession savedSession = this.repository.findById("test").block(); + + assertThat(savedSession.getId()).isEqualTo("test"); + assertThat(savedSession.changeSessionId()).isEqualTo("test"); + } + } diff --git a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoIndexedSessionRepository.java b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoIndexedSessionRepository.java index 3c5f5aff4..18cf41d57 100644 --- a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoIndexedSessionRepository.java +++ b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoIndexedSessionRepository.java @@ -36,6 +36,8 @@ import org.springframework.lang.Nullable; import org.springframework.session.FindByIndexNameSessionRepository; import org.springframework.session.MapSession; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.events.SessionCreatedEvent; import org.springframework.session.events.SessionDeletedEvent; import org.springframework.session.events.SessionExpiredEvent; @@ -81,6 +83,8 @@ public class MongoIndexedSessionRepository private ApplicationEventPublisher eventPublisher; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + public MongoIndexedSessionRepository(MongoOperations mongoOperations) { this.mongoOperations = mongoOperations; } @@ -88,7 +92,7 @@ public MongoIndexedSessionRepository(MongoOperations mongoOperations) { @Override public MongoSession createSession() { - MongoSession session = new MongoSession(); + MongoSession session = new MongoSession(this.sessionIdGenerationStrategy); session.setMaxInactiveInterval(this.defaultMaxInactiveInterval); @@ -116,10 +120,13 @@ public MongoSession findById(String id) { MongoSession session = MongoSessionUtils.convertToSession(this.mongoSessionConverter, sessionWrapper); - if (session != null && session.isExpired()) { - publishEvent(new SessionExpiredEvent(this, session)); - deleteById(id); - return null; + if (session != null) { + if (session.isExpired()) { + publishEvent(new SessionExpiredEvent(this, session)); + deleteById(id); + return null; + } + session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); } return session; @@ -140,6 +147,7 @@ public Map findByIndexNameAndIndexValue(String indexName, .map((query) -> this.mongoOperations.find(query, Document.class, this.collectionName)) .orElse(Collections.emptyList()).stream() .map((dbSession) -> MongoSessionUtils.convertToSession(this.mongoSessionConverter, dbSession)) + .peek((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy)) .collect(Collectors.toMap(MongoSession::getId, (mapSession) -> mapSession)); } @@ -216,4 +224,14 @@ public void setMongoSessionConverter(final AbstractMongoSessionConverter mongoSe this.mongoSessionConverter = mongoSessionConverter; } + /** + * Set the {@link SessionIdGenerationStrategy} to use to generate session ids. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoSession.java b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoSession.java index 39a47ffc1..b3d8cc919 100644 --- a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoSession.java +++ b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/MongoSession.java @@ -23,12 +23,14 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.UUID; import java.util.stream.Collectors; import org.springframework.lang.Nullable; import org.springframework.session.MapSession; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; +import org.springframework.util.Assert; /** * Session object providing additional information about the datetime of expiration. @@ -66,12 +68,24 @@ class MongoSession implements Session { private Map attrs = new HashMap<>(); + private transient SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy + .getInstance(); + + /** + * Constructs a new instance using the provided session id. + * @param sessionId the session id to use + * @since 3.2 + */ + MongoSession(String sessionId) { + this(sessionId, MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS); + } + MongoSession() { this(MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS); } MongoSession(long maxInactiveIntervalInSeconds) { - this(UUID.randomUUID().toString(), maxInactiveIntervalInSeconds); + this(UuidSessionIdGenerationStrategy.getInstance().generate(), maxInactiveIntervalInSeconds); } MongoSession(String id, long maxInactiveIntervalInSeconds) { @@ -82,6 +96,28 @@ class MongoSession implements Session { setLastAccessedTime(Instant.ofEpochMilli(this.createdMillis)); } + /** + * Constructs a new instance using the provided {@link SessionIdGenerationStrategy}. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + MongoSession(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this(sessionIdGenerationStrategy.generate(), MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + + /** + * Constructs a new instance using the provided {@link SessionIdGenerationStrategy} + * and max inactive interval. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @param maxInactiveIntervalInSeconds the max inactive interval in seconds + * @since 3.2 + */ + MongoSession(SessionIdGenerationStrategy sessionIdGenerationStrategy, long maxInactiveIntervalInSeconds) { + this(sessionIdGenerationStrategy.generate(), maxInactiveIntervalInSeconds); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + static String coverDot(String attributeName) { return attributeName.replace('.', DOT_COVER_CHAR); } @@ -93,7 +129,7 @@ static String uncoverDot(String attributeName) { @Override public String changeSessionId() { - String changedId = UUID.randomUUID().toString(); + String changedId = this.sessionIdGenerationStrategy.generate(); this.id = changedId; return changedId; } @@ -141,7 +177,6 @@ public Instant getLastAccessedTime() { @Override public void setLastAccessedTime(Instant lastAccessedTime) { - this.accessedMillis = lastAccessedTime.toEpochMilli(); this.expireAt = Date.from(lastAccessedTime.plus(Duration.ofSeconds(this.intervalSeconds))); } @@ -200,4 +235,23 @@ String getOriginalSessionId() { return this.originalSessionId; } + /** + * Sets the session id. + * @param id the id to set + * @since 3.2 + */ + void setId(String id) { + this.id = id; + } + + /** + * Sets the {@link SessionIdGenerationStrategy} to use. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepository.java b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepository.java index 71c06c351..36c6f09e9 100644 --- a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepository.java +++ b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepository.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.LogFactory; import org.bson.Document; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationEvent; @@ -34,6 +35,8 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.session.MapSession; import org.springframework.session.ReactiveSessionRepository; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.events.SessionCreatedEvent; import org.springframework.session.events.SessionDeletedEvent; import org.springframework.util.Assert; @@ -76,6 +79,8 @@ public class ReactiveMongoSessionRepository private ApplicationEventPublisher eventPublisher; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + public ReactiveMongoSessionRepository(ReactiveMongoOperations mongoOperations) { this.mongoOperations = mongoOperations; } @@ -93,11 +98,17 @@ public ReactiveMongoSessionRepository(ReactiveMongoOperations mongoOperations) { */ @Override public Mono createSession() { - - return Mono.justOrEmpty(this.defaultMaxInactiveInterval.toSeconds()) // - .map(MongoSession::new) // - .doOnNext((mongoSession) -> publishEvent(new SessionCreatedEvent(this, mongoSession))) // - .switchIfEmpty(Mono.just(new MongoSession())); + // @formatter:off + return Mono.fromSupplier(() -> this.sessionIdGenerationStrategy.generate()) + .map(MongoSession::new) + .doOnNext((mongoSession) -> mongoSession.setMaxInactiveInterval(this.defaultMaxInactiveInterval)) + .doOnNext( + (mongoSession) -> mongoSession.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy)) + .doOnNext((mongoSession) -> publishEvent(new SessionCreatedEvent(this, mongoSession))) + .switchIfEmpty(Mono.just(new MongoSession(this.sessionIdGenerationStrategy))) + .subscribeOn(Schedulers.boundedElastic()) + .publishOn(Schedulers.parallel()); + // @formatter:on } @Override @@ -127,6 +138,8 @@ public Mono findById(String id) { return findSession(id) // .map((document) -> MongoSessionUtils.convertToSession(this.mongoSessionConverter, document)) // .filter((mongoSession) -> !mongoSession.isExpired()) // + .doOnNext( + (mongoSession) -> mongoSession.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy)) .switchIfEmpty(Mono.defer(() -> this.deleteById(id).then(Mono.empty()))); } @@ -216,4 +229,9 @@ public void setBlockingMongoOperations(final MongoOperations blockingMongoOperat this.blockingMongoOperations = blockingMongoOperations; } + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfiguration.java b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfiguration.java index 3a920d05a..8ca28358e 100644 --- a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfiguration.java +++ b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfiguration.java @@ -36,6 +36,8 @@ import org.springframework.session.IndexResolver; import org.springframework.session.MapSession; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.http.SpringHttpSessionConfiguration; import org.springframework.session.data.mongo.AbstractMongoSessionConverter; @@ -70,6 +72,8 @@ public class MongoHttpSessionConfiguration implements BeanClassLoaderAware, Embe private IndexResolver indexResolver; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean public MongoIndexedSessionRepository mongoSessionRepository(MongoOperations mongoOperations) { @@ -98,6 +102,7 @@ public MongoIndexedSessionRepository mongoSessionRepository(MongoOperations mong if (StringUtils.hasText(this.collectionName)) { repository.setCollectionName(this.collectionName); } + repository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); this.sessionRepositoryCustomizers .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(repository)); @@ -160,4 +165,9 @@ public void setIndexResolver(IndexResolver indexResolver) { this.indexResolver = indexResolver; } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfiguration.java b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfiguration.java index 197d599a8..88f806f3d 100644 --- a/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfiguration.java +++ b/spring-session-data-mongodb/src/main/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfiguration.java @@ -37,6 +37,8 @@ import org.springframework.session.IndexResolver; import org.springframework.session.MapSession; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.ReactiveSessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.server.SpringWebSessionConfiguration; import org.springframework.session.data.mongo.AbstractMongoSessionConverter; @@ -74,6 +76,8 @@ public class ReactiveMongoWebSessionConfiguration private IndexResolver indexResolver; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean public ReactiveMongoSessionRepository reactiveMongoSessionRepository(ReactiveMongoOperations operations) { @@ -112,6 +116,8 @@ public ReactiveMongoSessionRepository reactiveMongoSessionRepository(ReactiveMon this.sessionRepositoryCustomizers .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(repository)); + repository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); + return repository; } @@ -180,4 +186,9 @@ public void setIndexResolver(IndexResolver indexResolver) { this.indexResolver = indexResolver; } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/MongoIndexedSessionRepositoryTest.java b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/MongoIndexedSessionRepositoryTest.java index c33ce79d9..90c0986bb 100644 --- a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/MongoIndexedSessionRepositoryTest.java +++ b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/MongoIndexedSessionRepositoryTest.java @@ -34,8 +34,10 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.session.FindByIndexNameSessionRepository; import org.springframework.session.MapSession; +import org.springframework.session.SessionIdGenerationStrategy; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -209,4 +211,52 @@ void shouldReturnEmptyMapForNotSupportedIndex() { assertThat(sessionsMap).isEmpty(); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(new FixedSessionIdGenerationStrategy("123")); + MongoSession session = this.repository.createSession(); + assertThat(session.getId()).isEqualTo("123"); + assertThat(session.changeSessionId()).isEqualTo("123"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(new FixedSessionIdGenerationStrategy("456")); + + Document sessionDocument = new Document(); + + given(this.mongoOperations.findById("123", Document.class, + MongoIndexedSessionRepository.DEFAULT_COLLECTION_NAME)).willReturn(sessionDocument); + + MongoSession session = new MongoSession("123"); + + given(this.converter.convert(sessionDocument, TypeDescriptor.valueOf(Document.class), + TypeDescriptor.valueOf(MongoSession.class))).willReturn(session); + + MongoSession retrievedSession = this.repository.findById("123"); + assertThat(retrievedSession.getId()).isEqualTo("123"); + String newSessionId = retrievedSession.changeSessionId(); + assertThat(newSessionId).isEqualTo("456"); + } + + static class FixedSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + private final String id; + + FixedSessionIdGenerationStrategy(String id) { + this.id = id; + } + + @Override + public String generate() { + return this.id; + } + + } + } diff --git a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepositoryTest.java b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepositoryTest.java index ed6966a79..afeffc055 100644 --- a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepositoryTest.java +++ b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/ReactiveMongoSessionRepositoryTest.java @@ -40,6 +40,7 @@ import org.springframework.session.events.SessionDeletedEvent; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.any; import static org.mockito.BDDMockito.eq; import static org.mockito.BDDMockito.given; @@ -210,4 +211,43 @@ void shouldInvokeMethodToCreateIndexesImperatively() { verify(this.converter, times(1)).ensureIndexes(indexOperations); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + + this.repository.createSession().as(StepVerifier::create).assertNext((mongoSession) -> { + assertThat(mongoSession.getId()).isEqualTo("test"); + assertThat(mongoSession.changeSessionId()).isEqualTo("test"); + }).verifyComplete(); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + + String sessionId = UUID.randomUUID().toString(); + Document sessionDocument = new Document(); + + given(this.mongoOperations.findById(sessionId, Document.class, + ReactiveMongoSessionRepository.DEFAULT_COLLECTION_NAME)).willReturn(Mono.just(sessionDocument)); + + MongoSession session = new MongoSession(sessionId); + + given(this.converter.convert(sessionDocument, TypeDescriptor.valueOf(Document.class), + TypeDescriptor.valueOf(MongoSession.class))).willReturn(session); + + this.repository.findById(sessionId).as(StepVerifier::create).assertNext((mongoSession) -> { + String oldId = mongoSession.getId(); + String newId = mongoSession.changeSessionId(); + assertThat(oldId).isEqualTo(sessionId); + assertThat(newId).isEqualTo("test"); + }).verifyComplete(); + } + } diff --git a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfigurationTest.java b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfigurationTest.java index 6612628da..238a335b2 100644 --- a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfigurationTest.java +++ b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/http/MongoHttpSessionConfigurationTest.java @@ -34,6 +34,8 @@ import org.springframework.mock.env.MockEnvironment; import org.springframework.session.IndexResolver; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.data.mongo.AbstractMongoSessionConverter; import org.springframework.session.data.mongo.JacksonMongoSessionConverter; @@ -200,6 +202,22 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void registerWhenSessionIdGenerationStrategyBeanThenUses() { + registerAndRefresh(SessionIdGenerationStrategyConfiguration.class); + MongoIndexedSessionRepository sessionRepository = this.context.getBean(MongoIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(TestSessionIdGenerationStrategy.class); + } + + @Test + void registerWhenNoSessionIdGenerationStrategyBeanThenDefault() { + registerAndRefresh(DefaultConfiguration.class); + MongoIndexedSessionRepository sessionRepository = this.context.getBean(MongoIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + private void registerAndRefresh(Class... annotatedClasses) { this.context.register(annotatedClasses); @@ -350,4 +368,25 @@ SessionRepositoryCustomizer sessionRepositoryCust } + @Configuration(proxyBeanMethods = false) + @EnableMongoHttpSession + @Import(MongoConfiguration.class) + static class SessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new TestSessionIdGenerationStrategy(); + } + + } + + static class TestSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + @Override + public String generate() { + return "test"; + } + + } + } diff --git a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfigurationTest.java b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfigurationTest.java index 46f2029d5..6b8f34cbc 100644 --- a/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfigurationTest.java +++ b/spring-session-data-mongodb/src/test/java/org/springframework/session/data/mongo/config/annotation/web/reactive/ReactiveMongoWebSessionConfigurationTest.java @@ -35,6 +35,8 @@ import org.springframework.session.IndexResolver; import org.springframework.session.ReactiveSessionRepository; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.ReactiveSessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.server.EnableSpringWebSession; import org.springframework.session.data.mongo.AbstractMongoSessionConverter; @@ -222,6 +224,28 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void registerWhenSessionIdGenerationStrategyBeanThenUses() { + registerAndRefresh(GoodConfig.class, SessionIdGenerationStrategyConfiguration.class); + ReactiveMongoSessionRepository sessionRepository = this.context.getBean(ReactiveMongoSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(TestSessionIdGenerationStrategy.class); + } + + @Test + void registerWhenNoSessionIdGenerationStrategyBeanThenDefault() { + registerAndRefresh(GoodConfig.class); + ReactiveMongoSessionRepository sessionRepository = this.context.getBean(ReactiveMongoSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + + private void registerAndRefresh(Class... annotatedClasses) { + this.context = new AnnotationConfigApplicationContext(); + this.context.register(annotatedClasses); + this.context.refresh(); + } + /** * Reflectively extract the {@link AbstractMongoSessionConverter} from the * {@link ReactiveMongoSessionRepository}. This is to avoid expanding the surface area @@ -411,4 +435,23 @@ ReactiveSessionRepositoryCustomizer sessionRepos } + @Configuration(proxyBeanMethods = false) + static class SessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new TestSessionIdGenerationStrategy(); + } + + } + + static class TestSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + @Override + public String generate() { + return "test"; + } + + } + } diff --git a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/ReactiveRedisSessionRepository.java b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/ReactiveRedisSessionRepository.java index df749a8cf..fbc568f9e 100644 --- a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/ReactiveRedisSessionRepository.java +++ b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/ReactiveRedisSessionRepository.java @@ -31,6 +31,8 @@ import org.springframework.session.ReactiveSessionRepository; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -61,6 +63,8 @@ public class ReactiveRedisSessionRepository private SaveMode saveMode = SaveMode.ON_SET_ATTRIBUTE; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Create a new {@link ReactiveRedisSessionRepository} instance. * @param sessionRedisOperations the {@link ReactiveRedisOperations} to use for @@ -120,7 +124,7 @@ public ReactiveRedisOperations getSessionRedisOperations() { @Override public Mono createSession() { return Mono.defer(() -> { - MapSession cached = new MapSession(); + MapSession cached = new MapSession(this.sessionIdGenerationStrategy); cached.setMaxInactiveInterval(this.defaultMaxInactiveInterval); RedisSession session = new RedisSession(cached, true); return Mono.just(session); @@ -167,6 +171,16 @@ private String getSessionKey(String sessionId) { return this.namespace + "sessions:" + sessionId; } + /** + * Set the {@link SessionIdGenerationStrategy} to use to generate session ids. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + /** * A custom implementation of {@link Session} that uses a {@link MapSession} as the * basis for its mapping. It keeps track of any attributes that have changed. When @@ -206,7 +220,9 @@ public String getId() { @Override public String changeSessionId() { - return this.cached.changeSessionId(); + String newSessionId = ReactiveRedisSessionRepository.this.sessionIdGenerationStrategy.generate(); + this.cached.setId(newSessionId); + return newSessionId; } @Override diff --git a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisIndexedSessionRepository.java b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisIndexedSessionRepository.java index 48ef4f919..b0275757f 100644 --- a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisIndexedSessionRepository.java +++ b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisIndexedSessionRepository.java @@ -52,6 +52,8 @@ import org.springframework.session.PrincipalNameIndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.events.SessionCreatedEvent; import org.springframework.session.events.SessionDeletedEvent; import org.springframework.session.events.SessionDestroyedEvent; @@ -322,6 +324,8 @@ public class RedisIndexedSessionRepository private ThreadPoolTaskScheduler taskScheduler; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Creates a new instance. For an example, refer to the class level javadoc. * @param sessionRedisOperations the {@link RedisOperations} to use for managing the @@ -547,7 +551,7 @@ public void deleteById(String sessionId) { @Override public RedisSession createSession() { - MapSession cached = new MapSession(); + MapSession cached = new MapSession(this.sessionIdGenerationStrategy); cached.setMaxInactiveInterval(this.defaultMaxInactiveInterval); RedisSession session = new RedisSession(cached, true); session.flushImmediateIfNecessary(); @@ -716,6 +720,16 @@ static String getSessionAttrNameKey(String attributeName) { return RedisSessionMapper.ATTRIBUTE_PREFIX + attributeName; } + /** + * Set the {@link SessionIdGenerationStrategy} to use to generate session ids. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + /** * A custom implementation of {@link Session} that uses a {@link MapSession} as the * basis for its mapping. It keeps track of any attributes that have changed. When @@ -780,7 +794,9 @@ public String getId() { @Override public String changeSessionId() { - return this.cached.changeSessionId(); + String newSessionId = RedisIndexedSessionRepository.this.sessionIdGenerationStrategy.generate(); + this.cached.setId(newSessionId); + return newSessionId; } @Override diff --git a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisSessionRepository.java b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisSessionRepository.java index d2408a701..0320e5e2b 100644 --- a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisSessionRepository.java +++ b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/RedisSessionRepository.java @@ -27,7 +27,9 @@ import org.springframework.session.MapSession; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; import org.springframework.session.SessionRepository; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.util.Assert; /** @@ -56,6 +58,8 @@ public class RedisSessionRepository implements SessionRepository sessionRepositoryCustomizer.customize(sessionRepository)); return sessionRepository; @@ -87,4 +93,9 @@ public void setImportMetadata(AnnotationMetadata importMetadata) { setSaveMode(attributes.getEnum("saveMode")); } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfiguration.java b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfiguration.java index c6c4243d5..1873a17c9 100644 --- a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfiguration.java +++ b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfiguration.java @@ -44,6 +44,8 @@ import org.springframework.data.redis.listener.RedisMessageListenerContainer; import org.springframework.session.IndexResolver; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.data.redis.RedisIndexedSessionRepository; import org.springframework.session.data.redis.config.ConfigureNotifyKeyspaceEventsAction; import org.springframework.session.data.redis.config.ConfigureRedisAction; @@ -80,6 +82,8 @@ public class RedisIndexedHttpSessionConfiguration private StringValueResolver embeddedValueResolver; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean @Override public RedisIndexedSessionRepository sessionRepository() { @@ -101,6 +105,7 @@ public RedisIndexedSessionRepository sessionRepository() { sessionRepository.setCleanupCron(this.cleanupCron); int database = resolveDatabase(); sessionRepository.setDatabase(database); + sessionRepository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); getSessionRepositoryCustomizers() .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(sessionRepository)); return sessionRepository; @@ -204,6 +209,11 @@ && getRedisConnectionFactory() instanceof JedisConnectionFactory) { return RedisIndexedSessionRepository.DEFAULT_DATABASE; } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + /** * Ensures that Redis is configured to send keyspace notifications. This is important * to ensure that expiration and deletion of sessions trigger SessionDestroyedEvents. diff --git a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/server/RedisWebSessionConfiguration.java b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/server/RedisWebSessionConfiguration.java index d7f140f79..5bbb89655 100644 --- a/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/server/RedisWebSessionConfiguration.java +++ b/spring-session-data-redis/src/main/java/org/springframework/session/data/redis/config/annotation/web/server/RedisWebSessionConfiguration.java @@ -39,6 +39,8 @@ import org.springframework.data.redis.serializer.RedisSerializer; import org.springframework.session.MapSession; import org.springframework.session.SaveMode; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.ReactiveSessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.server.SpringWebSessionConfiguration; import org.springframework.session.data.redis.ReactiveRedisSessionRepository; @@ -77,6 +79,8 @@ public class RedisWebSessionConfiguration implements BeanClassLoaderAware, Embed private StringValueResolver embeddedValueResolver; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean public ReactiveRedisSessionRepository sessionRepository() { ReactiveRedisTemplate reactiveRedisTemplate = createReactiveRedisTemplate(); @@ -86,6 +90,7 @@ public ReactiveRedisSessionRepository sessionRepository() { sessionRepository.setRedisKeyNamespace(this.redisNamespace); } sessionRepository.setSaveMode(this.saveMode); + sessionRepository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); this.sessionRepositoryCustomizers .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(sessionRepository)); return sessionRepository; @@ -168,4 +173,9 @@ private ReactiveRedisTemplate createReactiveRedisTemplate() { return new ReactiveRedisTemplate<>(this.redisConnectionFactory, serializationContext); } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/ReactiveRedisSessionRepositoryTests.java b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/ReactiveRedisSessionRepositoryTests.java index a85b9d367..c46f7752a 100644 --- a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/ReactiveRedisSessionRepositoryTests.java +++ b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/ReactiveRedisSessionRepositoryTests.java @@ -438,6 +438,46 @@ void saveWithSaveModeAlways() { verifyNoMoreInteractions(this.hashOperations); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + + this.repository.createSession().as(StepVerifier::create).assertNext((redisSession) -> { + assertThat(redisSession.getId()).isEqualTo("test"); + assertThat(redisSession.changeSessionId()).isEqualTo("test"); + }).verifyComplete(); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + @SuppressWarnings("unchecked") + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(() -> "changed-session-id"); + given(this.redisOperations.opsForHash()).willReturn(this.hashOperations); + String attribute1 = "attribute1"; + String attribute2 = "attribute2"; + MapSession expected = new MapSession("test"); + expected.setLastAccessedTime(Instant.now().minusSeconds(60)); + expected.setAttribute(attribute1, "test"); + expected.setAttribute(attribute2, null); + Map map = map(RedisSessionMapper.ATTRIBUTE_PREFIX + attribute1, expected.getAttribute(attribute1), + RedisSessionMapper.ATTRIBUTE_PREFIX + attribute2, expected.getAttribute(attribute2), + RedisSessionMapper.CREATION_TIME_KEY, expected.getCreationTime().toEpochMilli(), + RedisSessionMapper.MAX_INACTIVE_INTERVAL_KEY, (int) expected.getMaxInactiveInterval().getSeconds(), + RedisSessionMapper.LAST_ACCESSED_TIME_KEY, expected.getLastAccessedTime().toEpochMilli()); + given(this.hashOperations.entries(anyString())).willReturn(Flux.fromIterable(map.entrySet())); + + StepVerifier.create(this.repository.findById("test")).consumeNextWith((session) -> { + assertThat(session.getId()).isEqualTo(expected.getId()); + assertThat(session.changeSessionId()).isEqualTo("changed-session-id"); + }).verifyComplete(); + } + private Map map(Object... objects) { Map result = new HashMap<>(); if (objects == null) { diff --git a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryTests.java b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryTests.java index 697749e60..223f56dde 100644 --- a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryTests.java +++ b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryTests.java @@ -910,6 +910,46 @@ void saveWithSaveModeAlways() { assertThat(getDelta()).hasSize(3); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.redisRepository.setSessionIdGenerationStrategy(() -> "test"); + RedisSession session = this.redisRepository.createSession(); + assertThat(session.getId()).isEqualTo("test"); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.redisRepository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.redisRepository.setSessionIdGenerationStrategy(() -> "test"); + String attribute1 = "attribute1"; + String attribute2 = "attribute2"; + MapSession expected = new MapSession("original"); + expected.setLastAccessedTime(Instant.now().minusSeconds(60)); + expected.setAttribute(attribute1, "test"); + expected.setAttribute(attribute2, null); + given(this.redisOperations.boundHashOps(getKey(expected.getId()))) + .willReturn(this.boundHashOperations); + Map map = map(RedisIndexedSessionRepository.getSessionAttrNameKey(attribute1), + expected.getAttribute(attribute1), RedisIndexedSessionRepository.getSessionAttrNameKey(attribute2), + expected.getAttribute(attribute2), RedisSessionMapper.CREATION_TIME_KEY, + expected.getCreationTime().toEpochMilli(), RedisSessionMapper.MAX_INACTIVE_INTERVAL_KEY, + (int) expected.getMaxInactiveInterval().getSeconds(), RedisSessionMapper.LAST_ACCESSED_TIME_KEY, + expected.getLastAccessedTime().toEpochMilli()); + given(this.boundHashOperations.entries()).willReturn(map); + + RedisSession session = this.redisRepository.findById(expected.getId()); + String oldSessionId = session.getId(); + String newSessionId = session.changeSessionId(); + assertThat(oldSessionId).isEqualTo("original"); + assertThat(newSessionId).isEqualTo("test"); + } + private String getKey(String id) { return "spring:session:sessions:" + id; } diff --git a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisSessionRepositoryTests.java b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisSessionRepositoryTests.java index 0092e217e..3e22e30c4 100644 --- a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisSessionRepositoryTests.java +++ b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/RedisSessionRepositoryTests.java @@ -373,6 +373,35 @@ void getSessionRedisOperations__ShouldReturnRedisOperations() { verifyNoMoreInteractions(this.sessionHashOperations); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.sessionRepository.setSessionIdGenerationStrategy(() -> "test"); + RedisSessionRepository.RedisSession session = this.sessionRepository.createSession(); + assertThat(session.getId()).isEqualTo("test"); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.sessionRepository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.sessionRepository.setSessionIdGenerationStrategy(() -> "test"); + Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + given(this.sessionHashOperations.entries(eq(TEST_SESSION_KEY))) + .willReturn(mapOf(RedisSessionMapper.CREATION_TIME_KEY, Instant.EPOCH.toEpochMilli(), + RedisSessionMapper.LAST_ACCESSED_TIME_KEY, now.toEpochMilli(), + RedisSessionMapper.MAX_INACTIVE_INTERVAL_KEY, MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS, + RedisSessionMapper.ATTRIBUTE_PREFIX + "attribute1", "value1")); + RedisSession session = this.sessionRepository.findById(TEST_SESSION_ID); + assertThat(session.getId()).isEqualTo(TEST_SESSION_ID); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + private static String getSessionKey(String sessionId) { return "spring:session:sessions:" + sessionId; } diff --git a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisHttpsSessionConfigurationTests.java b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisHttpsSessionConfigurationTests.java index 736201c7d..3758bcf65 100644 --- a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisHttpsSessionConfigurationTests.java +++ b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisHttpsSessionConfigurationTests.java @@ -39,6 +39,8 @@ import org.springframework.mock.env.MockEnvironment; import org.springframework.session.FlushMode; import org.springframework.session.SaveMode; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.data.redis.RedisSessionRepository; import org.springframework.session.data.redis.config.annotation.SpringSessionRedisConnectionFactory; @@ -206,6 +208,22 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void registerWhenSessionIdGenerationStrategyBeanThenUses() { + registerAndRefresh(RedisConfig.class, SessionIdGenerationStrategyConfiguration.class); + RedisSessionRepository sessionRepository = this.context.getBean(RedisSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(TestSessionIdGenerationStrategy.class); + } + + @Test + void registerWhenNoSessionIdGenerationStrategyBeanThenDefault() { + registerAndRefresh(RedisConfig.class, DefaultConfiguration.class); + RedisSessionRepository sessionRepository = this.context.getBean(RedisSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + private void registerAndRefresh(Class... annotatedClasses) { this.context.register(annotatedClasses); this.context.refresh(); @@ -381,4 +399,30 @@ SessionRepositoryCustomizer sessionRepositoryCustomizer( } + @Configuration(proxyBeanMethods = false) + @EnableRedisHttpSession + static class SessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new TestSessionIdGenerationStrategy(); + } + + } + + @Configuration(proxyBeanMethods = false) + @EnableRedisHttpSession + static class DefaultConfiguration { + + } + + static class TestSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + @Override + public String generate() { + return "test"; + } + + } + } diff --git a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfigurationTests.java b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfigurationTests.java index f89dc8b4f..f135e4db4 100644 --- a/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfigurationTests.java +++ b/spring-session-data-redis/src/test/java/org/springframework/session/data/redis/config/annotation/web/http/RedisIndexedHttpSessionConfigurationTests.java @@ -43,6 +43,8 @@ import org.springframework.session.IndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.data.redis.RedisIndexedSessionRepository; import org.springframework.session.data.redis.config.annotation.SpringSessionRedisConnectionFactory; @@ -240,6 +242,22 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void registerWhenSessionIdGenerationStrategyBeanThenUses() { + registerAndRefresh(RedisConfig.class, SessionIdGenerationStrategyConfiguration.class); + RedisIndexedSessionRepository sessionRepository = this.context.getBean(RedisIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(TestSessionIdGenerationStrategy.class); + } + + @Test + void registerWhenNoSessionIdGenerationStrategyBeanThenDefault() { + registerAndRefresh(RedisConfig.class, DefaultConfiguration.class); + RedisIndexedSessionRepository sessionRepository = this.context.getBean(RedisIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + private void registerAndRefresh(Class... annotatedClasses) { this.context.register(annotatedClasses); this.context.refresh(); @@ -444,4 +462,30 @@ SessionRepositoryCustomizer sessionRepositoryCust } + @Configuration(proxyBeanMethods = false) + @EnableRedisIndexedHttpSession + static class SessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new TestSessionIdGenerationStrategy(); + } + + } + + @Configuration(proxyBeanMethods = false) + @EnableRedisIndexedHttpSession + static class DefaultConfiguration { + + } + + static class TestSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + @Override + public String generate() { + return "test"; + } + + } + } diff --git a/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepository.java b/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepository.java index d9be431ba..0b5109dff 100644 --- a/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepository.java +++ b/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepository.java @@ -48,6 +48,8 @@ import org.springframework.session.PrincipalNameIndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.events.AbstractSessionEvent; import org.springframework.session.events.SessionCreatedEvent; import org.springframework.session.events.SessionDeletedEvent; @@ -151,6 +153,8 @@ public class HazelcastIndexedSessionRepository private UUID sessionListenerId; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Create a new {@link HazelcastIndexedSessionRepository} instance. * @param hazelcastInstance the {@link HazelcastInstance} to use for managing sessions @@ -245,7 +249,7 @@ public void setSaveMode(SaveMode saveMode) { @Override public HazelcastSession createSession() { - MapSession cached = new MapSession(); + MapSession cached = new MapSession(this.sessionIdGenerationStrategy); cached.setMaxInactiveInterval(this.defaultMaxInactiveInterval); HazelcastSession session = new HazelcastSession(cached, true); session.flushImmediateIfNecessary(); @@ -349,6 +353,16 @@ public void entryExpired(EntryEvent event) { this.eventPublisher.publishEvent(new SessionExpiredEvent(this, event.getOldValue())); } + /** + * Set the {@link SessionIdGenerationStrategy} to use to generate session ids. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + /** * A custom implementation of {@link Session} that uses a {@link MapSession} as the * basis for its mapping. It keeps track if changes have been made since last save. @@ -405,7 +419,8 @@ public String getId() { @Override public String changeSessionId() { - String newSessionId = this.delegate.changeSessionId(); + String newSessionId = HazelcastIndexedSessionRepository.this.sessionIdGenerationStrategy.generate(); + this.delegate.setId(newSessionId); this.sessionIdChanged = true; return newSessionId; } diff --git a/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfiguration.java b/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfiguration.java index 55c542abb..16354a58c 100644 --- a/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfiguration.java +++ b/spring-session-hazelcast/src/main/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfiguration.java @@ -38,6 +38,8 @@ import org.springframework.session.MapSession; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.http.SpringHttpSessionConfiguration; import org.springframework.session.hazelcast.HazelcastIndexedSessionRepository; @@ -75,6 +77,8 @@ public class HazelcastHttpSessionConfiguration implements ImportAware { private List> sessionRepositoryCustomizers; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean public FindByIndexNameSessionRepository sessionRepository() { return createHazelcastIndexedSessionRepository(); @@ -158,9 +162,15 @@ private HazelcastIndexedSessionRepository createHazelcastIndexedSessionRepositor sessionRepository.setDefaultMaxInactiveInterval(this.maxInactiveInterval); sessionRepository.setFlushMode(this.flushMode); sessionRepository.setSaveMode(this.saveMode); + sessionRepository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); this.sessionRepositoryCustomizers .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(sessionRepository)); return sessionRepository; } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + } diff --git a/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepositoryTests.java b/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepositoryTests.java index abe728e38..73215f68e 100644 --- a/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepositoryTests.java +++ b/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/HazelcastIndexedSessionRepositoryTests.java @@ -465,4 +465,31 @@ void saveWithSaveModeAlways() { verifyNoMoreInteractions(this.sessions); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + HazelcastSession session = this.repository.createSession(); + assertThat(session.getId()).isEqualTo("test"); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + MapSession saved = new MapSession("original"); + saved.setAttribute("savedName", "savedValue"); + given(this.sessions.get(eq(saved.getId()))).willReturn(saved); + + HazelcastSession session = this.repository.findById(saved.getId()); + + assertThat(session.getId()).isEqualTo(saved.getId()); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + } diff --git a/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfigurationTests.java b/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfigurationTests.java index bb201fe79..062930ca4 100644 --- a/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfigurationTests.java +++ b/spring-session-hazelcast/src/test/java/org/springframework/session/hazelcast/config/annotation/web/http/HazelcastHttpSessionConfigurationTests.java @@ -36,6 +36,8 @@ import org.springframework.session.IndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.hazelcast.HazelcastIndexedSessionRepository; import org.springframework.session.hazelcast.config.annotation.SpringSessionHazelcastInstance; @@ -238,6 +240,24 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void registerWhenSessionIdGenerationStrategyBeanThenUses() { + registerAndRefresh(DefaultConfiguration.class, SessionIdGenerationStrategyConfiguration.class); + HazelcastIndexedSessionRepository sessionRepository = this.context + .getBean(HazelcastIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(TestSessionIdGenerationStrategy.class); + } + + @Test + void registerWhenNoSessionIdGenerationStrategyBeanThenDefault() { + registerAndRefresh(DefaultConfiguration.class); + HazelcastIndexedSessionRepository sessionRepository = this.context + .getBean(HazelcastIndexedSessionRepository.class); + assertThat(sessionRepository).extracting("sessionIdGenerationStrategy") + .isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + private void registerAndRefresh(Class... annotatedClasses) { this.context.register(annotatedClasses); this.context.refresh(); @@ -465,4 +485,23 @@ SessionRepositoryCustomizer sessionRepository } + @Configuration(proxyBeanMethods = false) + static class SessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new TestSessionIdGenerationStrategy(); + } + + } + + static class TestSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + @Override + public String generate() { + return "test"; + } + + } + } diff --git a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcIndexedSessionRepository.java b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcIndexedSessionRepository.java index 1d12c1607..5c40be5ca 100644 --- a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcIndexedSessionRepository.java +++ b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcIndexedSessionRepository.java @@ -62,6 +62,8 @@ import org.springframework.session.PrincipalNameIndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.transaction.support.TransactionOperations; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -252,6 +254,8 @@ public class JdbcIndexedSessionRepository implements private ThreadPoolTaskScheduler taskScheduler; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + /** * Create a new {@link JdbcIndexedSessionRepository} instance which uses the provided * {@link JdbcOperations} and {@link TransactionOperations} to manage sessions. @@ -461,7 +465,7 @@ public void setCleanupCron(String cleanupCron) { @Override public JdbcSession createSession() { - MapSession delegate = new MapSession(); + MapSession delegate = new MapSession(this.sessionIdGenerationStrategy); delegate.setMaxInactiveInterval(this.defaultMaxInactiveInterval); JdbcSession session = new JdbcSession(delegate, UUID.randomUUID().toString(), true); session.flushIfRequired(); @@ -686,6 +690,16 @@ private Object deserialize(byte[] bytes) { TypeDescriptor.valueOf(Object.class)); } + /** + * Set the {@link SessionIdGenerationStrategy} to use to generate session ids. + * @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use + * @since 3.2 + */ + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null"); + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + private enum DeltaValue { ADDED, UPDATED, REMOVED @@ -721,7 +735,7 @@ public T get() { */ final class JdbcSession implements Session { - private final Session delegate; + private final MapSession delegate; private final String primaryKey; @@ -773,7 +787,9 @@ public String getId() { @Override public String changeSessionId() { this.changed = true; - return this.delegate.changeSessionId(); + String newSessionId = JdbcIndexedSessionRepository.this.sessionIdGenerationStrategy.generate(); + this.delegate.setId(newSessionId); + return newSessionId; } @Override diff --git a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfiguration.java b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfiguration.java index 5c0cb2648..e937c1151 100644 --- a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfiguration.java +++ b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfiguration.java @@ -50,6 +50,8 @@ import org.springframework.session.MapSession; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; import org.springframework.session.config.annotation.web.http.SpringHttpSessionConfiguration; import org.springframework.session.jdbc.JdbcIndexedSessionRepository; @@ -109,6 +111,8 @@ public class JdbcHttpSessionConfiguration implements BeanClassLoaderAware, Embed private StringValueResolver embeddedValueResolver; + private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance(); + @Bean public JdbcIndexedSessionRepository sessionRepository() { JdbcTemplate jdbcTemplate = createJdbcTemplate(this.dataSource); @@ -144,6 +148,7 @@ else if (this.conversionService != null) { else { sessionRepository.setConversionService(createConversionServiceWithBeanClassLoader(this.classLoader)); } + sessionRepository.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy); this.sessionRepositoryCustomizers .forEach((sessionRepositoryCustomizer) -> sessionRepositoryCustomizer.customize(sessionRepository)); return sessionRepository; @@ -235,6 +240,11 @@ public void setSessionRepositoryCustomizer( this.sessionRepositoryCustomizers = sessionRepositoryCustomizers.orderedStream().collect(Collectors.toList()); } + @Autowired(required = false) + public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) { + this.sessionIdGenerationStrategy = sessionIdGenerationStrategy; + } + @Override public void setBeanClassLoader(ClassLoader classLoader) { this.classLoader = classLoader; diff --git a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/FixedSessionIdGenerationStrategy.java b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/FixedSessionIdGenerationStrategy.java new file mode 100644 index 000000000..e4872a2df --- /dev/null +++ b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/FixedSessionIdGenerationStrategy.java @@ -0,0 +1,34 @@ +/* + * Copyright 2014-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.session.jdbc; + +import org.springframework.session.SessionIdGenerationStrategy; + +public class FixedSessionIdGenerationStrategy implements SessionIdGenerationStrategy { + + private final String id; + + public FixedSessionIdGenerationStrategy(String id) { + this.id = id; + } + + @Override + public String generate() { + return this.id; + } + +} diff --git a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcIndexedSessionRepositoryTests.java b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcIndexedSessionRepositoryTests.java index 1d3283eaa..e2d9d51ed 100644 --- a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcIndexedSessionRepositoryTests.java +++ b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcIndexedSessionRepositoryTests.java @@ -263,6 +263,12 @@ void setCleanupCronDisabled() { assertThat(this.repository).extracting("taskScheduler").isNull(); } + @Test + void setSessionIdGenerationStrategyWhenNullThenException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + @Test void createSessionDefaultMaxInactiveInterval() { JdbcSession session = this.repository.createSession(); @@ -769,4 +775,32 @@ void saveAndFreeTemporaryLob() { verify(lobCreator, atLeastOnce()).close(); } + @Test + void createSessionWhenSessionIdGenerationStrategyThenUses() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + JdbcSession session = this.repository.createSession(); + assertThat(session.getId()).isEqualTo("test"); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + + @Test + void setSessionIdGenerationStrategyWhenNullThenThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null)) + .withMessage("sessionIdGenerationStrategy cannot be null"); + } + + @Test + void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() { + this.repository.setSessionIdGenerationStrategy(() -> "test"); + Session saved = this.repository.new JdbcSession(new MapSession(), "primaryKey", false); + saved.setAttribute("savedName", "savedValue"); + given(this.jdbcOperations.query(isA(String.class), isA(PreparedStatementSetter.class), + isA(ResultSetExtractor.class))).willReturn(Collections.singletonList(saved)); + + JdbcSession session = this.repository.findById(saved.getId()); + + assertThat(session.getId()).isEqualTo(saved.getId()); + assertThat(session.changeSessionId()).isEqualTo("test"); + } + } diff --git a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfigurationTests.java b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfigurationTests.java index dcf5c2c7e..8b9892c83 100644 --- a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfigurationTests.java +++ b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/config/annotation/web/http/JdbcHttpSessionConfigurationTests.java @@ -43,7 +43,10 @@ import org.springframework.session.IndexResolver; import org.springframework.session.SaveMode; import org.springframework.session.Session; +import org.springframework.session.SessionIdGenerationStrategy; +import org.springframework.session.UuidSessionIdGenerationStrategy; import org.springframework.session.config.SessionRepositoryCustomizer; +import org.springframework.session.jdbc.FixedSessionIdGenerationStrategy; import org.springframework.session.jdbc.JdbcIndexedSessionRepository; import org.springframework.session.jdbc.config.annotation.SpringSessionDataSource; import org.springframework.test.util.ReflectionTestUtils; @@ -318,11 +321,40 @@ void importConfigAndCustomize() { assertThat(sessionRepository).extracting("defaultMaxInactiveInterval").isEqualTo(Duration.ZERO); } + @Test + void sessionIdGenerationStrategyWhenCustomBeanThenUses() { + registerAndRefresh(DataSourceConfiguration.class, CustomSessionIdGenerationStrategyConfiguration.class); + JdbcIndexedSessionRepository sessionRepository = this.context.getBean(JdbcIndexedSessionRepository.class); + SessionIdGenerationStrategy sessionIdGenerationStrategy = (SessionIdGenerationStrategy) ReflectionTestUtils + .getField(sessionRepository, "sessionIdGenerationStrategy"); + assertThat(sessionIdGenerationStrategy).isInstanceOf(FixedSessionIdGenerationStrategy.class); + } + + @Test + void sessionIdGenerationStrategyWhenNoBeanThenDefault() { + registerAndRefresh(DataSourceConfiguration.class, DefaultConfiguration.class); + JdbcIndexedSessionRepository sessionRepository = this.context.getBean(JdbcIndexedSessionRepository.class); + SessionIdGenerationStrategy sessionIdGenerationStrategy = (SessionIdGenerationStrategy) ReflectionTestUtils + .getField(sessionRepository, "sessionIdGenerationStrategy"); + assertThat(sessionIdGenerationStrategy).isInstanceOf(UuidSessionIdGenerationStrategy.class); + } + private void registerAndRefresh(Class... annotatedClasses) { this.context.register(annotatedClasses); this.context.refresh(); } + @Configuration(proxyBeanMethods = false) + @EnableJdbcHttpSession + static class CustomSessionIdGenerationStrategyConfiguration { + + @Bean + SessionIdGenerationStrategy sessionIdGenerationStrategy() { + return new FixedSessionIdGenerationStrategy("my-id"); + } + + } + @Configuration(proxyBeanMethods = false) @EnableJdbcHttpSession static class NoDataSourceConfiguration {