Skip to content

Commit

Permalink
[JBPM-9738] Adding support non escaping headers
Browse files Browse the repository at this point in the history
  • Loading branch information
fjtirado committed Aug 23, 2021
1 parent 73f2298 commit b2884a5
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 77 deletions.
Expand Up @@ -17,8 +17,10 @@
package org.jbpm.process.workitem.webservice;

import java.io.File;
import java.io.Writer;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
Expand All @@ -34,8 +36,10 @@

import javax.xml.namespace.QName;

import org.apache.cxf.common.jaxb.JAXBUtils;
import org.apache.cxf.common.util.ProxyHelper;
import org.apache.cxf.common.util.ReflectionUtil;
import org.apache.cxf.configuration.security.AuthorizationPolicy;
import org.apache.cxf.databinding.DataBinding;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.endpoint.ClientCallback;
import org.apache.cxf.endpoint.Endpoint;
Expand Down Expand Up @@ -480,84 +484,135 @@ private void removeWrappingInterceptors(Client client){
});
}

@SuppressWarnings("unchecked")
protected Client getWSClient(WorkItem workItem, String interfaceRef) {
if (clients.containsKey(interfaceRef)) {
return clients.get(interfaceRef);
return clients.computeIfAbsent(interfaceRef, k -> createClient(workItem, interfaceRef));
}

private Client createClient(WorkItem workItem, String interfaceRef) {
String importLocation = (String) workItem.getParameter("Url");
String importNamespace = (String) workItem.getParameter("Namespace");
if (importLocation != null && importLocation.trim().length() > 0 && importNamespace != null && importNamespace.trim().length() > 0) {
return createClient(workItem, importLocation, importNamespace, interfaceRef);
}

synchronized (this) {

if (clients.containsKey(interfaceRef)) {
return clients.get(interfaceRef);
}

String importLocation = (String) workItem.getParameter("Url");
String importNamespace = (String) workItem.getParameter("Namespace");
if (importLocation != null && importLocation.trim().length() > 0
&& importNamespace != null && importNamespace.trim().length() > 0) {
return getClient (workItem, importLocation, importNamespace, interfaceRef);
}

long processInstanceId = ((WorkItemImpl) workItem).getProcessInstanceId();
WorkflowProcessImpl process = ((WorkflowProcessImpl) ksession.getProcessInstance(processInstanceId).getProcess());
List<Bpmn2Import> typedImports = (List<Bpmn2Import>) process.getMetaData("Bpmn2Imports");

if (typedImports != null) {
Client client = null;
for (Bpmn2Import importObj : typedImports) {
if (WSDL_IMPORT_TYPE.equalsIgnoreCase(importObj.getType())) {
try {
return getClient (workItem, importObj.getLocation(), importObj.getNamespace(), interfaceRef);
} catch (Exception e) {
logger.error("Error when creating WS Client", e);
}
long processInstanceId = ((WorkItemImpl) workItem).getProcessInstanceId();
WorkflowProcessImpl process = ((WorkflowProcessImpl) ksession.getProcessInstance(processInstanceId).getProcess());
@SuppressWarnings("unchecked")
List<Bpmn2Import> typedImports = (List<Bpmn2Import>) process.getMetaData("Bpmn2Imports");

if (typedImports != null) {
for (Bpmn2Import importObj : typedImports) {
if (WSDL_IMPORT_TYPE.equalsIgnoreCase(importObj.getType())) {
try {
return createClient(workItem, importObj.getLocation(), importObj.getNamespace(), interfaceRef);
} catch (Exception e) {
logger.error("Error when creating WS Client", e);
}
}
}
}
return null;
}


private Client getClient(WorkItem workItem, String location, String namespace, String interfaceRef) {
private Client createClient(WorkItem workItem, String location, String namespace, String interfaceRef) {
Client client = getDynamicClientFactory().createClient(location, new QName(namespace, interfaceRef),
getInternalClassLoader(), null);
setClientTimeout(workItem, client);
setEscapeHandler(workItem, client);
addHeaders(workItem, client);
return client;
}

private void addHeaders(WorkItem workItem, Client client) {
Collection<WorkItemHeaderInfo> headers = WorkItemHeaderUtils.getHeaderInfo(workItem);
if (!headers.isEmpty()) {
client.getRequestContext().put(Header.HEADER_LIST,
headers.stream().map(this::buildHeader).collect(Collectors.toList()));
headers.stream().map(h -> buildHeader(h, client)).collect(Collectors.toList()));
}
clients.put(interfaceRef, client);
return client;
}



private Header buildHeader(WorkItemHeaderInfo header) {
private Header buildHeader(WorkItemHeaderInfo header, Client client) {
String namespace = (String) header.getParam("NS");
QName name = namespace == null ? new QName(header.getName()) : new QName(namespace, header.getName());
Class<?> contentClass = String.class;
String type = (String) header.getParam("TYPE");
boolean escape = Boolean.parseBoolean((String) header.getParam("ESCAPE"));
if (type != null) {
JAXBDataBinding binding = (JAXBDataBinding) client.getConduitSelector().getEndpoint().getService().getDataBinding();
String escapeHandler = (String) header.getParam("ESCAPE");
if (escapeHandler != null)
try {
contentClass = classLoader.loadClass(type);
} catch (ClassNotFoundException ex) {
logger.warn("Cannot find type {}", type, ex);
binding = new JAXBDataBinding(binding.getContext());
setEscapeHandler(binding, escapeHandler);
} catch (Exception ex) {
logger.warn("Error creating binding for escapeHandler {}", escapeHandler, ex);
}
return new Header(name, header.getContent(), binding);
}

private void setEscapeHandler(WorkItem workItem, Client client) {
String escapeHandler = (String) workItem.getParameter("ESCAPE_HANDLER");
if (escapeHandler == null) {
escapeHandler = System.getProperty("org.jbpm.cxf.client.escapeHandler");
}
if (escapeHandler != null) {
DataBinding binding = client.getConduitSelector().getEndpoint().getService().getDataBinding();
if (binding instanceof JAXBDataBinding) {
setEscapeHandler((JAXBDataBinding) binding, escapeHandler);
}
}
JAXBDataBinding binding = null;
}

private static void setEscapeHandler(JAXBDataBinding dataBinding, String escapeHandler) {
Object escapeHandlerObj = createEscapeHandler(dataBinding, escapeHandler);
logger.debug("Escape handler {} created. Object is {}", escapeHandler, escapeHandlerObj);
Map<String, Object> map = dataBinding.getMarshallerProperties();
if (map == null) {
map = new HashMap<>();
}
// if implementation is not reference one, this should be ignored.
map.put("com.sun.xml.bind.characterEscapeHandler", escapeHandlerObj);
map.put("com.sun.xml.bind.marshaller.CharacterEscapeHandler", escapeHandlerObj);
logger.debug("Marshalling properties {}", map);
dataBinding.setMarshallerProperties(map);
}

/* this code is a modified version of cxf that uses the class loader of context, no current thread context */
public static Object createEscapeHandler(JAXBDataBinding binding, String escapeHandler) {
Class<?> cls = binding.getContext().getClass();
ClassLoader classLoader = cls.getClassLoader();
String className = cls.getName();
String postFix = className.contains("com.sun.xml.internal") || className.contains("eclipse") ? ".internal" : "";
try {
binding = new JAXBDataBinding(contentClass);
if (!escape) {
binding.setEscapeHandler(JAXBUtils.createNoEscapeHandler(binding.getContext().getClass()));
Class<?> handlerInterface = classLoader.loadClass("com.sun.xml" + postFix + ".bind.marshaller.CharacterEscapeHandler");
Class<?> handlerClass = classLoader.loadClass("com.sun.xml" + postFix + ".bind.marshaller." + escapeHandler);
Object targetHandler = ReflectionUtil.getDeclaredField(handlerClass, "theInstance").get(null);
return ProxyHelper.getProxy(classLoader, new Class[]{handlerInterface}, new LoggingEscapeHandlerInvocationHandler(targetHandler));
} catch (Exception e) {
logger.warn("Error instantiating escape handler, characters will be escaped", e);
}
return null;
}

private static final class LoggingEscapeHandlerInvocationHandler implements InvocationHandler {

private Object target;

public LoggingEscapeHandlerInvocationHandler(Object obj) {
target = obj;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
logger.debug("Escape handler invoked with args {}", args);
Object result = null;
if (method.getName().equals("escape") && args.length == 5) {
if ((Integer) args[1] == 0 && (Integer) args[2] == 0) {
Writer writer = (Writer) args[4];
writer.write("");
return null;
}
result = method.invoke(target, args);
}
} catch (Exception ex) {
logger.warn("Error creating binding for type {}", type, ex);
logger.debug("Escape handler result {}", result);
return result;
}
return new Header(name, header.getContent(), binding);
}

private void setClientTimeout(WorkItem workItem, Client client) {
Expand Down
Expand Up @@ -16,6 +16,16 @@

package org.jbpm.process.workitem.webservice;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -30,6 +40,7 @@
import org.apache.cxf.transport.http.HTTPConduit;
import org.drools.core.process.instance.impl.WorkItemImpl;
import org.jbpm.process.workitem.core.TestWorkItemManager;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.kie.api.runtime.KieSession;
Expand All @@ -38,16 +49,6 @@
import org.mockito.Mockito;
import org.mockito.runners.MockitoJUnitRunner;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
public class WebServiceWorkItemHandlerTest {

Expand All @@ -60,12 +61,15 @@ public class WebServiceWorkItemHandlerTest {
@Mock
ConcurrentHashMap<String, Client> clients;

@Before
public void setUp() {
when(clients.computeIfAbsent(any(), any())).thenReturn(client);
}

@Test
public void testExecuteSyncOperation() throws Exception {

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(client);


TestWorkItemManager manager = new TestWorkItemManager();
WorkItemImpl workItem = new WorkItemImpl();
workItem.setParameter("Interface",
Expand All @@ -91,8 +95,6 @@ public void testExecuteSyncOperation() throws Exception {
@Test
public void testExecuteWrappedModeSync() throws Exception {

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(client);
Endpoint endpoint = mock(Endpoint.class);
when(client.getEndpoint()).thenReturn(endpoint);
ArrayList<Interceptor<? extends Message>> interceptors = new ArrayList<>();
Expand Down Expand Up @@ -129,8 +131,6 @@ public void testExecuteWrappedModeSync() throws Exception {
@Test
public void testExecuteWrappedModeOneWay() throws Exception {

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(client);
Endpoint endpoint = mock(Endpoint.class);
when(client.getEndpoint()).thenReturn(endpoint);
ArrayList<Interceptor<? extends Message>> interceptors = new ArrayList<>();
Expand Down Expand Up @@ -164,8 +164,6 @@ public void testExecuteWrappedModeOneWay() throws Exception {
@Test
public void testExecuteWrappedModeAsync() throws Exception {

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(client);
Endpoint endpoint = mock(Endpoint.class);
when(client.getEndpoint()).thenReturn(endpoint);
ArrayList<Interceptor<? extends Message>> interceptors = new ArrayList<>();
Expand Down Expand Up @@ -202,8 +200,6 @@ public void testExecuteSyncOperationWithBasicAuth() throws Exception {
HTTPConduit http = Mockito.mock(HTTPConduit.class,
Mockito.CALLS_REAL_METHODS);

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(client);
when(client.getConduit()).thenReturn(http);

TestWorkItemManager manager = new TestWorkItemManager();
Expand Down Expand Up @@ -241,10 +237,7 @@ public void testExecuteSyncOperationWithBasicAuth() throws Exception {

@Test
public void testExecuteSyncOperationHandlingException() throws Exception {

when(clients.containsKey(any())).thenReturn(true);
when(clients.get(any())).thenReturn(null);

when(clients.computeIfAbsent(any(), any())).thenReturn(null);
TestWorkItemManager manager = new TestWorkItemManager();
WorkItemImpl workItem = new WorkItemImpl();
workItem.setParameter("Interface",
Expand Down

0 comments on commit b2884a5

Please sign in to comment.