Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #10679 - Review HTTP/2 rate control (CVE-2023-44487) #10680

Merged
merged 1 commit into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,28 @@ public boolean parse(ByteBuffer buffer)
int remaining = buffer.remaining();
if (remaining < length)
{
headerBlockFragments.storeFragment(buffer, remaining, false);
ContinuationFrame frame = new ContinuationFrame(getStreamId(), false);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");

if (!headerBlockFragments.storeFragment(buffer, remaining, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_continuation_stream");

length -= remaining;
break;
}
else
{
boolean last = hasFlag(Flags.END_HEADERS);
headerBlockFragments.storeFragment(buffer, length, last);
boolean endHeaders = hasFlag(Flags.END_HEADERS);
ContinuationFrame frame = new ContinuationFrame(getStreamId(), endHeaders);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");

if (!headerBlockFragments.storeFragment(buffer, length, endHeaders))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_continuation_stream");

reset();
if (last)
if (endHeaders)
return onHeaders(buffer);
return true;
}
Expand All @@ -107,17 +119,21 @@ private boolean onHeaders(ByteBuffer buffer)
{
ByteBuffer headerBlock = headerBlockFragments.complete();
MetaData metaData = headerBlockParser.parse(headerBlock, headerBlock.remaining());
if (metaData == null)
return true;
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, headerBlockFragments.getPriorityFrame(), headerBlockFragments.isEndStream());
headerBlockFragments.reset();

if (metaData == HeaderBlockParser.SESSION_FAILURE)
return false;
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, headerBlockFragments.getPriorityFrame(), headerBlockFragments.isEndStream());
if (metaData == HeaderBlockParser.STREAM_FAILURE)

if (metaData != HeaderBlockParser.STREAM_FAILURE)
{
notifyHeaders(frame);
}
else
{
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
notifyHeaders(frame);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,40 @@

public class HeaderBlockFragments
{
private final int maxCapacity;
private PriorityFrame priorityFrame;
private boolean endStream;
private int streamId;
private boolean endStream;
private ByteBuffer storage;

public void storeFragment(ByteBuffer fragment, int length, boolean last)
public HeaderBlockFragments(int maxCapacity)
{
this.maxCapacity = maxCapacity;
}

void reset()
{
priorityFrame = null;
streamId = 0;
endStream = false;
storage = null;
}

public boolean storeFragment(ByteBuffer fragment, int length, boolean last)
{
if (storage == null)
{
int space = last ? length : length * 2;
storage = ByteBuffer.allocate(space);
if (length > maxCapacity)
return false;
int capacity = last ? length : length * 2;
storage = ByteBuffer.allocate(capacity);
}

// Grow the storage if necessary.
if (storage.remaining() < length)
{
if (storage.position() + length > maxCapacity)
return false;
int space = last ? length : length * 2;
int capacity = storage.position() + space;
ByteBuffer newStorage = ByteBuffer.allocate(capacity);
Expand All @@ -53,6 +71,7 @@ public void storeFragment(ByteBuffer fragment, int length, boolean last)
fragment.limit(fragment.position() + length);
storage.put(fragment);
fragment.limit(limit);
return true;
}

public PriorityFrame getPriorityFrame()
Expand All @@ -77,10 +96,8 @@ public void setEndStream(boolean endStream)

public ByteBuffer complete()
{
ByteBuffer result = storage;
storage = null;
result.flip();
return result;
storage.flip();
return storage;
}

public int getStreamId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,15 @@ else if (hasFlag(Flags.END_HEADERS))
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (headerBlockFragments.getStreamId() != 0)
{
connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
}
}
}

Expand Down Expand Up @@ -172,6 +179,18 @@ else if (hasFlag(Flags.PRIORITY))
break;
}
case HEADERS:
{
if (!hasFlag(Flags.END_HEADERS))
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (hasFlag(Flags.PRIORITY))
headerBlockFragments.setPriorityFrame(new PriorityFrame(getStreamId(), parentStreamId, weight, exclusive));
}
state = State.HEADER_BLOCK;
break;
}
case HEADER_BLOCK:
{
if (hasFlag(Flags.END_HEADERS))
{
Expand All @@ -196,7 +215,7 @@ else if (hasFlag(Flags.PRIORITY))
{
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, null, isEndStream());
if (!rateControlOnEvent(frame))
connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
}
}
Expand All @@ -205,16 +224,14 @@ else if (hasFlag(Flags.PRIORITY))
int remaining = buffer.remaining();
if (remaining < length)
{
headerBlockFragments.storeFragment(buffer, remaining, false);
if (!headerBlockFragments.storeFragment(buffer, remaining, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
length -= remaining;
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (hasFlag(Flags.PRIORITY))
headerBlockFragments.setPriorityFrame(new PriorityFrame(getStreamId(), parentStreamId, weight, exclusive));
headerBlockFragments.storeFragment(buffer, length, false);
if (!headerBlockFragments.storeFragment(buffer, length, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
state = State.PADDING;
loop = paddingLength == 0;
}
Expand Down Expand Up @@ -258,6 +275,6 @@ private void onHeaders(HeadersFrame frame)

private enum State
{
PREPARE, PADDING_LENGTH, EXCLUSIVE, PARENT_STREAM_ID, PARENT_STREAM_ID_BYTES, WEIGHT, HEADERS, PADDING
PREPARE, PADDING_LENGTH, EXCLUSIVE, PARENT_STREAM_ID, PARENT_STREAM_ID_BYTES, WEIGHT, HEADERS, HEADER_BLOCK, PADDING
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public void init(Listener listener)
this.listener = listener;
unknownBodyParser = new UnknownBodyParser(headerParser, listener);
HeaderBlockParser headerBlockParser = new HeaderBlockParser(headerParser, byteBufferPool, hpackDecoder, unknownBodyParser);
HeaderBlockFragments headerBlockFragments = new HeaderBlockFragments();
HeaderBlockFragments headerBlockFragments = new HeaderBlockFragments(hpackDecoder.getMaxHeaderListSize());
bodyParsers[FrameType.DATA.getType()] = new DataBodyParser(headerParser, listener);
bodyParsers[FrameType.HEADERS.getType()] = new HeadersBodyParser(headerParser, listener, headerBlockParser, headerBlockFragments);
bodyParsers[FrameType.PRIORITY.getType()] = new PriorityBodyParser(headerParser, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http2.ErrorCode;
import org.eclipse.jetty.http2.Flags;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.http2.frames.PushPromiseFrame;

public class PushPromiseBodyParser extends BodyParser
Expand Down Expand Up @@ -70,13 +71,9 @@ public boolean parse(ByteBuffer buffer)
length = getBodyLength();

if (isPadding())
{
state = State.PADDING_LENGTH;
}
else
{
state = State.STREAM_ID;
}
break;
}
case PADDING_LENGTH:
Expand Down Expand Up @@ -136,7 +133,15 @@ public boolean parse(ByteBuffer buffer)
state = State.PADDING;
loop = paddingLength == 0;
if (metaData != HeaderBlockParser.STREAM_FAILURE)
{
onPushPromise(streamId, metaData);
}
else
{
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, null, isEndStream());
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
}
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public boolean parse(ByteBuffer buffer)
{
if (buffer.remaining() >= 4)
{
return onReset(buffer.getInt());
return onReset(buffer, buffer.getInt());
}
else
{
Expand All @@ -78,7 +78,7 @@ public boolean parse(ByteBuffer buffer)
--cursor;
error += currByte << (8 * cursor);
if (cursor == 0)
return onReset(error);
return onReset(buffer, error);
break;
}
default:
Expand All @@ -90,9 +90,11 @@ public boolean parse(ByteBuffer buffer)
return false;
}

private boolean onReset(int error)
private boolean onReset(ByteBuffer buffer, int error)
{
ResetFrame frame = new ResetFrame(getStreamId(), error);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_rst_stream_frame_rate");
reset();
notifyReset(frame);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,22 @@ public int getMaxKeys()
@Override
protected void emptyBody(ByteBuffer buffer)
{
if (!validateFrame(buffer, getStreamId(), 0))
return;
boolean isReply = hasFlag(Flags.ACK);
SettingsFrame frame = new SettingsFrame(Collections.emptyMap(), isReply);
if (!isReply && !rateControlOnEvent(frame))
connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_settings_frame_rate");
else
onSettings(frame);
onSettings(buffer, frame);
}

private boolean validateFrame(ByteBuffer buffer, int streamId, int bodyLength)
{
// SPEC: wrong streamId is treated as connection error.
if (streamId != 0)
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_settings_frame");
// SPEC: reply with body is treated as connection error.
if (hasFlag(Flags.ACK) && bodyLength > 0)
return connectionFailure(buffer, ErrorCode.FRAME_SIZE_ERROR.code, "invalid_settings_frame");
return true;
}

@Override
Expand All @@ -95,9 +105,8 @@ private boolean parse(ByteBuffer buffer, int streamId, int bodyLength)
{
case PREPARE:
{
// SPEC: wrong streamId is treated as connection error.
if (streamId != 0)
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_settings_frame");
if (!validateFrame(buffer, streamId, bodyLength))
return false;
length = bodyLength;
settings = new HashMap<>();
state = State.SETTING_ID;
Expand Down Expand Up @@ -211,11 +220,13 @@ protected boolean onSettings(ByteBuffer buffer, Map<Integer, Integer> settings)
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_settings_max_frame_size");

SettingsFrame frame = new SettingsFrame(settings, hasFlag(Flags.ACK));
return onSettings(frame);
return onSettings(buffer, frame);
}

private boolean onSettings(SettingsFrame frame)
private boolean onSettings(ByteBuffer buffer, SettingsFrame frame)
{
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_settings_frame_rate");
reset();
notifySettings(frame);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public boolean parse(ByteBuffer buffer)
boolean parsed = cursor == 0;
if (parsed && !rateControlOnEvent(new UnknownFrame(getFrameType())))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_unknown_frame_rate");

return parsed;
}

Expand Down