Skip to content

Commit

Permalink
Add support for decoding base64.
Browse files Browse the repository at this point in the history
An upcoming patch to LLDB will require the ability to decode base64. This patch adds support for decoding base64 and adds tests.

Differential Revision: https://reviews.llvm.org/D126254
  • Loading branch information
clayborg committed Jun 23, 2022
1 parent 878309c commit 8b987ca
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 7 deletions.
81 changes: 81 additions & 0 deletions llvm/include/llvm/Support/Base64.h
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_SUPPORT_BASE64_H
#define LLVM_SUPPORT_BASE64_H

#include "llvm/Support/Error.h"
#include <cstdint>
#include <string>

Expand Down Expand Up @@ -52,6 +53,86 @@ template <class InputBytes> std::string encodeBase64(InputBytes const &Bytes) {
return Buffer;
}

template <class OutputBytes>
llvm::Error decodeBase64(llvm::StringRef Input, OutputBytes &Output) {
// Invalid table value with short name to fit in the table init below. The
// invalid value is 64 since valid base64 values are 0 - 63.
constexpr char Inv = 64;
static char DecodeTable[] = {
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
Inv, Inv, Inv, 62, Inv, Inv, Inv, 63, // ...+.../
52, 53, 54, 55, 56, 57, 58, 59, // 01234567
60, 61, Inv, Inv, Inv, 0, Inv, Inv, // 89...=..
Inv, 0, 1, 2, 3, 4, 5, 6, // .ABCDEFG
7, 8, 9, 10, 11, 12, 13, 14, // HIJKLMNO
15, 16, 17, 18, 19, 20, 21, 22, // PQRSTUVW
23, 24, 25, Inv, Inv, Inv, Inv, Inv, // XYZ.....
Inv, 26, 27, 28, 29, 30, 31, 32, // .abcdefg
33, 34, 35, 36, 37, 38, 39, 40, // hijklmno
41, 42, 43, 44, 45, 46, 47, 48, // pqrstuvw
49, 50, 51 // xyz.....
};
auto decodeBase64Byte = [](uint8_t Ch) -> char {
if (Ch >= sizeof(DecodeTable))
return Inv;
return DecodeTable[Ch];
};
Output.clear();
const uint64_t InputLength = Input.size();
if (InputLength == 0)
return Error::success();
// Make sure we have a valid input string length which must be a multiple
// of 4.
if ((InputLength % 4) != 0)
return createStringError(std::errc::illegal_byte_sequence,
"Base64 encoded strings must be a multiple of 4 "
"bytes in length");
const uint64_t FirstValidEqualIdx = InputLength - 2;
char Hex64Bytes[4];
for (uint64_t Idx = 0; Idx < InputLength; Idx += 4) {
for (uint64_t ByteOffset = 0; ByteOffset < 4; ++ByteOffset) {
const uint64_t ByteIdx = Idx + ByteOffset;
const char Byte = Input[ByteIdx];
const char DecodedByte = decodeBase64Byte(Byte);
bool Illegal = DecodedByte == Inv;
if (!Illegal && Byte == '=') {
if (ByteIdx < FirstValidEqualIdx) {
// We have an '=' in the middle of the string which is invalid, only
// the last two characters can be '=' characters.
Illegal = true;
} else if (ByteIdx == FirstValidEqualIdx && Input[ByteIdx + 1] != '=') {
// We have an equal second to last from the end and the last character
// is not also an equal, so the '=' character is invalid
Illegal = true;
}
}
if (Illegal)
return createStringError(
std::errc::illegal_byte_sequence,
"Invalid Base64 character %#2.2x at index %" PRIu64, Byte, ByteIdx);
Hex64Bytes[ByteOffset] = DecodedByte;
}
// Now we have 6 bits of 3 bytes in value in each of the Hex64Bytes bytes.
// Extract the right bytes into the Output buffer.
Output.push_back((Hex64Bytes[0] << 2) + ((Hex64Bytes[1] >> 4) & 0x03));
Output.push_back((Hex64Bytes[1] << 4) + ((Hex64Bytes[2] >> 2) & 0x0f));
Output.push_back((Hex64Bytes[2] << 6) + (Hex64Bytes[3] & 0x3f));
}
// If we had valid trailing '=' characters strip the right number of bytes
// from the end of the output buffer. We already know that the Input length
// it a multiple of 4 and is not zero, so direct character access is safe.
if (Input.back() == '=') {
Output.pop_back();
if (Input[InputLength - 2] == '=')
Output.pop_back();
}
return Error::success();
}

} // end namespace llvm

