@@ -63,15 +63,7 @@ namespace aes {
6363 #define Nr 10 // The number of rounds in AES Cipher.
6464#endif
6565
66- #if !defined(__CUDACC__) && !defined(__HIPCC__)
67- struct ulonglong2 // TODO: should have something like `__builtin_align__(16)`
68- {
69- unsigned long long int x, y;
70- };
71- #endif
72-
73- typedef ulonglong2 block_t ;
74- constexpr size_t block_t_size = sizeof (block_t );
66+ constexpr size_t block_t_size = 16 ;
7567
7668typedef uint8_t state_t [4 ][4 ];
7769
@@ -97,13 +89,33 @@ TORCH_CSPRNG_CONSTANT const uint8_t sbox[256] = {
9789 0xe1 , 0xf8 , 0x98 , 0x11 , 0x69 , 0xd9 , 0x8e , 0x94 , 0x9b , 0x1e , 0x87 , 0xe9 , 0xce , 0x55 , 0x28 , 0xdf ,
9890 0x8c , 0xa1 , 0x89 , 0x0d , 0xbf , 0xe6 , 0x42 , 0x68 , 0x41 , 0x99 , 0x2d , 0x0f , 0xb0 , 0x54 , 0xbb , 0x16 };
9991
92+ TORCH_CSPRNG_CONSTANT const uint8_t rsbox[256 ] = {
93+ 0x52 , 0x09 , 0x6a , 0xd5 , 0x30 , 0x36 , 0xa5 , 0x38 , 0xbf , 0x40 , 0xa3 , 0x9e , 0x81 , 0xf3 , 0xd7 , 0xfb ,
94+ 0x7c , 0xe3 , 0x39 , 0x82 , 0x9b , 0x2f , 0xff , 0x87 , 0x34 , 0x8e , 0x43 , 0x44 , 0xc4 , 0xde , 0xe9 , 0xcb ,
95+ 0x54 , 0x7b , 0x94 , 0x32 , 0xa6 , 0xc2 , 0x23 , 0x3d , 0xee , 0x4c , 0x95 , 0x0b , 0x42 , 0xfa , 0xc3 , 0x4e ,
96+ 0x08 , 0x2e , 0xa1 , 0x66 , 0x28 , 0xd9 , 0x24 , 0xb2 , 0x76 , 0x5b , 0xa2 , 0x49 , 0x6d , 0x8b , 0xd1 , 0x25 ,
97+ 0x72 , 0xf8 , 0xf6 , 0x64 , 0x86 , 0x68 , 0x98 , 0x16 , 0xd4 , 0xa4 , 0x5c , 0xcc , 0x5d , 0x65 , 0xb6 , 0x92 ,
98+ 0x6c , 0x70 , 0x48 , 0x50 , 0xfd , 0xed , 0xb9 , 0xda , 0x5e , 0x15 , 0x46 , 0x57 , 0xa7 , 0x8d , 0x9d , 0x84 ,
99+ 0x90 , 0xd8 , 0xab , 0x00 , 0x8c , 0xbc , 0xd3 , 0x0a , 0xf7 , 0xe4 , 0x58 , 0x05 , 0xb8 , 0xb3 , 0x45 , 0x06 ,
100+ 0xd0 , 0x2c , 0x1e , 0x8f , 0xca , 0x3f , 0x0f , 0x02 , 0xc1 , 0xaf , 0xbd , 0x03 , 0x01 , 0x13 , 0x8a , 0x6b ,
101+ 0x3a , 0x91 , 0x11 , 0x41 , 0x4f , 0x67 , 0xdc , 0xea , 0x97 , 0xf2 , 0xcf , 0xce , 0xf0 , 0xb4 , 0xe6 , 0x73 ,
102+ 0x96 , 0xac , 0x74 , 0x22 , 0xe7 , 0xad , 0x35 , 0x85 , 0xe2 , 0xf9 , 0x37 , 0xe8 , 0x1c , 0x75 , 0xdf , 0x6e ,
103+ 0x47 , 0xf1 , 0x1a , 0x71 , 0x1d , 0x29 , 0xc5 , 0x89 , 0x6f , 0xb7 , 0x62 , 0x0e , 0xaa , 0x18 , 0xbe , 0x1b ,
104+ 0xfc , 0x56 , 0x3e , 0x4b , 0xc6 , 0xd2 , 0x79 , 0x20 , 0x9a , 0xdb , 0xc0 , 0xfe , 0x78 , 0xcd , 0x5a , 0xf4 ,
105+ 0x1f , 0xdd , 0xa8 , 0x33 , 0x88 , 0x07 , 0xc7 , 0x31 , 0xb1 , 0x12 , 0x10 , 0x59 , 0x27 , 0x80 , 0xec , 0x5f ,
106+ 0x60 , 0x51 , 0x7f , 0xa9 , 0x19 , 0xb5 , 0x4a , 0x0d , 0x2d , 0xe5 , 0x7a , 0x9f , 0x93 , 0xc9 , 0x9c , 0xef ,
107+ 0xa0 , 0xe0 , 0x3b , 0x4d , 0xae , 0x2a , 0xf5 , 0xb0 , 0xc8 , 0xeb , 0xbb , 0x3c , 0x83 , 0x53 , 0x99 , 0x61 ,
108+ 0x17 , 0x2b , 0x04 , 0x7e , 0xba , 0x77 , 0xd6 , 0x26 , 0xe1 , 0x69 , 0x14 , 0x63 , 0x55 , 0x21 , 0x0c , 0x7d };
109+
100110// The round constant word array, Rcon[i], contains the values given by
101111// x to the power (i-1) being powers of x (x is denoted as {02}) in the field GF(2^8)
102112TORCH_CSPRNG_CONSTANT const uint8_t Rcon[11 ] = {
103113 0x8d , 0x01 , 0x02 , 0x04 , 0x08 , 0x10 , 0x20 , 0x40 , 0x80 , 0x1b , 0x36 };
104114
105115#define getSBoxValue (num ) (sbox[(num)])
106116
117+ #define getSBoxInvert (num ) (rsbox[(num)])
118+
107119// This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states.
108120TORCH_CSPRNG_HOST_DEVICE void KeyExpansion (uint8_t * RoundKey, const uint8_t * Key){
109121 unsigned int i, j, k;
@@ -257,6 +269,78 @@ TORCH_CSPRNG_HOST_DEVICE void MixColumns(state_t* state)
257269 }
258270}
259271
272+ TORCH_CSPRNG_HOST_DEVICE uint8_t Multiply (uint8_t x, uint8_t y)
273+ {
274+ return (((y & 1 ) * x) ^
275+ ((y>>1 & 1 ) * xtime (x)) ^
276+ ((y>>2 & 1 ) * xtime (xtime (x))) ^
277+ ((y>>3 & 1 ) * xtime (xtime (xtime (x)))) ^
278+ ((y>>4 & 1 ) * xtime (xtime (xtime (xtime (x)))))); /* this last call to xtime() can be omitted */
279+ }
280+
281+ // MixColumns function mixes the columns of the state matrix.
282+ // The method used to multiply may be difficult to understand for the inexperienced.
283+ // Please use the references to gain more information.
284+ TORCH_CSPRNG_HOST_DEVICE void InvMixColumns (state_t * state)
285+ {
286+ int i;
287+ uint8_t a, b, c, d;
288+ for (i = 0 ; i < 4 ; ++i)
289+ {
290+ a = (*state)[i][0 ];
291+ b = (*state)[i][1 ];
292+ c = (*state)[i][2 ];
293+ d = (*state)[i][3 ];
294+
295+ (*state)[i][0 ] = Multiply (a, 0x0e ) ^ Multiply (b, 0x0b ) ^ Multiply (c, 0x0d ) ^ Multiply (d, 0x09 );
296+ (*state)[i][1 ] = Multiply (a, 0x09 ) ^ Multiply (b, 0x0e ) ^ Multiply (c, 0x0b ) ^ Multiply (d, 0x0d );
297+ (*state)[i][2 ] = Multiply (a, 0x0d ) ^ Multiply (b, 0x09 ) ^ Multiply (c, 0x0e ) ^ Multiply (d, 0x0b );
298+ (*state)[i][3 ] = Multiply (a, 0x0b ) ^ Multiply (b, 0x0d ) ^ Multiply (c, 0x09 ) ^ Multiply (d, 0x0e );
299+ }
300+ }
301+
302+ // The SubBytes Function Substitutes the values in the
303+ // state matrix with values in an S-box.
304+ TORCH_CSPRNG_HOST_DEVICE void InvSubBytes (state_t * state)
305+ {
306+ uint8_t i, j;
307+ for (i = 0 ; i < 4 ; ++i)
308+ {
309+ for (j = 0 ; j < 4 ; ++j)
310+ {
311+ (*state)[j][i] = getSBoxInvert ((*state)[j][i]);
312+ }
313+ }
314+ }
315+
316+ TORCH_CSPRNG_HOST_DEVICE void InvShiftRows (state_t * state)
317+ {
318+ uint8_t temp;
319+
320+ // Rotate first row 1 columns to right
321+ temp = (*state)[3 ][1 ];
322+ (*state)[3 ][1 ] = (*state)[2 ][1 ];
323+ (*state)[2 ][1 ] = (*state)[1 ][1 ];
324+ (*state)[1 ][1 ] = (*state)[0 ][1 ];
325+ (*state)[0 ][1 ] = temp;
326+
327+ // Rotate second row 2 columns to right
328+ temp = (*state)[0 ][2 ];
329+ (*state)[0 ][2 ] = (*state)[2 ][2 ];
330+ (*state)[2 ][2 ] = temp;
331+
332+ temp = (*state)[1 ][2 ];
333+ (*state)[1 ][2 ] = (*state)[3 ][2 ];
334+ (*state)[3 ][2 ] = temp;
335+
336+ // Rotate third row 3 columns to right
337+ temp = (*state)[0 ][3 ];
338+ (*state)[0 ][3 ] = (*state)[1 ][3 ];
339+ (*state)[1 ][3 ] = (*state)[2 ][3 ];
340+ (*state)[2 ][3 ] = (*state)[3 ][3 ];
341+ (*state)[3 ][3 ] = temp;
342+ }
343+
260344TORCH_CSPRNG_HOST_DEVICE void encrypt (uint8_t * state, const uint8_t * key) {
261345 uint8_t RoundKey[176 ];
262346 KeyExpansion (RoundKey, key);
@@ -284,4 +368,29 @@ TORCH_CSPRNG_HOST_DEVICE void encrypt(uint8_t* state, const uint8_t* key) {
284368 AddRoundKey (Nr, (state_t *)state, RoundKey);
285369}
286370
371+ TORCH_CSPRNG_HOST_DEVICE void decrypt (uint8_t * state, const uint8_t * key) {
372+ uint8_t RoundKey[176 ];
373+ KeyExpansion (RoundKey, key);
374+
375+ uint8_t round = 0 ;
376+
377+ // Add the First round key to the state before starting the rounds.
378+ AddRoundKey (Nr, (state_t *)state, RoundKey);
379+
380+ // There will be Nr rounds.
381+ // The first Nr-1 rounds are identical.
382+ // These Nr rounds are executed in the loop below.
383+ // Last one without InvMixColumn()
384+ for (round = (Nr - 1 ); ; --round)
385+ {
386+ InvShiftRows ((state_t *)state);
387+ InvSubBytes ((state_t *)state);
388+ AddRoundKey (round, (state_t *)state, RoundKey);
389+ if (round == 0 ) {
390+ break ;
391+ }
392+ InvMixColumns ((state_t *)state);
393+ }
394+ }
395+
287396}}}
0 commit comments