diff --git a/openssl/hmac.go b/openssl/hmac.go index 80f320b..81918cd 100644 --- a/openssl/hmac.go +++ b/openssl/hmac.go @@ -14,6 +14,11 @@ import ( "unsafe" ) +var ( + paramAlgHMAC = C.CString("HMAC") + paramDigest = C.CString("digest") +) + // NewHMAC returns a new HMAC using OpenSSL. // The function h must return a hash implemented by // OpenSSL (for example, h could be openssl.NewSHA256). @@ -38,19 +43,19 @@ func NewHMAC(h func() hash.Hash, key []byte) hash.Hash { // we pass an "empty" key. hkey = make([]byte, C.GO_EVP_MAX_MD_SIZE) } - hmac := &opensslHMAC{ - md: md, - size: ch.Size(), - blockSize: ch.BlockSize(), - key: hkey, - ctx: hmacCtxNew(), + switch vMajor { + case 1: + return newHMAC1(hkey, ch, md) + case 3: + return newHMAC3(hkey, ch, md) + default: + panic(errUnsuportedVersion()) } - runtime.SetFinalizer(hmac, (*opensslHMAC).finalize) - hmac.Reset() - return hmac } -type opensslHMAC struct { +// hmac1 implements hash.Hash +// using functions available in OpenSSL 1. +type hmac1 struct { md C.GO_EVP_MD_PTR ctx C.GO_HMAC_CTX_PTR size int @@ -59,8 +64,21 @@ type opensslHMAC struct { sum []byte } -func (h *opensslHMAC) Reset() { - hmacCtxReset(h.ctx) +func newHMAC1(key []byte, h hash.Hash, md C.GO_EVP_MD_PTR) *hmac1 { + hmac := &hmac1{ + md: md, + size: h.Size(), + blockSize: h.BlockSize(), + key: key, + ctx: hmac1CtxNew(), + } + runtime.SetFinalizer(hmac, (*hmac1).finalize) + hmac.Reset() + return hmac +} + +func (h *hmac1) Reset() { + hmac1CtxReset(h.ctx) if C.go_openssl_HMAC_Init_ex(h.ctx, unsafe.Pointer(&h.key[0]), C.int(len(h.key)), h.md, nil) == 0 { panic("openssl: HMAC_Init failed") @@ -73,11 +91,11 @@ func (h *opensslHMAC) Reset() { h.sum = nil } -func (h *opensslHMAC) finalize() { - hmacCtxFree(h.ctx) +func (h *hmac1) finalize() { + hmac1CtxFree(h.ctx) } -func (h *opensslHMAC) Write(p []byte) (int, error) { +func (h *hmac1) Write(p []byte) (int, error) { if len(p) > 0 { C.go_openssl_HMAC_Update(h.ctx, base(p), C.size_t(len(p))) } @@ -85,15 +103,15 @@ func (h *opensslHMAC) Write(p []byte) (int, error) { return len(p), nil } -func (h *opensslHMAC) Size() int { +func (h *hmac1) Size() int { return h.size } -func (h *opensslHMAC) BlockSize() int { +func (h *hmac1) BlockSize() int { return h.blockSize } -func (h *opensslHMAC) Sum(in []byte) []byte { +func (h *hmac1) Sum(in []byte) []byte { if h.sum == nil { size := h.Size() h.sum = make([]byte, size) @@ -102,8 +120,8 @@ func (h *opensslHMAC) Sum(in []byte) []byte { // that Sum has no effect on the underlying stream. // In particular it is OK to Sum, then Write more, then Sum again, // and the second Sum acts as if the first didn't happen. - ctx2 := hmacCtxNew() - defer hmacCtxFree(ctx2) + ctx2 := hmac1CtxNew() + defer hmac1CtxFree(ctx2) if C.go_openssl_HMAC_CTX_copy(ctx2, h.ctx) == 0 { panic("openssl: HMAC_CTX_copy failed") } @@ -111,7 +129,7 @@ func (h *opensslHMAC) Sum(in []byte) []byte { return append(in, h.sum...) } -func hmacCtxNew() C.GO_HMAC_CTX_PTR { +func hmac1CtxNew() C.GO_HMAC_CTX_PTR { if vMajor == 1 && vMinor == 0 { // 0x120 is the sizeof value when building against OpenSSL 1.0.2 on Ubuntu 16.04. ctx := (C.GO_HMAC_CTX_PTR)(C.malloc(0x120)) @@ -123,7 +141,7 @@ func hmacCtxNew() C.GO_HMAC_CTX_PTR { return C.go_openssl_HMAC_CTX_new() } -func hmacCtxReset(ctx C.GO_HMAC_CTX_PTR) { +func hmac1CtxReset(ctx C.GO_HMAC_CTX_PTR) { if ctx == nil { return } @@ -135,7 +153,7 @@ func hmacCtxReset(ctx C.GO_HMAC_CTX_PTR) { C.go_openssl_HMAC_CTX_reset(ctx) } -func hmacCtxFree(ctx C.GO_HMAC_CTX_PTR) { +func hmac1CtxFree(ctx C.GO_HMAC_CTX_PTR) { if ctx == nil { return } @@ -146,3 +164,88 @@ func hmacCtxFree(ctx C.GO_HMAC_CTX_PTR) { } C.go_openssl_HMAC_CTX_free(ctx) } + +// hmac3 implements hash.Hash +// using functions available in OpenSSL 3. +type hmac3 struct { + md C.GO_EVP_MAC_PTR + ctx C.GO_EVP_MAC_CTX_PTR + params [2]C.OSSL_PARAM + size int + blockSize int + key []byte + sum []byte +} + +func newHMAC3(key []byte, h hash.Hash, md C.GO_EVP_MD_PTR) *hmac3 { + mac := C.go_openssl_EVP_MAC_fetch(nil, paramAlgHMAC, nil) + ctx := C.go_openssl_EVP_MAC_CTX_new(mac) + if ctx == nil { + panic("openssl: EVP_MAC_CTX_new failed") + } + digest := C.go_openssl_EVP_MD_get0_name(md) + params := [2]C.OSSL_PARAM{ + C.go_openssl_OSSL_PARAM_construct_utf8_string(paramDigest, digest, 0), + C.go_openssl_OSSL_PARAM_construct_end(), + } + hmac := &hmac3{ + md: mac, + ctx: ctx, + params: params, + size: h.Size(), + blockSize: h.BlockSize(), + key: key, + } + runtime.SetFinalizer(hmac, (*hmac3).finalize) + hmac.Reset() + return hmac +} + +func (h *hmac3) Reset() { + if C.go_openssl_EVP_MAC_init(h.ctx, base(h.key), C.size_t(len(h.key)), &h.params[0]) == 0 { + panic(newOpenSSLError("EVP_MAC_init failed")) + } + runtime.KeepAlive(h) // Next line will keep h alive too; just making doubly sure. + h.sum = nil +} + +func (h *hmac3) finalize() { + if h.ctx == nil { + return + } + C.go_openssl_EVP_MAC_CTX_free(h.ctx) +} + +func (h *hmac3) Write(p []byte) (int, error) { + if len(p) > 0 { + C.go_openssl_EVP_MAC_update(h.ctx, base(p), C.size_t(len(p))) + } + runtime.KeepAlive(h) + return len(p), nil +} + +func (h *hmac3) Size() int { + return h.size +} + +func (h *hmac3) BlockSize() int { + return h.blockSize +} + +func (h *hmac3) Sum(in []byte) []byte { + if h.sum == nil { + size := h.Size() + h.sum = make([]byte, size) + } + // Make copy of context because Go hash.Hash mandates + // that Sum has no effect on the underlying stream. + // In particular it is OK to Sum, then Write more, then Sum again, + // and the second Sum acts as if the first didn't happen. + ctx2 := C.go_openssl_EVP_MAC_CTX_dup(h.ctx) + if ctx2 == nil { + panic("openssl: EVP_MAC_CTX_dup failed") + } + defer C.go_openssl_EVP_MAC_CTX_free(ctx2) + C.go_openssl_EVP_MAC_final(ctx2, base(h.sum), nil, C.size_t(len(h.sum))) + return append(in, h.sum...) +} diff --git a/openssl/openssl_funcs.h b/openssl/openssl_funcs.h index 9abe8ad..21c7ee9 100644 --- a/openssl/openssl_funcs.h +++ b/openssl/openssl_funcs.h @@ -83,6 +83,21 @@ typedef void* GO_EC_KEY_PTR; typedef void* GO_EC_POINT_PTR; typedef void* GO_EC_GROUP_PTR; typedef void* GO_RSA_PTR; +typedef void* GO_EVP_MAC_PTR; +typedef void* GO_EVP_MAC_CTX_PTR; + +// OSSL_PARAM does not follow the GO_FOO_PTR pattern +// because it is not passed around as a pointer but on the stack. +// We can't abstract it away by using a void*. +// Copied from +// https://github.com/openssl/openssl/blob/fcae2ae4f675def607d338b7945b9af1dd9bb746/include/openssl/core.h#L82-L88. +typedef struct { + const char *key; + unsigned int data_type; + void *data; + size_t data_size; + size_t return_size; +} OSSL_PARAM; // List of all functions from the libcrypto that are used in this package. // Forgetting to add a function here results in build failure with message reporting the function @@ -158,6 +173,7 @@ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX DEFINEFUNC(int, EVP_MD_CTX_copy_ex, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ DEFINEFUNC(int, EVP_MD_CTX_copy, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ DEFINEFUNC_RENAMED_1_1(int, EVP_MD_CTX_reset, EVP_MD_CTX_cleanup, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_md5, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha1, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha224, (void), ()) \ @@ -249,4 +265,14 @@ DEFINEFUNC(int, EVP_PKEY_decrypt_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_encrypt_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_sign_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_verify_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ -DEFINEFUNC(int, EVP_PKEY_sign, (GO_EVP_PKEY_CTX_PTR arg0, unsigned char *arg1, size_t *arg2, const unsigned char *arg3, size_t arg4), (arg0, arg1, arg2, arg3, arg4)) +DEFINEFUNC(int, EVP_PKEY_sign, (GO_EVP_PKEY_CTX_PTR arg0, unsigned char *arg1, size_t *arg2, const unsigned char *arg3, size_t arg4), (arg0, arg1, arg2, arg3, arg4)) \ +DEFINEFUNC_3_0(GO_EVP_MAC_PTR, EVP_MAC_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \ +DEFINEFUNC_3_0(GO_EVP_MAC_CTX_PTR, EVP_MAC_CTX_new, (GO_EVP_MAC_PTR arg0), (arg0)) \ +DEFINEFUNC_3_0(void, EVP_MAC_CTX_free, (GO_EVP_MAC_CTX_PTR arg0), (arg0)) \ +DEFINEFUNC_3_0(GO_EVP_MAC_CTX_PTR, EVP_MAC_CTX_dup, (const GO_EVP_MAC_CTX_PTR arg0), (arg0)) \ +DEFINEFUNC_3_0(int, EVP_MAC_init, (GO_EVP_MAC_CTX_PTR ctx, const unsigned char *key, size_t keylen, const OSSL_PARAM params[]), (ctx, key, keylen, params)) \ +DEFINEFUNC_3_0(int, EVP_MAC_update, (GO_EVP_MAC_CTX_PTR ctx, const unsigned char *data, size_t datalen), (ctx, data, datalen)) \ +DEFINEFUNC_3_0(int, EVP_MAC_final, (GO_EVP_MAC_CTX_PTR ctx, unsigned char *out, size_t *outl, size_t outsize), (ctx, out, outl, outsize)) \ +DEFINEFUNC_3_0(OSSL_PARAM, OSSL_PARAM_construct_utf8_string, (const char *key, char *buf, size_t bsize), (key, buf, bsize)) \ +DEFINEFUNC_3_0(OSSL_PARAM, OSSL_PARAM_construct_end, (void), ()) \ +