Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1857d6f
encrypt/decrypt
pbelevich Nov 14, 2020
1737f81
Update on "encrypt/decrypt"
pbelevich Nov 16, 2020
e4955f7
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
583fb80
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
66b2e55
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
6bd0030
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
7465ef1
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
247ca74
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
f6a260a
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
9cbd83f
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
5ae1104
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
1ea4090
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
a040964
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
2ae9d3e
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
cb7d77b
Update on "encrypt/decrypt"
pbelevich Nov 18, 2020
fb23aae
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
fa7970d
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
caefc34
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
71e00c3
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
9469ece
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
8b01c56
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
7f91368
Update on "encrypt/decrypt"
pbelevich Nov 19, 2020
5d9c1c7
Update on "torchcsprng.encrypt/torchcsprng.decrypt with AES128 ECB/CT…
pbelevich Nov 19, 2020
c1c75bb
Update on "torchcsprng.encrypt/torchcsprng.decrypt with AES128 ECB/CT…
pbelevich Nov 19, 2020
c45ff91
Update on "torchcsprng.encrypt/torchcsprng.decrypt with AES128 ECB/CT…
pbelevich Nov 19, 2020
36545c2
Update on "torchcsprng.encrypt/torchcsprng.decrypt with AES128 ECB/CT…
pbelevich Nov 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/test_csprng.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,5 +354,51 @@ def test_const_generator(self):
second = torch.empty(self.size, dtype=dtype, device=device).random_(generator=const_gen)
self.assertTrue((first - second).max().abs() == 0)

def test_encrypt_decrypt(self):
key_size_bytes = 16
block_size_bytes = 16

def sizeof(dtype):
if dtype == torch.bool:
return 1
elif dtype.is_floating_point:
return torch.finfo(dtype).bits // 8
else:
return torch.iinfo(dtype).bits // 8

