In [4]:
#Copyright (c) 2022 Yuki Tanaka
#Released under the MIT license
#https://github.com/kwdlab/tanaka.yuki/blob/kuwakado.hidenori/LICENCE

from ctypes import *
import ctypes.util  
from numpy.ctypeslib import ndpointer 
import numpy as np
from numpy.random import *

from pynq import Overlay
import math
import random, string
import time


# Load the FPGA bitstream and take a handle to the permutation IP core that
# Permutation() drives through write()/read()/register_map.
overlay = Overlay('/home/xilinx/pynq/overlays/photon256_4/photon256_4.bit') #Overlay
permutation_ip = overlay.Permutation_0

# Hash parameters (this configuration prints as D=6, Rate=32, Ratep=32, DigestSize=256).
ROUND=12          # round count passed as R to Permutation / Permutation_c
D=6               # the state is a D x D matrix of cells
RATE=32           # bits XORed into the state per CompressFunction call
RATEP=32          # bits extracted from the state per squeezing step
DIGESTSIZE=256    # digest length in bits
S=8               # bits per state cell
ReductionPoly = 0x1b  # not read by any function in this file; presumably the GF(2^S) reduction polynomial — confirm
WORDFILTER = (1<<S)-1 # mask selecting the low S bits of a cell
DEBUG = 0         # not read by any function in this file
MessBitLen = 0    # not read by any function in this file

presets=[0,0,0]   # 3-byte IV buffer; filled by Init() and XORed into the state

def genRandomMessage(n):
    """Return a random string of *n* characters drawn from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(n))

def printState(state):
    """Print the D x D state as two-digit hex cells, one row per line.

    Each cell is followed by a single space (including the last one), and a
    blank line is printed after the final row.
    """
    for r in range(D):
        print("".join("{:02x} ".format(state[r][c]) for c in range(D)))
    print("")

def printDigest(digest):
    """Print *digest* on one line as two-digit hex bytes, each followed by a space."""
    print("".join(f"{byte:02x} " for byte in digest))

def GetByte(str,BitOffSet):
    """Return the 8-bit value that starts at bit position *BitOffSet* of *str*.

    *str* is an indexable sequence of byte values (list of ints or bytes).
    For a byte-aligned offset the byte is returned as-is. For an unaligned
    offset the result is the low nibble of the current byte followed by the
    high nibble of the next byte, i.e. this path assumes the offset is
    misaligned by exactly 4 bits.
    """
    index = BitOffSet >> 3
    if BitOffSet & 0x7:
        # Mask to 8 bits: without it the ``<< 4`` carries the high nibble of
        # str[index] into bits 8-11 of the result.
        return ((str[index] << 4) | (str[index + 1] >> 4)) & 0xFF
    return str[index]

def WordXorByte(state,str,BitOffSet,WordOffSet,NoOfBits):
    """XOR *NoOfBits* bits of *str* (starting at *BitOffSet*) into *state*.

    The bits are consumed S at a time. Each chunk is XORed into one cell,
    starting at cell number *WordOffSet* in row-major order of the D x D
    state. *state* is modified in place.
    """
    bit = 0
    while bit < NoOfBits:
        word = WordOffSet + int(bit / S)      # row-major cell index for this chunk
        row, col = int(word / D), int(word % D)
        state[row][col] ^= GetByte(str, BitOffSet + bit)
        bit += S
        
def WriteByte(str,value,BitOffSet,NoOfBits):
    """Store the low *NoOfBits* bits of *value* into *str* at bit *BitOffSet*.

    Bits are placed MSB-first and the surrounding bits of the touched byte(s)
    are preserved. When the field crosses a byte boundary, two bytes are
    updated and only the low 8 bits of the value/mask are used.
    *str* is a mutable sequence of byte values and is modified in place.
    """
    pos = BitOffSet >> 3
    bit = BitOffSet & 0x7
    mask = (1 << NoOfBits) - 1
    value &= mask

    if bit + NoOfBits > 8:
        # Field straddles str[pos] and str[pos + 1]: edit them as one 16-bit window.
        shift = 16 - bit - NoOfBits
        window = ((str[pos] << 8) & 0xFF00) | (str[pos + 1] & 0xFF)
        window = (window & ~((mask & 0xFF) << shift)) | ((value & 0xFF) << shift)
        str[pos], str[pos + 1] = (window >> 8) & 0xFF, window & 0xFF
        return

    shift = 8 - bit - NoOfBits
    str[pos] = (str[pos] & ~(mask << shift)) | (value << shift)

def WordToByte(state,str,BitOffSet,NoOfBits):
    """Serialize *NoOfBits* bits of *state* into *str* starting at bit *BitOffSet*.

    Cells are read in row-major order from cell 0, S bits each. If the last
    chunk is shorter than S bits, the top bits of that cell are written.
    """
    offset = 0
    while offset < NoOfBits:
        nbits = min(S, NoOfBits - offset)
        cell = state[math.floor(offset / (S * D))][int(offset / S) % D] & WORDFILTER
        WriteByte(str, cell >> (S - nbits), BitOffSet + offset, nbits)
        offset += S


def Init(state):
    """Zero *state* and XOR the 3-byte IV into its last cells.

    The IV bytes are (DIGESTSIZE >> 2), RATE and RATEP, each truncated to
    8 bits. They are written into the module-level ``presets`` list and then
    absorbed into the final 24 / S cells of the row-major D x D state.

    Args:
        state: D x D nested list of ints, modified in place.
    """
    for i in range(D):
        for j in range(D):
            state[i][j] = 0
    presets[0] = (DIGESTSIZE >> 2) & 0xFF
    presets[1] = RATE & 0xFF
    presets[2] = RATEP & 0xFF
    # Integer cell offset so the 24 IV bits land in the last 24 // S cells.
    WordXorByte(state, presets, 0, D * D - 24 // S, 24)

def Permutation(state, R):
    """Apply the permutation to *state* in place using the FPGA IP core.

    The D*D cells are taken in row-major order and packed four per 32-bit
    word, with the first cell in the least-significant byte. Word k is written
    to, and after the run read back from, byte offset 0x40 + 4*k of the IP.
    The core is started with AP_START and polled (no timeout) until AP_DONE.

    Args:
        state: D x D nested list of byte values, updated in place.
        R: number of rounds, written to the IP's ``R`` register.
    """
    cells = [state[i][j] for i in range(D) for j in range(D)]
    # Round up so a trailing partial word is sent too (the read loop reads it back).
    n_words = (len(cells) + 3) // 4

    for k in range(n_words):
        chunk = cells[4 * k:4 * k + 4]
        word = sum(value << (8 * b) for b, value in enumerate(chunk))
        permutation_ip.write(0x40 + 4 * k, word)

    # Use the caller's round count rather than the module-wide ROUND.
    permutation_ip.register_map.R = R
    permutation_ip.register_map.CTRL.AP_START = 1
    while permutation_ip.register_map.CTRL.AP_DONE == 0:
        pass

    for k in range(n_words):
        word = permutation_ip.read(0x40 + 4 * k)
        for b in range(4):
            idx = 4 * k + b
            if idx < len(cells):
                state[idx // D][idx % D] = (word >> (8 * b)) & 0xFF


def _load_permutation_c_lib():
    """Load ./permutation_c3.so once, declare Permutation()'s C signature, and cache it."""
    lib = getattr(_load_permutation_c_lib, "cached", None)
    if lib is None:
        # Note: "." is resolved against the working directory of the first call.
        lib = np.ctypeslib.load_library("permutation_c3.so", ".")
        # Each state row is passed as a C-contiguous 1-D uint8 array.
        row_ptr = ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS')
        lib.Permutation.argtypes = [row_ptr] * 6 + [c_int32]  # six rows, round count
        lib.Permutation.restype = None
        _load_permutation_c_lib.cached = lib
    return lib


