In [6]:
import os
import shutil
import numpy as np
from tqdm.notebook import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb


In [9]:
# config
net_name= 'MnistStochasticLutSimple'
data_path= os.path.join('./data/', net_name)
rtl_sim_path= '../../verilog/mnist/tb_mnist_lut_simple'
rtl_module_name='MnistLutSimple'
output_velilog_file=os.path.join(data_path, net_name + '.v')
sim_velilog_file=os.path.join(rtl_sim_path, rtl_module_name + '.v')
epochs=4
mini_batch_size=32


##dataset

dataset_path='./data/'
dataset_train=torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transforms.ToTensor(),download=True)
dataset_test=torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transforms.ToTensor(), download=True)
loader_train=torch.utils.data.DataLoader(dataset=dataset_train, batch_size=mini_batch_size, shuffle=True, num_workers=2)
loader_test=torch.utils.data.DataLoader(dataset=dataset_test, batch_size=mini_batch_size, shuffle=False, num_workers=2)






In [72]:
#define network

lut_layer0_0 = bb.DifferentiableLut([6*36], batch_norm=False, binarize=False)
lut_layer0_1 = bb.DifferentiableLut([36], batch_norm=False, binarize=False)
lut_layer1_0 = bb.DifferentiableLut([2*6*36], batch_norm=False, binarize=False)
lut_layer1_1 = bb.DifferentiableLut([2*36], batch_norm=False, binarize=False)
lut_layer2_0 = bb.DifferentiableLut([2*6*36], batch_norm=False, binarize=False)
lut_layer2_1 = bb.DifferentiableLut([2*36], batch_norm=False, binarize=False)

lut_layer3_0 = bb.DifferentiableLut([4*6*36], batch_norm=False, binarize=False)
lut_layer3_1 = bb.DifferentiableLut([10], batch_norm=False, binarize=False)

print(lut_layer0_0)
print(lut_layer0_1)
print(lut_layer1_0)
print(lut_layer1_1)
print(lut_layer2_0)
print(lut_layer2_1)
print(lut_layer3_0)
print(lut_layer3_1)

net = bb.Sequential([
       lut_layer0_0, lut_layer0_1,
       lut_layer1_0, lut_layer1_1,
       lut_layer2_0, lut_layer2_1,
        lut_layer3_0, lut_layer3_1,
      
])

print(net.get_info())





----------------------------------------------------------------------
[DifferentiableLut6] 
 input  shape : {} output shape : {216}
 connection : random
 binary : 0 batch_norm : 0
----------------------------------------------------------------------

----------------------------------------------------------------------
[DifferentiableLut6] 
 input  shape : {} output shape : {36}
 connection : random
 binary : 0 batch_norm : 0
----------------------------------------------------------------------

----------------------------------------------------------------------
[DifferentiableLut6] 
 input  shape : {} output shape : {432}
 connection : random
 binary : 0 batch_norm : 0
----------------------------------------------------------------------

----------------------------------------------------------------------
[DifferentiableLut6] 
 input  shape : {} output shape : {72}
 connection : random
 binary : 0 batch_norm : 0
--------------------------------------------------------------

In [73]:
net.set_input_shape([1, 28, 28])
print(net.get_input_shape())
print(net.get_input_node_size())
print(net.get_output_node_size())
print(net.get_output_shape())

net.send_command("binary false")  
net.send_command("lut_binarize true")
print(net.get_info())


[1, 28, 28]
784
864
[10]
----------------------------------------------------------------------
[Sequential] 
 input  shape : [1, 28, 28] output shape : [10]
  --------------------------------------------------------------------
  [DifferentiableLut6] 
   input  shape : {1, 28, 28} output shape : {216}
   connection : random
   binary : 0   batch_norm : 0
  --------------------------------------------------------------------
  [DifferentiableLut6] 
   input  shape : {216} output shape : {36}
   connection : random
   binary : 0   batch_norm : 0
  --------------------------------------------------------------------
  [DifferentiableLut6] 
   input  shape : {36} output shape : {432}
   connection : random
   binary : 0   batch_norm : 0
  --------------------------------------------------------------------
  [DifferentiableLut6] 
   input  shape : {432} output shape : {72}
   connection : random
   binary : 0   batch_norm : 0
  -------------------------------------------------------------

