Skip to content

Commit

Permalink
Don't allocate memory when performing hpack huffman decoding
Browse files Browse the repository at this point in the history
Use a ring buffer instead.  Hopefully this should increase the fuzzing
throughput.
  • Loading branch information
lpereira committed May 16, 2024
1 parent 3508980 commit 6d12dd8
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 110 deletions.
11 changes: 3 additions & 8 deletions src/bin/fuzz/h2_huffman_fuzzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@
#include <stdlib.h>

extern "C" {
uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input,
size_t input_len);
bool lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input,
size_t input_len);

int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
{
uint8_t *decoded = lwan_h2_huffman_decode_for_fuzzing(data, size);
if (decoded) {
free(decoded);
return 0;
}
return 1;
return lwan_h2_huffman_decode_for_fuzzing(data, size) == true;
}
}
110 changes: 67 additions & 43 deletions src/lib/lwan-h2-huffman.c
Original file line number Diff line number Diff line change
Expand Up @@ -305,86 +305,97 @@ static inline bool consume(struct bit_reader *reader, int count)
return reader->total_bitcount > 0;
}

static inline size_t output_size(size_t input_size)
DEFINE_RING_BUFFER_TYPE(uint8_ring_buffer, uint8_t, 64)

struct lwan_h2_huffman_decoder {
struct bit_reader bit_reader;
struct uint8_ring_buffer buffer;
};

void lwan_h2_huffman_init(struct lwan_h2_huffman_decoder *huff,
const uint8_t *input,
size_t input_len)
{
/* Smallest input is 5 bits which produces 8 bits. Scaling that to 8 bits,
* we get 12.8 bits of output per 8 bits of input. */
return (input_size * 128) / 10;
huff->bit_reader = (struct bit_reader){
.bitptr = input,
.total_bitcount = (int64_t)input_len * 8,
};
uint8_ring_buffer_init(&huff->buffer);
}

uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input,
size_t input_len)
ssize_t lwan_h2_huffman_next(struct lwan_h2_huffman_decoder *huff)
{
uint8_t *output = malloc(output_size(input_len));
uint8_t *ret = output;
struct bit_reader bit_reader = {.bitptr = input,
.total_bitcount = (int64_t)input_len * 8};
struct bit_reader *reader = &huff->bit_reader;
struct uint8_ring_buffer *buffer = &huff->buffer;

while (bit_reader.total_bitcount > 7) {
uint8_t peeked_byte = peek_byte(&bit_reader);
while (reader->total_bitcount > 7) {
if (uint8_ring_buffer_full(buffer))
goto done;

uint8_t peeked_byte = peek_byte(reader);
if (LIKELY(level0[peeked_byte].num_bits)) {
*output++ = level0[peeked_byte].symbol;
consume(&bit_reader, level0[peeked_byte].num_bits);
uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol);
consume(reader, level0[peeked_byte].num_bits);
assert(bit_reader.total_bitcount >= 0);
continue;
}

if (!consume(&bit_reader, 8))
goto fail;
if (!consume(reader, 8))
return -1;

const struct h2_huffman_code *level1 = next_level0(peeked_byte);
peeked_byte = peek_byte(&bit_reader);
peeked_byte = peek_byte(reader);
if (level1[peeked_byte].num_bits) {
*output++ = level1[peeked_byte].symbol;
if (!consume(&bit_reader, level1[peeked_byte].num_bits))
goto fail;
uint8_ring_buffer_put_copy(buffer, level1[peeked_byte].symbol);
if (!consume(reader, level1[peeked_byte].num_bits))
return -1;
continue;
}

if (!consume(&bit_reader, 8))
goto fail;
if (!consume(reader, 8))
return -1;

const struct h2_huffman_code *level2 = next_level1(peeked_byte);
peeked_byte = peek_byte(&bit_reader);
peeked_byte = peek_byte(reader);
if (level2[peeked_byte].num_bits) {
*output++ = level2[peeked_byte].symbol;
if (!consume(&bit_reader, level2[peeked_byte].num_bits))
goto fail;
uint8_ring_buffer_put_copy(buffer, level2[peeked_byte].symbol);
if (!consume(reader, level2[peeked_byte].num_bits))
return -1;
continue;
}

if (!consume(&bit_reader, 8))
if (!consume(reader, 8))
goto fail;

const struct h2_huffman_code *level3 = next_level2(peeked_byte);
if (LIKELY(level3)) {
peeked_byte = peek_byte(&bit_reader);
peeked_byte = peek_byte(reader);
if (level3[peeked_byte].num_bits < 0) {
/* EOS found */
return ret;
goto done;
}
if (LIKELY(level3[peeked_byte].num_bits)) {
*output++ = level3[peeked_byte].symbol;
if (!consume(&bit_reader, level3[peeked_byte].num_bits))
goto fail;
uint8_ring_buffer_put_copy(buffer, level3[peeked_byte].symbol);
if (!consume(reader, level3[peeked_byte].num_bits))
return -1;
continue;
}
}

goto fail;
return -1;
}

/* FIXME: ensure we're not promoting types unnecessarily here */
if (bit_reader.total_bitcount) {
const uint8_t peeked_byte = peek_byte(&bit_reader);
if (reader->total_bitcount) {
const uint8_t peeked_byte = peek_byte(reader);
const uint8_t eos_prefix = ((1 << bit_reader.total_bitcount) - 1)
<< (8 - bit_reader.total_bitcount);

if ((peeked_byte & eos_prefix) == eos_prefix)
goto done;

if (level0[peeked_byte].num_bits == (int8_t)bit_reader.total_bitcount) {
*output = level0[peeked_byte].symbol;
if (level0[peeked_byte].num_bits == (int8_t)reader->total_bitcount) {
uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol);
goto done;
}

Expand All @@ -393,15 +404,28 @@ uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input,
* - Incomplete sequence
* - Has overlong padding
*/
goto fail;
return -1;
}

done:
return ret;

fail:
free(ret);
return NULL;
return (ssize_t)uint8_ring_buffer_size(buffer);
}

bool lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, size_t input_len)
{
struct lwan_h2_huffman_decoder decoder;

lwan_h2_huffman_init(&decoder, input, input_len);

while (true) {
ssize_t n_decoded = lwan_h2_huffman_next(&decoder);

if (UNLIKELY(n_decoded < 0))
return false;
if (n_decoded < 64)
return true;

uint8_ring_buffer_init(&decoder->buffer);
}
}
#endif
22 changes: 22 additions & 0 deletions src/lib/ringbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@
memcpy(&rb->array[type_name_##_mask(rb->write++)], e, sizeof(*e)); \
} \
\
__attribute__((unused)) static inline void type_name_##_put_copy( \
struct type_name_ *rb, const element_type_ e) \
{ \
assert(!type_name_##_full(rb)); \
rb->array[type_name_##_mask(rb->write++)] = e; \
} \
\
__attribute__((unused)) static inline bool type_name_##_try_put( \
struct type_name_ *rb, const element_type_ *e) \
{ \
Expand All @@ -93,6 +100,21 @@
return rb->array[type_name_##_mask(rb->read++)]; \
} \
\
__attribute__((unused)) static inline void type_name_##_consume( \
struct type_name_ *rb, element_type_ *buffer, uint32_t entries) \
{ \
assert(type_name_##_size(rb) >= entries); \
const uint32_t mask = type_name_##_mask(rb->read); \
if (mask && entries > mask) { \
memcpy(buffer, &rb->array[rb->read], \
sizeof(*buffer) * (entries - mask)); \
memcpy(buffer + mask, &rb->array[0], sizeof(*buffer) * mask); \
} else { \
memcpy(buffer, &rb->array[rb->read], sizeof(*buffer) * entries); \
} \
rb->read += entries; \
} \
\
__attribute__((unused)) static inline element_type_ *type_name_##_get_ptr( \
struct type_name_ *rb) \
{ \
Expand Down
Loading

0 comments on commit 6d12dd8

Please sign in to comment.