#endif
66 changes: 59 additions & 7 deletions llvm/unittests/Support/Base64Test.cpp
Expand Up @@ -13,6 +13,7 @@

#include "llvm/Support/Base64.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"

using namespace llvm;
Expand All @@ -24,6 +25,28 @@ void TestBase64(StringRef Input, StringRef Final) {
EXPECT_EQ(Res, Final);
}

void TestBase64Decode(StringRef Input, StringRef Expected,
StringRef ExpectedErrorMessage = {}) {
std::vector<char> DecodedBytes;
if (ExpectedErrorMessage.empty()) {
ASSERT_THAT_ERROR(decodeBase64(Input, DecodedBytes), Succeeded());
EXPECT_EQ(llvm::ArrayRef<char>(DecodedBytes),
llvm::ArrayRef<char>(Expected.data(), Expected.size()));
} else {
ASSERT_THAT_ERROR(decodeBase64(Input, DecodedBytes),
FailedWithMessage(ExpectedErrorMessage));
}
}

char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
0x00, 0x08, (char)0xff, (char)0xee};

char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};

} // namespace

TEST(Base64Test, Base64) {
Expand All @@ -37,16 +60,45 @@ TEST(Base64Test, Base64) {
TestBase64("foobar", "Zm9vYmFy");

// With non-printable values.
char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
0x00, 0x08, (char)0xff, (char)0xee};
TestBase64({NonPrintableVector, sizeof(NonPrintableVector)}, "AAAARgAI/+4=");

// Large test case
char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};
TestBase64({LargeVector, sizeof(LargeVector)},
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
}

TEST(Base64Test, DecodeBase64) {
std::vector<llvm::StringRef> Outputs = {"", "f", "fo", "foo",
"foob", "fooba", "foobar"};
Outputs.push_back(
llvm::StringRef(NonPrintableVector, sizeof(NonPrintableVector)));

Outputs.push_back(llvm::StringRef(LargeVector, sizeof(LargeVector)));
// Make sure we can encode and decode any byte.
std::vector<char> AllChars;
for (int Ch = INT8_MIN; Ch <= INT8_MAX; ++Ch)
AllChars.push_back(Ch);
Outputs.push_back(llvm::StringRef(AllChars.data(), AllChars.size()));

for (const auto &Output : Outputs) {
// We trust that encoding is working after running the Base64Test::Base64()
// test function above, so we can use it to encode the string and verify we
// can decode it correctly.
auto Input = encodeBase64(Output);
TestBase64Decode(Input, Output);
}
struct ErrorInfo {
llvm::StringRef Input;
llvm::StringRef ErrorMessage;
};
std::vector<ErrorInfo> ErrorInfos = {
{"f", "Base64 encoded strings must be a multiple of 4 bytes in length"},
{"=abc", "Invalid Base64 character 0x3d at index 0"},
{"a=bc", "Invalid Base64 character 0x3d at index 1"},
{"ab=c", "Invalid Base64 character 0x3d at index 2"},
{"fun!", "Invalid Base64 character 0x21 at index 3"},
};

for (const auto &EI : ErrorInfos)
TestBase64Decode(EI.Input, "", EI.ErrorMessage);
}

0 comments on commit 8b987ca

Please sign in to comment.