In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.autograd import Function
import numpy as np
import torchvision
from BNN import *
from spikingjelly.activation_based import functional

In [2]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='/home/curry/code', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='/home/curry/code', train=False, transform=transform, download=True)

In [3]:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = BinaryConv2d(1, 2, kernel_size=3, stride=1, padding=0, bias=False)
        self.conv2 = BinaryConv2d(2, 4, kernel_size=3, stride=1, padding=0, bias=False)
        self.sn1 = BinaryActivation()
        self.sn2 = BinaryActivation()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = BinaryLinear(4*12*12, 10, bias=False)
    def forward(self, x):
        #大于127的变成1，小于等于127的变成0
        x = torch.where(x>0.5, torch.tensor(1.), torch.tensor(-1.))
        x = self.conv1(x)
        #x = self.pool1(x)
        #print(x.int())
        #x = self.pool1(x)
        #print(x.int())
        # x = torch.where(x>0, torch.tensor(1.), torch.tensor(-1.))
        x = self.sn1(x)
        #print(x.int())
        x = self.conv2(x)
        #print(x.int())
        #x = self.pool2(x)
        #x = self.relu2(x)
        #print(x.int())
        x = self.pool2(x) 
        x = self.sn2(x)
        #print(x.shape)
        # x = torch.where(x>0, torch.tensor(1.), torch.tensor(-1.))
        #x = self.conv3(x)
        #x = self.pool3(x)
        #x = torch.where(x>0, torch.tensor(1.), torch.tensor(-1.))
        #print(x.int())
        x = x.view(x.size(0), -1)
        #print(x.int())
        x = self.fc1(x)
        # x = self.relu3(x)
        # x = self.fc2(x)
        # x = self.relu3(x)
        # x = self.fc3(x)
        return x

In [13]:
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
net = LeNet()
from tqdm import tqdm
np.int = int
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
criterion = nn.CrossEntropyLoss()
LR = 0.001

optimizer = optim.Adam(net.parameters(), lr=LR)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
EPOCH = 10
device = torch.device('cuda:0')
net.to(device)
for epoch in range(EPOCH):
    net.train()
    for i ,data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.to(device).cuda()
        labels = labels.to(device).cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        functional.reset_net(net)
        if i % 100 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, loss.item()))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i,data in enumerate(testloader):
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if(100*correct/total>=90):
                torch.save(net.state_dict(), 'weight/lenet.pth')
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

[1,     1] loss: 2.440
[1,   101] loss: 0.605
[1,   201] loss: 0.581
[1,   301] loss: 0.311
[1,   401] loss: 0.351
Accuracy of the network on the 10000 test images: 88 %
[2,     1] loss: 0.306
[2,   101] loss: 0.401
[2,   201] loss: 0.430
[2,   301] loss: 0.453
[2,   401] loss: 0.307
Accuracy of the network on the 10000 test images: 88 %
[3,     1] loss: 0.314
[3,   101] loss: 0.310
[3,   201] loss: 0.424
[3,   301] loss: 0.302
[3,   401] loss: 0.351
Accuracy of the network on the 10000 test images: 89 %
[4,     1] loss: 0.230
[4,   101] loss: 0.462
[4,   201] loss: 0.330
[4,   301] loss: 0.550
[4,   401] loss: 0.364
Accuracy of the network on the 10000 test images: 89 %
[5,     1] loss: 0.378
[5,   101] loss: 0.334
[5,   201] loss: 0.286
[5,   301] loss: 0.321
[5,   401] loss: 0.342
Accuracy of the network on the 10000 test images: 89 %
[6,     1] loss: 0.277
[6,   101] loss: 0.263
[6,   201] loss: 0.450
[6,   301] loss: 0.238
[6,   401] loss: 0.319
Accuracy of the network on the 1000

