Permalink
Browse files

fix for wrong ptr use, debug prints, benchmarks, flags

  • Loading branch information...
1 parent e6032a3 commit f70547da2dfa741c410eb72d09685b4a0a5ddf33 Nate Wiger committed Aug 30, 2012
Showing with 102 additions and 70 deletions.
  1. +8 −0 ext/extconf.rb
  2. +39 −27 ext/fast_aes.c
  3. +5 −3 ext/fast_aes.h
  4. +0 −4 lib/fast-aes.rb
  5. +0 −8 lib/fast_aes_static.rb
  6. +22 −24 test/benchmark.rb
  7. +11 −1 test/fast_aes_spec.rb
  8. +17 −3 test/spec_helper.rb
View
@@ -1,6 +1,14 @@
# Loads mkmf which is used to make makefiles for Ruby extensions
require 'mkmf'
+# http://www.ruby-forum.com/topic/4374530
+def add_define(name)
+ $defs.push("-D#{name}")
+end
+
+# Generate faster code
+add_define "FULL_UNROLL"
+
# Give it a name
extension_name = 'fast_aes'
View
@@ -23,6 +23,13 @@ int fast_aes_do_gen_tables = 1;
#define RSTRING_LEN(s) (RSTRING(s)->len)
#endif
+/* http://stackoverflow.com/questions/1941307/c-debug-print-macros */
+#ifdef DEBUG
+ #define DEBUG_PRINT(...) do{ fprintf( stderr, __VA_ARGS__ ); } while(0)
+#else
+ #define DEBUG_PRINT(...) do{ } while (0)
+#endif
+
/* Ruby buckets */
VALUE rb_cFastAES;
@@ -72,22 +79,26 @@ VALUE fast_aes_initialize(VALUE self, VALUE key)
char* key_data = StringValuePtr(key);
/* determine key bits based on string length. breaks if try to use \0 in string. */
- key_bits = strlen(key_data)*8;
+ key_bits = (int)strlen(key_data)*8;
+ /*DEBUG_PRINT("AES key=%s, bits=%d\n", key_data, key_bits);*/
switch(key_bits)
{
case 128:
case 192:
case 256:
fast_aes->key_bits = key_bits;
memcpy(fast_aes->key, key_data, key_bits/8);
- /*printf("AES key=%s, bits=%d\n", fast_aes->key, fast_aes->key_bits);*/
+ /*DEBUG_PRINT("AES key=%s, bits=%d\n", fast_aes->key, fast_aes->key_bits);*/
break;
default:
- sprintf(error_mesg, "AES key must be 128, 192, or 256 bits in length (got %d): %s", key_bits, key_data);
- rb_raise(rb_eArgError, error_mesg);
+ rb_raise(rb_eArgError, "AES key must be 128, 192, or 256 bits in length (got %d): %s", key_bits, key_data);
return Qnil;
}
+ /* reset state flags */
+ fast_aes->inited_erk = 0;
+ fast_aes->inited_drk = 0;
+
if (fast_aes_initialize_state(fast_aes)) {
rb_raise(rb_eRuntimeError, "Failed to initialize AES internal state");
return Qnil;
@@ -146,9 +157,14 @@ VALUE fast_aes_encrypt(
Data_Get_Struct(self, fast_aes_t, fast_aes);
char* text = StringValuePtr(buffer);
- int bytes_in = RSTRING_LEN(buffer);
+ int bytes_in = (int)RSTRING_LEN(buffer);
char* data = malloc((bytes_in + 15) & -16); /* auto-malloc min size in 16-byte increments */
+ if (data == NULL) {
+ rb_raise(rb_eRuntimeError, "Failed to allocate memory for encrypt");
+ return Qnil;
+ }
unsigned char temp[16];
+ DEBUG_PRINT("encrypt: %s\n", text);
/* pointers to traverse text/data */
unsigned char *ptext, *pdata;
@@ -175,6 +191,7 @@ VALUE fast_aes_encrypt(
// 16 bytes of input remaining.
*/
while (bytes_in >= 16) {
+ DEBUG_PRINT("encrypt: %d in, %d out; in=%p, out=%p\n", bytes_in, bytes_out, ptext, pdata);
rijndaelEncrypt(fast_aes->erk, fast_aes->nr, ptext, pdata);
ptext += 16;
pdata += 16;
@@ -188,14 +205,15 @@ VALUE fast_aes_encrypt(
// 16-byte blocks. The policy here will be to pad the input with zeros.
*/
if (bytes_in > 0) {
+ DEBUG_PRINT("encrypt: %d extra bytes\n", bytes_in);
memset(temp, 0, sizeof(temp)); /* pad with 0's */
memcpy(temp, ptext, bytes_in);
rijndaelEncrypt(fast_aes->erk, fast_aes->nr, temp, pdata);
bytes_out += 16;
}
/* return the encrypted string */
- VALUE new_str = rb_str_new(pdata, bytes_out);
+ VALUE new_str = rb_str_new(data, bytes_out);
free(data);
return new_str;
}
@@ -210,9 +228,14 @@ VALUE fast_aes_decrypt(
Data_Get_Struct(self, fast_aes_t, fast_aes);
char* data = StringValuePtr(buffer);
- int bytes_in = RSTRING_LEN(buffer);
+ int bytes_in = (int)RSTRING_LEN(buffer);
char* text = malloc((bytes_in + 15) & -16); /* auto-malloc min size in 16-byte increments */
+ if (text == NULL) {
+ rb_raise(rb_eRuntimeError, "Failed to allocate memory for decrypt");
+ return Qnil;
+ }
unsigned char temp[16];
+ DEBUG_PRINT("decrypt: %s\n", data);
/* pointers to traverse text/data */
unsigned char *ptext, *pdata;
@@ -239,37 +262,26 @@ VALUE fast_aes_decrypt(
// 16 bytes of input remaining.
*/
while (bytes_in >= 16) {
- rijndaelEncrypt(fast_aes->drk, fast_aes->nr, pdata, ptext);
+ DEBUG_PRINT("decrypt: %d in, %d out; in=%p, out=%p\n", bytes_in, bytes_out, pdata, ptext);
+ rijndaelDecrypt(fast_aes->drk, fast_aes->nr, pdata, ptext);
ptext += 16;
pdata += 16;
bytes_in -= 16;
bytes_out += 16;
}
- /*//////////////////////////////////////////////////////////////////////////
- ////////////////////////////////////////////////////////////////////////////
- // Have to catch any straggling bytes that are left after encoding the
- // 16-byte blocks. The policy here will be to pad the input with zeros.
- */
- if (bytes_in > 0) {
- memset(temp, 0, sizeof(temp)); /* pad with 0's */
- memcpy(temp, pdata, bytes_in);
- rijndaelEncrypt(fast_aes->drk, fast_aes->nr, temp, ptext);
- bytes_out += 16;
- }
+ /* AES decryption should always be 16-byte blocks */
+ if (bytes_in != 0)
+ rb_raise(rb_eRuntimeError, "Hit %d straggling bytes on AES decrypt (should be 16-byte blocks)", bytes_in);
- /*//////////////////////////////////////////////////////////////////////////
- // Strip trailing zeros, simple but effective. This is something fucking
- // loose-cannon rjc couldn't figure out despite being a "genius". He needs
- // a punch in the junk, I swear to god.
- */
- while (bytes_out > 0) {
- if (ptext[bytes_out - 1] != 0) break;
+ /* Strip trailing zeros post decrypt which were added by padding. */
+ for (int i=0; bytes_out > 0 && i < 16; i++) {
+ if (text[bytes_out - 1] != 0) break;
bytes_out -= 1;
}
/* return the encrypted string */
- VALUE new_str = rb_str_new(ptext, bytes_out);
+ VALUE new_str = rb_str_new(text, bytes_out);
free(text);
return new_str;
}
View
@@ -25,19 +25,22 @@
/* structure to store our key and keysize */
typedef struct {
- unsigned char key[256]; /* max key is 256 */
- int key_bits; /* 128, 192, 256 */
+ u8 key[256]; /* max key is 256 */
+ int key_bits; /* 128, 192, 256 */
/* Encryption Round Keys */
u32 erk[4*(MAXNR + 1)];
u32 initial_erk[4*(MAXNR + 1)];
+ char inited_erk;
/* Decryption Round Keys */
u32 drk[4*(MAXNR + 1)];
u32 initial_drk[4*(MAXNR + 1)];
+ char inited_drk;
/* Number of rounds. */
int nr;
+
} fast_aes_t;
/* class functions */
@@ -50,7 +53,6 @@ VALUE fast_aes_key(VALUE self);
/* setup round keys */
int fast_aes_initialize_state(fast_aes_t* fast_aes_config);
-int fast_aes_initialize_state();
/* encryption routines */
VALUE fast_aes_encrypt(VALUE self, VALUE buffer);
View
@@ -1,4 +0,0 @@
-# This file just bridges the loading of the fast_aes.c extension,
-# with the FastAES include that can be used.
-require 'fast_aes'
-require 'fast_aes_static'
@@ -1,8 +0,0 @@
-class FastAES
- module Static
-
-
-
-
- end
-end
View
@@ -2,39 +2,37 @@
# Comparison of AES encryption libraries
#
-$LOAD_PATH.unshift "#{File.dirname(__FILE__)}/../ext/#{RUBY_PLATFORM}"
+require File.expand_path('spec_helper', File.dirname(__FILE__))
require 'benchmark'
-require './fast_aes'
-#require 'crypt/rijndael'
+require 'openssl'
-CHARS = ('A'..'z').collect{|x| x.to_s}
+Benchmark.bmbm(20) do |bm|
+ bm.report 'openssl' do
+ 10000.times do
+ cipher = OpenSSL::Cipher.new("AES-256-CBC")
+ cipher.encrypt
+ key = cipher.random_key
+ iv = cipher.random_iv
+ enc = cipher.update(LOREM_IPSUM) + cipher.final
-def random_key(bits)
- str = ''
- (bits/8).times do
- str += CHARS[rand(CHARS.length)]
- end
- str
-end
+ cipher = OpenSSL::Cipher.new("AES-256-CBC")
+ cipher.decrypt
+ cipher.key = key
+ cipher.iv = iv
+ dec = cipher.update(enc) + cipher.final
-Benchmark.bmbm(20) do |bm|
- #bm.report 'crypt/rijndael' do
- # 1000.times do
- # rijndael = Crypt::Rijndael.new(random_key(256))
- # plainBlock = "ABCDEFGH12345678"
- # encryptedBlock = rijndael.encrypt_block(plainBlock)
- # decryptedBlock = rijndael.decrypt_block(encryptedBlock)
- # end
- #end
+ raise "mismatch" unless dec == LOREM_IPSUM
+ end
+ end
bm.report 'fast_aes' do
10000.times do
aes = FastAES.new(random_key(256))
- plainBlock = "ABCDEFGH12345678"
- encryptedBlock = aes.encrypt(plainBlock)
- decryptedBlock = aes.decrypt(encryptedBlock)
+ enc = aes.encrypt(LOREM_IPSUM)
+ dec = aes.decrypt(enc)
+ raise "mismatch" unless dec == LOREM_IPSUM
end
end
-end
+end
View
@@ -14,11 +14,21 @@
text.should == aes.decrypt(data) # "Hey there, how are you?"
end
+ it "should fail on bad keys" do
+ e = nil
+ begin
+ FastAES.new("too short")
+ rescue => e
+ end
+ e.should.be.kind_of ArgumentError
+ end
+
it "should handle different key lengths" do
TEST_KEYS.each do |key|
- aes = FastAES.new(key.to_s)
+ aes = FastAES.new(key)
data = aes.encrypt(LOREM_IPSUM)
text = aes.decrypt(data)
+ text.should == LOREM_IPSUM
end
end
end
View
@@ -1,6 +1,8 @@
require 'bacon'
$LOAD_PATH.unshift File.expand_path(File.dirname(__FILE__) + '/..')
+$LOAD_PATH.unshift File.expand_path(File.dirname(__FILE__) + "/../ext/#{RUBY_PLATFORM}")
+
require 'fast_aes'
Bacon.summary_at_exit
@@ -10,11 +12,23 @@
require 'digest'
+def random_key(len)
+ x = ''
+ (len/8).times do
+ c = rand(256) until !c.nil? && c != 0
+ x << c.chr
+ end
+ x
+end
+
# Some sample keys
TEST_KEYS = [
- Digest::SHA256.hexdigest('foo').hex,
- Digest::SHA256.hexdigest('bar').hex,
- Digest::SHA256.hexdigest('fuck').hex
+ random_key(128),
+ random_key(128),
+ random_key(192),
+ random_key(192),
+ random_key(256),
+ random_key(256)
]
LOREM_IPSUM = <<-EndLorem

0 comments on commit f70547d

Please sign in to comment.