In [1]:
from pynq import Overlay
from pynq import Xlnk
import numpy as np
from struct import unpack
import math
import time

overlay = Overlay("./bit/forw_back_1.bit")  # ./bit/forw_back.bit  和  ./bit/model.bit
core = overlay.forw_back_0

In [2]:
# 向PL端口输入数据——将xlnk接口地址写入PL端接口地址
# 地址可在HLS下solution1/impl/misc/drivers/forw_back_v0_0/src/xforw_back_hw.h可查
# 地址也可在SDK下system.hdf文件IP blocks present in the design部分Registers可查


# 导入导出参数 flag==0导入，flag>4导出
def Import_export_parameter(flag, conv1, conv2, conv3, fc1, fc2, fc3):
    core.write(0x10, flag)
    core.write(0x20, conv1)
    core.write(0x28, conv2)
    core.write(0x30, conv3)
    core.write(0x38, fc1)
    core.write(0x40, fc2)
    core.write(0x48, fc3)
    core.write(0x00, 0x01)
    while(core.read(0x00)!=4):
        i = 1
    return 0

# 导入一张图像 flag==1
def Import_a_image(flag, in_d):
    core.write(0x10, flag)
    core.write(0x18, in_d)
    core.write(0x00, 0x01)
    while(core.read(0x00)!=4):  # 如果0x00!=4就代表PL端还在运行
        i = 1
    return 0

# 开始测试 flag==2
def Begin_test(flag):
    core.write(0x10, flag)
    core.write(0x00, 0x01)
    while(core.read(0x00)!=4):  # 如果0x00!=4就代表PL端还在运行
        i = 1
    return 0

# 开始训练 flag==3
def Begin_train(flag, label, lr):
    core.write(0x10, flag)
    core.write(0x58, label)
    core.write(0x60, lr)
    core.write(0x00, 0x01)
    while(core.read(0x00)!=4):
        i = 1
    return 0

# 导出结果 flag==4
def Export_result(flag, out):
    core.write(0x10, flag)
    core.write(0x50, out)
    core.write(0x00, 0x01)
    while(core.read(0x00)!=4):
        i = 1
    return 0

In [3]:
# 取最大值函数
def max10(data):
    for i in range(10):
        if(data[i] == max(data)):
            location = i
            break
    return location

In [4]:
xlnk = Xlnk()       # 该接口必须申请内存后才能被IP使用，可以使用xlnk来申请一段连续内存缓冲区，该缓冲区允许PS跟PL之间进行有效的数据传输
read_or_rand = 0    # 1_read 0_rand

# 给每个接口分配空间
data_in = xlnk.cma_array(shape=(30*30,), dtype=np.float32)
conv1 = xlnk.cma_array(shape=(9,), dtype=np.float32)
conv2 = xlnk.cma_array(shape=(9,), dtype=np.float32)
conv3 = xlnk.cma_array(shape=(9,), dtype=np.float32)
fc1 = xlnk.cma_array(shape=(576*180,), dtype=np.float32)
fc2 = xlnk.cma_array(shape=(180*45,), dtype=np.float32)
fc3 = xlnk.cma_array(shape=(45*10,), dtype=np.float32)
data_out = xlnk.cma_array(shape=(10,), dtype=np.float32)
lr = xlnk.cma_array(shape=(1,), dtype=np.float32)

para = open("./Network_parameter.bin", "rb")
conv1_f = []
conv2_f = []
conv3_f = []
fc1_f = []
fc2_f = []
fc3_f = []

if read_or_rand == 1:
    # 读取卷积核参数——存储器
    for i in range(9):
        data = para.read(4)
        conv1_f.append(unpack("f", data)[0])
    for i in range(9):
        data = para.read(4)
        conv2_f.append(unpack("f", data)[0])
    for i in range(9):
        data = para.read(4)
        conv3_f.append(unpack("f", data)[0])
    # 读取全连接系数矩阵参数——存储器
    for i in range(576*180):
        data = para.read(4)
        fc1_f.append(unpack("f", data)[0])
    for i in range(180*45):
        data = para.read(4)
        fc2_f.append(unpack("f", data)[0])
    for i in range(45*10):
        data = para.read(4)
        fc3_f.append(unpack("f", data)[0])
