diff --git a/jbpm-workitems/jbpm-workitems-webservice/src/main/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandler.java b/jbpm-workitems/jbpm-workitems-webservice/src/main/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandler.java index c097ee1953..3e51dd09e6 100644 --- a/jbpm-workitems/jbpm-workitems-webservice/src/main/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandler.java +++ b/jbpm-workitems/jbpm-workitems-webservice/src/main/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandler.java @@ -17,8 +17,10 @@ package org.jbpm.process.workitem.webservice; import java.io.File; +import java.io.Writer; import java.lang.reflect.Array; import java.lang.reflect.Field; +import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; @@ -34,8 +36,10 @@ import javax.xml.namespace.QName; -import org.apache.cxf.common.jaxb.JAXBUtils; +import org.apache.cxf.common.util.ProxyHelper; +import org.apache.cxf.common.util.ReflectionUtil; import org.apache.cxf.configuration.security.AuthorizationPolicy; +import org.apache.cxf.databinding.DataBinding; import org.apache.cxf.endpoint.Client; import org.apache.cxf.endpoint.ClientCallback; import org.apache.cxf.endpoint.Endpoint; @@ -480,38 +484,29 @@ private void removeWrappingInterceptors(Client client){ }); } - @SuppressWarnings("unchecked") protected Client getWSClient(WorkItem workItem, String interfaceRef) { - if (clients.containsKey(interfaceRef)) { - return clients.get(interfaceRef); + return clients.computeIfAbsent(interfaceRef, k -> createClient(workItem, interfaceRef)); + } + + private Client createClient(WorkItem workItem, String interfaceRef) { + String importLocation = (String) workItem.getParameter("Url"); + String importNamespace = (String) workItem.getParameter("Namespace"); + if (importLocation != null && importLocation.trim().length() > 0 && importNamespace != null && importNamespace.trim().length() > 0) { + return createClient(workItem, importLocation, importNamespace, interfaceRef); } - synchronized (this) { - - if (clients.containsKey(interfaceRef)) { - return clients.get(interfaceRef); - } - - String importLocation = (String) workItem.getParameter("Url"); - String importNamespace = (String) workItem.getParameter("Namespace"); - if (importLocation != null && importLocation.trim().length() > 0 - && importNamespace != null && importNamespace.trim().length() > 0) { - return getClient (workItem, importLocation, importNamespace, interfaceRef); - } - - long processInstanceId = ((WorkItemImpl) workItem).getProcessInstanceId(); - WorkflowProcessImpl process = ((WorkflowProcessImpl) ksession.getProcessInstance(processInstanceId).getProcess()); - List typedImports = (List) process.getMetaData("Bpmn2Imports"); - - if (typedImports != null) { - Client client = null; - for (Bpmn2Import importObj : typedImports) { - if (WSDL_IMPORT_TYPE.equalsIgnoreCase(importObj.getType())) { - try { - return getClient (workItem, importObj.getLocation(), importObj.getNamespace(), interfaceRef); - } catch (Exception e) { - logger.error("Error when creating WS Client", e); - } + long processInstanceId = ((WorkItemImpl) workItem).getProcessInstanceId(); + WorkflowProcessImpl process = ((WorkflowProcessImpl) ksession.getProcessInstance(processInstanceId).getProcess()); + @SuppressWarnings("unchecked") + List typedImports = (List) process.getMetaData("Bpmn2Imports"); + + if (typedImports != null) { + for (Bpmn2Import importObj : typedImports) { + if (WSDL_IMPORT_TYPE.equalsIgnoreCase(importObj.getType())) { + try { + return createClient(workItem, importObj.getLocation(), importObj.getNamespace(), interfaceRef); + } catch (Exception e) { + logger.error("Error when creating WS Client", e); } } } @@ -519,45 +514,105 @@ protected Client getWSClient(WorkItem workItem, String interfaceRef) { return null; } - - private Client getClient(WorkItem workItem, String location, String namespace, String interfaceRef) { + private Client createClient(WorkItem workItem, String location, String namespace, String interfaceRef) { Client client = getDynamicClientFactory().createClient(location, new QName(namespace, interfaceRef), getInternalClassLoader(), null); setClientTimeout(workItem, client); + setEscapeHandler(workItem, client); + addHeaders(workItem, client); + return client; + } + + private void addHeaders(WorkItem workItem, Client client) { Collection headers = WorkItemHeaderUtils.getHeaderInfo(workItem); if (!headers.isEmpty()) { client.getRequestContext().put(Header.HEADER_LIST, - headers.stream().map(this::buildHeader).collect(Collectors.toList())); + headers.stream().map(h -> buildHeader(h, client)).collect(Collectors.toList())); } - clients.put(interfaceRef, client); - return client; } - - - private Header buildHeader(WorkItemHeaderInfo header) { + private Header buildHeader(WorkItemHeaderInfo header, Client client) { String namespace = (String) header.getParam("NS"); QName name = namespace == null ? new QName(header.getName()) : new QName(namespace, header.getName()); - Class contentClass = String.class; - String type = (String) header.getParam("TYPE"); - boolean escape = Boolean.parseBoolean((String) header.getParam("ESCAPE")); - if (type != null) { + JAXBDataBinding binding = (JAXBDataBinding) client.getConduitSelector().getEndpoint().getService().getDataBinding(); + String escapeHandler = (String) header.getParam("ESCAPE"); + if (escapeHandler != null) try { - contentClass = classLoader.loadClass(type); - } catch (ClassNotFoundException ex) { - logger.warn("Cannot find type {}", type, ex); + binding = new JAXBDataBinding(binding.getContext()); + setEscapeHandler(binding, escapeHandler); + } catch (Exception ex) { + logger.warn("Error creating binding for escapeHandler {}", escapeHandler, ex); + } + return new Header(name, header.getContent(), binding); + } + + private void setEscapeHandler(WorkItem workItem, Client client) { + String escapeHandler = (String) workItem.getParameter("ESCAPE_HANDLER"); + if (escapeHandler == null) { + escapeHandler = System.getProperty("org.jbpm.cxf.client.escapeHandler"); + } + if (escapeHandler != null) { + DataBinding binding = client.getConduitSelector().getEndpoint().getService().getDataBinding(); + if (binding instanceof JAXBDataBinding) { + setEscapeHandler((JAXBDataBinding) binding, escapeHandler); } } - JAXBDataBinding binding = null; + } + + private static void setEscapeHandler(JAXBDataBinding dataBinding, String escapeHandler) { + Object escapeHandlerObj = createEscapeHandler(dataBinding, escapeHandler); + logger.debug("Escape handler {} created. Object is {}", escapeHandler, escapeHandlerObj); + Map map = dataBinding.getMarshallerProperties(); + if (map == null) { + map = new HashMap<>(); + } + // if implementation is not reference one, this should be ignored. + map.put("com.sun.xml.bind.characterEscapeHandler", escapeHandlerObj); + map.put("com.sun.xml.bind.marshaller.CharacterEscapeHandler", escapeHandlerObj); + logger.debug("Marshalling properties {}", map); + dataBinding.setMarshallerProperties(map); + } + + /* this code is a modified version of cxf that uses the class loader of context, no current thread context */ + public static Object createEscapeHandler(JAXBDataBinding binding, String escapeHandler) { + Class cls = binding.getContext().getClass(); + ClassLoader classLoader = cls.getClassLoader(); + String className = cls.getName(); + String postFix = className.contains("com.sun.xml.internal") || className.contains("eclipse") ? ".internal" : ""; try { - binding = new JAXBDataBinding(contentClass); - if (!escape) { - binding.setEscapeHandler(JAXBUtils.createNoEscapeHandler(binding.getContext().getClass())); + Class handlerInterface = classLoader.loadClass("com.sun.xml" + postFix + ".bind.marshaller.CharacterEscapeHandler"); + Class handlerClass = classLoader.loadClass("com.sun.xml" + postFix + ".bind.marshaller." + escapeHandler); + Object targetHandler = ReflectionUtil.getDeclaredField(handlerClass, "theInstance").get(null); + return ProxyHelper.getProxy(classLoader, new Class[]{handlerInterface}, new LoggingEscapeHandlerInvocationHandler(targetHandler)); + } catch (Exception e) { + logger.warn("Error instantiating escape handler, characters will be escaped", e); + } + return null; + } + + private static final class LoggingEscapeHandlerInvocationHandler implements InvocationHandler { + + private Object target; + + public LoggingEscapeHandlerInvocationHandler(Object obj) { + target = obj; + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + logger.debug("Escape handler invoked with args {}", args); + Object result = null; + if (method.getName().equals("escape") && args.length == 5) { + if ((Integer) args[1] == 0 && (Integer) args[2] == 0) { + Writer writer = (Writer) args[4]; + writer.write(""); + return null; + } + result = method.invoke(target, args); } - } catch (Exception ex) { - logger.warn("Error creating binding for type {}", type, ex); + logger.debug("Escape handler result {}", result); + return result; } - return new Header(name, header.getContent(), binding); } private void setClientTimeout(WorkItem workItem, Client client) { diff --git a/jbpm-workitems/jbpm-workitems-webservice/src/test/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandlerTest.java b/jbpm-workitems/jbpm-workitems-webservice/src/test/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandlerTest.java index ea83fafc93..3fbff077bc 100644 --- a/jbpm-workitems/jbpm-workitems-webservice/src/test/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandlerTest.java +++ b/jbpm-workitems/jbpm-workitems-webservice/src/test/java/org/jbpm/process/workitem/webservice/WebServiceWorkItemHandlerTest.java @@ -16,6 +16,16 @@ package org.jbpm.process.workitem.webservice; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import java.util.ArrayList; import java.util.concurrent.ConcurrentHashMap; @@ -30,6 +40,7 @@ import org.apache.cxf.transport.http.HTTPConduit; import org.drools.core.process.instance.impl.WorkItemImpl; import org.jbpm.process.workitem.core.TestWorkItemManager; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.kie.api.runtime.KieSession; @@ -38,16 +49,6 @@ import org.mockito.Mockito; import org.mockito.runners.MockitoJUnitRunner; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - @RunWith(MockitoJUnitRunner.class) public class WebServiceWorkItemHandlerTest { @@ -60,12 +61,15 @@ public class WebServiceWorkItemHandlerTest { @Mock ConcurrentHashMap clients; + @Before + public void setUp() { + when(clients.computeIfAbsent(any(), any())).thenReturn(client); + } + @Test public void testExecuteSyncOperation() throws Exception { - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(client); - + TestWorkItemManager manager = new TestWorkItemManager(); WorkItemImpl workItem = new WorkItemImpl(); workItem.setParameter("Interface", @@ -91,8 +95,6 @@ public void testExecuteSyncOperation() throws Exception { @Test public void testExecuteWrappedModeSync() throws Exception { - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(client); Endpoint endpoint = mock(Endpoint.class); when(client.getEndpoint()).thenReturn(endpoint); ArrayList> interceptors = new ArrayList<>(); @@ -129,8 +131,6 @@ public void testExecuteWrappedModeSync() throws Exception { @Test public void testExecuteWrappedModeOneWay() throws Exception { - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(client); Endpoint endpoint = mock(Endpoint.class); when(client.getEndpoint()).thenReturn(endpoint); ArrayList> interceptors = new ArrayList<>(); @@ -164,8 +164,6 @@ public void testExecuteWrappedModeOneWay() throws Exception { @Test public void testExecuteWrappedModeAsync() throws Exception { - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(client); Endpoint endpoint = mock(Endpoint.class); when(client.getEndpoint()).thenReturn(endpoint); ArrayList> interceptors = new ArrayList<>(); @@ -202,8 +200,6 @@ public void testExecuteSyncOperationWithBasicAuth() throws Exception { HTTPConduit http = Mockito.mock(HTTPConduit.class, Mockito.CALLS_REAL_METHODS); - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(client); when(client.getConduit()).thenReturn(http); TestWorkItemManager manager = new TestWorkItemManager(); @@ -241,10 +237,7 @@ public void testExecuteSyncOperationWithBasicAuth() throws Exception { @Test public void testExecuteSyncOperationHandlingException() throws Exception { - - when(clients.containsKey(any())).thenReturn(true); - when(clients.get(any())).thenReturn(null); - + when(clients.computeIfAbsent(any(), any())).thenReturn(null); TestWorkItemManager manager = new TestWorkItemManager(); WorkItemImpl workItem = new WorkItemImpl(); workItem.setParameter("Interface",