for device in self.all_devices:
for key_dtype in self.all_dtypes:
key_size = key_size_bytes // sizeof(key_dtype)
key = torch.empty(key_size, dtype=key_dtype, device=device).random_()
for initial_dtype in self.all_dtypes:
for encrypted_dtype in self.all_dtypes:
for decrypted_dtype in self.all_dtypes:
for initial_size in [0, 4, 8, 15, 16, 23, 42]:
for mode in ["ecb", "ctr"]:
encrypted_size = (initial_size * sizeof(initial_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype)
decrypted_size = (encrypted_size * sizeof(encrypted_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(decrypted_dtype)

initial = torch.empty(initial_size, dtype=initial_dtype, device=device).random_()
encrypted = torch.empty(encrypted_size, dtype=encrypted_dtype, device=device).random_()
decrypted = torch.empty(decrypted_size, dtype=decrypted_dtype, device=device).random_()

initial_np = initial.cpu().numpy().view(np.int8)
decrypted_np = decrypted.cpu().numpy().view(np.int8)
padding_size_bytes = initial_size * sizeof(initial_dtype) - decrypted_size * sizeof(decrypted_dtype)
if padding_size_bytes != 0:
decrypted_np = decrypted_np[:padding_size_bytes]

csprng.encrypt(initial, encrypted, key, "aes128", mode)

if initial_size > 8:
self.assertFalse(np.array_equal(initial_np, decrypted_np))

csprng.decrypt(encrypted, decrypted, key, "aes128", mode)
decrypted_np = decrypted.cpu().numpy().view(np.int8)
if padding_size_bytes != 0:
decrypted_np = decrypted_np[:padding_size_bytes]

self.assertTrue(np.array_equal(initial_np, decrypted_np))

if __name__ == '__main__':
unittest.main()
127 changes: 118 additions & 9 deletions torchcsprng/csrc/aes.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,7 @@ namespace aes {
#define Nr 10 // The number of rounds in AES Cipher.
#endif

#if !defined(__CUDACC__) && !defined(__HIPCC__)
struct ulonglong2 // TODO: should have something like `__builtin_align__(16)`
{
unsigned long long int x, y;
};
#endif

typedef ulonglong2 block_t;
constexpr size_t block_t_size = sizeof(block_t);
constexpr size_t block_t_size = 16;

typedef uint8_t state_t[4][4];

Expand All @@ -97,13 +89,33 @@ TORCH_CSPRNG_CONSTANT const uint8_t sbox[256] = {
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 };

TORCH_CSPRNG_CONSTANT const uint8_t rsbox[256] = {
0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d };

// The round constant word array, Rcon[i], contains the values given by
// x to the power (i-1) being powers of x (x is denoted as {02}) in the field GF(2^8)
TORCH_CSPRNG_CONSTANT const uint8_t Rcon[11] = {
0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };

#define getSBoxValue(num) (sbox[(num)])

#define getSBoxInvert(num) (rsbox[(num)])

// This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states.
TORCH_CSPRNG_HOST_DEVICE void KeyExpansion(uint8_t* RoundKey, const uint8_t* Key){
unsigned int i, j, k;
Expand Down Expand Up @@ -257,6 +269,78 @@ TORCH_CSPRNG_HOST_DEVICE void MixColumns(state_t* state)
}
}

TORCH_CSPRNG_HOST_DEVICE uint8_t Multiply(uint8_t x, uint8_t y)
{
return (((y & 1) * x) ^
((y>>1 & 1) * xtime(x)) ^
((y>>2 & 1) * xtime(xtime(x))) ^
((y>>3 & 1) * xtime(xtime(xtime(x)))) ^
((y>>4 & 1) * xtime(xtime(xtime(xtime(x)))))); /* this last call to xtime() can be omitted */
}

// MixColumns function mixes the columns of the state matrix.
// The method used to multiply may be difficult to understand for the inexperienced.
// Please use the references to gain more information.
TORCH_CSPRNG_HOST_DEVICE void InvMixColumns(state_t* state)
{
int i;
uint8_t a, b, c, d;
for (i = 0; i < 4; ++i)
{
a = (*state)[i][0];
b = (*state)[i][1];
c = (*state)[i][2];
d = (*state)[i][3];

(*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09);
(*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d);
(*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b);
(*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e);
}
}

// The SubBytes Function Substitutes the values in the
// state matrix with values in an S-box.
TORCH_CSPRNG_HOST_DEVICE void InvSubBytes(state_t* state)
{
uint8_t i, j;
for (i = 0; i < 4; ++i)
{
for (j = 0; j < 4; ++j)
{
(*state)[j][i] = getSBoxInvert((*state)[j][i]);
}
}
}

TORCH_CSPRNG_HOST_DEVICE void InvShiftRows(state_t* state)
{
uint8_t temp;

// Rotate first row 1 columns to right
temp = (*state)[3][1];
(*state)[3][1] = (*state)[2][1];
(*state)[2][1] = (*state)[1][1];
(*state)[1][1] = (*state)[0][1];
(*state)[0][1] = temp;

// Rotate second row 2 columns to right
temp = (*state)[0][2];
(*state)[0][2] = (*state)[2][2];
(*state)[2][2] = temp;

temp = (*state)[1][2];
(*state)[1][2] = (*state)[3][2];
(*state)[3][2] = temp;

// Rotate third row 3 columns to right
temp = (*state)[0][3];
(*state)[0][3] = (*state)[1][3];
(*state)[1][3] = (*state)[2][3];
(*state)[2][3] = (*state)[3][3];
(*state)[3][3] = temp;
}

TORCH_CSPRNG_HOST_DEVICE void encrypt(uint8_t* state, const uint8_t* key) {
uint8_t RoundKey[176];
KeyExpansion(RoundKey, key);
Expand Down Expand Up @@ -284,4 +368,29 @@ TORCH_CSPRNG_HOST_DEVICE void encrypt(uint8_t* state, const uint8_t* key) {
AddRoundKey(Nr, (state_t*)state, RoundKey);
}

TORCH_CSPRNG_HOST_DEVICE void decrypt(uint8_t* state, const uint8_t* key) {
uint8_t RoundKey[176];
KeyExpansion(RoundKey, key);

uint8_t round = 0;

// Add the First round key to the state before starting the rounds.
AddRoundKey(Nr, (state_t*)state, RoundKey);

// There will be Nr rounds.
// The first Nr-1 rounds are identical.
// These Nr rounds are executed in the loop below.
// Last one without InvMixColumn()
for (round = (Nr - 1); ; --round)
{
InvShiftRows((state_t*)state);
InvSubBytes((state_t*)state);
AddRoundKey(round, (state_t*)state, RoundKey);
if (round == 0) {
break;
}
InvMixColumns((state_t*)state);
}
}

}}}
Loading