def Permutation_c(state, R):
    """Apply the permutation to *state* in place via the C shared library.

    Rows 0..5 of *state* are copied into uint8 NumPy arrays. The library
    function ``Permutation(row0, ..., row5, R)`` mutates those arrays, and the
    results are written back into *state* as Python lists (rows are replaced,
    not updated element-wise).

    The library is loaded and its argument types declared only once and then
    reused, instead of on every call.

    Args:
        state: list of at least 6 rows of byte values; rows 0..5 are replaced.
        R: number of rounds, passed to the C function as a C int.
    """
    lib = _load_permutation_c_lib()
    rows = [np.array(state[k], dtype=np.uint8) for k in range(6)]
    lib.Permutation(*rows, ctypes.c_int(R))
    for k in range(6):
        state[k] = rows[k].tolist()

def Squeeze(state,digest,mode):
    """Extract DIGESTSIZE bits from *state* into *digest*, RATEP bits per step.

    The state is permuted between consecutive extractions but not after the
    last one. ``mode == 0`` selects the FPGA permutation; any other value
    selects the C-library permutation.
    """
    permute = Permutation if mode == 0 else Permutation_c
    for offset in range(0, DIGESTSIZE, RATEP):
        if offset:
            permute(state, ROUND)
        WordToByte(state, digest, offset, min(RATEP, DIGESTSIZE - offset))

def CompressFunction(state,mess,BitOffSet,mode):
    """Absorb one block: XOR RATE bits of *mess* (from *BitOffSet*) into the
    first cells of *state*, then permute it.

    ``mode == 0`` selects the FPGA permutation; any other value selects the
    C-library permutation.
    """
    WordXorByte(state, mess, BitOffSet, 0, RATE)
    permute = Permutation if mode == 0 else Permutation_c
    permute(state, ROUND)
    
