In [14]:
import csv
import random

filename = "test_data/input_bn.csv"
B = 1<<5    # number of rows per batch (32)
d = 100     # number of values per row

# Write one space-delimited CSV row holding 2*B*d random ratios: the first B*d
# values are the "x" block, the following B*d values are the second block that
# the reference cell reads back as the upstream gradient.
# newline='' is required by the csv module: without it, platforms that translate
# '\n' turn the writer's '\r\n' terminator into '\r\r\n', i.e. an extra blank row.
# The generator is intentionally left unseeded, so every run produces new data.
with open(filename, 'w', newline='') as csv_out:
    writer = csv.writer(csv_out, delimiter=' ')

    # x: ratio of two integers drawn uniformly from [10, 80]
    values = [random.randint(10,80)/random.randint(10,80) for _ in range(B*d)]
    # second block: ratio of two integers drawn uniformly from [10, 30]
    values += [random.randint(10,30)/random.randint(10,30) for _ in range(B*d)]
    writer.writerow(values)
            

In [20]:
import math
import numpy as np
import torch
import csv

IT_N = 3  # number of Newton refinement steps applied by inverse_sqrt
def inverse_sqrt(x):
    """Approximate ``1 / sqrt(x)`` from an exponential seed refined by Newton steps.

    The starting estimate is ``2 * exp(-(x / 2 + 0.2)) - x / 1024``. It is then
    refined ``IT_N`` times with the update ``g <- g * (3 - x * g * g) / 2``.
    With only ``IT_N`` steps the result is an approximation whose quality
    depends on how close the seed is to the true value.

    Args:
        x: Scalar or NumPy array; all arithmetic is applied element-wise.

    Returns:
        The refined estimate (same shape as ``x`` for array input).
    """
    estimate = 2 * np.exp(-(0.2 + x / 2)) - x / 1024

    step = 0
    while step < IT_N:
        estimate = estimate * (3 - x * estimate * estimate) / 2
        step += 1
    return estimate
    