else:
    # 随机初始化网络参数——存储器
    conv1_f = np.random.random(9)
    conv2_f = np.random.random(9) / 5
    conv3_f = np.random.random(9) / 5
    fc1_f = np.random.random(576*180) / 1000
    fc2_f = np.random.random(180*45) / 100
    fc3_f = np.random.random(45*10) / 10
    
# 写入接口缓存——xlnk
for k in range(9):
    conv1[k] = conv1_f[k]
    conv2[k] = conv2_f[k]
    conv3[k] = conv3_f[k]
for k in range(576*180):
    fc1[k] = fc1_f[k]
for k in range(180*45):
    fc2[k] = fc2_f[k]
for k in range(45*10):
    fc3[k] = fc3_f[k]

In [5]:
# 读取图片标签和数据
img_label = []
def label_loader():
    for i in range(10):
        for j in range(30):
            img_label.append(i)

img_data = []
def data_loader(path):
    for i in range(10):                                 # 遍历分组
        imgs_path = path + '/'+ str(i) + '/'
        for j in range(30):                             # 遍历图片
            img_path = imgs_path + str(j + 1) + '.bmp'
            img = open(img_path, 'rb')
            img.seek(62)                                 # 跳过前62个没用的字节
            nums = []
            and_list = [1, 2, 4, 8, 16, 32, 64, 128]     # 辅助从字节中提取比特
            for i in range(120):                        
                num = unpack("B", img.read(1))[0]        # 从头开始逐一读取120个字节
                for j in range(8):                      # 遍历8个bit
                    if (i % 4 == 3) and (j >= 6):       # 检测到两个多余bit就跳过
                        continue
                    nums.append(int((num & and_list[7 - j]) == and_list[7 - j]))  # 真实数据就放到列表后面
            img_data.append(nums[:])
            

In [7]:
# 训练
data_loader("./Training_set")
label_loader()

# 0 导入初始化参数
Import_export_parameter(0, conv1.physical_address, conv2.physical_address, conv3.physical_address, fc1.physical_address, fc2.physical_address, fc3.physical_address)

corss_loss_max = 2
for epoch in range(100):
    lr[0] = pow((corss_loss_max/10), 1.7)
    if lr[0] > 0.01:
        lr[0] = 0.01
    if (epoch+1)%10==0:
        print('lose=',corss_loss_max,'   Learning rate=',lr[0])
    if lr[0] < 0.0000000001:
        break
    
    # 打乱图片——不打乱会学不到东西
    img_data_change = []
    img_label_change = 0
    label_1 = 0
    label_2 = 0
    for i in range(300):
        label_1 = np.random.randint(0,299)
        label_2 = np.random.randint(0,299)
        if label_1 != label_2:
            img_data_change = img_data[label_1]
            img_data[label_1] = img_data[label_2]
            img_data[label_2] = img_data_change
            img_label_change = img_label[label_1]
            img_label[label_1] = img_label[label_2]
            img_label[label_2] = img_label_change
    
    for i in range(300):
        corss_loss_max = 0
        for j in range(30*30):
            data_in[j] = img_data[i][j]

        # 1 导入一张图片数据
        Import_a_image(1, data_in.physical_address)
        
        # 3 开始训练
        # start = time.time()
        Begin_train(3, img_label[i], lr.physical_address)
        # end = time.time()
        # print('Running time: %s Seconds'%((end-start)))
        
        # 4 导出训练结果
        Export_result(4, data_out.physical_address)
        corss_loss_current = -(math.log10(data_out[img_label[i]]))
        
        if corss_loss_current > corss_loss_max:
            corss_loss_max = corss_loss_current

lose= 0.5882764844476362    Learning rate= 0.008096336
lose= 0.09338737000746733    Learning rate= 0.00035439685
lose= 0.09385506883810144    Learning rate= 0.00035741943
lose= 0.09220626963106518    Learning rate= 0.00034681094
lose= 0.08088558395621845    Learning rate= 0.00027757537
lose= 0.07674634398855432    Learning rate= 0.00025386224
lose= 0.07401391696625764    Learning rate= 0.0002386892
lose= 0.06725663005496929    Learning rate= 0.00020283817
lose= 0.0691017498137789    Learning rate= 0.00021238868
lose= 0.06831462469603568    Learning rate= 0.00020829233