In [77]:
# learning

loss= bb.LossSoftmaxCrossEntropy()
print(loss)

metrics =bb.MetricsBinaryCategoricalAccuracy()
print(metrics)

optimizer=bb.OptimizerAdam()
print(optimizer)

optimizer.set_variables(net.get_parameters(), net.get_gradients())
print(net.get_parameters())
print(net.get_gradients())
epochs=20
for epoch in range(epochs):
    loss.clear()
    metrics.clear()

    #learning

    with tqdm(loader_train) as t:
        for images, labels in t:
            x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
            t_buf= bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))
            y_buf= net.forward(x_buf, train=True)
            dy_buf= loss.calculate(y_buf, t_buf)
            metrics.calculate(y_buf, t_buf)
            net.backward(dy_buf)
            optimizer.update()

            t.set_postfix(loss=loss.get(), acc=metrics.get())

    loss.clear()
    metrics.clear()
    for images, labels in loader_test:
        x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
        t_buf = bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))
        y_buf = net.forward(x_buf, train=False)    
        loss.calculate(y_buf, t_buf)
        metrics.calculate(y_buf, t_buf)

    bb.save_networks(data_path, net)
    print('epoch[%d] : loss=%f accuracy=%f' % (epoch, loss.get(), metrics.get()))

            

            

    

    

<binarybrain.losses.LossSoftmaxCrossEntropy object at 0x7f7034808e80>
<binarybrain.metrics.MetricsBinaryCategoricalAccuracy object at 0x7f703480b910>
<binarybrain.optimizer.OptimizerAdam object at 0x7f703480b8b0>
<binarybrain.variables.Variables object at 0x7f6f19b329e0>
<binarybrain.variables.Variables object at 0x7f6f222d88b0>


  0%|          | 0/1875 [00:00<?, ?it/s]

  x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
  t_buf= bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))
  x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
  t_buf = bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))


epoch[0] : loss=1.559163 accuracy=0.964700


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[1] : loss=1.560958 accuracy=0.964510


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[2] : loss=1.559725 accuracy=0.963130


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[3] : loss=1.552849 accuracy=0.967690


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[4] : loss=1.555471 accuracy=0.965400


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[5] : loss=1.550608 accuracy=0.966080


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[6] : loss=1.549486 accuracy=0.967360


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[7] : loss=1.547668 accuracy=0.968090


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[8] : loss=1.544050 accuracy=0.970890


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[9] : loss=1.548690 accuracy=0.969390


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[10] : loss=1.547224 accuracy=0.970100


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[11] : loss=1.543706 accuracy=0.970970


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[12] : loss=1.544134 accuracy=0.971620


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[13] : loss=1.542414 accuracy=0.972110


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[14] : loss=1.542337 accuracy=0.972400


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[15] : loss=1.540827 accuracy=0.973360


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[16] : loss=1.540165 accuracy=0.972300


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[17] : loss=1.542743 accuracy=0.971910


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[18] : loss=1.538157 accuracy=0.973610


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch[19] : loss=1.542132 accuracy=0.972810


In [78]:
# export verilog
with open(output_velilog_file, 'w') as f:
    f.write('\n`timescale 1ns / 1ps\n\n\n')
    f.write(bb.make_verilog_lut_layers(rtl_module_name + 'Simple' , net[0]))
    f.write(bb.make_verilog_lut_layers(rtl_module_name + 'Simple' , net[1]))
    f.write(bb.make_verilog_lut_layers(rtl_module_name + 'Simple' , net[2]))

