From 6d12dd827709e8ee7485371d6cdf9d56961c1ee6 Mon Sep 17 00:00:00 2001 From: "L. Pereira" Date: Wed, 15 May 2024 22:28:39 -0700 Subject: [PATCH] Don't allocate memory when performing hpack huffman decoding Use a ring buffer instead. Hopefully this should increase the fuzzing throughput. --- src/bin/fuzz/h2_huffman_fuzzer.cc | 11 +- src/lib/lwan-h2-huffman.c | 110 +++++++++++-------- src/lib/ringbuffer.h | 22 ++++ src/scripts/gentables.py | 172 ++++++++++++++++++++---------- 4 files changed, 205 insertions(+), 110 deletions(-) diff --git a/src/bin/fuzz/h2_huffman_fuzzer.cc b/src/bin/fuzz/h2_huffman_fuzzer.cc index 270ae7a5e..ef0de8106 100644 --- a/src/bin/fuzz/h2_huffman_fuzzer.cc +++ b/src/bin/fuzz/h2_huffman_fuzzer.cc @@ -2,16 +2,11 @@ #include extern "C" { -uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, - size_t input_len); +bool lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, + size_t input_len); int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { - uint8_t *decoded = lwan_h2_huffman_decode_for_fuzzing(data, size); - if (decoded) { - free(decoded); - return 0; - } - return 1; + return lwan_h2_huffman_decode_for_fuzzing(data, size) == true; } } diff --git a/src/lib/lwan-h2-huffman.c b/src/lib/lwan-h2-huffman.c index 81816edf7..adeae37ea 100644 --- a/src/lib/lwan-h2-huffman.c +++ b/src/lib/lwan-h2-huffman.c @@ -305,86 +305,97 @@ static inline bool consume(struct bit_reader *reader, int count) return reader->total_bitcount > 0; } -static inline size_t output_size(size_t input_size) +DEFINE_RING_BUFFER_TYPE(uint8_ring_buffer, uint8_t, 64) + +struct lwan_h2_huffman_decoder { + struct bit_reader bit_reader; + struct uint8_ring_buffer buffer; +}; + +void lwan_h2_huffman_init(struct lwan_h2_huffman_decoder *huff, + const uint8_t *input, + size_t input_len) { - /* Smallest input is 5 bits which produces 8 bits. Scaling that to 8 bits, - * we get 12.8 bits of output per 8 bits of input. */ - return (input_size * 128) / 10; + huff->bit_reader = (struct bit_reader){ + .bitptr = input, + .total_bitcount = (int64_t)input_len * 8, + }; + uint8_ring_buffer_init(&huff->buffer); } -uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, - size_t input_len) +ssize_t lwan_h2_huffman_next(struct lwan_h2_huffman_decoder *huff) { - uint8_t *output = malloc(output_size(input_len)); - uint8_t *ret = output; - struct bit_reader bit_reader = {.bitptr = input, - .total_bitcount = (int64_t)input_len * 8}; + struct bit_reader *reader = &huff->bit_reader; + struct uint8_ring_buffer *buffer = &huff->buffer; - while (bit_reader.total_bitcount > 7) { - uint8_t peeked_byte = peek_byte(&bit_reader); + while (reader->total_bitcount > 7) { + if (uint8_ring_buffer_full(buffer)) + goto done; + + uint8_t peeked_byte = peek_byte(reader); if (LIKELY(level0[peeked_byte].num_bits)) { - *output++ = level0[peeked_byte].symbol; - consume(&bit_reader, level0[peeked_byte].num_bits); + uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol); + consume(reader, level0[peeked_byte].num_bits); assert(bit_reader.total_bitcount >= 0); continue; } - if (!consume(&bit_reader, 8)) - goto fail; + if (!consume(reader, 8)) + return -1; const struct h2_huffman_code *level1 = next_level0(peeked_byte); - peeked_byte = peek_byte(&bit_reader); + peeked_byte = peek_byte(reader); if (level1[peeked_byte].num_bits) { - *output++ = level1[peeked_byte].symbol; - if (!consume(&bit_reader, level1[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level1[peeked_byte].symbol); + if (!consume(reader, level1[peeked_byte].num_bits)) + return -1; continue; } - if (!consume(&bit_reader, 8)) - goto fail; + if (!consume(reader, 8)) + return -1; const struct h2_huffman_code *level2 = next_level1(peeked_byte); - peeked_byte = peek_byte(&bit_reader); + peeked_byte = peek_byte(reader); if (level2[peeked_byte].num_bits) { - *output++ = level2[peeked_byte].symbol; - if (!consume(&bit_reader, level2[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level2[peeked_byte].symbol); + if (!consume(reader, level2[peeked_byte].num_bits)) + return -1; continue; } - if (!consume(&bit_reader, 8)) + if (!consume(reader, 8)) goto fail; const struct h2_huffman_code *level3 = next_level2(peeked_byte); if (LIKELY(level3)) { - peeked_byte = peek_byte(&bit_reader); + peeked_byte = peek_byte(reader); if (level3[peeked_byte].num_bits < 0) { /* EOS found */ - return ret; + goto done; } if (LIKELY(level3[peeked_byte].num_bits)) { - *output++ = level3[peeked_byte].symbol; - if (!consume(&bit_reader, level3[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level3[peeked_byte].symbol); + if (!consume(reader, level3[peeked_byte].num_bits)) + return -1; continue; } } - goto fail; + return -1; } /* FIXME: ensure we're not promoting types unnecessarily here */ - if (bit_reader.total_bitcount) { - const uint8_t peeked_byte = peek_byte(&bit_reader); + if (reader->total_bitcount) { + const uint8_t peeked_byte = peek_byte(reader); const uint8_t eos_prefix = ((1 << bit_reader.total_bitcount) - 1) << (8 - bit_reader.total_bitcount); if ((peeked_byte & eos_prefix) == eos_prefix) goto done; - if (level0[peeked_byte].num_bits == (int8_t)bit_reader.total_bitcount) { - *output = level0[peeked_byte].symbol; + if (level0[peeked_byte].num_bits == (int8_t)reader->total_bitcount) { + uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol); goto done; } @@ -393,15 +404,28 @@ uint8_t *lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, * - Incomplete sequence * - Has overlong padding */ - goto fail; + return -1; } done: - return ret; - -fail: - free(ret); - return NULL; + return (ssize_t)uint8_ring_buffer_size(buffer); } +bool lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, size_t input_len) +{ + struct lwan_h2_huffman_decoder decoder; + + lwan_h2_huffman_init(&decoder, input, input_len); + + while (true) { + ssize_t n_decoded = lwan_h2_huffman_next(&decoder); + + if (UNLIKELY(n_decoded < 0)) + return false; + if (n_decoded < 64) + return true; + + uint8_ring_buffer_init(&decoder->buffer); + } +} #endif diff --git a/src/lib/ringbuffer.h b/src/lib/ringbuffer.h index 1bece7c5e..7a5ecaf8e 100644 --- a/src/lib/ringbuffer.h +++ b/src/lib/ringbuffer.h @@ -76,6 +76,13 @@ memcpy(&rb->array[type_name_##_mask(rb->write++)], e, sizeof(*e)); \ } \ \ + __attribute__((unused)) static inline void type_name_##_put_copy( \ + struct type_name_ *rb, const element_type_ e) \ + { \ + assert(!type_name_##_full(rb)); \ + rb->array[type_name_##_mask(rb->write++)] = e; \ + } \ + \ __attribute__((unused)) static inline bool type_name_##_try_put( \ struct type_name_ *rb, const element_type_ *e) \ { \ @@ -93,6 +100,21 @@ return rb->array[type_name_##_mask(rb->read++)]; \ } \ \ + __attribute__((unused)) static inline void type_name_##_consume( \ + struct type_name_ *rb, element_type_ *buffer, uint32_t entries) \ + { \ + assert(type_name_##_size(rb) >= entries); \ + const uint32_t mask = type_name_##_mask(rb->read); \ + if (mask && entries > mask) { \ + memcpy(buffer, &rb->array[rb->read], \ + sizeof(*buffer) * (entries - mask)); \ + memcpy(buffer + mask, &rb->array[0], sizeof(*buffer) * mask); \ + } else { \ + memcpy(buffer, &rb->array[rb->read], sizeof(*buffer) * entries); \ + } \ + rb->read += entries; \ + } \ + \ __attribute__((unused)) static inline element_type_ *type_name_##_get_ptr( \ struct type_name_ *rb) \ { \ diff --git a/src/scripts/gentables.py b/src/scripts/gentables.py index 2642bfb1c..195a496a0 100755 --- a/src/scripts/gentables.py +++ b/src/scripts/gentables.py @@ -209,83 +209,97 @@ def generate_level(level, next_table): } """) - print("""static inline size_t output_size(size_t input_size) { - /* Smallest input is 5 bits which produces 8 bits. Scaling that to 8 bits, we - * get 12.8 bits of output per 8 bits of input. */ - return (input_size * 128) / 10; -}""") + print("""DEFINE_RING_BUFFER_TYPE(uint8_ring_buffer, uint8_t, 64) - print("""uint8_t *h2_huffman_decode(const uint8_t *input, size_t input_len) +struct lwan_h2_huffman_decoder { + struct bit_reader bit_reader; + struct uint8_ring_buffer buffer; +}; + +void lwan_h2_huffman_init(struct lwan_h2_huffman_decoder *huff, + const uint8_t *input, + size_t input_len) { - uint8_t *output = malloc(output_size(input_len)); - uint8_t *ret = output; - struct bit_reader bit_reader = {.bitptr = input, - .total_bitcount = (int64_t)input_len * 8}; + huff->bit_reader = (struct bit_reader){ + .bitptr = input, + .total_bitcount = (int64_t)input_len * 8, + }; + uint8_ring_buffer_init(&huff->buffer); +} - while (bit_reader.total_bitcount > 7) { - uint8_t peeked_byte = peek_byte(&bit_reader); +ssize_t lwan_h2_huffman_next(struct lwan_h2_huffman_decoder *huff) +{ + struct bit_reader *reader = &huff->bit_reader; + struct uint8_ring_buffer *buffer = &huff->buffer; + + while (reader->total_bitcount > 7) { + if (uint8_ring_buffer_full(buffer)) + goto done; + + uint8_t peeked_byte = peek_byte(reader); if (LIKELY(level0[peeked_byte].num_bits)) { - *output++ = level0[peeked_byte].symbol; - consume(&bit_reader, level0[peeked_byte].num_bits); + uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol); + consume(reader, level0[peeked_byte].num_bits); + assert(bit_reader.total_bitcount >= 0); continue; } - if (!consume(&bit_reader, 8)) - goto fail; + if (!consume(reader, 8)) + return -1; const struct h2_huffman_code *level1 = next_level0(peeked_byte); - peeked_byte = peek_byte(&bit_reader); + peeked_byte = peek_byte(reader); if (level1[peeked_byte].num_bits) { - *output++ = level1[peeked_byte].symbol; - if (!consume(&bit_reader, level1[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level1[peeked_byte].symbol); + if (!consume(reader, level1[peeked_byte].num_bits)) + return -1; continue; } - if (!consume(&bit_reader, 8)) - goto fail; + if (!consume(reader, 8)) + return -1; const struct h2_huffman_code *level2 = next_level1(peeked_byte); - peeked_byte = peek_byte(&bit_reader); + peeked_byte = peek_byte(reader); if (level2[peeked_byte].num_bits) { - *output++ = level2[peeked_byte].symbol; - if (!consume(&bit_reader, level2[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level2[peeked_byte].symbol); + if (!consume(reader, level2[peeked_byte].num_bits)) + return -1; continue; } - if (!consume(&bit_reader, 8)) + if (!consume(reader, 8)) goto fail; const struct h2_huffman_code *level3 = next_level2(peeked_byte); if (LIKELY(level3)) { - peeked_byte = peek_byte(&bit_reader); - if (UNLIKELY(level3[peeked_byte].num_bits < 0)) { + peeked_byte = peek_byte(reader); + if (level3[peeked_byte].num_bits < 0) { /* EOS found */ - return ret; + goto done; } if (LIKELY(level3[peeked_byte].num_bits)) { - *output++ = level3[peeked_byte].symbol; - if (!consume(&bit_reader, level3[peeked_byte].num_bits)) - goto fail; + uint8_ring_buffer_put_copy(buffer, level3[peeked_byte].symbol); + if (!consume(reader, level3[peeked_byte].num_bits)) + return -1; continue; } } - goto fail; + return -1; } /* FIXME: ensure we're not promoting types unnecessarily here */ - if (bit_reader.total_bitcount) { - const uint8_t peeked_byte = peek_byte(&bit_reader); + if (reader->total_bitcount) { + const uint8_t peeked_byte = peek_byte(reader); const uint8_t eos_prefix = ((1 << bit_reader.total_bitcount) - 1) << (8 - bit_reader.total_bitcount); if ((peeked_byte & eos_prefix) == eos_prefix) goto done; - if (level0[peeked_byte].num_bits == (int8_t)bit_reader.total_bitcount) { - *output = level0[peeked_byte].symbol; + if (level0[peeked_byte].num_bits == (int8_t)reader->total_bitcount) { + uint8_ring_buffer_put_copy(buffer, level0[peeked_byte].symbol); goto done; } @@ -294,30 +308,68 @@ def generate_level(level, next_table): * - Incomplete sequence * - Has overlong padding */ - goto fail; + return -1; } done: - return ret; + return (ssize_t)uint8_ring_buffer_size(buffer); +} + +bool lwan_h2_huffman_decode_for_fuzzing(const uint8_t *input, size_t input_len) +{ + struct lwan_h2_huffman_decoder decoder; -fail: - free(ret); - return NULL; + lwan_h2_huffman_init(&decoder, input, input_len); + + while (true) { + ssize_t n_decoded = lwan_h2_huffman_next(&decoder); + + if (UNLIKELY(n_decoded < 0)) + return false; + if (n_decoded < 64) + return true; + + uint8_ring_buffer_init(&decoder->buffer); + } +} + +bool test_decoder(unsigned char input[], size_t input_size, const char expected[], size_t expected_size) +{ + struct lwan_h2_huffman_decoder decoder; + const char *expected_ptr = expected; + + lwan_h2_huffman_init(&decoder, input, input_size); + + while (true) { + ssize_t n_decoded = lwan_h2_huffman_next(&decoder); + if (n_decoded < 0) + return false; + if (n_decoded == 0 && expected_size == 0) + return true; + if (n_decoded == 0 && expected_size != 0) + return false; + while (n_decoded && expected_size) { + uint8_t expected = *expected_ptr; + uint8_t got = uint8_ring_buffer_get(&decoder->buffer); + expected_ptr++; + expected_size--; + n_decoded--; + if (expected != got) { + fprintf(stderr, "expected %d, got %d\n", expected, got); + return false; + } + } + } } -int main(int argc, char *argv[]) { +int main(int argc, char *argv[]) +{ /* "litespeed" */ - unsigned char litespeed_huff[128] = {0xce, 0x64, 0x97, 0x75, 0x65, 0x2c, 0x9f}; - unsigned char *decoded; - - decoded = h2_huffman_decode(litespeed_huff, 7); - if (!decoded) { - puts("could not decode"); + unsigned char litespeed_huff[] = {0xce, 0x64, 0x97, 0x75, 0x65, 0x2c, 0x9f}; + if (!test_decoder(litespeed_huff, sizeof(litespeed_huff), "LiteSpeed", + sizeof("LiteSpeed") - 1)) { return 1; } - printf("%s\\n", !strcmp(decoded, "LiteSpeed") ? "pass!" : "fail!"); - printf("decoded: '%s'\\n", decoded); - free(decoded); unsigned char x_fb_debug[128] = { 0xa7, 0x06, 0xa7, 0x63, 0x97, 0xc6, 0x1d, 0xc9, 0xbb, 0xa3, 0xc6, 0x5e, @@ -327,13 +379,15 @@ def generate_level(level, next_table): 0x8b, 0xac, 0x7f, 0xef, 0x65, 0x5d, 0x9f, 0x8c, 0x9d, 0x3c, 0x72, 0x8f, 0xc5, 0xfd, 0x9e, 0xd0, 0x51, 0xb1, 0xdf, 0x46, 0xc8, 0x20, }; - decoded = h2_huffman_decode(x_fb_debug, 6*12-2); - if (!decoded) { - puts("could not decode"); + unsigned char x_fb_debug_decoded[] = + "mEO7bfwFStBMwJWfW4pmg2XL25AswjrVlfcfYbxkcS2ssduZmiKoipMH9XwoTGkb+" + "Qnq9bcjwWbwDQzsea/vMQ=="; + if (!test_decoder(x_fb_debug, sizeof(x_fb_debug), x_fb_debug_decoded, + sizeof(x_fb_debug_decoded) - 1)) { return 1; } - printf("%s\\n", !strcmp(decoded, "mEO7bfwFStBMwJWfW4pmg2XL25AswjrVlfcfYbxkcS2ssduZmiKoipMH9XwoTGkb+Qnq9bcjwWbwDQzsea/vMQ==") ? "pass!" : "fail!"); - printf("decoded: '%s'\\n", decoded); - free(decoded); + + puts("passed!"); + return 0; } """)