Skip to content

Commit

Permalink
Made MutableHttpServletRequest case-insensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
MINHN98 committed May 4, 2023
1 parent 01b5d12 commit dec6d30
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 45 deletions.
Expand Up @@ -18,31 +18,34 @@

import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.List;
import java.util.Set;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.glassfish.jersey.internal.util.collection.StringKeyIgnoreCaseMultivaluedMap;

final class MutableHttpServletRequest extends HttpServletRequestWrapper {
// Allows for adding and replacing custom headers in the HttpServletRequest object.
private final Map<String, String> customHeaders;
// Only one value per custom header is allowed.
private final StringKeyIgnoreCaseMultivaluedMap<String> customHeaders;

public MutableHttpServletRequest(HttpServletRequest request) {
super(request);
this.customHeaders = new HashMap<String, String>();
this.customHeaders = new StringKeyIgnoreCaseMultivaluedMap<String>();
}

public void putHeader(String name, String value) {
// Putting a new header will take precedence over existing values in HttpServletRequest
this.customHeaders.put(name, value);
// Putting a new header will take precedence over existing values in HttpServletRequest.
// Value will also overwrite any existing custom header value.
this.customHeaders.putSingle(name, value);
}

public String getHeader(String name) {
// check the custom headers first
String headerValue = customHeaders.get(name);
// check the custom headers first and return the value if it exists.
String headerValue = customHeaders.getFirst(name);

if (headerValue != null) {
return headerValue;
Expand All @@ -52,26 +55,43 @@ public String getHeader(String name) {
}

public Enumeration<String> getHeaders(String name) {
// check the custom headers first
String headerValue = customHeaders.get(name);
// check the custom headers first and return the value if it exists.
List<String> headerValues = customHeaders.get(name);

if (headerValue != null) {
return Collections.enumeration(Collections.singletonList(headerValue));
if (headerValues != null) {
return Collections.enumeration(headerValues);
}
// else return from into the original wrapped object
return ((HttpServletRequest) getRequest()).getHeaders(name);
}

public Enumeration<String> getHeaderNames() {
// create a set of the custom header names
Set<String> set = new HashSet<String>(customHeaders.keySet());
// Return the unique (case-insensitive) header names in customHeaders and HttpServletRequest
HttpServletRequest request = (HttpServletRequest)getRequest();
Set<String> set = new HashSet<String>();
Set<String> usedHeaders = new HashSet<String>();

// add the custom headers
for (String key : customHeaders.keySet()) {
// Only add custom header it hasn't already been added (case-insensitive)
String keyLower = key.toLowerCase();
if (!usedHeaders.contains(keyLower)) {
set.add(key);
usedHeaders.add(keyLower);
}
}

// now add the headers from the wrapped request object
Enumeration<String> e = ((HttpServletRequest) getRequest()).getHeaderNames();
// add the HttpServletRequest headers
Enumeration<String> e = request.getHeaderNames();
while (e.hasMoreElements()) {
// add the names of the request headers into the list
String n = e.nextElement();
set.add(n);
String key = e.nextElement();
String keyLower = key.toLowerCase();
// Only add custom header it hasn't already been added (case insensitive)
if (!usedHeaders.contains(keyLower)) {
set.add(key);
usedHeaders.add(keyLower);
}
}

// create an enumeration from the set and return
Expand Down
Expand Up @@ -4,13 +4,11 @@

package io.confluent.kafka.schemaregistry.rest;

import org.eclipse.jetty.server.Request;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -61,8 +59,6 @@ public void testGetWithNoInitialHeaders() {

@Test
public void testGetAndPutWithNoInitialHeaders() {
String headerValue;
Enumeration<String> headerValues;
Enumeration<String> headerNames;
List<String> headerValuesList;
List<String> headerNamesList;
Expand All @@ -74,18 +70,41 @@ public void testGetAndPutWithNoInitialHeaders() {

// Add a "header-key-0" header
mutableRequest.putHeader("header-key-0", "new-header-value-98");
headerValue = mutableRequest.getHeader("header-key-0");
Assert.assertEquals("new-header-value-98", headerValue);

// Validate gets
headerValue = mutableRequest.getHeader("header-key-0");
Assert.assertEquals("new-header-value-98", headerValue);
// Validate gets (should be case-insensitive)
Assert.assertEquals("new-header-value-98", mutableRequest.getHeader("header-Key-0"));
Assert.assertEquals("new-header-value-98", mutableRequest.getHeader("header-key-0"));
Assert.assertEquals("new-header-value-98", mutableRequest.getHeader("HEADER-KEY-0"));

headerValues = mutableRequest.getHeaders("header-key-0");
headerValuesList = Collections.list(headerValues);
headerValuesList = Collections.list(mutableRequest.getHeaders("header-key-0"));
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-98", headerValuesList.get(0));

headerValuesList = Collections.list(mutableRequest.getHeaders("header-KEY-0"));
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-98", headerValuesList.get(0));

headerNames = mutableRequest.getHeaderNames();
headerNamesList = Collections.list(headerNames);
Assert.assertEquals(1, headerNamesList.size());
Assert.assertTrue(headerNamesList.contains("header-key-0"));

// Replace with a "header-Key-0" header key to test cast-insensitivity
mutableRequest.putHeader("header-Key-0", "new-header-value-75");

// Validate gets (should be case-insensitive)
Assert.assertEquals("new-header-value-75", mutableRequest.getHeader("header-Key-0"));
Assert.assertEquals("new-header-value-75", mutableRequest.getHeader("header-key-0"));
Assert.assertEquals("new-header-value-75", mutableRequest.getHeader("HEADER-KEY-0"));

headerValuesList = Collections.list(mutableRequest.getHeaders("header-key-0"));
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-75", headerValuesList.get(0));

headerValuesList = Collections.list(mutableRequest.getHeaders("header-KEY-0"));
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-75", headerValuesList.get(0));

headerNames = mutableRequest.getHeaderNames();
headerNamesList = Collections.list(headerNames);
Assert.assertEquals(1, headerNamesList.size());
Expand All @@ -95,7 +114,6 @@ public void testGetAndPutWithNoInitialHeaders() {
@Test
public void testGetAndPutWithInitialHeaders(){
// A more comprehensive test
String headerValue;
Enumeration<String> headerValues;
Enumeration<String> headerNames;
List<String> headerValuesList;
Expand All @@ -105,28 +123,26 @@ public void testGetAndPutWithInitialHeaders(){
when(httpServletRequest.getHeader("header-key-0")).thenReturn("header-value-78.1");
when(httpServletRequest.getHeaders("header-key-0")).thenReturn(
Collections.enumeration(Arrays.asList("header-value-78.1", "header-value-78.2")));
when(httpServletRequest.getHeaderNames()).thenReturn(
Collections.enumeration(Collections.singletonList("header-key-0")));

when(httpServletRequest.getHeader("header-key-2")).thenReturn("header-value-62.1");
when(httpServletRequest.getHeaders("header-key-2")).thenReturn(
Collections.enumeration(Arrays.asList("header-value-62.1", "header-value-62.2")));

// Include "header-KEY-0" here to make sure that getHeaderNames() filters it out.
when(httpServletRequest.getHeaderNames()).thenReturn(
Collections.enumeration(Collections.singletonList("header-key-2")));
Collections.enumeration(Arrays.asList("header-key-0", "header-KEY-0", "header-key-2")));


// Test getting header values from HttpServletRequest (should match the mock values above)
headerValue = mutableRequest.getHeader("header-key-0");
Assert.assertEquals("header-value-78.1", headerValue);
Assert.assertEquals("header-value-78.1", mutableRequest.getHeader("header-key-0"));

headerValues = mutableRequest.getHeaders("header-key-0");
headerValuesList = Collections.list(headerValues);
Assert.assertEquals(2, headerValuesList.size());
Assert.assertTrue(headerValuesList.contains("header-value-78.1"));
Assert.assertTrue(headerValuesList.contains("header-value-78.2"));

headerValue = mutableRequest.getHeader("header-key-2");
Assert.assertEquals("header-value-62.1", headerValue);
Assert.assertEquals("header-value-62.1", mutableRequest.getHeader("header-key-2"));

headerValues = mutableRequest.getHeaders("header-key-2");
headerValuesList = Collections.list(headerValues);
Expand All @@ -137,28 +153,39 @@ public void testGetAndPutWithInitialHeaders(){

// Test putHeader (overwrite and new values) and getHeader
mutableRequest.putHeader("header-key-0", "new-header-value-98"); // overwrite mock
headerValue = mutableRequest.getHeader("header-key-0");
Assert.assertEquals("new-header-value-98", headerValue);
Assert.assertEquals("new-header-value-98", mutableRequest.getHeader("header-key-0"));
Assert.assertEquals("new-header-value-98", mutableRequest.getHeader("header-KEY-0"));

mutableRequest.putHeader("header-key-0", "new-header-value-100"); // overwrite above
headerValue = mutableRequest.getHeader("header-key-0");
Assert.assertEquals("new-header-value-100", headerValue);
Assert.assertEquals("new-header-value-100", mutableRequest.getHeader("header-key-0"));
Assert.assertEquals("new-header-value-100", mutableRequest.getHeader("Header-Key-0"));

mutableRequest.putHeader("header-key-1", "new-header-value-54"); // new
headerValue = mutableRequest.getHeader("header-key-1");
Assert.assertEquals("new-header-value-54", headerValue);
mutableRequest.putHeader("HEADER-KEY-0", "new-header-value-999"); // overwrite above
Assert.assertEquals("new-header-value-999", mutableRequest.getHeader("header-key-0"));
Assert.assertEquals("new-header-value-999", mutableRequest.getHeader("Header-Key-0"));

mutableRequest.putHeader("header-key-1", "new-header-value-54"); // new
Assert.assertEquals("new-header-value-54", mutableRequest.getHeader("header-key-1"));
Assert.assertEquals("new-header-value-54", mutableRequest.getHeader("HEADER-key-1"));

// Test getHeaders
headerValues = mutableRequest.getHeaders("header-key-0");
headerValuesList = Collections.list(headerValues);
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-100", headerValuesList.get(0));
Assert.assertEquals("new-header-value-999", headerValuesList.get(0));
headerValues = mutableRequest.getHeaders("Header-Key-0"); // Test case-insensitivity
headerValuesList = Collections.list(headerValues);
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-999", headerValuesList.get(0));

headerValues = mutableRequest.getHeaders("header-key-1");
headerValuesList = Collections.list(headerValues);
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-54", headerValuesList.get(0));
headerValues = mutableRequest.getHeaders("HEADER-key-1"); // Test case-insensitivity
headerValuesList = Collections.list(headerValues);
Assert.assertEquals(1, headerValuesList.size());
Assert.assertEquals("new-header-value-54", headerValuesList.get(0));


// Test getHeaderNames
Expand Down

0 comments on commit dec6d30

Please sign in to comment.