diff --git a/lib/internal/crypto/cipher.js b/lib/internal/crypto/cipher.js index add56eae680ced..80b0c0e9dab382 100644 --- a/lib/internal/crypto/cipher.js +++ b/lib/internal/crypto/cipher.js @@ -151,13 +151,13 @@ Cipher.prototype.update = function update(data, inputEncoding, outputEncoding) { inputEncoding = inputEncoding || encoding; outputEncoding = outputEncoding || encoding; - if (typeof data !== 'string' && !isArrayBufferView(data)) { + if (typeof data === 'string') { + validateEncoding(data, inputEncoding); + } else if (!isArrayBufferView(data)) { throw new ERR_INVALID_ARG_TYPE( 'data', ['string', 'Buffer', 'TypedArray', 'DataView'], data); } - validateEncoding(data, inputEncoding); - const ret = this[kHandle].update(data, inputEncoding); if (outputEncoding && outputEncoding !== 'buffer') { diff --git a/lib/internal/crypto/hash.js b/lib/internal/crypto/hash.js index dca0ba767f6e29..1cf0188da2f35e 100644 --- a/lib/internal/crypto/hash.js +++ b/lib/internal/crypto/hash.js @@ -78,17 +78,13 @@ Hash.prototype.update = function update(data, encoding) { if (state[kFinalized]) throw new ERR_CRYPTO_HASH_FINALIZED(); - if (typeof data !== 'string' && !isArrayBufferView(data)) { - throw new ERR_INVALID_ARG_TYPE('data', - ['string', - 'Buffer', - 'TypedArray', - 'DataView'], - data); + if (typeof data === 'string') { + validateEncoding(data, encoding); + } else if (!isArrayBufferView(data)) { + throw new ERR_INVALID_ARG_TYPE( + 'data', ['string', 'Buffer', 'TypedArray', 'DataView'], data); } - validateEncoding(data, encoding); - if (!this[kHandle].update(data, encoding)) throw new ERR_CRYPTO_HASH_UPDATE_FAILED(); return this; diff --git a/lib/internal/crypto/sig.js b/lib/internal/crypto/sig.js index 27930ce1acf793..7e3b7aa7ff394d 100644 --- a/lib/internal/crypto/sig.js +++ b/lib/internal/crypto/sig.js @@ -9,7 +9,7 @@ const { ERR_INVALID_ARG_TYPE, ERR_INVALID_OPT_VALUE } = require('internal/errors').codes; -const { validateString } = require('internal/validators'); +const { validateEncoding, validateString } = require('internal/validators'); const { Sign: _Sign, Verify: _Verify, @@ -50,8 +50,15 @@ Sign.prototype._write = function _write(chunk, encoding, callback) { Sign.prototype.update = function update(data, encoding) { encoding = encoding || getDefaultEncoding(); - data = getArrayBufferView(data, 'data', encoding); - this[kHandle].update(data); + + if (typeof data === 'string') { + validateEncoding(data, encoding); + } else if (!isArrayBufferView(data)) { + throw new ERR_INVALID_ARG_TYPE( + 'data', ['string', 'Buffer', 'TypedArray', 'DataView'], data); + } + + this[kHandle].update(data, encoding); return this; }; diff --git a/src/node_crypto.cc b/src/node_crypto.cc index 92760fb8c8577b..e26e0f9f17d8cb 100644 --- a/src/node_crypto.cc +++ b/src/node_crypto.cc @@ -168,6 +168,26 @@ template int SSLWrap::SelectALPNCallback( unsigned int inlen, void* arg); +template +void Decode(const FunctionCallbackInfo& args, + void (*callback)(T*, const FunctionCallbackInfo&, + const char*, size_t)) { + T* ctx; + ASSIGN_OR_RETURN_UNWRAP(&ctx, args.Holder()); + + if (args[0]->IsString()) { + StringBytes::InlineDecoder decoder; + Environment* env = Environment::GetCurrent(args); + enum encoding enc = ParseEncoding(env->isolate(), args[1], UTF8); + if (decoder.Decode(env, args[0].As(), enc).IsNothing()) + return; + callback(ctx, args, decoder.out(), decoder.size()); + } else { + ArrayBufferViewContents buf(args[0]); + callback(ctx, args, buf.data(), buf.length()); + } +} + static int PasswordCallback(char* buf, int size, int rwflag, void* u) { const char* passphrase = static_cast(u); if (passphrase != nullptr) { @@ -4455,38 +4475,24 @@ CipherBase::UpdateResult CipherBase::Update(const char* data, void CipherBase::Update(const FunctionCallbackInfo& args) { - Environment* env = Environment::GetCurrent(args); - - CipherBase* cipher; - ASSIGN_OR_RETURN_UNWRAP(&cipher, args.Holder()); - - AllocatedBuffer out; - UpdateResult r; - - // Only copy the data if we have to, because it's a string - if (args[0]->IsString()) { - StringBytes::InlineDecoder decoder; - enum encoding enc = ParseEncoding(env->isolate(), args[1], UTF8); - - if (decoder.Decode(env, args[0].As(), enc).IsNothing()) + Decode(args, [](CipherBase* cipher, + const FunctionCallbackInfo& args, + const char* data, size_t size) { + AllocatedBuffer out; + UpdateResult r = cipher->Update(data, size, &out); + + if (r != kSuccess) { + if (r == kErrorState) { + Environment* env = Environment::GetCurrent(args); + ThrowCryptoError(env, ERR_get_error(), + "Trying to add data in unsupported state"); + } return; - r = cipher->Update(decoder.out(), decoder.size(), &out); - } else { - ArrayBufferViewContents buf(args[0]); - r = cipher->Update(buf.data(), buf.length(), &out); - } - - if (r != kSuccess) { - if (r == kErrorState) { - ThrowCryptoError(env, ERR_get_error(), - "Trying to add data in unsupported state"); } - return; - } - CHECK(out.data() != nullptr || out.size() == 0); - - args.GetReturnValue().Set(out.ToBuffer().ToLocalChecked()); + CHECK(out.data() != nullptr || out.size() == 0); + args.GetReturnValue().Set(out.ToBuffer().ToLocalChecked()); + }); } @@ -4642,26 +4648,11 @@ bool Hmac::HmacUpdate(const char* data, int len) { void Hmac::HmacUpdate(const FunctionCallbackInfo& args) { - Environment* env = Environment::GetCurrent(args); - - Hmac* hmac; - ASSIGN_OR_RETURN_UNWRAP(&hmac, args.Holder()); - - // Only copy the data if we have to, because it's a string - bool r = false; - if (args[0]->IsString()) { - StringBytes::InlineDecoder decoder; - enum encoding enc = ParseEncoding(env->isolate(), args[1], UTF8); - - if (!decoder.Decode(env, args[0].As(), enc).IsNothing()) { - r = hmac->HmacUpdate(decoder.out(), decoder.size()); - } - } else { - ArrayBufferViewContents buf(args[0]); - r = hmac->HmacUpdate(buf.data(), buf.length()); - } - - args.GetReturnValue().Set(r); + Decode(args, [](Hmac* hmac, const FunctionCallbackInfo& args, + const char* data, size_t size) { + bool r = hmac->HmacUpdate(data, size); + args.GetReturnValue().Set(r); + }); } @@ -4778,28 +4769,11 @@ bool Hash::HashUpdate(const char* data, int len) { void Hash::HashUpdate(const FunctionCallbackInfo& args) { - Environment* env = Environment::GetCurrent(args); - - Hash* hash; - ASSIGN_OR_RETURN_UNWRAP(&hash, args.Holder()); - - // Only copy the data if we have to, because it's a string - bool r = true; - if (args[0]->IsString()) { - StringBytes::InlineDecoder decoder; - enum encoding enc = ParseEncoding(env->isolate(), args[1], UTF8); - - if (decoder.Decode(env, args[0].As(), enc).IsNothing()) { - args.GetReturnValue().Set(false); - return; - } - r = hash->HashUpdate(decoder.out(), decoder.size()); - } else if (args[0]->IsArrayBufferView()) { - ArrayBufferViewContents buf(args[0].As()); - r = hash->HashUpdate(buf.data(), buf.length()); - } - - args.GetReturnValue().Set(r); + Decode(args, [](Hash* hash, const FunctionCallbackInfo& args, + const char* data, size_t size) { + bool r = hash->HashUpdate(data, size); + args.GetReturnValue().Set(r); + }); } @@ -4992,14 +4966,11 @@ void Sign::SignInit(const FunctionCallbackInfo& args) { void Sign::SignUpdate(const FunctionCallbackInfo& args) { - Sign* sign; - ASSIGN_OR_RETURN_UNWRAP(&sign, args.Holder()); - - Error err; - ArrayBufferViewContents buf(args[0]); - err = sign->Update(buf.data(), buf.length()); - - sign->CheckThrow(err); + Decode(args, [](Sign* sign, const FunctionCallbackInfo& args, + const char* data, size_t size) { + Error err = sign->Update(data, size); + sign->CheckThrow(err); + }); } static int GetDefaultSignPadding(const ManagedEVPPKey& key) { @@ -5311,14 +5282,12 @@ void Verify::VerifyInit(const FunctionCallbackInfo& args) { void Verify::VerifyUpdate(const FunctionCallbackInfo& args) { - Verify* verify; - ASSIGN_OR_RETURN_UNWRAP(&verify, args.Holder()); - - Error err; - ArrayBufferViewContents buf(args[0]); - err = verify->Update(buf.data(), buf.length()); - - verify->CheckThrow(err); + Decode(args, [](Verify* verify, + const FunctionCallbackInfo& args, + const char* data, size_t size) { + Error err = verify->Update(data, size); + verify->CheckThrow(err); + }); } diff --git a/test/parallel/test-crypto-update-encoding.js b/test/parallel/test-crypto-update-encoding.js new file mode 100644 index 00000000000000..e1e6d029aa5e30 --- /dev/null +++ b/test/parallel/test-crypto-update-encoding.js @@ -0,0 +1,22 @@ +'use strict'; +const common = require('../common'); + +if (!common.hasCrypto) + common.skip('missing crypto'); + +const crypto = require('crypto'); + +const zeros = Buffer.alloc; +const key = zeros(16); +const iv = zeros(16); + +const cipher = () => crypto.createCipheriv('aes-128-cbc', key, iv); +const decipher = () => crypto.createDecipheriv('aes-128-cbc', key, iv); +const hash = () => crypto.createSign('sha256'); +const hmac = () => crypto.createHmac('sha256', key); +const sign = () => crypto.createSign('sha256'); +const verify = () => crypto.createVerify('sha256'); + +for (const f of [cipher, decipher, hash, hmac, sign, verify]) + for (const n of [15, 16]) + f().update(zeros(n), 'hex'); // Should ignore inputEncoding.