# Simulation用ファイルに上書きコピー
shutil.copyfile(output_velilog_file, sim_velilog_file)

# Simulationで使う画像の生成
def img_geneator():
    for data in dataset_test:
        yield data[0] # 画像とラベルの画像の方を返す

img = (bb.make_image_tile(480//28+1, 640//28+1, img_geneator())*255).astype(np.uint8)
bb.write_ppm(os.path.join(rtl_sim_path, 'mnist_test_160x120.ppm'), img[:,:120,:160])
bb.write_ppm(os.path.join(rtl_sim_path, 'mnist_test_640x480.ppm'), img[:,:480,:640])

In [104]:
bb.load_networks(data_path, net)

# LUTモデルは BIT型を使ってメモリ節約が可能
bin_dtype = bb.DType.BIT  # bb.DType.BIT or bb.DType.FP32

# 同一形状のバイナリLUTを生成
bin_lut0_0 = bb.BinaryLut.from_sparse_model(lut_layer0_0)
bin_lut0_1 = bb.BinaryLut.from_sparse_model(lut_layer0_1)
bin_lut1_0 = bb.BinaryLut.from_sparse_model(lut_layer1_0)
bin_lut1_1 = bb.BinaryLut.from_sparse_model(lut_layer1_1)
bin_lut2_0 = bb.BinaryLut.from_sparse_model(lut_layer2_0)
bin_lut2_1 = bb.BinaryLut.from_sparse_model(lut_layer2_1)
bin_lut3_0 = bb.BinaryLut.from_sparse_model(lut_layer3_0)
bin_lut3_1 = bb.BinaryLut.from_sparse_model(lut_layer3_1)

print(bin_lut0_0)
print(bin_lut0_1)
frame_modulation_size = 7

test_net= bb.Sequential([
           #     bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype),
                bb.Sequential([bin_lut0_0, bin_lut0_1]), bb.Sequential([bin_lut1_0, bin_lut1_1]),  
                bb.Sequential([bin_lut2_0, bin_lut2_1]), bb.Sequential([bin_lut3_0, bin_lut3_1]),
         #       bb.BinaryToReal(frame_integration_size=frame_modulation_size)
])
print(test_net.get_info())

test_net.set_input_shape([1, 28, 28])

test_loss    = bb.LossSoftmaxCrossEntropy()
test_metrics = bb.MetricsCategoricalAccuracy()

loss.clear()
metrics.clear()
for images, labels in tqdm(loader_test):
    x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
   # print(x_buf.get_type())
    
    t_buf = bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))
    y_buf = test_net.forward(x_buf, train=False)
    #print(y_buf.get_type())
    test_loss.calculate(y_buf, t_buf)
    test_metrics.calculate(y_buf, t_buf)

print('Binary LUT test : loss=%f accuracy=%f' % (test_loss.get(), test_metrics.get()))


----------------------------------------------------------------------
[BinaryLut6] 
 input  shape : {1, 28, 28} output shape : {216}
 connection : random
----------------------------------------------------------------------

----------------------------------------------------------------------
[BinaryLut6] 
 input  shape : {216} output shape : {36}
 connection : random
----------------------------------------------------------------------

----------------------------------------------------------------------
[Sequential] 
 input  shape : [1, 28, 28] output shape : [10]
  --------------------------------------------------------------------
  [Sequential] 
   input  shape : [1, 28, 28]   output shape : [36]
    ------------------------------------------------------------------
    [BinaryLut6] 
     input  shape : {1, 28, 28} output shape : {216}
     connection : random
    ------------------------------------------------------------------
    [BinaryLut6] 
     input  shape : {216}

  0%|          | 0/313 [00:00<?, ?it/s]

  x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
  t_buf = bb.FrameBuffer.from_numpy(np.identity(10)[np.array(labels)].astype(np.float32))


Binary LUT test : loss=1.583333 accuracy=0.810200
