diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java index df7875fc7ae..0330874cdf4 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java @@ -82,7 +82,7 @@ protected GrpcHttp2InboundHeaders newHeaders() { * A {@link Http2Headers} implementation optimized for inbound/received headers. * *

Header names and values are stored in simple arrays, which makes insert run in O(1) - * and retrievial a O(n). Header name equality is not determined by the equals implementation of + * and retrieval a O(n). Header name equality is not determined by the equals implementation of * {@link CharSequence} type, but by comparing two names byte to byte. * *

All {@link CharSequence} input parameters and return values are required to be of type @@ -104,6 +104,11 @@ abstract static class GrpcHttp2InboundHeaders extends AbstractHttp2Headers { values = new AsciiString[numHeadersGuess]; } + @Override + public final Http2Headers add(CharSequence csName, CharSequence csValue) { + return add(validateName(requireAsciiString(csName)), requireAsciiString(csValue)); + } + protected Http2Headers add(AsciiString name, AsciiString value) { byte[] nameBytes = bytes(name); byte[] valueBytes; @@ -136,6 +141,11 @@ private void addHeader(AsciiString value, byte[] nameBytes, byte[] valueBytes) { namesAndValuesIdx++; } + @Override + public final CharSequence get(CharSequence csName) { + return get(requireAsciiString(csName)); + } + protected CharSequence get(AsciiString name) { for (int i = 0; i < namesAndValuesIdx; i += 2) { if (equals(name, namesAndValues[i])) { @@ -146,25 +156,31 @@ protected CharSequence get(AsciiString name) { } @Override - public boolean contains(CharSequence name) { + public final boolean contains(CharSequence name) { return get(name) != null; } @Override - public CharSequence status() { + public final CharSequence status() { return get(Http2Headers.PseudoHeaderName.STATUS.value()); } @Override - public List getAll(CharSequence csName) { - AsciiString name = requireAsciiString(csName); - List returnValues = new ArrayList<>(4); + public final List getAll(CharSequence csName) { + return getAll(requireAsciiString(csName)); + } + + protected List getAll(AsciiString name) { + List returnValues = null; for (int i = 0; i < namesAndValuesIdx; i += 2) { if (equals(name, namesAndValues[i])) { + if (returnValues == null) { + returnValues = new ArrayList<>(4); + } returnValues.add(values[i / 2]); } } - return returnValues; + return returnValues != null ? returnValues : Collections.emptyList(); } @CanIgnoreReturnValue @@ -195,7 +211,7 @@ public boolean remove(CharSequence csName) { } @Override - public Http2Headers set(CharSequence name, CharSequence value) { + public final Http2Headers set(CharSequence name, CharSequence value) { remove(name); return add(name, value); } @@ -231,16 +247,16 @@ protected static boolean equals(AsciiString str0, byte[] str1) { } protected static boolean equals(AsciiString str0, AsciiString str1) { - return equals(str0.array(), str0.arrayOffset(), str0.length(), str1.array(), - str1.arrayOffset(), str1.length()); + int length0 = str0.length(); + return length0 == str1.length() && str0.hashCode() == str1.hashCode() + && PlatformDependent.equals(str0.array(), str0.arrayOffset(), + str1.array(), str1.arrayOffset(), length0); } protected static boolean equals(byte[] bytes0, int offset0, int length0, byte[] bytes1, int offset1, int length1) { - if (length0 != length1) { - return false; - } - return PlatformDependent.equals(bytes0, offset0, bytes1, offset1, length0); + return length0 == length1 + && PlatformDependent.equals(bytes0, offset0, bytes1, offset1, length0); } protected static byte[] bytes(AsciiString str) { @@ -296,9 +312,7 @@ protected static void appendNameAndValue(StringBuilder builder, CharSequence nam builder.append(name).append(": ").append(value); } - protected final String namesAndValuesToString() { - StringBuilder builder = new StringBuilder(); - boolean prependSeparator = false; + protected StringBuilder appendNamesAndValues(StringBuilder builder, boolean prependSeparator) { for (int i = 0; i < namesAndValuesIdx; i += 2) { String name = new String(namesAndValues[i], US_ASCII); // If binary headers, the value is base64 encoded. @@ -306,7 +320,13 @@ protected final String namesAndValuesToString() { appendNameAndValue(builder, name, value, prependSeparator); prependSeparator = true; } - return builder.toString(); + return builder; + } + + @Override + public final String toString() { + StringBuilder builder = new StringBuilder(getClass().getSimpleName()).append('['); + return appendNamesAndValues(builder, false).append(']').toString(); } } @@ -336,62 +356,75 @@ static final class GrpcHttp2RequestHeaders extends GrpcHttp2InboundHeaders { } @Override - public Http2Headers add(CharSequence csName, CharSequence csValue) { - AsciiString name = validateName(requireAsciiString(csName)); - AsciiString value = requireAsciiString(csValue); + protected Http2Headers add(AsciiString name, AsciiString value) { if (isPseudoHeader(name)) { - addPseudoHeader(name, value); + setPseudoHeader(name, value, true); return this; } if (equals(TE_HEADER, name)) { te = value; return this; } - return add(name, value); + return super.add(name, value); } @Override - public CharSequence get(CharSequence csName) { - AsciiString name = requireAsciiString(csName); - checkArgument(!isPseudoHeader(name), "Use direct accessor methods for pseudo headers."); + protected CharSequence get(AsciiString name) { + if (isPseudoHeader(name)) { + return getPseudoHeader(name); + } if (equals(TE_HEADER, name)) { return te; } - return get(name); + return super.get(name); } - private void addPseudoHeader(CharSequence csName, CharSequence csValue) { - AsciiString name = requireAsciiString(csName); - AsciiString value = requireAsciiString(csValue); + private CharSequence getPseudoHeader(AsciiString name) { + if (equals(PATH_HEADER, name)) { + return path; + } + if (equals(AUTHORITY_HEADER, name)) { + return authority; + } + if (equals(METHOD_HEADER, name)) { + return method; + } + if (equals(SCHEME_HEADER, name)) { + return scheme; + } + return null; + } + @CanIgnoreReturnValue + private boolean setPseudoHeader(AsciiString name, AsciiString value, boolean requireUnique) { + boolean wasPresent = false; if (equals(PATH_HEADER, name)) { - if (path != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :path header")); - } + wasPresent = checkPseudoHeader(PATH_HEADER, path, requireUnique); path = value; } else if (equals(AUTHORITY_HEADER, name)) { - if (authority != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :authority header")); - } + wasPresent = checkPseudoHeader(AUTHORITY_HEADER, authority, requireUnique); authority = value; } else if (equals(METHOD_HEADER, name)) { - if (method != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :method header")); - } + wasPresent = checkPseudoHeader(METHOD_HEADER, method, requireUnique); method = value; } else if (equals(SCHEME_HEADER, name)) { - if (scheme != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :scheme header")); - } + wasPresent = checkPseudoHeader(SCHEME_HEADER, scheme, requireUnique); scheme = value; } else { PlatformDependent.throwException( connectionError(PROTOCOL_ERROR, "Illegal pseudo-header '%s' in request.", name)); } + return wasPresent; + } + + private static boolean checkPseudoHeader(AsciiString name, AsciiString oldValue, + boolean requireNull) { + boolean present = oldValue != null; + if (requireNull && present) { + PlatformDependent.throwException( + connectionError(PROTOCOL_ERROR, "Duplicate %s header", name)); + } + return present; } @Override @@ -415,24 +448,20 @@ public CharSequence scheme() { } @Override - public List getAll(CharSequence csName) { - AsciiString name = requireAsciiString(csName); - if (isPseudoHeader(name)) { - // This code should never be reached. - throw new IllegalArgumentException("Use direct accessor methods for pseudo headers."); + protected List getAll(AsciiString name) { + boolean isPseudo = isPseudoHeader(name); + if (isPseudo || equals(TE_HEADER, name)) { + CharSequence value = isPseudo ? get(name) : te; + return value != null ? Collections.singletonList(value) : Collections.emptyList(); } - if (equals(TE_HEADER, name)) { - return Collections.singletonList((CharSequence) te); - } - return super.getAll(csName); + return super.getAll(name); } @Override public boolean remove(CharSequence csName) { AsciiString name = requireAsciiString(csName); if (isPseudoHeader(name)) { - // This code should never be reached. - throw new IllegalArgumentException("Use direct accessor methods for pseudo headers."); + return setPseudoHeader(name, null, false); } if (equals(TE_HEADER, name)) { boolean wasPresent = te != null; @@ -463,15 +492,11 @@ public int size() { if (te != null) { size++; } - size += super.size(); - return size; + return size + super.size(); } @Override - public String toString() { - StringBuilder builder = new StringBuilder(getClass().getSimpleName()).append('['); - boolean prependSeparator = false; - + protected StringBuilder appendNamesAndValues(StringBuilder builder, boolean prependSeparator) { if (path != null) { appendNameAndValue(builder, PATH_HEADER, path, prependSeparator); prependSeparator = true; @@ -490,18 +515,9 @@ public String toString() { } if (te != null) { appendNameAndValue(builder, TE_HEADER, te, prependSeparator); + prependSeparator = true; } - - String namesAndValues = namesAndValuesToString(); - - if (builder.length() > 0 && namesAndValues.length() > 0) { - builder.append(", "); - } - - builder.append(namesAndValues); - builder.append(']'); - - return builder.toString(); + return super.appendNamesAndValues(builder, prependSeparator); } } @@ -517,25 +533,5 @@ static final class GrpcHttp2ResponseHeaders extends GrpcHttp2InboundHeaders { GrpcHttp2ResponseHeaders(int numHeadersGuess) { super(numHeadersGuess); } - - @Override - public Http2Headers add(CharSequence csName, CharSequence csValue) { - AsciiString name = validateName(requireAsciiString(csName)); - AsciiString value = requireAsciiString(csValue); - return add(name, value); - } - - @Override - public CharSequence get(CharSequence csName) { - AsciiString name = requireAsciiString(csName); - return get(name); - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder(getClass().getSimpleName()).append('['); - builder.append(namesAndValuesToString()).append(']'); - return builder.toString(); - } } }