Skip to content

Commit 2bcb8cc

Browse files
committed
encrypt/decrypt
ghstack-source-id: c7cca30 Pull Request resolved: #83
1 parent 64fedf7 commit 2bcb8cc

File tree

4 files changed

+451
-124
lines changed

4 files changed

+451
-124
lines changed

test/test_csprng.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,5 +354,51 @@ def test_const_generator(self):
354354
second = torch.empty(self.size, dtype=dtype, device=device).random_(generator=const_gen)
355355
self.assertTrue((first - second).max().abs() == 0)
356356

357+
def test_encrypt_decrypt(self):
358+
key_size_bytes = 16
359+
block_size_bytes = 16
360+
361+
def sizeof(dtype):
362+
if dtype == torch.bool:
363+
return 1
364+
elif dtype.is_floating_point:
365+
return torch.finfo(dtype).bits // 8
366+
else:
367+
return torch.iinfo(dtype).bits // 8
368+
369+
for device in self.all_devices:
370+
for key_dtype in self.all_dtypes:
371+
key_size = key_size_bytes // sizeof(key_dtype)
372+
key = torch.empty(key_size, dtype=key_dtype, device=device).random_()
373+
for initial_dtype in self.all_dtypes:
374+
for encrypted_dtype in self.all_dtypes:
375+
for decrypted_dtype in self.all_dtypes:
376+
for initial_size in [0, 4, 8, 15, 16, 23, 42]:
377+
for mode in ["ecb", "ctr"]:
378+
encrypted_size = (initial_size * sizeof(initial_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype)
379+
decrypted_size = (encrypted_size * sizeof(encrypted_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(decrypted_dtype)
380+
381+
initial = torch.empty(initial_size, dtype=initial_dtype, device=device).random_()
382+
encrypted = torch.empty(encrypted_size, dtype=encrypted_dtype, device=device).random_()
383+
decrypted = torch.empty(decrypted_size, dtype=decrypted_dtype, device=device).random_()
384+
385+
initial_np = initial.cpu().numpy().view(np.int8)
386+
decrypted_np = decrypted.cpu().numpy().view(np.int8)
387+
padding_size_bytes = initial_size * sizeof(initial_dtype) - decrypted_size * sizeof(decrypted_dtype)
388+
if padding_size_bytes != 0:
389+
decrypted_np = decrypted_np[:padding_size_bytes]
390+
391+
csprng.encrypt(initial, encrypted, key, "aes128", mode)
392+
393+
if initial_size > 8:
394+
self.assertFalse(np.array_equal(initial_np, decrypted_np))
395+
396+
csprng.decrypt(encrypted, decrypted, key, "aes128", mode)
397+
decrypted_np = decrypted.cpu().numpy().view(np.int8)
398+
if padding_size_bytes != 0:
399+
decrypted_np = decrypted_np[:padding_size_bytes]
400+
401+
self.assertTrue(np.array_equal(initial_np, decrypted_np))
402+
357403
if __name__ == '__main__':
358404
unittest.main()

torchcsprng/csrc/aes.h

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,7 @@ namespace aes {
6363
#define Nr 10 // The number of rounds in AES Cipher.
6464
#endif
6565

66-
#if !defined(__CUDACC__) && !defined(__HIPCC__)
67-
struct ulonglong2 // TODO: should have something like `__builtin_align__(16)`
68-
{
69-
unsigned long long int x, y;
70-
};
71-
#endif
72-
73-
typedef ulonglong2 block_t;
74-
constexpr size_t block_t_size = sizeof(block_t);
66+
constexpr size_t block_t_size = 16;
7567

7668
typedef uint8_t state_t[4][4];
7769

@@ -97,13 +89,33 @@ TORCH_CSPRNG_CONSTANT const uint8_t sbox[256] = {
9789
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
9890
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 };
9991

92+
TORCH_CSPRNG_CONSTANT const uint8_t rsbox[256] = {
93+
0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
94+
0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
95+
0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
96+
0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
97+
0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
98+
0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
99+
0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
100+
0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
101+
0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
102+
0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
103+
0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
104+
0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
105+
0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
106+
0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
107+
0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
108+
0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d };
109+
100110
// The round constant word array, Rcon[i], contains the values given by
101111
// x to the power (i-1) being powers of x (x is denoted as {02}) in the field GF(2^8)
102112
TORCH_CSPRNG_CONSTANT const uint8_t Rcon[11] = {
103113
0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };
104114

105115
#define getSBoxValue(num) (sbox[(num)])
106116

117+
#define getSBoxInvert(num) (rsbox[(num)])
118+
107119
// This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states.
108120
TORCH_CSPRNG_HOST_DEVICE void KeyExpansion(uint8_t* RoundKey, const uint8_t* Key){
109121
unsigned int i, j, k;
@@ -257,6 +269,78 @@ TORCH_CSPRNG_HOST_DEVICE void MixColumns(state_t* state)
257269
}
258270
}
259271

272+
TORCH_CSPRNG_HOST_DEVICE uint8_t Multiply(uint8_t x, uint8_t y)
273+
{
274+
return (((y & 1) * x) ^
275+
((y>>1 & 1) * xtime(x)) ^
276+
((y>>2 & 1) * xtime(xtime(x))) ^
277+
((y>>3 & 1) * xtime(xtime(xtime(x)))) ^
278+
((y>>4 & 1) * xtime(xtime(xtime(xtime(x)))))); /* this last call to xtime() can be omitted */
279+
}
280+
281+
// MixColumns function mixes the columns of the state matrix.
282+
// The method used to multiply may be difficult to understand for the inexperienced.
283+
// Please use the references to gain more information.
284+
TORCH_CSPRNG_HOST_DEVICE void InvMixColumns(state_t* state)
285+
{
286+
int i;
287+
uint8_t a, b, c, d;
288+
for (i = 0; i < 4; ++i)
289+
{
290+
a = (*state)[i][0];
291+
b = (*state)[i][1];
292+
c = (*state)[i][2];
293+
d = (*state)[i][3];
294+
295+
(*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09);
296+
(*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d);
297+
(*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b);
298+
(*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e);
299+
}
300+
}
301+
302+
// The SubBytes Function Substitutes the values in the
303+
// state matrix with values in an S-box.
304+
TORCH_CSPRNG_HOST_DEVICE void InvSubBytes(state_t* state)
305+
{
306+
uint8_t i, j;
307+
for (i = 0; i < 4; ++i)
308+
{
309+
for (j = 0; j < 4; ++j)
310+
{
311+
(*state)[j][i] = getSBoxInvert((*state)[j][i]);
312+
}
313+
}
314+
}
315+
316+
TORCH_CSPRNG_HOST_DEVICE void InvShiftRows(state_t* state)
317+
{
318+
uint8_t temp;
319+
320+
// Rotate first row 1 columns to right
321+
temp = (*state)[3][1];
322+
(*state)[3][1] = (*state)[2][1];
323+
(*state)[2][1] = (*state)[1][1];
324+
(*state)[1][1] = (*state)[0][1];
325+
(*state)[0][1] = temp;
326+
327+
// Rotate second row 2 columns to right
328+
temp = (*state)[0][2];
329+
(*state)[0][2] = (*state)[2][2];
330+
(*state)[2][2] = temp;
331+
332+
temp = (*state)[1][2];
333+
(*state)[1][2] = (*state)[3][2];
334+
(*state)[3][2] = temp;
335+
336+
// Rotate third row 3 columns to right
337+
temp = (*state)[0][3];
338+
(*state)[0][3] = (*state)[1][3];
339+
(*state)[1][3] = (*state)[2][3];
340+
(*state)[2][3] = (*state)[3][3];
341+
(*state)[3][3] = temp;
342+
}
343+
260344
TORCH_CSPRNG_HOST_DEVICE void encrypt(uint8_t* state, const uint8_t* key) {
261345
uint8_t RoundKey[176];
262346
KeyExpansion(RoundKey, key);
@@ -284,4 +368,29 @@ TORCH_CSPRNG_HOST_DEVICE void encrypt(uint8_t* state, const uint8_t* key) {
284368
AddRoundKey(Nr, (state_t*)state, RoundKey);
285369
}
286370

371+
TORCH_CSPRNG_HOST_DEVICE void decrypt(uint8_t* state, const uint8_t* key) {
372+
uint8_t RoundKey[176];
373+
KeyExpansion(RoundKey, key);
374+
375+
uint8_t round = 0;
376+
377+
// Add the First round key to the state before starting the rounds.
378+
AddRoundKey(Nr, (state_t*)state, RoundKey);
379+
380+
// There will be Nr rounds.
381+
// The first Nr-1 rounds are identical.
382+
// These Nr rounds are executed in the loop below.
383+
// Last one without InvMixColumn()
384+
for (round = (Nr - 1); ; --round)
385+
{
386+
InvShiftRows((state_t*)state);
387+
InvSubBytes((state_t*)state);
388+
AddRoundKey(round, (state_t*)state, RoundKey);
389+
if (round == 0) {
390+
break;
391+
}
392+
InvMixColumns((state_t*)state);
393+
}
394+
}
395+
287396
}}}

0 commit comments

Comments
 (0)