In [1]:
from extern_funcs import interpolate, Fusion, ln
from pynq import Overlay, allocate
import numpy as np
import math
import time
import nngen_ctrl as ng

In [2]:

# Network parameter blob: the archive's only (unnamed) array, stored as 'arr_0'.
params_npz = np.load("../params_nngen/params.npz")
params = params_npz['arr_0']

# Archives indexed by tensor name in the cells below.
inputs_npz = np.load("../params_nngen/inputs.npz")          # tensors fed to the design
outputs_npz = np.load("../params_nngen/outputs.npz")        # reference outputs for the final check
intrinsics_npz = np.load("../params_nngen/intrinsics.npz")  # half_K and pose arrays for Fusion

In [3]:
# Keys into outputs.npz, in the order the tensors are laid out in memory.
output_files = [
    'cell_state',
    'hidden_state',
    'depth_org',
]

# Keys into inputs.npz, in the order the tensors are laid out in memory.
input_files = [
    'reference_image',
    'hidden_state',
    'cell_state',
]

In [4]:
# Granularity (in bytes) to which every region's end address is rounded up.
chunk_size = 64


def get_end_addr(addr, memory_size):
    """Return the end address of a region, rounded up to ``chunk_size``.

    Args:
        addr: Start address of the region, in bytes.
        memory_size: Size of the region, in bytes.

    Returns:
        The smallest multiple of ``chunk_size`` that is greater than or equal
        to ``addr + memory_size``, as a Python ``int``.
    """
    end = addr + memory_size
    # Ceiling division done with floor division on the negated value: this
    # stays exact for arbitrarily large integers, whereas true division goes
    # through a float and can round the quotient down.
    return int(-(-end // chunk_size) * chunk_size)

def shape2size(shape):
    """Return the number of elements in a tensor of the given shape.

    Args:
        shape: Iterable of dimension lengths.

    Returns:
        The product of all dimensions; ``1`` for an empty shape.
    """
    total = 1
    for dim in shape:
        total = total * dim
    return total

In [5]:
# Layout constants: activations are `act_bit` bits wide and the last axis of
# every tensor is padded to a multiple of `num_align_words` elements, i.e. the
# number of activations that fit in `axi_datawidth` bits.
axi_datawidth = 128
act_bit = 16
num_align_words = axi_datawidth // act_bit
bytes_per_word = act_bit // 8

# Output region. addrs[i] is the start address of output_files[i]; each end
# address (rounded up by get_end_addr) becomes the next start address.
output_offset = 0
outputs = []
output_aligned_shapes = []
addrs = [output_offset]
for output_name in output_files:
    expected = outputs_npz[output_name]
    # Integer ceiling of the last axis to the alignment width.
    aligned_last = -(-expected.shape[-1] // num_align_words) * num_align_words
    aligned_shape = (*expected.shape[:-1], aligned_last)
    outputs.append(expected)
    output_aligned_shapes.append(aligned_shape)
    addrs.append(get_end_addr(addrs[-1], shape2size(aligned_shape) * bytes_per_word))

# Input region, placed directly after the outputs. Each input is zero-padded
# on its last axis to the alignment width, then flattened to int16.
input_offset = addrs[-1]
inputs = []
for input_name in input_files:
    tensor = inputs_npz[input_name]
    lack = (num_align_words - (tensor.shape[-1] % num_align_words)) % num_align_words
    if lack > 0:
        pad_width = [(0, 0)] * (tensor.ndim - 1) + [(0, lack)]
        tensor = np.pad(tensor, pad_width, mode='constant')
    flat = np.reshape(tensor, [-1]).astype(np.int16)
    inputs.append(flat)
    addrs.append(get_end_addr(addrs[-1], flat.size * bytes_per_word))

# Change the output cell_state address to the input cell_state address.
# The input slot is looked up by name rather than assumed to be the last
# input. The output cell_state is assumed to be slot 0, matching output_files.
cell_state_offset = addrs[len(output_files) + input_files.index('cell_state')]
addrs[0] = cell_state_offset
param_offset = addrs[-1]
print(output_offset, input_offset, param_offset)
print(addrs)

0 110592 221184
[215040, 6144, 12288, 110592, 208896, 215040, 221184]


In [6]:
# Bitstream file and the name of the accelerator IP inside it.
bitfile = 'design_1.bit'
ipname = 'dvmvs_0'

# Build the PYNQ overlay from the bitfile (PYNQ's Overlay downloads the
# bitstream by default -- confirm against the installed PYNQ version).
overlay = Overlay(bitfile)
# overlay.ip_dict
# Controller object used below to drive the IP named `ipname`.
ip = ng.nngen_core(overlay, ipname)

In [7]:
# 192 MiB byte buffer; all addresses computed in this notebook are byte
# offsets into it.
memory_size = 1024 * 1024 * 192
buf = allocate(shape=(memory_size,), dtype=np.uint8)

In [8]:
# Copy every flattened input tensor to its byte address. The input addresses
# are the entries of `addrs` after the output slots, excluding the final entry
# (the parameter offset).
input_addrs = addrs[len(output_files):-1]
for tensor, addr in zip(inputs, input_addrs):
    raw = tensor.view(np.uint8)
    buf[addr:addr + raw.size] = raw

# Copy the parameter blob byte-for-byte. Taking the length from the flattened
# uint8 view keeps the slice and the data the same size whatever the dtype or
# rank of `params`, and avoids an implicit int8 -> uint8 cast on assignment.
params_raw = params.reshape(-1).view(np.uint8)
buf[param_offset:param_offset + params_raw.size] = params_raw

In [9]:
# Hand the allocated buffer to the IP controller.
ip.set_global_buffer(buf)
# Point buffer slot 0 at the cell_state input address, mirroring addrs[0].
ip.write_buffer_address(0, cell_state_offset)
# Print the buffer addresses held by the IP for inspection.
# NOTE(review): 7 appears to be len(addrs) (3 outputs + 3 inputs + params) -- confirm.
for i in range(7):
    print(ip.read_buffer_address(i))

215040
6144
12288
110592
208896
215040
221184


In [10]:
# Camera/pose arrays from intrinsics.npz, passed to the Fusion extern.
half_K = intrinsics_npz["half_K"]
pose1s = intrinsics_npz["current_pose"]
pose2ss = intrinsics_npz["measurement_poses"]
# NOTE(review): the meaning of the first argument (11) is defined in
# extern_funcs -- confirm there.
fusion = Fusion(11, half_K, pose1s, pose2ss)

# Table of operators executed on the host, keyed by the opcode that
# ip.wait_extern() returns. Addresses are byte offsets into `buf`; shapes are
# the aligned (padded) shapes used to slice and reshape the buffer.
# NOTE(review): the literal addresses are presumably taken from the generated
# design's memory map and must be regenerated with it -- confirm.
# NOTE(review): 0x108 reads from 6144, which equals addrs[1] (the hidden_state
# output slot) in the layout printed above -- confirm they should stay in sync.
# opcode -> (func, input.addr, input.aligned_shape, output.addr, output.aligned_shape)
externs = {0x80: (fusion, 175484352, (1, 32, 48, 32), 175680960, (1, 32, 48, 64)),
           0x105: (ln(12), 176906688, (1, 2, 3, 512), 176925120, (1, 2, 3, 512)),
           0x106: (ln(12), 176931264, (1, 2, 3, 512), cell_state_offset, (1, 2, 3, 512)),
           0x108: (interpolate(4, 6, 0, 'bilinear'), 6144, (1, 2, 3, 512), 176943552, (1, 4, 6, 512)),
           0x114: (interpolate(8, 12, 0, 'bilinear'), 177029568, (1, 4, 6, 256), 177041856, (1, 8, 12, 256)),
           0x116: (interpolate(8, 12, 0, 'bilinear'), 177091008, (1, 4, 6, 8), 177091392, (1, 8, 12, 8)),
           0x121: (interpolate(16, 24, 0, 'bilinear'), 177194304, (1, 8, 12, 128), 177218880, (1, 16, 24, 128)),
           0x123: (interpolate(16, 24, 0, 'bilinear'), 177317184, (1, 8, 12, 8), 177318720, (1, 16, 24, 8)),
           0x128: (interpolate(32, 48, 0, 'bilinear'), 177533760, (1, 16, 24, 64), 177582912, (1, 32, 48, 64)),
           0x130: (interpolate(32, 48, 0, 'bilinear'), 177779520, (1, 16, 24, 8), 177785664, (1, 32, 48, 8)),
           0x135: (interpolate(64, 96, 0, 'bilinear'), 178350912, (1, 32, 48, 8), 178768704, (1, 64, 96, 8)),
           0x136: (interpolate(64, 96, 0, 'bilinear'), 178252608, (1, 32, 48, 32), 178375488, (1, 64, 96, 32))}

In [11]:
def run_extern(code):
    """Run the host-side operator registered for ``code`` on data in ``buf``.

    Looks up ``externs[code]``, reads the operator's input from ``buf`` as
    int16 in its aligned shape, calls the operator, zero-pads the result on
    its last axis up to the aligned output shape, and writes it back to
    ``buf`` as int16. Prints ``code`` and the elapsed wall-clock time.

    Args:
        code: Key into the module-level ``externs`` table.

    Raises:
        KeyError: If ``code`` is not present in ``externs``.
    """
    start_time = time.time()

    func, input_addr, input_aligned_shape, output_addr, output_aligned_shape = externs[code]
    bytes_per_word = act_bit // 8

    input_bytes = shape2size(input_aligned_shape) * bytes_per_word
    operand = buf[input_addr:input_addr + input_bytes].view(np.int16).reshape(input_aligned_shape)
    # An input whose aligned last axis is exactly 8 is reduced to its first
    # channel (presumably a single-channel tensor padded for alignment).
    if operand.shape[-1] == 8:
        operand = operand[..., :1]

    result = func(operand)

    # Pad by the actual shortfall on the last axis instead of a fixed 7, so
    # outputs with any channel count up to the aligned width are handled.
    pad = output_aligned_shape[-1] - result.shape[-1]
    if pad > 0:
        filler = np.zeros((*result.shape[:-1], pad), dtype=result.dtype)
        result = np.append(result, filler, axis=result.ndim - 1)
    flat = result.astype(np.int16).reshape(-1)

    output_bytes = shape2size(output_aligned_shape) * bytes_per_word
    buf[output_addr:output_addr + output_bytes] = flat.view(np.uint8)

    print(code, time.time() - start_time)

In [12]:
# Per-frame arguments for fusion.prep, read from inputs.npz.
frame_number = inputs_npz["frame_number"]
n_measurement_frames = inputs_npz["n_measurement_frames"]
image2s = [inputs_npz['measurement_feature0'], inputs_npz['measurement_feature1']]

start_time = time.time()

# Number of repetitions; the time printed at the end is the mean per run.
N = 1
for _ in range(N):
    ip.run()
    # Called after ip.run() and before the first wait_extern().
    # NOTE(review): presumably this overlaps host-side preparation with the
    # accelerator's first layers -- confirm that ip.run() is non-blocking.
    fusion.prep(frame_number, n_measurement_frames, image2s)
    # One wait/execute/resume round per entry in `externs`; the opcode
    # returned by wait_extern() selects which host operator to run.
    for i in range(len(externs)):
        code = ip.wait_extern()
        run_extern(code)
        ip.resume_extern()
    ip.wait()

print((time.time() - start_time) / N)

128 0.5190370082855225
261 0.0035042762756347656
262 0.003245830535888672
264 0.018647193908691406
276 0.0018463134765625
278 0.0012662410736083984
289 0.00281524658203125
291 0.0019576549530029297
296 0.005766630172729492
304 0.0016286373138427734
310 0.013348579406738281
309 0.0028235912322998047
2.6100411415100098


In [13]:
# Compare each reference output with the tensor found at its address in `buf`
# and print the Pearson correlation coefficient between the two.
for name, expected, aligned_shape, addr in zip(output_files, outputs, output_aligned_shapes, addrs):
    n_bytes = shape2size(aligned_shape) * (act_bit // 8)
    actual = buf[addr:addr + n_bytes].view(np.int16).reshape(aligned_shape)
    # Drop the alignment padding on the last axis before comparing.
    actual = actual[:, :, :, :expected.shape[-1]]
    print(name, np.corrcoef(expected.reshape(-1), actual.reshape(-1))[0, 1])

cell_state 0.9999965920678364
hidden_state 0.9999952132076847
depth_org 0.9999482153902011


In [14]:
# Free the buffer created by allocate() now that the results have been read.
buf.freebuffer()