Skip to content

Commit

Permalink
ISPN-12690 Support per-cache marshaller in Spring caches
Browse files Browse the repository at this point in the history
* Always register a ProtoStreamMarshaller in RemoteCacheManager
* Register the serialization context initializers even if
  ProtoStreamMarshaller is not the default marshaller
* Repeat both steps in SpringRemoteCacheManager
* Test null values with per-cache marshaller
  • Loading branch information
danberindei authored and karesti committed Apr 2, 2021
1 parent 00ac506 commit bb103db
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,30 +340,35 @@ private void actualStart() {
log.debugf("Starting remote cache manager %x", System.identityHashCode(this));
channelFactory = createChannelFactory();

marshallerRegistry.registerMarshaller(BytesOnlyMarshaller.INSTANCE);
marshallerRegistry.registerMarshaller(new UTF8StringMarshaller());
marshallerRegistry.registerMarshaller(new JavaSerializationMarshaller(configuration.getClassAllowList()));
registerProtoStreamMarshaller();

boolean customMarshallerInstance = true;
marshaller = configuration.marshaller();
if (marshaller == null) {
marshaller = configuration.marshaller();
Class<? extends Marshaller> clazz = configuration.marshallerClass();
marshaller = marshallerRegistry.getMarshaller(clazz);
if (marshaller == null) {
Class<? extends Marshaller> clazz = configuration.marshallerClass();
if (marshaller == null) {
marshaller = Util.getInstance(clazz);
}
marshaller = Util.getInstance(clazz);
} else {
customMarshallerInstance = false;
}
}
if (!configuration.serialAllowList().isEmpty()) {
marshaller.initialize(configuration.getClassAllowList());
}
if (marshaller instanceof ProtoStreamMarshaller) {
SerializationContext ctx = ((ProtoStreamMarshaller) marshaller).getSerializationContext();
for (SerializationContextInitializer sci : configuration.getContextInitializers()) {
sci.registerSchema(ctx);
sci.registerMarshallers(ctx);

if (customMarshallerInstance) {
if (!configuration.serialAllowList().isEmpty()) {
marshaller.initialize(configuration.getClassAllowList());
}

if (marshaller instanceof ProtoStreamMarshaller) {
initializeProtoStreamMarshaller((ProtoStreamMarshaller) marshaller);
}

// Replace any default marshaller with the same media type
marshallerRegistry.registerMarshaller(marshaller);
}
marshallerRegistry.registerMarshaller(BytesOnlyMarshaller.INSTANCE);
marshallerRegistry.registerMarshaller(new UTF8StringMarshaller());
marshallerRegistry.registerMarshaller(new JavaSerializationMarshaller(configuration.getClassAllowList()));
// Register this one last, so it will replace any that may support the same media type
marshallerRegistry.registerMarshaller(marshaller);

codec = configuration.version().getCodec();

Expand All @@ -387,6 +392,25 @@ private void actualStart() {
started = true;
}

private void registerProtoStreamMarshaller() {
try {
ProtoStreamMarshaller protoMarshaller = new ProtoStreamMarshaller();
marshallerRegistry.registerMarshaller(protoMarshaller);

initializeProtoStreamMarshaller(protoMarshaller);
} catch (NoClassDefFoundError e) {
// Ignore the error, it the protostream dependency is missing
}
}

private void initializeProtoStreamMarshaller(ProtoStreamMarshaller protoMarshaller) {
SerializationContext ctx = protoMarshaller.getSerializationContext();
for (SerializationContextInitializer sci : configuration.getContextInitializers()) {
sci.registerSchema(ctx);
sci.registerMarshallers(ctx);
}
}

public ChannelFactory createChannelFactory() {
return new ChannelFactory();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.impl.MarshallerRegistry;
import org.infinispan.client.hotrod.marshall.MarshallerUtil;
import org.infinispan.commons.configuration.ClassAllowList;
import org.infinispan.commons.dataconversion.MediaType;
import org.infinispan.commons.logging.Log;
import org.infinispan.commons.marshall.JavaSerializationMarshaller;
import org.infinispan.commons.marshall.ProtoStreamMarshaller;
import org.infinispan.protostream.BaseMarshaller;
import org.infinispan.protostream.SerializationContext;
import org.infinispan.protostream.SerializationContextInitializer;
import org.infinispan.spring.common.provider.NullValue;
import org.infinispan.spring.common.provider.SpringCache;
import org.infinispan.spring.common.session.MapSessionProtoAdapter;
Expand Down Expand Up @@ -148,8 +148,24 @@ private void configureMarshallers(RemoteCacheManager nativeCacheManager) {
// Protostream support
ProtoStreamMarshaller protoMarshaller =
(ProtoStreamMarshaller) marshallerRegistry.getMarshaller(MediaType.APPLICATION_PROTOSTREAM);
if (protoMarshaller == null) {
try {
protoMarshaller = new ProtoStreamMarshaller();
marshallerRegistry.registerMarshaller(protoMarshaller);

// Apply the serialization context initializers in the configuration first
SerializationContext ctx = protoMarshaller.getSerializationContext();
for (SerializationContextInitializer sci : nativeCacheManager.getConfiguration().getContextInitializers()) {
sci.registerSchema(ctx);
sci.registerMarshallers(ctx);
}
} catch (NoClassDefFoundError e) {
// Ignore the error, the protostream dependency is missing
}
}
if (protoMarshaller != null) {
SerializationContext ctx = MarshallerUtil.getSerializationContext(nativeCacheManager);
// Apply our own serialization context initializers
SerializationContext ctx = protoMarshaller.getSerializationContext();
addProviderContextInitializer(ctx);
addSessionContextInitializerAndMarshaller(ctx, serializationMarshaller);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.configuration.ConfigurationBuilder;
import org.infinispan.commons.dataconversion.MediaType;
import org.infinispan.commons.marshall.JavaSerializationMarshaller;
import org.infinispan.commons.marshall.ProtoStreamMarshaller;
import org.infinispan.manager.EmbeddedCacheManager;
import org.infinispan.server.hotrod.HotRodServer;
import org.infinispan.server.hotrod.test.HotRodTestingUtil;
Expand All @@ -19,20 +22,25 @@
import org.springframework.cache.Cache;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

@Test(testName = "spring.provider.SpringRemoteCacheTest", groups = "functional")
public class SpringRemoteCacheTest extends SingleCacheManagerTest {

private static final String TEST_CACHE_NAME = "spring.remote.cache.Test";
private static final String TEST_CACHE_NAME = "SerializationCache";
private static final String TEST_CACHE_NAME_PROTO = "ProtoStreamCache";

private RemoteCacheManager remoteCacheManager;
private HotRodServer hotrodServer;

@Override
protected EmbeddedCacheManager createCacheManager() throws Exception {
cacheManager = TestCacheManagerFactory.createCacheManager(hotRodCacheConfiguration());
cacheManager.defineConfiguration(TEST_CACHE_NAME, cacheManager.getDefaultCacheConfiguration());
cacheManager.defineConfiguration(TEST_CACHE_NAME,
hotRodCacheConfiguration(MediaType.APPLICATION_SERIALIZED_OBJECT).build());
cacheManager.defineConfiguration(TEST_CACHE_NAME_PROTO,
hotRodCacheConfiguration(MediaType.APPLICATION_PROTOSTREAM).build());
cache = cacheManager.getCache(TEST_CACHE_NAME);

return cacheManager;
Expand All @@ -43,6 +51,8 @@ public void setupRemoteCacheFactory() {
hotrodServer = HotRodTestingUtil.startHotRodServer(cacheManager, 0);
ConfigurationBuilder builder = new ConfigurationBuilder();
builder.addServer().host("localhost").port(hotrodServer.getPort());
builder.remoteCache(TEST_CACHE_NAME).marshaller(JavaSerializationMarshaller.class);
builder.remoteCache(TEST_CACHE_NAME_PROTO).marshaller(ProtoStreamMarshaller.class);
remoteCacheManager = new RemoteCacheManager(builder.build());
}

Expand Down Expand Up @@ -93,10 +103,19 @@ public void testValueLoaderWithLocking() throws Exception {
assertEquals("thread1", valueObtainedByThread2);
}

public void testNullValues() {
@DataProvider(name = "caches")
public Object[][] caches() {
return new Object[][] {
{TEST_CACHE_NAME},
{TEST_CACHE_NAME_PROTO},
};
}

@Test(dataProvider = "caches")
public void testNullValues(String cacheName) {
//given
final SpringRemoteCacheManager springRemoteCacheManager = new SpringRemoteCacheManager(remoteCacheManager);
final SpringCache cache = springRemoteCacheManager.getCache(TEST_CACHE_NAME);
final SpringCache cache = springRemoteCacheManager.getCache(cacheName);

// when
cache.put("key", null);
Expand Down

0 comments on commit bb103db

Please sign in to comment.