Skip to content

Commit

Permalink
[libc][NFC] Add supporting class for atof implementation
Browse files Browse the repository at this point in the history
This change adds the High Precision Decimal described here:
https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
It will be used for the atof implementation later, but is complete and
tested now.

The code is inspired by the golang implmentation of the HPD class, which
can be found here: https://github.com/golang/go/blob/release-branch.go1.16/src/strconv/decimal.go

Reviewed By: sivachandra

Differential Revision: https://reviews.llvm.org/D110454
  • Loading branch information
michaelrj-google committed Oct 4, 2021
1 parent bb69f1d commit 6f80339
Show file tree
Hide file tree
Showing 5 changed files with 842 additions and 0 deletions.
8 changes: 8 additions & 0 deletions libc/src/__support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ add_header_library(
ctype_utils.h
)

add_header_library(
high_precision_decimal
HDRS
high_precision_decimal.h

)

add_header_library(
str_conv_utils
HDRS
str_conv_utils.h
DEPENDS
.ctype_utils
.high_precision_decimal
libc.include.errno
libc.src.errno.__errno_location
libc.utils.CPP.standalone_cpp
Expand Down
378 changes: 378 additions & 0 deletions libc/src/__support/high_precision_decimal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,378 @@
//===-- High Precision Decimal ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See httpss//llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H
#define LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H

#include "src/__support/ctype_utils.h"
#include "src/__support/str_conv_utils.h"
#include <stdint.h>

namespace __llvm_libc {
namespace internal {

struct LShiftTableEntry {
uint32_t newDigits;
char const *powerOfFive;
};

// This is based on the HPD data structure described as part of the Simple
// Decimal Conversion algorithm by Nigel Tao, described at this link:
// https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
class HighPrecsisionDecimal {

// This precomputed table speeds up left shifts by having the number of new
// digits that will be added by multiplying 5^i by 2^i. If the number is less
// than 5^i then it will add one fewer digit. There are only 60 entries since
// that's the max shift amount.
// This table was generated by the script at
// libc/utils/mathtools/GenerateHPDConstants.py
static constexpr LShiftTableEntry LEFT_SHIFT_DIGIT_TABLE[] = {
{0, ""},
{1, "5"},
{1, "25"},
{1, "125"},
{2, "625"},
{2, "3125"},
{2, "15625"},
{3, "78125"},
{3, "390625"},
{3, "1953125"},
{4, "9765625"},
{4, "48828125"},
{4, "244140625"},
{4, "1220703125"},
{5, "6103515625"},
{5, "30517578125"},
{5, "152587890625"},
{6, "762939453125"},
{6, "3814697265625"},
{6, "19073486328125"},
{7, "95367431640625"},
{7, "476837158203125"},
{7, "2384185791015625"},
{7, "11920928955078125"},
{8, "59604644775390625"},
{8, "298023223876953125"},
{8, "1490116119384765625"},
{9, "7450580596923828125"},
{9, "37252902984619140625"},
{9, "186264514923095703125"},
{10, "931322574615478515625"},
{10, "4656612873077392578125"},
{10, "23283064365386962890625"},
{10, "116415321826934814453125"},
{11, "582076609134674072265625"},
{11, "2910383045673370361328125"},
{11, "14551915228366851806640625"},
{12, "72759576141834259033203125"},
{12, "363797880709171295166015625"},
{12, "1818989403545856475830078125"},
{13, "9094947017729282379150390625"},
{13, "45474735088646411895751953125"},
{13, "227373675443232059478759765625"},
{13, "1136868377216160297393798828125"},
{14, "5684341886080801486968994140625"},
{14, "28421709430404007434844970703125"},
{14, "142108547152020037174224853515625"},
{15, "710542735760100185871124267578125"},
{15, "3552713678800500929355621337890625"},
{15, "17763568394002504646778106689453125"},
{16, "88817841970012523233890533447265625"},
{16, "444089209850062616169452667236328125"},
{16, "2220446049250313080847263336181640625"},
{16, "11102230246251565404236316680908203125"},
{17, "55511151231257827021181583404541015625"},
{17, "277555756156289135105907917022705078125"},
{17, "1387778780781445675529539585113525390625"},
{18, "6938893903907228377647697925567626953125"},
{18, "34694469519536141888238489627838134765625"},
{18, "173472347597680709441192448139190673828125"},
{19, "867361737988403547205962240695953369140625"},
};

// The maximum amount we can shift is the number of bits used in the
// accumulator, minus the number of bits needed to represent the base (in this
// case 4).
static constexpr uint32_t MAX_SHIFT_AMOUNT = sizeof(uint64_t) - 4;

// 800 is an arbitrary number of digits, but should be
// large enough for any practical number.
static constexpr uint32_t MAX_NUM_DIGITS = 800;

uint32_t numDigits = 0;
int32_t decimalPoint = 0;
bool truncated = false;
uint8_t digits[MAX_NUM_DIGITS];

private:
bool shouldRoundUp(uint32_t roundToDigit) {
if (roundToDigit < 0 || roundToDigit >= this->numDigits) {
return false;
}

// If we're right in the middle and there are no extra digits
if (this->digits[roundToDigit] == 5 &&
roundToDigit + 1 == this->numDigits) {

// Round up if we've truncated (since that means the result is slightly
// higher than what's represented.)
if (this->truncated) {
return true;
}

// If this exactly halfway, round to even.
return this->digits[roundToDigit - 1] % 2 != 0;
}
// If there are digits after roundToDigit, they must be non-zero since we
// trim trailing zeroes after all operations that change digits.
return this->digits[roundToDigit] >= 5;
}

// Takes an amount to left shift and returns the number of new digits needed
// to store the result based on LEFT_SHIFT_DIGIT_TABLE.
uint32_t getNumNewDigits(uint32_t lShiftAmount) {
const char *powerOfFive = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].powerOfFive;
uint32_t newDigits = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].newDigits;
uint32_t digitIndex = 0;
while (powerOfFive[digitIndex] != 0) {
if (digitIndex >= this->numDigits) {
return newDigits - 1;
}
if (this->digits[digitIndex] != powerOfFive[digitIndex] - '0') {
return newDigits -
((this->digits[digitIndex] < powerOfFive[digitIndex] - '0') ? 1
: 0);
}
++digitIndex;
}
return newDigits;
}

// Trim all trailing 0s
void trimTrailingZeroes() {
while (this->numDigits > 0 && this->digits[this->numDigits - 1] == 0) {
--this->numDigits;
}
if (this->numDigits == 0) {
this->decimalPoint = 0;
}
}

// Perform a digitwise binary non-rounding right shift on this value by
// shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
// overflow.
void rightShift(uint32_t shiftAmount) {
uint32_t readIndex = 0;
uint32_t writeIndex = 0;

uint64_t accumulator = 0;

const uint64_t shiftMask = (uint64_t(1) << shiftAmount) - 1;

// Warm Up phase: we don't have enough digits to start writing, so just
// read them into the accumulator.
while (accumulator >> shiftAmount == 0) {
uint64_t readDigit = 0;
// If there are still digits to read, read the next one, else the digit is
// assumed to be 0.
if (readIndex < this->numDigits) {
readDigit = this->digits[readIndex];
}
accumulator = accumulator * 10 + readDigit;
++readIndex;
}

// Shift the decimal point by the number of digits it took to fill the
// accumulator.
this->decimalPoint -= readIndex - 1;

// Middle phase: we have enough digits to write, as well as more digits to
// read. Keep reading until we run out of digits.
while (readIndex < this->numDigits) {
uint64_t readDigit = this->digits[readIndex];
uint64_t writeDigit = accumulator >> shiftAmount;
accumulator &= shiftMask;
this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
accumulator = accumulator * 10 + readDigit;
++readIndex;
++writeIndex;
}

// Cool Down phase: All of the readable digits have been read, so just write
// the remainder, while treating any more digits as 0.
while (accumulator > 0) {
uint64_t writeDigit = accumulator >> shiftAmount;
accumulator &= shiftMask;
if (writeIndex < MAX_NUM_DIGITS) {
this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
++writeIndex;
} else if (writeDigit > 0) {
this->truncated = true;
}
accumulator = accumulator * 10;
}
this->numDigits = writeIndex;
this->trimTrailingZeroes();
}

// Perform a digitwise binary non-rounding left shift on this value by
// shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
// overflow.
void leftShift(uint32_t shiftAmount) {
uint32_t newDigits = this->getNumNewDigits(shiftAmount);

int32_t readIndex = this->numDigits - 1;
uint32_t writeIndex = this->numDigits + newDigits;

uint64_t accumulator = 0;

// No Warm Up phase. Since we're putting digits in at the top and taking
// digits from the bottom we don't have to wait for the accumulator to fill.

// Middle phase: while we have more digits to read, keep reading as well as
// writing.
while (readIndex >= 0) {
accumulator += static_cast<uint64_t>(this->digits[readIndex])
<< shiftAmount;
uint64_t nextAccumulator = accumulator / 10;
uint64_t writeDigit = accumulator - (10 * nextAccumulator);
--writeIndex;
if (writeIndex < MAX_NUM_DIGITS) {
this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
} else if (writeDigit != 0) {
this->truncated = true;
}
accumulator = nextAccumulator;
--readIndex;
}

// Cool Down phase: there are no more digits to read, so just write the
// remaining digits in the accumulator.
while (accumulator > 0) {
uint64_t nextAccumulator = accumulator / 10;
uint64_t writeDigit = accumulator - (10 * nextAccumulator);
--writeIndex;
if (writeIndex < MAX_NUM_DIGITS) {
this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
} else if (writeDigit != 0) {
this->truncated = true;
}
accumulator = nextAccumulator;
}

this->numDigits += newDigits;
if (this->numDigits > MAX_NUM_DIGITS) {
this->numDigits = MAX_NUM_DIGITS;
}
this->decimalPoint += newDigits;
this->trimTrailingZeroes();
}

public:
// numString is assumed to be a string of numeric characters. It doesn't
// handle leading spaces.
HighPrecsisionDecimal(const char *__restrict numString) {
bool sawDot = false;
bool sawDigit = false;
while (isdigit(*numString) || *numString == '.') {
if (*numString == '.') {
if (sawDot) {
break;
}
this->decimalPoint = this->numDigits;
sawDot = true;
} else {
sawDigit = true;
if (*numString == '0' && this->numDigits == 0) {
--this->decimalPoint;
continue;
}
if (this->numDigits < MAX_NUM_DIGITS) {
this->digits[this->numDigits] = *numString - '0';
++this->numDigits;
} else if (*numString != '0') {
this->truncated = true;
}
}
++numString;
}

if (!sawDot) {
this->decimalPoint = this->numDigits;
}

if ((*numString | 32) == 'e') {
++numString;
if (isdigit(*numString) || *numString == '+' || *numString == '-') {
int32_t addToExp = strtointeger<int32_t>(numString, nullptr, 10);
this->decimalPoint += addToExp;
}
}

this->trimTrailingZeroes();
}

// Binary shift left (shiftAmount > 0) or right (shiftAmount < 0)
void shift(int shiftAmount) {
if (shiftAmount == 0) {
return;
}
// Left
else if (shiftAmount > 0) {
while (static_cast<uint32_t>(shiftAmount) > MAX_SHIFT_AMOUNT) {
this->leftShift(MAX_SHIFT_AMOUNT);
shiftAmount -= MAX_SHIFT_AMOUNT;
}
this->leftShift(shiftAmount);
}
// Right
else {
while (static_cast<uint32_t>(shiftAmount) < -MAX_SHIFT_AMOUNT) {
this->rightShift(MAX_SHIFT_AMOUNT);
shiftAmount += MAX_SHIFT_AMOUNT;
}
this->rightShift(-shiftAmount);
}
}

// Round the number represented to the closest value of unsigned int type T.
// This is done ignoring overflow.
template <class T> T roundToIntegerType() {
T result = 0;
uint32_t curDigit = 0;

while (static_cast<int32_t>(curDigit) < this->decimalPoint &&
curDigit < this->numDigits) {
result = result * 10 + (this->digits[curDigit]);
++curDigit;
}

// If there are implicit 0s at the end of the number, include those.
while (static_cast<int32_t>(curDigit) < this->decimalPoint) {
result *= 10;
++curDigit;
}
if (this->shouldRoundUp(this->decimalPoint)) {
++result;
}
return result;
}

// Extra functions for testing.

uint8_t *getDigits() { return this->digits; }
uint32_t getNumDigits() { return this->numDigits; }
int32_t getDecimalPoint() { return this->decimalPoint; }
void setTruncated(bool trunc) { this->truncated = trunc; }
};

} // namespace internal
} // namespace __llvm_libc

#endif // LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H
Loading

0 comments on commit 6f80339

Please sign in to comment.