Skip to content

Commit

Permalink
Non-enclave Dependent Encrypt/Decrypt (#65)
Browse files Browse the repository at this point in the history
Add `encrypt()` and `decrypt()` util functions to remove client side enclave dependency when encrypting and decrypting

Reorder outputted cipher in enclave crypto functions to match Java's ordering. Was previously `(IV, MAC, data)`, is now `(IV, data, MAC)`

Co-authored by @ankurdave

Fixes #37 
WIP on #64
  • Loading branch information
chester-leung authored and ankurdave committed Nov 20, 2018
1 parent 68223fe commit a3502ab
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 54 deletions.
27 changes: 0 additions & 27 deletions src/enclave/App/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,33 +808,6 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla
return ciphertext;
}

JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Decrypt(
JNIEnv *env, jobject obj, jlong eid, jbyteArray ciphertext) {
(void)obj;

uint32_t clength = (uint32_t) env->GetArrayLength(ciphertext);
jboolean if_copy = false;
jbyte *ptr = env->GetByteArrayElements(ciphertext, &if_copy);

uint8_t *ciphertext_ptr = (uint8_t *) ptr;

const jsize plength = clength - SGX_AESGCM_IV_SIZE - SGX_AESGCM_MAC_SIZE;
jbyteArray plaintext = env->NewByteArray(plength);

uint8_t *plaintext_copy = new uint8_t[plength];

sgx_check_quiet(
"Decrypt", ecall_decrypt(eid, ciphertext_ptr, clength, plaintext_copy, (uint32_t) plength));

env->SetByteArrayRegion(plaintext, 0, plength, (jbyte *) plaintext_copy);

env->ReleaseByteArrayElements(ciphertext, ptr, 0);

delete[] plaintext_copy;

return plaintext;
}

JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Sample(
JNIEnv *env, jobject obj, jlong eid, jbyteArray input_rows) {
(void)obj;
Expand Down
16 changes: 9 additions & 7 deletions src/enclave/Enclave/Crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,23 @@ void encrypt(uint8_t *plaintext, uint32_t plaintext_length,

initKeySchedule();

// key size is 12 bytes/128 bits
// key size is 16 bytes/128 bits
// IV size is 12 bytes/96 bits
// MAC size is 16 bytes/128 bits

// one buffer to store IV (12 bytes) + ciphertext + mac (16 bytes)

uint8_t *iv_ptr = ciphertext;
uint8_t *ciphertext_ptr = ciphertext + SGX_AESGCM_IV_SIZE;
sgx_aes_gcm_128bit_tag_t *mac_ptr =
(sgx_aes_gcm_128bit_tag_t *) (ciphertext + SGX_AESGCM_IV_SIZE + plaintext_length);

// generate random IV
sgx_read_rand(iv_ptr, SGX_AESGCM_IV_SIZE);
sgx_aes_gcm_128bit_tag_t *mac_ptr = (sgx_aes_gcm_128bit_tag_t *) (ciphertext + SGX_AESGCM_IV_SIZE);
uint8_t *ciphertext_ptr = ciphertext + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;

AesGcm cipher(ks, iv_ptr, SGX_AESGCM_IV_SIZE);
cipher.encrypt(plaintext, plaintext_length, ciphertext_ptr, plaintext_length);
memcpy(mac_ptr, cipher.tag().t, SGX_AESGCM_MAC_SIZE);

}


Expand All @@ -57,7 +58,7 @@ void decrypt(const uint8_t *ciphertext, uint32_t ciphertext_length,
// decrypt using a global key
// TODO: fix this; should use key obtained from client

// key size is 12 bytes/128 bits
// key size is 16 bytes/128 bits
// IV size is 12 bytes/96 bits
// MAC size is 16 bytes/128 bits

Expand All @@ -66,8 +67,9 @@ void decrypt(const uint8_t *ciphertext, uint32_t ciphertext_length,
uint32_t plaintext_length = ciphertext_length - SGX_AESGCM_IV_SIZE - SGX_AESGCM_MAC_SIZE;

uint8_t *iv_ptr = (uint8_t *) ciphertext;
sgx_aes_gcm_128bit_tag_t *mac_ptr = (sgx_aes_gcm_128bit_tag_t *) (ciphertext + SGX_AESGCM_IV_SIZE);
uint8_t *ciphertext_ptr = (uint8_t *) (ciphertext + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE);
uint8_t *ciphertext_ptr = (uint8_t *) (ciphertext + SGX_AESGCM_IV_SIZE);
sgx_aes_gcm_128bit_tag_t *mac_ptr =
(sgx_aes_gcm_128bit_tag_t *) (ciphertext + SGX_AESGCM_IV_SIZE + plaintext_length);

AesGcm decipher(ks, iv_ptr, SGX_AESGCM_IV_SIZE);
decipher.decrypt(ciphertext_ptr, plaintext_length, plaintext, plaintext_length);
Expand Down
15 changes: 0 additions & 15 deletions src/enclave/Enclave/Enclave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,6 @@ void ecall_encrypt(uint8_t *plaintext, uint32_t plaintext_length,
}
}

void ecall_decrypt(uint8_t *ciphertext,
uint32_t ciphertext_length,
uint8_t *plaintext,
uint32_t plaintext_length) {
try {
// IV (12 bytes) + ciphertext + mac (16 bytes)
assert(ciphertext_length >= plaintext_length + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE);
(void)ciphertext_length;
(void)plaintext_length;
decrypt(ciphertext, ciphertext_length, plaintext);
} catch (const std::runtime_error &e) {
ocall_throw(e.what());
}
}

void ecall_project(uint8_t *condition, size_t condition_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {
Expand Down
4 changes: 0 additions & 4 deletions src/enclave/Enclave/Enclave.edl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ enclave {
[user_check] uint8_t *plaintext, uint32_t length,
[user_check] uint8_t *ciphertext, uint32_t cipher_length);

public void ecall_decrypt(
[in, size=ciphertext_length] uint8_t *ciphertext, uint32_t ciphertext_length,
[out, size=plaintext_length] uint8_t *plaintext, uint32_t plaintext_length);

public void ecall_sample(
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
Expand Down
33 changes: 32 additions & 1 deletion src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import java.io.File
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.UUID
import javax.crypto._
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom;

import scala.collection.mutable.ArrayBuilder

Expand Down Expand Up @@ -199,6 +203,33 @@ object Utils extends Logging {
}
}

final val GCM_IV_LENGTH = 12
final val GCM_KEY_LENGTH = 16
final val GCM_TAG_LENGTH = 16

def encrypt(data: Array[Byte]): Array[Byte] = {
val random = SecureRandom.getInstance("SHA1PRNG")
val key = new Array[Byte](GCM_KEY_LENGTH)
val cipherKey = new SecretKeySpec(key, "AES")
val iv = new Array[Byte](GCM_IV_LENGTH)
random.nextBytes(iv)
val spec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)
val cipher = Cipher.getInstance("AES/GCM/NoPadding", "SunJCE")
cipher.init(Cipher.ENCRYPT_MODE, cipherKey, spec)
val cipherText = cipher.doFinal(data)
iv ++ cipherText
}

def decrypt(data: Array[Byte]): Array[Byte] = {
val key = new Array[Byte](GCM_KEY_LENGTH)
val cipherKey = new SecretKeySpec(key, "AES")
val iv = data.take(GCM_IV_LENGTH)
val cipherText = data.drop(GCM_IV_LENGTH)
val cipher = Cipher.getInstance("AES/GCM/NoPadding", "SunJCE")
cipher.init(Cipher.DECRYPT_MODE, cipherKey, new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv))
cipher.doFinal(cipherText)
}

var eid = 0L
var attested : Boolean = false
var attesting_getepid : Boolean = false
Expand Down Expand Up @@ -642,7 +673,7 @@ object Utils extends Logging {

// 2. Decrypt the row data
val (enclave, eid) = initEnclave()
val plaintext = enclave.Decrypt(eid, ciphertext)
val plaintext = decrypt(ciphertext)

// 1. Deserialize the tuix.Rows and return them as Scala InternalRow objects
val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext))
Expand Down
9 changes: 9 additions & 0 deletions src/test/scala/edu/berkeley/cs/rise/opaque/QEDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,13 @@ class QEDSuite extends FunSuite with BeforeAndAfterAll {
val data = for (i <- 0 until 8) yield (i)
RA.initRA(spark.sparkContext.parallelize(data, 2))
}

test("java encryption/decryption") {
val data = Array[Byte](0, 1, 2)
val (enclave, eid) = Utils.initEnclave()
assert(data === Utils.decrypt(Utils.encrypt(data)))
assert(data === Utils.decrypt(enclave.Encrypt(eid, data)))
}
}


0 comments on commit a3502ab

Please sign in to comment.