# Import

In [1]:
from pynq import allocate
from Crypto.Protocol.KDF import PBKDF2 
from Crypto.Cipher import AES
import numpy as np
import time
import os

In [2]:
# Run the helper notebook before the definitions below.
# NOTE(review): aesEncrypt, aesDecrypt, pad, unpad and cmd are used later in
# this file but are not defined here — presumably they come from
# AES_Utils.ipynb; confirm.
%run ./AES_Utils.ipynb

## Define AES core

In [3]:
# Key-size selectors accepted by AES_Core(AES_MODE=...).
AES128_MODE, AES192_MODE, AES256_MODE = 1, 2, 3

# Register offsets inside the AES core's register space:
#   AES_LENGTH_ADDR - where AES_Core writes the transfer length
#   AES_KEY_ADDR    - base offset where AES_Core writes the expanded key words
AES_LENGTH_ADDR, AES_KEY_ADDR = 0x100, 0x10

In [4]:
class AES_Core():
    """Driver for one hardware AES core that is fed through two DMA engines.

    On construction the key is expanded in software, the round keys are
    written word by word into the core's registers and, optionally, a
    self-test compares one hardware run against the software reference.

    A core object is either an encryptor or a decryptor; the role is derived
    from the substring 'encrypt' / 'decrypt' in ``aes_core._fullpath``.

    Typical use: call ``encryption(data)`` (or ``decryption(data)``) to start
    a transfer, then ``wait_done()`` to collect the result, and ``free()``
    once the core is no longer needed.

    NOTE(review): ``aesEncrypt`` / ``aesDecrypt`` used by the self-test are
    not defined in this file — presumably they come from AES_Utils.ipynb.
    """

    # Round constants, indexed by (word index // Nk) during key expansion.
    Rcon = [ 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 ]

    # AES substitution box used by the key expansion.
    sbox = [
        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
        0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
        0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
        0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
        0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
        0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
        0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
        0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
        0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
        0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
        0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
        0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
        0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
        0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
        0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
    ]

    def __init__(self, aes_core, dma_write, dma_read, key, buffer_size, AES_MODE = AES256_MODE, self_test = True):
        """Configure the core, load the expanded key and optionally self-test.

        Args:
            aes_core: register interface of the core; must provide
                ``write(offset, value)`` and a ``_fullpath`` string that
                contains 'encrypt' or 'decrypt'.
            dma_write: DMA whose ``sendchannel`` streams input to the core.
            dma_read: DMA whose ``recvchannel`` receives the core's output.
            key: indexable sequence of key bytes; only the first
                ``AES_KEYLEN`` entries are used for the key expansion.
            buffer_size: size in bytes of the DMA buffers, and therefore the
                largest input one call can process.
            AES_MODE: AES128_MODE, AES192_MODE or AES256_MODE.
            self_test: when true, run ``buffer_size`` random bytes through
                the hardware and compare with the software implementation.

        Raises:
            Exception: unsupported mode, key too short, core role not
                recognisable from ``_fullpath``, or self-test mismatch.
        """
        self.key = key
        self.aes_core = aes_core
        self.dma_write = dma_write
        self.dma_read = dma_read
        self.length_addr = AES_LENGTH_ADDR
        self.key_expansion_addr = AES_KEY_ADDR
        #config aes
        self.AES_MODE = AES_MODE
        self.buffer_size = buffer_size
        self.Nb = 4
        if self.AES_MODE == AES128_MODE:
            self.Nk = 4
            self.Nr = 10
            self.AES_KEYLEN = 16
            self.AES_keyExpSize = 176
        elif self.AES_MODE == AES192_MODE:
            self.Nk = 6
            self.Nr = 12
            self.AES_KEYLEN = 24
            self.AES_keyExpSize = 208
        elif self.AES_MODE == AES256_MODE:
            self.Nk = 8
            self.Nr = 14
            self.AES_KEYLEN = 32
            self.AES_keyExpSize = 240
        else:
            raise Exception ("UnSupport AES Mode")

        # Fail early with a readable message instead of an IndexError from
        # deep inside the key expansion.
        if len(key) < self.AES_KEYLEN:
            raise Exception (f"key must be at least {self.AES_KEYLEN} bytes for this AES mode")

        self.key_expansion = self._key_expansion(key)
        self.input_buffer = None
        self.output_buffer = None
        self.input_length = None

        if 'encrypt' in self.aes_core._fullpath:
            self.aes_core_mode = 'encrypt'
        elif 'decrypt' in self.aes_core._fullpath:
            self.aes_core_mode = 'decrypt'
        else:
            raise Exception ("init AES core fail")

        #config key to PL: one little-endian 32-bit word per register write
        for offset in range(0, len(self.key_expansion), 4):
            word = int.from_bytes(self.key_expansion[offset:offset + 4].tobytes(), 'little')
            self.aes_core.write(self.key_expansion_addr + offset, word)

        #AES self test
        if self_test:
            sample = os.urandom(buffer_size)
            if self.aes_core_mode == 'encrypt':
                expected = aesEncrypt(sample, self.key, padding=False)
                self.encryption(sample)
            else:
                expected = aesDecrypt(sample, self.key, padding=False)
                self.decryption(sample)
            hw_r = self.wait_done()
            sw_r = np.frombuffer(expected, dtype=np.uint8)
            # array_equal also copes with a length mismatch between the two.
            if not np.array_equal(sw_r, hw_r):
                print(f"software result: {sw_r}")
                print(f"hardware result: {hw_r}")
                raise Exception (f"{self.aes_core._fullpath} self {self.aes_core_mode} test fail")
            print(f"{self.aes_core._fullpath} self {self.aes_core_mode} test success")

    def _start_transfer(self, data):
        """Copy ``data`` into the input buffer and start the core and both DMAs.

        Does not wait for completion; call ``wait_done()`` for that.

        Raises:
            Exception: if ``data`` is longer than ``buffer_size``.
        """
        buffer_len = len(data)
        # Validate before touching any hardware register or object state.
        if buffer_len > self.buffer_size:
            raise Exception ("input size is bigger than buffer size")

        self.input_length = buffer_len
        self.buffer_length = buffer_len

        #config PL data length. The whole input buffer is always transferred,
        #so the core is told the full buffer size; wait_done() trims the
        #result back to the real input length.
        self.aes_core.write(self.length_addr, self.buffer_size)

        # DMA buffers are allocated on first use and then reused.
        if self.input_buffer is None:
            self.input_buffer = allocate(shape=(self.buffer_size,), dtype=np.uint8)
        if self.output_buffer is None:
            self.output_buffer = allocate(shape=(self.buffer_size,), dtype=np.uint8)

        self.input_buffer[:buffer_len] = np.frombuffer(data, dtype=np.uint8)

        #start the PL
        self.aes_core.write(0x00, 0x01)
        self.dma_write.sendchannel.transfer(self.input_buffer)
        self.dma_read.recvchannel.transfer(self.output_buffer)

    def encryption(self, plain):
        """Start encrypting ``plain`` (a bytes-like object) on this core.

        Returns immediately; fetch the ciphertext with ``wait_done()``.

        Raises:
            Exception: if this core is not an encryption core or ``plain``
                is longer than ``buffer_size``.
        """
        if self.aes_core_mode != 'encrypt':
            raise Exception ("This core unsupport encryption")
        self._start_transfer(plain)

    def decryption(self, encrypt):
        """Start decrypting ``encrypt`` (a bytes-like object) on this core.

        Returns immediately; fetch the plaintext with ``wait_done()``.

        Raises:
            Exception: if this core is not a decryption core or ``encrypt``
                is longer than ``buffer_size``.
        """
        if self.aes_core_mode != 'decrypt':
            raise Exception ("This core unsupport decryption")
        self._start_transfer(encrypt)

    def wait_done(self):
        """Block until both DMA channels finish and return the result.

        Returns:
            A NumPy copy of the output buffer, trimmed to the length of the
            most recently submitted input.
        """
        self.dma_write.sendchannel.wait()
        self.dma_read.recvchannel.wait()
        return np.array(self.output_buffer)[:self.input_length]

    def free(self):
        """Release the DMA buffers, if any; they are re-allocated on next use."""
        for buf in (self.input_buffer, self.output_buffer):
            if buf is not None:
                buf.invalidate()
                buf.freebuffer()
                buf.close()
        self.input_buffer = None
        self.output_buffer = None

    def _key_expansion(self, key):
        """Expand ``key`` into the full AES round-key schedule.

        Args:
            key: indexable sequence of byte values; the first ``4 * Nk``
                entries are used.

        Returns:
            ``np.uint8`` array of ``AES_keyExpSize`` bytes holding
            ``Nb * (Nr + 1)`` four-byte words.
        """
        round_key = bytearray(self.AES_keyExpSize)

        # The first Nk words are the cipher key itself.
        key_bytes = 4 * self.Nk
        round_key[:key_bytes] = bytes(int(key[j]) for j in range(key_bytes))

        # Every further word is derived from the previous word and the word
        # Nk positions earlier.
        for i in range(self.Nk, self.Nb * (self.Nr + 1)):
            temp = list(round_key[4 * (i - 1):4 * i])

            if i % self.Nk == 0:
                # Rotate left by one byte, substitute, then mix in the
                # round constant.
                temp = temp[1:] + temp[:1]
                temp = [self.sbox[b] for b in temp]
                temp[0] ^= self.Rcon[i // self.Nk]
            elif self.AES_MODE == AES256_MODE and i % self.Nk == 4:
                # 256-bit keys get an extra substitution half-way through
                # each Nk-word group.
                temp = [self.sbox[b] for b in temp]

            base = 4 * (i - self.Nk)
            for j in range(4):
                round_key[4 * i + j] = round_key[base + j] ^ temp[j]

        # Copy so the returned array is writable, like a freshly built one.
        return np.frombuffer(bytes(round_key), dtype=np.uint8).copy()

## AES core schedule

In [5]:
class AES_Schedule():
    """Spread file encryption / decryption over several AES cores.

    Files are queued with ``enqueue_file`` and processed by
    ``start_encrypt_schedule`` or ``start_decrypt_schedule``. Each file is
    read in chunks of ``block_size`` bytes; consecutive chunks are handed to
    consecutive cores and the results are written back in the same order.
    Output goes to ``<file>.encrypt`` / ``<file>.decrypt``.

    NOTE(review): ``pad``, ``unpad`` and ``cmd`` are not defined in this
    file — presumably they come from AES_Utils.ipynb.
    """

    def __init__(self, cores, block_size = 10 * 1024 * 1024):
        """
        Args:
            cores: sequence of core objects providing ``encryption`` or
                ``decryption``, plus ``wait_done`` and ``free``.
            block_size: number of bytes handed to a core per transfer.
        """
        self.aes_cores = cores  #AES_Core array
        self.aes_cores_number = len(self.aes_cores)
        print(f"Init success with {self.aes_cores_number} AES cores")
        self.queue = []
        self.done = []
        self.encrypt_file_expansion = ".encrypt"
        self.decrypt_file_expansion = ".decrypt"
        self.block_size = block_size

    def enqueue_file(self, file):
        """Add a file path to the queue of the next schedule run."""
        self.queue.append(file)

    def list_md5(self):
        """Print the md5sum of every processed file and of its output file."""
        for mode, file in self.done:
            print(cmd("md5sum %s"%( file )))
            if mode == 'encrypt':
                print(cmd("md5sum %s"%( file + self.encrypt_file_expansion )))
            else:
                print(cmd("md5sum %s"%( file + self.decrypt_file_expansion )))

    def free(self):
        """Release the DMA buffers of every core."""
        for core in self.aes_cores:
            core.free()

    @staticmethod
    def _file_size(f):
        """Return the size of the open binary file ``f`` and rewind it."""
        size = f.seek(0, 2)   # seek to file end
        f.seek(0, 0)          # seek to file start
        return size

    def _progress(self, part, total, slot):
        """Build the one-line progress message for chunk ``part`` on core ``slot``."""
        return 'Start part %4d / %4d on core #%d %s'%(part, total, slot, '.'*(slot+1)+' '*(self.aes_cores_number-1-slot))

    def start_encrypt_schedule(self):
        """Encrypt every queued file to ``<file>.encrypt`` and clear the queue.

        Full ``block_size`` chunks are distributed over the cores; the
        remaining tail (possibly empty) is passed through ``pad`` and
        encrypted on core 0.
        """
        for path in self.queue:
            # Both handles are closed even if a core raises mid-file.
            with open(path, 'rb') as f, open(path + self.encrypt_file_expansion, 'wb') as f_w:
                file_size = self._file_size(f)
                print(f"start encrypt {path} with size {file_size}")
                do_time = file_size // self.block_size
                print(f"need do {do_time} time")

                # Chunk number currently assigned to each core (display only).
                core_status = [0] * self.aes_cores_number
                num_core_config = 0
                for t in range(do_time):
                    read_buffer = f.read(self.block_size)
                    print(self._progress(t + 1, do_time, num_core_config), end='\r')
                    core_status[num_core_config] = t + 1
                    self.aes_cores[num_core_config].encryption(read_buffer)
                    num_core_config += 1
                    # Drain once every core is busy or no full chunk is left.
                    if num_core_config >= self.aes_cores_number or t == do_time - 1:
                        for k in range(num_core_config):  # wait all core done
                            print(core_status, end='\r')
                            f_w.write(self.aes_cores[k].wait_done())
                        num_core_config = 0
                print()

                last_size = file_size % self.block_size
                read_buffer = f.read(last_size)
                self.aes_cores[0].encryption(pad(read_buffer))
                f_w.write(self.aes_cores[0].wait_done())
            self.done.append([ "encrypt" , path ])
        self.queue = []

    def start_decrypt_schedule(self):
        """Decrypt every queued file to ``<file>.decrypt`` and clear the queue.

        Full ``block_size`` chunks are distributed over the cores; a shorter
        tail, if present, is decrypted on core 0. ``unpad`` is applied only
        to the very last chunk of the file.
        """
        for path in self.queue:
            # Both handles are closed even if a core raises mid-file.
            with open(path, 'rb') as f, open(path + self.decrypt_file_expansion, 'wb') as f_w:
                file_size = self._file_size(f)
                print(f"start decrypt {path} with size {file_size}")
                do_time = file_size // self.block_size
                last_size = file_size % self.block_size
                print(f"need do {do_time} time")

                num_core_config = 0
                for t in range(do_time):
                    read_buffer = f.read(self.block_size)
                    print(self._progress(t + 1, do_time, num_core_config), end='\r')
                    self.aes_cores[num_core_config].decryption(read_buffer)
                    num_core_config += 1
                    # Drain once every core is busy or no full chunk is left.
                    if num_core_config >= self.aes_cores_number or t == do_time - 1:
                        # With no tail, the file's final chunk is the last
                        # one of this batch.
                        ends_file = (t == do_time - 1 and last_size == 0)
                        for k in range(num_core_config):  # wait all core done
                            write_buffer = self.aes_cores[k].wait_done()
                            # Strip padding from the final chunk only, not
                            # from the other chunks of the same batch.
                            if ends_file and k == num_core_config - 1:
                                write_buffer = unpad(write_buffer)
                            f_w.write(write_buffer)
                        num_core_config = 0
                print()

                if last_size != 0:
                    read_buffer = f.read(last_size)
                    self.aes_cores[0].decryption(read_buffer)
                    write_buffer = self.aes_cores[0].wait_done()
                    f_w.write(unpad(write_buffer))
            self.done.append([ "decrypt" , path ])
        self.queue = []