In [14]:
class scale_Bconvd(BinaryConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(scale_Bconvd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
    def forward(self, x):
        w = self.weight
        bw = BinaryWeight.apply(w)
        scaling_factor = torch.mean(torch.mean(torch.mean(torch.mean(abs(w),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True),dim=0,keepdim=True)
        scaling_factor = scaling_factor.detach()
        #print(scaling_factor)
        # bw = scaling_factor * BinaryWeight.apply(w)
    
        return F.conv2d(x, bw, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
class scale_Blinear(BinaryLinear):
    def __init__(self, in_features, out_features, bias=False):
        super(scale_Blinear, self).__init__(in_features, out_features, bias)
        
    def forward(self, x):

        w = self.weight
        bw = BinaryWeight.apply(w)
        # print(w.shape)
        scaling_factor = torch.mean(torch.mean(abs(w),dim=1,keepdim=True),dim=0,keepdim=True)
        scaling_factor = scaling_factor.detach()
        #print(scaling_factor)
        # bw = scaling_factor * BinaryWeight.apply(w)
        
        return F.linear(x, bw, self.bias)
    
class scale_leNet(nn.Module):
    def __init__(self, num_classes=10, T=4):
        super().__init__()
        self.T = T
        self.conv1 = scale_Bconvd(1, 2, kernel_size=3, stride=1, padding=0, bias=False)
        self.sn1 = BinaryActivation()
        self.conv2 = scale_Bconvd(3, 2, kernel_size=3, stride=1, padding=0, bias=False)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.sn2 = BinaryActivation()
        self.fc1 = scale_Blinear(4*12*12, 10, bias=False)
        # self.relu3 = nn.ReLU()
        # self.fc2 = scale_Blinear(120, 84, bias=False)
        # self.relu4 = nn.ReLU()
        # self.fc3 = scale_Blinear(84, num_classes, bias=False)
    def forward(self, x):
        x = torch.where(x>0.5, torch.tensor(1.), torch.tensor(-1.))
        x = self.conv1(x)
        #print(x.int())
        #x = self.pool1(x)
        #print(x.int())
        x = self.sn1(x)
        #print(x.int())
        x = self.conv2(x)
        #print(x.int())
        #x = self.pool2(x)
        #x = self.relu2(x)
        #print(x.int())
        x = self.pool2(x)
        x = self.sn2(x)
        #x = self.conv3(x)
        #x = self.pool3(x)
        #x = torch.where(x>0, torch.tensor(1.), torch.tensor(-1.))
        #print(x.int())
        x = x.view(x.size(0), -1)
        #print(x.int())
        x = self.fc1(x)
        # x = self.relu3(x)
        # x = self.fc2(x)
        # x = self.relu3(x)
        # x = self.fc3(x)
        return x

In [15]:
#将权重加载，二值化，保存为另外一个pth文件
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
net = LeNet()
device = torch.device('cuda:0')
net.to(device)
net.load_state_dict(torch.load('weight/lenet.pth',weights_only=True))
for name, param in net.named_parameters():
    if 'weight' in name:
        param.data = BinaryWeight.apply(param.data)
torch.save(net.state_dict(), 'weight/lenet_binary.pth')
#加载二值化后的权重
net.load_state_dict(torch.load('weight/lenet_binary.pth',weights_only=True))
net.to(device)
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for i,data in enumerate(testloader):
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 90 %


In [19]:
#将网络第一层的权重导出
conv1_weight = net.conv1.weight.data
#print(conv1_weight)
conv1_weight = conv1_weight.view(-1)
#print(conv1_weight.shape)
conv1_weight_int = conv1_weight.int()
#print(conv1_weight_int)
#将第一层的权重导出为txt文件
with open('test_array/test_conv1_weight_txt.txt', 'w') as f:
    for weight in conv1_weight_int:
        weight = weight.item()
        if weight < 0:
            weight = 0
        f.write(format(weight, 'b') + '\n')
#将网络第二层的权重导出
conv2_weight = net.conv2.weight.data
#print(conv2_weight.shape)
#print(conv2_weight)
conv2_weight = conv2_weight.view(-1)
#print(conv2_weight.shape)
conv2_weight_int = conv2_weight.int()
#print(conv2_weight_int)
#将第二层的权重导出为txt文件
with open('test_array/test_conv2_weight_txt.txt', 'w') as f:
    for weight in conv2_weight_int:
        weight = weight.item()
        if weight < 0:
            weight = 0
        f.write(format(weight, 'b') + '\n')
#将网络第三层的权重导出
fc1_weight = net.fc1.weight.data
print(fc1_weight)
print(fc1_weight.shape)
fc1_weight = fc1_weight.view(-1)
print(fc1_weight.shape)
fc1_weight_int = fc1_weight.int()
print(fc1_weight_int)
#将第三层的权重导出为txt文件
for i in range(10):
    with open('test_array/test_fc1_weight_txt'+str(i)+'.txt', 'w') as f:
        for weight in fc1_weight_int[i*576:i*576+576]:
            weight = weight.item()
            if weight < 0:
                weight = 0
            f.write(format(weight, 'b') + '\n')
    

tensor([[ 1.,  1., -1.,  ..., -1., -1., -1.],
        [-1., -1.,  1.,  ..., -1.,  1.,  1.],
        [ 1.,  1.,  1.,  ...,  1.,  1., -1.],
        ...,
        [ 1., -1.,  1.,  ..., -1., -1., -1.],
        [-1., -1.,  1.,  ..., -1., -1., -1.],
        [ 1.,  1., -1.,  ...,  1.,  1.,  1.]], device='cuda:0')
torch.Size([10, 576])
torch.Size([5760])
tensor([ 1,  1, -1,  ...,  1,  1,  1], device='cuda:0', dtype=torch.int32)


In [21]:
class watch_scale_leNet(nn.Module):
    def __init__(self, num_classes=10, T=4):
        super().__init__()
        self.T = T
        self.conv1 = scale_Bconvd(1, 2, kernel_size=3, stride=1, padding=0, bias=False)
        self.sn1 = BinaryActivation()   
        self.conv2 = scale_Bconvd(2, 4, kernel_size=3, stride=1, padding=0, bias=False)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.sn2 = BinaryActivation()
        self.fc1 = scale_Blinear(4*12*12, 10, bias=False)
    def forward(self, x):
        x = torch.where(x>127, torch.tensor(1.), torch.tensor(-1.))
        #print(x.shape)
        #print(x.int())
        x = self.conv1(x)
        x = self.sn1(x)
        conv1_out_int = torch.where(x==1., torch.tensor(1.), torch.tensor(0.)).int()
        with open('test_array/test_conv1_output_txt.txt', 'w') as f:
            for i in range(26):
                for j in range(26):
                    f.write(format(conv1_out_int[0][0][i][j].item(), 'b') + '\n')
        # print(x.shape)
        # print(x.int())
        x = self.conv2(x)
        print(x.shape)
        print(x.int())
        #print(self.conv2.weight)
        x = self.pool2(x)
        x = self.sn2(x)
        print(x.shape)
        # print(x.int())
        x = x.view(x.size(0), -1)
        print(x.shape)
        print(x.int())
        x_b = torch.where(x==1., torch.tensor(1.), torch.tensor(0.))
        x_b_int = x_b.int()
        with open('test_array/test_fc1_input_txt.txt', 'w') as f:
            for i in range(4*12*12):
                f.write(format(x_b_int[0][i].item(), 'b') + '\n')
        #print(self.fc1.weight[0])
        x = self.fc1(x)
        # x = self.relu3(x)
        # x = self.fc2(x)
        # x = self.relu4(x)
        # x = self.fc3(x)
        return x

net = watch_scale_leNet()
net.load_state_dict(torch.load('weight/lenet_binary.pth',weights_only=True))
image, label = train_dataset[0]
image_b = torch.where(image>0.5, torch.tensor(1.), torch.tensor(0.))
image_b_int = image_b.int()
with open('test_array/test_image_b_txt.txt', 'w') as f:
    for i in range(28):
        for j in range(28):
            f.write(format(image_b_int[0][i][j].item(), 'b') + '\n')
image = image*255
image_int = image.int()
with open('test_array/test_image_txt.txt', 'w') as f:
    for i in range(28):
        for j in range(28):
            f.write(format(image_int[0][i][j].item(), '08b') + '\n')
image = image
print(image.shape)
output = net(image.unsqueeze(0))
print(output.int())


torch.Size([1, 28, 28])
torch.Size([1, 4, 24, 24])
tensor([[[[ 16,  16,  16,  ...,  16,  16,  16],
          [ 16,  16,  16,  ...,  16,  16,  16],
          [ 16,  16,  16,  ...,  16,  16,  16],
          ...,
          [ 14,  10,   6,  ...,  16,  16,  16],
          [ 12,  10,   2,  ...,  16,  16,  16],
          [ 12,  10,   2,  ...,  16,  16,  16]],

         [[-10, -10, -10,  ..., -10, -10, -10],
          [-10, -10, -10,  ..., -10, -10, -10],
          [-10, -10, -10,  ..., -10, -10, -10],
          ...,
          [ -8,  -4,  -4,  ..., -10, -10, -10],
          [ -6,   0,   8,  ..., -10, -10, -10],
          [ -6,  -4,   0,  ..., -10, -10, -10]],

         [[ -4,  -4,  -4,  ...,  -4,  -4,  -4],
          [ -4,  -4,  -4,  ...,  -4,  -4,  -4],
          [ -4,  -4,  -4,  ...,  -4,  -4,  -4],
          ...,
          [ -2,   2,   6,  ...,  -4,  -4,  -4],
          [ -4,   2,   6,  ...,  -4,  -4,  -4],
          [ -4,  -6,  -2,  ...,  -4,  -4,  -4]],

         [[ -2,  -2,  -2,  ...,  -

In [11]:
import torch
import torch.nn as nn
weight_conv2 = torch.tensor([[-1.,-1.,-1.],[-1.,1.,-1.],[-1.,1.,-1.]])
test_conv2 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
with torch.no_grad():
    test_conv2.weight.copy_(weight_conv2)
#input来自一个 txt 文件
input_data = []
with open('/home/ygl/code/BNN_accelerator/train/test_array/test_conv1_output_txt.txt', 'r') as f:
    for line in f:
        input_data.append(int(line.strip(), 2))
input_tensor = torch.tensor(input_data, dtype=torch.float32).view(1, 1, 26, 26)
input_tensor = torch.where(input_tensor==1., torch.tensor(1.), torch.tensor(-1.))
output = test_conv2(input_tensor)
print(output)




tensor([[[[ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
            5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
            5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
            5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  3.,
            5.,  3.,  3.,  3.,  3.,  3.,  5.,  3.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  3.,  5.,  3.,  3.,  3.,  3.,  1.,
            3.,  1.,  1.,  1.,  1.,  3.,  3.,  3.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  3.,  5.,  3.,  1.,  3.,  1.,  1.,  1.,  1., -1.,
           -3., -3., -5., -3., -3., -1., -1.,  3.,  5.,  5.],
          [ 5.,  5.,  5.,  5.,  1.,  5.,  1., -1., -3., -5., -5., -5., -5., -5.,
           -5., -1., -5., -1., -1.,  1.,  3.,  5.,  5.,  5.],