def hash(digest,mess,BitLen,mode):
    """Hash the first *BitLen* bits of *mess* and write the result into *digest*.

    Full RATE-bit blocks are absorbed directly. The remaining bits are copied
    into a padding buffer, followed by a single '1' bit (MSB-first) and zeros,
    and absorbed as one last block. The digest is then squeezed out.

    Args:
        digest: mutable list with at least DIGESTSIZE / 8 entries; overwritten.
        mess: message as a str (encoded once with ``str.encode()``, i.e. UTF-8)
            or a bytes-like object / list of byte values.
        BitLen: number of message bits to hash. NOTE: for non-ASCII str input
            the encoded length exceeds ``len(mess)``, so pass the encoded bit length.
        mode: 0 -> FPGA permutation, anything else -> C-library permutation.
    """
    data = mess.encode() if isinstance(mess, str) else bytes(mess)
    state = [[0] * D for _ in range(D)]
    padded = [0] * ((RATE + 7) // 8 + 1)
    Init(state)

    MessIndex = 0
    while MessIndex <= BitLen - RATE:
        CompressFunction(state, data, MessIndex, mode)
        MessIndex += RATE

    # Copy the tail (fewer than RATE bits), which begins bit_off bits into
    # byte `start`, so that padded[0] corresponds to data[start].
    start = MessIndex >> 3
    bit_off = MessIndex & 0x7
    end_bit = bit_off + (BitLen - MessIndex)   # first bit position after the message
    for k in range((end_bit + 7) // 8):
        padded[k] = data[start + k]

    # Append the '1' padding bit right after the last message bit; clear the rest.
    full, extra = divmod(end_bit, 8)
    if extra:
        padded[full] = (padded[full] & ((0xFF << (8 - extra)) & 0xFF)) | (0x80 >> extra)
    else:
        padded[full] = 0x80

    CompressFunction(state, padded, bit_off, mode)
    Squeeze(state, digest, mode)
    
def testPermutation(n,mode=0,result="default"):
    """Time the permutation on *n* random D x D states and print the mean.

    Args:
        n: number of random states to permute.
        mode: 0 -> FPGA ``Permutation``; anything else -> C-library ``Permutation_c``.
        result: "all" additionally prints the duration of every call.
    """
    # Generate all random states up front so RNG cost is excluded from timing.
    test_states = [
        [[random.randint(0x00, 0xFF) for _ in range(D)] for _ in range(D)]
        for _ in range(n)
    ]

    if mode == 0:
        print("-------FPGA-------")
    else:
        print("-------C言語-------")
    permute = Permutation if mode == 0 else Permutation_c

    total = 0.0
    for k, st in enumerate(test_states):
        start_time = time.perf_counter()
        permute(st, ROUND)
        elapsed = time.perf_counter() - start_time
        if result == "all":
            print(f"{k}:{elapsed}s")
        total += elapsed

    # Average over the number of states actually timed (n), not a loop index.
    if test_states:
        print(f"平均:{total/len(test_states)}s")


def testHash(message_len):
    """Hash one random message on the FPGA and on the CPU and compare.

    Prints each digest with its wall-clock time. Then prints whether the two
    digests match and the CPU/FPGA time ratio.

    Args:
        message_len: number of random alphanumeric characters to hash.
    """
    mess = genRandomMessage(message_len)
    bit_len = 8 * len(mess)

    print("-------FPGA-------")
    fpga_digest = [0] * (DIGESTSIZE // 8)
    start_time = time.perf_counter()
    hash(fpga_digest, mess, bit_len, 0)
    fpga_time = time.perf_counter() - start_time
    printDigest(fpga_digest)
    print(f"time:{fpga_time}s")

    print("-------CPU-------")
    # Separate buffer so the two results can be compared afterwards.
    cpu_digest = [0] * (DIGESTSIZE // 8)
    start_time = time.perf_counter()
    hash(cpu_digest, mess, bit_len, 1)
    cpu_time = time.perf_counter() - start_time
    printDigest(cpu_digest)
    print(f"time:{cpu_time}s")

    print(f"match:{fpga_digest == cpu_digest}")
    if fpga_time > 0:
        print(f"speedup:{cpu_time / fpga_time}x")

# Report the active parameter set, then benchmark one 30000-character message.
print(f"D = {D}, Rate = {RATE}, Ratep = {RATEP}, DigestSize = {DIGESTSIZE}")
testHash(30000)

D = 6, Rate = 32, Ratep = 32, DigestSize = 256
-------FPGA-------
58 d4 e6 b3 23 d2 ce 90 d0 f9 ea 61 aa f1 aa fc dd 17 0a 48 0f 0b 3f 37 12 ce 97 65 25 7e e1 2b 
time:10.389401573000214s
-------CPU-------
58 d4 e6 b3 23 d2 ce 90 d0 f9 ea 61 aa f1 aa fc dd 17 0a 48 0f 0b 3f 37 12 ce 97 65 25 7e e1 2b 
time:24.75431899599971s