class BN:
    """Plain NumPy batch-normalisation layer with a built-in gradient step.

    ``forward`` normalises each column of a ``[B, D]`` batch with the batch mean
    and the biased batch variance, then applies the per-feature scale ``gamma``
    and shift ``beta``. ``backward`` computes the gradients with respect to the
    input, ``gamma`` and ``beta`` and immediately updates ``gamma`` and ``beta``.

    The trailing "N multiplication / truncation" comments are kept from the
    original code; they appear to count the operations of the protocol this
    reference mirrors (NOTE(review): confirm against the protocol implementation).
    """

    def __init__(self, dims, gamma=0, beta=0):
        """Create a layer for inputs with ``dims`` features.

        Args:
            dims: Number of features (length of ``gamma`` and ``beta``).
            gamma: Accepted but NOT used; the scale always starts at ones.
            beta: Accepted but NOT used; the shift always starts at zeros.
        """
        self.eps = 1e-5  # added to the variance before the inverse square root
        self.gamma = np.ones((dims, ), dtype="float32")
        self.beta = np.zeros((dims, ), dtype="float32")

        # Values cached by forward() for use in backward().
        self.inv_sqrt = None
        self.norm_x = None

        # Gradients produced by the most recent backward() call.
        self.beta_grad = None
        self.gamma_grad = None
        self.act_grad = None

    def forward(self, x):
        """Normalise ``x`` over axis 0 and apply scale and shift.

        Args:
            x: Array of shape ``[B, D]``.

        Returns:
            ``gamma * norm_x + beta`` with shape ``[B, D]``. ``inv_sqrt`` and
            ``norm_x`` are stored on the instance for ``backward``.
        """
        mean = np.mean(x, axis=0)   # 1 truncation by batchSize [1, D]
        x_mean = x - mean   # [B, D]
        var = np.mean(x_mean * x_mean, axis=0)  # 1 multiplication, 1 truncation by batchsize [1, D]
        var_eps = var + self.eps

        # protocol inv_sqrt: computed exactly here. The module-level
        # inverse_sqrt(var_eps) can be printed next to it for comparison.
        self.inv_sqrt = 1. / np.sqrt(var_eps)   # 1 inverse sqrt [1, D]
        self.norm_x = x_mean * self.inv_sqrt    # 1 multiplication [B, D] * [1, D]. Falcon has a bug here.

        return self.gamma * self.norm_x + self.beta     # 1 multiplication

    def backward(self, grad, lr=0.03125):
        """Back-propagate ``grad`` and apply one gradient step to ``gamma``/``beta``.

        Args:
            grad: Upstream gradient of shape ``[B, D]``, matching the batch
                passed to the preceding ``forward`` call.
            lr: Step size for the ``gamma``/``beta`` update. Defaults to
                0.03125, the value that was previously hard-coded.

        Returns:
            Tuple ``(act_grad, gamma_grad, beta_grad)``: the gradient with
            respect to the input (``[B, D]``) and the per-feature gradients.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.norm_x is None or self.inv_sqrt is None:
            raise RuntimeError("BN.backward() called before BN.forward()")

        B, D = grad.shape
        self.beta_grad = np.sum(grad, axis=0)
        self.gamma_grad = np.sum(self.norm_x * grad, axis=0)    # 1 multiplication

        dxhat = grad * self.gamma   # 1 multiplication
        self.act_grad = self.inv_sqrt * \
                        (B*dxhat - np.sum(dxhat, axis=0) - self.norm_x * np.sum(dxhat * self.norm_x, axis=0)) \
                        / B # 3 multiplication, 1 truncation

        # The parameters are updated only after act_grad has been computed, so
        # act_grad is based on the gamma that was used in forward().
        self.gamma = self.gamma - self.gamma_grad*lr
        self.beta = self.beta - self.beta_grad*lr
        return self.act_grad, self.gamma_grad, self.beta_grad

# Batch size and feature width of the generated input. They are defined here so
# that this cell does not depend on kernel state left behind by another cell;
# they must match the values used when the input file was written.
B = 1<<5
d = 100
filename = "test_data/input_bn.csv"
outf = "test_data/output_bn_plain.csv"

# newline='' is what the csv module expects for both readers and writers.
with open(filename, 'r', newline='') as in_file:
    for row in csv.reader(in_file, delimiter=' '):
        if not row:
            # Blank lines (e.g. from a doubled line terminator) hold no sample.
            continue
        if len(row) < 2*B*d:
            raise ValueError(
                "expected at least %d values per row in %s, got %d"
                % (2*B*d, filename, len(row)))

        # Row layout: B*d input values followed by B*d upstream-gradient
        # values, each block stored row-major as [B, d].
        x = np.array(row[:B*d]).astype(np.float32).reshape(B, d)
        grad = np.array(row[B*d:2*B*d]).astype(np.float32).reshape(B, d)

        bn = BN(d)
        x_forward = bn.forward(x)
        x_grad, gamma_grad, beta_grad = bn.backward(grad)

        # Output layout: B rows of x_forward, B rows of x_grad, then one row
        # each for gamma_grad and beta_grad. The file is reopened in 'w' mode
        # for every input row, so only the last row's results are kept.
        with open(outf, 'w', newline='') as out_file:
            writer = csv.writer(out_file)
            writer.writerows(x_forward)
            writer.writerows(x_grad)
            writer.writerow(gamma_grad)
            writer.writerow(beta_grad)
        


In [21]:
# Calculate the error (MSE and MAE) between the two BN output files
import numpy as np
import csv
from sklearn import metrics

def acc(x1,x2):
    """Print the mean squared error, then the mean absolute error, of x1 vs x2.

    Despite the name, no accuracy score is computed: the two values printed
    (one per line) are error measures, so smaller means closer agreement.

    Args:
        x1: First sequence of values.
        x2: Second sequence of values, compared element-wise with ``x1``.

    Returns:
        None; the two metrics are only printed.
    """
    squared_error = metrics.mean_squared_error(x1, x2)
    absolute_error = metrics.mean_absolute_error(x1, x2)
    print(squared_error, absolute_error, sep="\n")

# Both files share one layout: B rows of x_forward, B rows of x_grad, then one
# row of gamma_grad and one row of beta_grad (d comma-separated values each).
out1 = "test_data/output_bn_plain.csv"
out2 = "test_data/output_bn_l.csv"

B = 1<<5
d = 100


def _load_bn_outputs(path, batch, dim):
    """Read one BN output file and return its four sections as flat float lists.

    Args:
        path: CSV file laid out as described at the top of this cell.
        batch: Number of rows in each of the two leading sections.
        dim: Number of values read from every row.

    Returns:
        Tuple ``(x_forward, x_grad, gamma_grad, beta_grad)`` of lists holding
        ``batch*dim``, ``batch*dim``, ``dim`` and ``dim`` floats.

    Raises:
        IndexError: If the file has fewer rows or columns than expected.
    """
    with open(path, 'r', newline='') as handle:
        # Drop blank rows so a doubled line terminator cannot shift the
        # row indices used below.
        rows = [r for r in csv.reader(handle, delimiter=',') if r]

    x_forward = [float(rows[b][i]) for b in range(batch) for i in range(dim)]
    x_grad = [float(rows[batch + b][i]) for b in range(batch) for i in range(dim)]
    gamma_grad = [float(rows[2*batch][i]) for i in range(dim)]
    beta_grad = [float(rows[2*batch + 1][i]) for i in range(dim)]
    return x_forward, x_grad, gamma_grad, beta_grad


xf1, xg1, gg1, bg1 = _load_bn_outputs(out1, B, d)
xf2, xg2, gg2, bg2 = _load_bn_outputs(out2, B, d)

# Printed in this order: x_forward, x_grad, gamma_grad, beta_grad.
acc(xf1,xf2)
acc(xg1,xg2)
acc(gg1,gg2)
acc(bg1,bg2)

10.046148684402512
0.6443539869589062
0.4632511726820433
0.28873576541717155
318.1356745708251
4.327837305729999
2.770577780000388e-06
0.0016532000000001545
