From 5cf6f2609ee118aa4317865d0bc706a545c92832 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Mon, 29 Apr 2024 22:35:25 +0800 Subject: [PATCH] Add sm4_MODE_encrypt_blocks --- CMakeLists.txt | 1 + include/gmssl/sm4.h | 14 ++- src/sm4.c | 83 ++++++++++++++ src/sm4_cbc.c | 62 ++++------ src/sm4_ctr.c | 95 ++++++++-------- src/sm4_ecb.c | 29 +++-- src/sm4_tbox.c | 272 +++++++++++++++++++++++++++++++++++++++++++- tests/sm4_ecbtest.c | 8 +- tests/sm4_gcmtest.c | 34 ++++++ tests/sm4test.c | 135 ++++++++++++++++++++-- 10 files changed, 615 insertions(+), 118 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e09bdc589..a72da6096 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,6 +138,7 @@ set(tests sm4 sm4_cbc sm4_ctr + sm4_gcm sm3 sm4_sm3_hmac sm2_z256 diff --git a/include/gmssl/sm4.h b/include/gmssl/sm4.h index bd8f4d2a4..6a47c0917 100644 --- a/include/gmssl/sm4.h +++ b/include/gmssl/sm4.h @@ -33,18 +33,22 @@ typedef struct { void sm4_set_encrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]); void sm4_set_decrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]); void sm4_encrypt(const SM4_KEY *key, const uint8_t in[SM4_BLOCK_SIZE], uint8_t out[SM4_BLOCK_SIZE]); -void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); - +void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); +void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out); +void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out); -int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], +int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], +int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out); +void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out); + typedef struct { union { @@ -140,7 +144,7 @@ _gmssl_export int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx, #ifdef ENABLE_SM4_ECB // call `sm4_set_decrypt_key` before decrypt -void sm4_ecb_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); +//void sm4_ecb_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); typedef struct { SM4_KEY sm4_key; diff --git a/src/sm4.c b/src/sm4.c index 257bf2a80..c372175a5 100644 --- a/src/sm4.c +++ b/src/sm4.c @@ -180,3 +180,86 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u out += 16; } } + +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + while (nblocks--) { + size_t i; + for (i = 0; i < 16; i++) { + out[i] = in[i] ^ iv[i]; + } + sm4_encrypt(key, out, out); + iv = out; + in += 16; + out += 16; + } +} + +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + while (nblocks--) { + size_t i; + sm4_encrypt(key, in, out); + for (i = 0; i < 16; i++) { + out[i] ^= iv[i]; + } + iv = in; + in += 16; + out += 16; + } +} + +static void ctr_incr(uint8_t a[16]) { + int i; + for (i = 15; i >= 0; i--) { + a[i]++; + if (a[i]) break; + } +} + +void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len, i; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + sm4_encrypt(key, ctr, block); + for (i = 0; i < len; i++) { + out[i] = in[i] ^ block[i]; + } + ctr_incr(ctr); + in += len; + out += len; + inlen -= len; + } +} + +// inc32() in nist-sp800-38d +static void ctr32_incr(uint8_t a[16]) { + int i; + for (i = 15; i >= 12; i--) { + a[i]++; + if (a[i]) break; + } +} + +void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len, i; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + sm4_encrypt(key, ctr, block); + for (i = 0; i < len; i++) { + out[i] = in[i] ^ block[i]; + } + ctr32_incr(ctr); + in += len; + out += len; + inlen -= len; + } +} diff --git a/src/sm4_cbc.c b/src/sm4_cbc.c index 7c73ba5fe..f0858326b 100644 --- a/src/sm4_cbc.c +++ b/src/sm4_cbc.c @@ -13,44 +13,6 @@ #include -void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], - const uint8_t *in, size_t nblocks, uint8_t *out) -{ - while (nblocks--) { - gmssl_memxor(out, in, iv, 16); - sm4_encrypt(key, out, out); - iv = out; - in += 16; - out += 16; - } -} - -void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], - const uint8_t *in, size_t nblocks, uint8_t *out) -{ - while (nblocks >= 8) { - uint8_t buf[16 * 8]; - - sm4_encrypt_blocks(key, in, 8, buf); - - gmssl_memxor(out, buf, iv, 16); - gmssl_memxor(out + 16, buf + 16, in, 16 * (8 - 1)); - - iv = in + 16 * (8 - 1); - in += 16 * 8; - out += 16 * 8; - nblocks -= 8; - } - - while (nblocks--) { - sm4_encrypt(key, in, out); - memxor(out, iv, 16); - iv = in; - in += 16; - out += 16; - } -} - int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) @@ -109,6 +71,10 @@ int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16], int sm4_cbc_encrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !iv) { + error_print(); + return -1; + } sm4_set_encrypt_key(&ctx->sm4_key, key); memcpy(ctx->iv, iv, SM4_BLOCK_SIZE); memset(ctx->block, 0, SM4_BLOCK_SIZE); @@ -123,6 +89,10 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -162,6 +132,10 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -176,6 +150,10 @@ int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen) int sm4_cbc_decrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !iv) { + error_print(); + return -1; + } sm4_set_decrypt_key(&ctx->sm4_key, key); memcpy(ctx->iv, iv, SM4_BLOCK_SIZE); memset(ctx->block, 0, SM4_BLOCK_SIZE); @@ -188,6 +166,10 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, { size_t left, len, nblocks; + if (!ctx || !in || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes > SM4_BLOCK_SIZE) { error_print(); return -1; @@ -226,6 +208,10 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes != SM4_BLOCK_SIZE) { error_print(); return -1; diff --git a/src/sm4_ctr.c b/src/sm4_ctr.c index 52aa9561b..0f03196e7 100644 --- a/src/sm4_ctr.c +++ b/src/sm4_ctr.c @@ -13,34 +13,49 @@ #include -static void ctr_incr(uint8_t a[16]) +void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) { - int i; - for (i = 15; i >= 0; i--) { - a[i]++; - if (a[i]) break; + if (inlen >= 16) { + size_t nblocks = inlen / 16; + size_t len = nblocks * 16; + sm4_ctr_encrypt_blocks(key, ctr, in, nblocks, out); + in += len; + out += len; + inlen -= len; + } + if (inlen) { + uint8_t block[16] = {0}; + memcpy(block, in, inlen); + sm4_ctr_encrypt_blocks(key, ctr, block, 1, block); + memcpy(out, block, inlen); } } -void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) { - uint8_t block[16]; - size_t len; - - while (inlen) { - len = inlen < 16 ? inlen : 16; - sm4_encrypt(key, ctr, block); - gmssl_memxor(out, in, block, len); - ctr_incr(ctr); + if (inlen >= 16) { + size_t nblocks = inlen / 16; + size_t len = nblocks * 16; + sm4_ctr32_encrypt_blocks(key, ctr, in, nblocks, out); in += len; out += len; inlen -= len; } + if (inlen) { + uint8_t block[16] = {0}; + memcpy(block, in, inlen); + sm4_ctr32_encrypt_blocks(key, ctr, block, 1, block); + memcpy(out, block, inlen); + } } int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !ctr) { + error_print(); + return -1; + } sm4_set_encrypt_key(&ctx->sm4_key, key); memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE); memset(ctx->block, 0, SM4_BLOCK_SIZE); @@ -55,6 +70,10 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -68,7 +87,7 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, return 1; } memcpy(ctx->block + ctx->block_nbytes, in, left); - sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out); + sm4_ctr_encrypt_blocks(&ctx->sm4_key, ctx->ctr, ctx->block, 1, out); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -77,7 +96,7 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, if (inlen >= SM4_BLOCK_SIZE) { nblocks = inlen / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; - sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, in, len, out); + sm4_ctr_encrypt_blocks(&ctx->sm4_key, ctx->ctr, in, nblocks, out); in += len; inlen -= len; out += len; @@ -92,44 +111,27 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; } - sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out); + sm4_ctr_encrypt_blocks(&ctx->sm4_key, ctx->ctr, ctx->block, 1, ctx->block); + memcpy(out, ctx->block, ctx->block_nbytes); *outlen = ctx->block_nbytes; return 1; } -// inc32() in nist-sp800-38d -static void ctr32_incr(uint8_t a[16]) -{ - int i; - for (i = 15; i >= 12; i--) { - a[i]++; - if (a[i]) break; - } -} - -void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) -{ - uint8_t block[16]; - size_t len; - - while (inlen) { - len = inlen < 16 ? inlen : 16; - sm4_encrypt(key, ctr, block); - gmssl_memxor(out, in, block, len); - ctr32_incr(ctr); - in += len; - out += len; - inlen -= len; - } -} - int sm4_ctr32_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !ctr) { + error_print(); + return -1; + } sm4_set_encrypt_key(&ctx->sm4_key, key); memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE); memset(ctx->block, 0, SM4_BLOCK_SIZE); @@ -157,7 +159,7 @@ int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, return 1; } memcpy(ctx->block + ctx->block_nbytes, in, left); - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out); + sm4_ctr32_encrypt_blocks(&ctx->sm4_key, ctx->ctr, ctx->block, 1, out); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -166,7 +168,7 @@ int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, if (inlen >= SM4_BLOCK_SIZE) { nblocks = inlen / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, in, len, out); + sm4_ctr32_encrypt_blocks(&ctx->sm4_key, ctx->ctr, in, nblocks, out); in += len; inlen -= len; out += len; @@ -185,7 +187,8 @@ int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) error_print(); return -1; } - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out); + sm4_ctr32_encrypt_blocks(&ctx->sm4_key, ctx->ctr, ctx->block, 1, ctx->block); + memcpy(out, ctx->block, ctx->block_nbytes); *outlen = ctx->block_nbytes; return 1; } diff --git a/src/sm4_ecb.c b/src/sm4_ecb.c index 394d5daf0..06f4b2656 100644 --- a/src/sm4_ecb.c +++ b/src/sm4_ecb.c @@ -13,17 +13,12 @@ #include -void sm4_ecb_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out) -{ - while (nblocks--) { - sm4_encrypt(key, in, out); - in += SM4_BLOCK_SIZE; - out += SM4_BLOCK_SIZE; - } -} - int sm4_ecb_encrypt_init(SM4_ECB_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE]) { + if (!ctx || !key) { + error_print(); + return -1; + } sm4_set_encrypt_key(&ctx->sm4_key, key); memset(ctx->block, 0, SM4_BLOCK_SIZE); ctx->block_nbytes = 0; @@ -37,6 +32,10 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !out || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -50,7 +49,7 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, return 1; } memcpy(ctx->block + ctx->block_nbytes, in, left); - sm4_ecb_encrypt_blocks(&ctx->sm4_key, ctx->block, 1, out); + sm4_encrypt_blocks(&ctx->sm4_key, ctx->block, 1, out); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -59,7 +58,7 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, if (inlen >= SM4_BLOCK_SIZE) { nblocks = inlen / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; - sm4_ecb_encrypt_blocks(&ctx->sm4_key, in, nblocks, out); + sm4_encrypt_blocks(&ctx->sm4_key, in, nblocks, out); in += len; inlen -= len; out += len; @@ -74,6 +73,10 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, int sm4_ecb_encrypt_finish(SM4_ECB_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !outlen) { + error_print(); + return -1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -88,6 +91,10 @@ int sm4_ecb_encrypt_finish(SM4_ECB_CTX *ctx, uint8_t *out, size_t *outlen) int sm4_ecb_decrypt_init(SM4_ECB_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE]) { + if (!ctx || !key) { + error_print(); + return -1; + } sm4_set_decrypt_key(&ctx->sm4_key, key); memset(ctx->block, 0, SM4_BLOCK_SIZE); ctx->block_nbytes = 0; diff --git a/src/sm4_tbox.c b/src/sm4_tbox.c index d48be58fe..7aa042b34 100644 --- a/src/sm4_tbox.c +++ b/src/sm4_tbox.c @@ -9,7 +9,7 @@ #include - +#include static uint32_t FK[4] = { 0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc, @@ -61,7 +61,7 @@ const uint8_t S[256] = { 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, }; - +/* #define GETU32(ptr) \ ((uint32_t)(ptr)[0] << 24 | \ (uint32_t)(ptr)[1] << 16 | \ @@ -75,6 +75,7 @@ const uint8_t S[256] = { (ptr)[3] = (uint8_t)(X)) #define ROL32(X,n) (((X)<<(n)) | ((X)>>(32-(n)))) +*/ #define L32(X) \ ((X) ^ \ @@ -526,3 +527,270 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u out += 16; } } + +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint32_t *rk = key->rk; + uint32_t X0, X1, X2, X3, X4; + uint32_t X5; + + X0 = GETU32(iv ); // X0 = IV0 + X4 = GETU32(iv + 4); // X4 = IV1 + X3 = GETU32(iv + 8); // X3 = IV2 + X5 = GETU32(iv + 12); // X5 = IV3 + + while (nblocks--) { + + X0 = X0 ^ GETU32(in ); + X1 = X4 ^ GETU32(in + 4); + X2 = X3 ^ GETU32(in + 8); + X3 = X5 ^ GETU32(in + 12); + + ROUND( 0, X0, X1, X2, X3, X4); + ROUND( 1, X1, X2, X3, X4, X0); + ROUND( 2, X2, X3, X4, X0, X1); + ROUND( 3, X3, X4, X0, X1, X2); + ROUND( 4, X4, X0, X1, X2, X3); + ROUND( 5, X0, X1, X2, X3, X4); + ROUND( 6, X1, X2, X3, X4, X0); + ROUND( 7, X2, X3, X4, X0, X1); + ROUND( 8, X3, X4, X0, X1, X2); + ROUND( 9, X4, X0, X1, X2, X3); + ROUND(10, X0, X1, X2, X3, X4); + ROUND(11, X1, X2, X3, X4, X0); + ROUND(12, X2, X3, X4, X0, X1); + ROUND(13, X3, X4, X0, X1, X2); + ROUND(14, X4, X0, X1, X2, X3); + ROUND(15, X0, X1, X2, X3, X4); + ROUND(16, X1, X2, X3, X4, X0); + ROUND(17, X2, X3, X4, X0, X1); + ROUND(18, X3, X4, X0, X1, X2); + ROUND(19, X4, X0, X1, X2, X3); + ROUND(20, X0, X1, X2, X3, X4); + ROUND(21, X1, X2, X3, X4, X0); + ROUND(22, X2, X3, X4, X0, X1); + ROUND(23, X3, X4, X0, X1, X2); + ROUND(24, X4, X0, X1, X2, X3); + ROUND(25, X0, X1, X2, X3, X4); + ROUND(26, X1, X2, X3, X4, X0); + ROUND(27, X2, X3, X4, X0, X1); + ROUND(28, X3, X4, X0, X1, X2); + PUTU32(out + 12, X2); + ROUND(29, X4, X0, X1, X2, X3); + PUTU32(out + 8, X3); + ROUND(30, X0, X1, X2, X3, X4); + PUTU32(out + 4, X4); + ROUND(31, X1, X2, X3, X4, X0); + PUTU32(out, X0); + + X5 = X2; + + in += 16; + out += 16; + } +} + +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint32_t *rk = key->rk; + uint32_t IV0, IV1, IV2, IV3; + uint32_t X0, X1, X2, X3, X4; + uint32_t C0, C1, C2, C3; + + IV0 = GETU32(iv ); // X0 = IV0 + IV1 = GETU32(iv + 4); // X4 = IV1 + IV2 = GETU32(iv + 8); // X3 = IV2 + IV3 = GETU32(iv + 12); // X5 = IV3 + + while (nblocks--) { + + X0 = C0 = GETU32(in ); + X1 = C1 = GETU32(in + 4); + X2 = C2 = GETU32(in + 8); + X3 = C3 = GETU32(in + 12); + + ROUND( 0, X0, X1, X2, X3, X4); + ROUND( 1, X1, X2, X3, X4, X0); + ROUND( 2, X2, X3, X4, X0, X1); + ROUND( 3, X3, X4, X0, X1, X2); + ROUND( 4, X4, X0, X1, X2, X3); + ROUND( 5, X0, X1, X2, X3, X4); + ROUND( 6, X1, X2, X3, X4, X0); + ROUND( 7, X2, X3, X4, X0, X1); + ROUND( 8, X3, X4, X0, X1, X2); + ROUND( 9, X4, X0, X1, X2, X3); + ROUND(10, X0, X1, X2, X3, X4); + ROUND(11, X1, X2, X3, X4, X0); + ROUND(12, X2, X3, X4, X0, X1); + ROUND(13, X3, X4, X0, X1, X2); + ROUND(14, X4, X0, X1, X2, X3); + ROUND(15, X0, X1, X2, X3, X4); + ROUND(16, X1, X2, X3, X4, X0); + ROUND(17, X2, X3, X4, X0, X1); + ROUND(18, X3, X4, X0, X1, X2); + ROUND(19, X4, X0, X1, X2, X3); + ROUND(20, X0, X1, X2, X3, X4); + ROUND(21, X1, X2, X3, X4, X0); + ROUND(22, X2, X3, X4, X0, X1); + ROUND(23, X3, X4, X0, X1, X2); + ROUND(24, X4, X0, X1, X2, X3); + ROUND(25, X0, X1, X2, X3, X4); + ROUND(26, X1, X2, X3, X4, X0); + ROUND(27, X2, X3, X4, X0, X1); + ROUND(28, X3, X4, X0, X1, X2); + PUTU32(out + 12, IV3 ^ X2); + ROUND(29, X4, X0, X1, X2, X3); + PUTU32(out + 8, IV2 ^ X3); + ROUND(30, X0, X1, X2, X3, X4); + PUTU32(out + 4, IV1 ^ X4); + ROUND(31, X1, X2, X3, X4, X0); + PUTU32(out, IV0 ^ X0); + + IV0 = C0; + IV1 = C1; + IV2 = C2; + IV3 = C3; + + in += 16; + out += 16; + } +} + +void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint32_t *rk = key->rk; + uint32_t X0, X1, X2, X3, X4; + uint64_t C0, C1; + uint32_t D0, D1, D2, D3; + + C0 = GETU64(ctr ); + C1 = GETU64(ctr + 8); + + while (nblocks--) { + + X0 = (uint32_t)(C0 >> 32); + X1 = (uint32_t)(C0 ); + X2 = (uint32_t)(C1 >> 32); + X3 = (uint32_t)(C1 ); + + D0 = GETU32(in ); + D1 = GETU32(in + 4); + D2 = GETU32(in + 8); + D3 = GETU32(in + 12); + + ROUND( 0, X0, X1, X2, X3, X4); + ROUND( 1, X1, X2, X3, X4, X0); + ROUND( 2, X2, X3, X4, X0, X1); + ROUND( 3, X3, X4, X0, X1, X2); + ROUND( 4, X4, X0, X1, X2, X3); + ROUND( 5, X0, X1, X2, X3, X4); + ROUND( 6, X1, X2, X3, X4, X0); + ROUND( 7, X2, X3, X4, X0, X1); + ROUND( 8, X3, X4, X0, X1, X2); + ROUND( 9, X4, X0, X1, X2, X3); + ROUND(10, X0, X1, X2, X3, X4); + ROUND(11, X1, X2, X3, X4, X0); + ROUND(12, X2, X3, X4, X0, X1); + ROUND(13, X3, X4, X0, X1, X2); + ROUND(14, X4, X0, X1, X2, X3); + ROUND(15, X0, X1, X2, X3, X4); + ROUND(16, X1, X2, X3, X4, X0); + ROUND(17, X2, X3, X4, X0, X1); + ROUND(18, X3, X4, X0, X1, X2); + ROUND(19, X4, X0, X1, X2, X3); + ROUND(20, X0, X1, X2, X3, X4); + ROUND(21, X1, X2, X3, X4, X0); + ROUND(22, X2, X3, X4, X0, X1); + ROUND(23, X3, X4, X0, X1, X2); + ROUND(24, X4, X0, X1, X2, X3); + ROUND(25, X0, X1, X2, X3, X4); + ROUND(26, X1, X2, X3, X4, X0); + ROUND(27, X2, X3, X4, X0, X1); + ROUND(28, X3, X4, X0, X1, X2); + PUTU32(out + 12, D3 ^ X2); + ROUND(29, X4, X0, X1, X2, X3); + PUTU32(out + 8, D2 ^ X3); + ROUND(30, X0, X1, X2, X3, X4); + PUTU32(out + 4, D1 ^ X4); + ROUND(31, X1, X2, X3, X4, X0); + PUTU32(out, D0 ^ X0); + + C1++; + C0 = (C1 == 0) ? C0 + 1 : C0; + + in += 16; + out += 16; + } + + PUTU64(ctr , C0); + PUTU64(ctr + 8, C1); +} + +void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint32_t *rk = key->rk; + uint32_t X0, X1, X2, X3, X4; + uint32_t C0, C1, C2, C3; + uint32_t D0, D1, D2, D3; + + C0 = GETU32(ctr ); + C1 = GETU32(ctr + 4); + C2 = GETU32(ctr + 8); + C3 = GETU32(ctr + 12); + + while (nblocks--) { + + X0 = C0; + X1 = C1; + X2 = C2; + X3 = C3++; + + D0 = GETU32(in ); + D1 = GETU32(in + 4); + D2 = GETU32(in + 8); + D3 = GETU32(in + 12); + + ROUND( 0, X0, X1, X2, X3, X4); + ROUND( 1, X1, X2, X3, X4, X0); + ROUND( 2, X2, X3, X4, X0, X1); + ROUND( 3, X3, X4, X0, X1, X2); + ROUND( 4, X4, X0, X1, X2, X3); + ROUND( 5, X0, X1, X2, X3, X4); + ROUND( 6, X1, X2, X3, X4, X0); + ROUND( 7, X2, X3, X4, X0, X1); + ROUND( 8, X3, X4, X0, X1, X2); + ROUND( 9, X4, X0, X1, X2, X3); + ROUND(10, X0, X1, X2, X3, X4); + ROUND(11, X1, X2, X3, X4, X0); + ROUND(12, X2, X3, X4, X0, X1); + ROUND(13, X3, X4, X0, X1, X2); + ROUND(14, X4, X0, X1, X2, X3); + ROUND(15, X0, X1, X2, X3, X4); + ROUND(16, X1, X2, X3, X4, X0); + ROUND(17, X2, X3, X4, X0, X1); + ROUND(18, X3, X4, X0, X1, X2); + ROUND(19, X4, X0, X1, X2, X3); + ROUND(20, X0, X1, X2, X3, X4); + ROUND(21, X1, X2, X3, X4, X0); + ROUND(22, X2, X3, X4, X0, X1); + ROUND(23, X3, X4, X0, X1, X2); + ROUND(24, X4, X0, X1, X2, X3); + ROUND(25, X0, X1, X2, X3, X4); + ROUND(26, X1, X2, X3, X4, X0); + ROUND(27, X2, X3, X4, X0, X1); + ROUND(28, X3, X4, X0, X1, X2); + PUTU32(out + 12, D3 ^ X2); + ROUND(29, X4, X0, X1, X2, X3); + PUTU32(out + 8, D2 ^ X3); + ROUND(30, X0, X1, X2, X3, X4); + PUTU32(out + 4, D1 ^ X4); + ROUND(31, X1, X2, X3, X4, X0); + PUTU32(out, D0 ^ X0); + + in += 16; + out += 16; + } + + PUTU32(ctr + 12, C3); +} + diff --git a/tests/sm4_ecbtest.c b/tests/sm4_ecbtest.c index 920950fb3..e7b8d3c41 100644 --- a/tests/sm4_ecbtest.c +++ b/tests/sm4_ecbtest.c @@ -32,10 +32,10 @@ static int test_sm4_ecb(void) rand_bytes(plaintext, sizeof(plaintext)); sm4_set_encrypt_key(&sm4_key, key); - sm4_ecb_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); + sm4_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); sm4_set_decrypt_key(&sm4_key, key); - sm4_ecb_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); + sm4_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); if (memcmp(decrypted, plaintext, sizeof(plaintext)) != 0) { error_print(); @@ -69,7 +69,7 @@ static int test_sm4_ecb_test_vectors(void) uint8_t decrypted[sizeof(plaintext)] = {0}; sm4_set_encrypt_key(&sm4_key, key); - sm4_ecb_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); + sm4_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); format_bytes(stderr, 0, 0, "", encrypted, sizeof(encrypted)); @@ -79,7 +79,7 @@ static int test_sm4_ecb_test_vectors(void) } sm4_set_decrypt_key(&sm4_key, key); - sm4_ecb_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); + sm4_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); if (memcmp(decrypted, plaintext, sizeof(plaintext)) != 0) { error_print(); diff --git a/tests/sm4_gcmtest.c b/tests/sm4_gcmtest.c index 4d7bee7b1..d70459e29 100644 --- a/tests/sm4_gcmtest.c +++ b/tests/sm4_gcmtest.c @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -312,12 +313,45 @@ static int test_sm4_gcm_ctx(void) return 1; } +static int speed_sm4_gcm_encrypt(void) +{ + SM4_KEY sm4_key; + uint8_t key[16] = {0}; + uint8_t iv[12]; + uint8_t aad[16]; + uint8_t tag[16]; + uint32_t buf[1024]; + clock_t begin, end; + double seconds; + int i; + + sm4_set_encrypt_key(&sm4_key, key); + + for (i = 0; i < 4096; i++) { + sm4_gcm_encrypt(&sm4_key, iv, sizeof(iv), aad, sizeof(aad), (uint8_t *)buf, sizeof(buf), (uint8_t *)buf, 16, tag); + } + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_gcm_encrypt(&sm4_key, iv, sizeof(iv), aad, sizeof(aad), (uint8_t *)buf, sizeof(buf), (uint8_t *)buf, 16, tag); + } + end = clock(); + + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + + int main(void) { if (test_sm4_gcm() != 1) goto err; if (test_sm4_gcm_gbt36624_1() != 1) goto err; if (test_sm4_gcm_gbt36624_2() != 1) goto err; if (test_sm4_gcm_ctx() != 1) goto err; +#if ENABLE_TEST_SPEED + if (speed_sm4_gcm_encrypt() != 1) goto err; +#endif printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tests/sm4test.c b/tests/sm4test.c index ae42930d5..8a388db4e 100644 --- a/tests/sm4test.c +++ b/tests/sm4test.c @@ -133,7 +133,7 @@ static int test_sm4_encrypt_blocks(void) return 1; } -static int test_sm4_encrypt_speed(void) +static int speed_sm4_encrypt(void) { SM4_KEY sm4_key; uint8_t key[16] = {0}; @@ -155,49 +155,160 @@ static int test_sm4_encrypt_speed(void) end = clock(); seconds = (double)(end - begin)/ CLOCKS_PER_SEC; - fprintf(stderr, "sm4_encrypt: %f MiB per second\n", nbytes/(1024 * 1024 *seconds)); + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, nbytes/(1024 * 1024 *seconds)); return 1; } -static int test_sm4_encrypt_blocks_speed(void) +static int speed_sm4_encrypt_blocks(void) { SM4_KEY sm4_key; uint8_t key[16] = {0}; - //uint32_t buf[1024]; - uint8_t buf[4096 + 100] __attribute__((aligned(16))); + uint32_t buf[1024]; clock_t begin, end; double seconds; int i; sm4_set_encrypt_key(&sm4_key, key); for (i = 0; i < 4096; i++) { - // fprintf(stderr, "."); sm4_encrypt_blocks(&sm4_key, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); } - //fprintf(stderr, "start\n"); - begin = clock(); for (i = 0; i < 4096; i++) { sm4_encrypt_blocks(&sm4_key, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); - // fprintf(stderr, "."); } end = clock(); seconds = (double)(end - begin)/ CLOCKS_PER_SEC; - fprintf(stderr, "sm4_encrypt_blocks: %f MiB per second\n", 16/seconds); + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + +static int speed_sm4_cbc_encrypt_blocks(void) +{ + SM4_KEY sm4_key; + uint8_t key[16] = {0}; + uint8_t iv[16]; + uint32_t buf[1024]; + clock_t begin, end; + double seconds; + int i; + + sm4_set_encrypt_key(&sm4_key, key); + + for (i = 0; i < 4096; i++) { + sm4_cbc_encrypt_blocks(&sm4_key, iv, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_cbc_encrypt_blocks(&sm4_key, iv, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + end = clock(); + + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + +static int speed_sm4_cbc_decrypt_blocks(void) +{ + SM4_KEY sm4_key; + uint8_t key[16] = {0}; + uint8_t iv[16]; + uint32_t buf[1024]; + clock_t begin, end; + double seconds; + int i; + + sm4_set_decrypt_key(&sm4_key, key); + + for (i = 0; i < 4096; i++) { + sm4_cbc_decrypt_blocks(&sm4_key, iv, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_cbc_decrypt_blocks(&sm4_key, iv, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + end = clock(); + + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + +static int speed_sm4_ctr_encrypt_blocks(void) +{ + SM4_KEY sm4_key; + uint8_t key[16] = {0}; + uint8_t ctr[16]; + uint32_t buf[1024]; + clock_t begin, end; + double seconds; + int i; + + sm4_set_encrypt_key(&sm4_key, key); + rand_bytes(ctr, sizeof(ctr)); + + for (i = 0; i < 4096; i++) { + sm4_ctr_encrypt_blocks(&sm4_key, ctr, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_ctr_encrypt_blocks(&sm4_key, ctr, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + end = clock(); + + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); return 1; } +static int speed_sm4_ctr32_encrypt_blocks(void) +{ + SM4_KEY sm4_key; + uint8_t key[16] = {0}; + uint8_t ctr[16]; + uint32_t buf[1024]; + clock_t begin, end; + double seconds; + int i; + + sm4_set_encrypt_key(&sm4_key, key); + rand_bytes(ctr, sizeof(ctr)); + + for (i = 0; i < 4096; i++) { + sm4_ctr32_encrypt_blocks(&sm4_key, ctr, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_ctr32_encrypt_blocks(&sm4_key, ctr, (uint8_t *)buf, sizeof(buf)/16, (uint8_t *)buf); + } + end = clock(); + + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + + + int main(void) { if (test_sm4() != 1) goto err; if (test_sm4_encrypt_blocks() != 1) goto err; #if ENABLE_TEST_SPEED - if (test_sm4_encrypt_speed() != 1) goto err; - if (test_sm4_encrypt_blocks_speed() != 1) goto err; + if (speed_sm4_encrypt() != 1) goto err; + if (speed_sm4_encrypt_blocks() != 1) goto err; + if (speed_sm4_cbc_encrypt_blocks() != 1) goto err; + if (speed_sm4_cbc_decrypt_blocks() != 1) goto err; + if (speed_sm4_ctr_encrypt_blocks() != 1) goto err; + if (speed_sm4_ctr32_encrypt_blocks() != 1) goto err; #endif printf("%s all tests passed\n", __FILE__); return 0;