In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from torchsummary import summary
import random

seed = 123
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [29]:
is_transform = True
is_shuffle = False
batch_size = 64

## Load data

In [30]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     ]) if not is_transform else transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)


path = f'{os.path.expanduser("~")}/Downloads'
trainset = datasets.CIFAR10(root=path, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=is_shuffle)

testset = datasets.CIFAR10(root=path, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [31]:
# import matplotlib.pyplot as plt
# import numpy as np

# def imshow(img):
#     img = img / 2 + 0.5     # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()

# dataiter = iter(trainloader)
# images, labels = next(dataiter)
# print(images[0]*255.0)

# # get some random training images
for i, x in enumerate(trainloader, 0):
    img, label = x
    print(img.shape)
    print(img[0][0][0])
    print(label[0])
    break

torch.Size([64, 3, 32, 32])
tensor([-0.5373, -0.6627, -0.6078, -0.4667, -0.2314, -0.0667,  0.0902,  0.1373,
         0.1686,  0.1686,  0.0275, -0.0196,  0.1137,  0.1294,  0.0745,  0.0118,
         0.0745,  0.0510, -0.0275,  0.0902,  0.0902,  0.0431,  0.0667,  0.0902,
         0.1922,  0.2784,  0.3176,  0.2471,  0.2392,  0.2392,  0.1922,  0.1608])
tensor(6)


## Define Model

In [32]:
class AlexNet_CryptGPU(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_CryptGPU, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=9),
            # nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=96, momentum=0),

            nn.Conv2d(96, 256, kernel_size=5, padding=1),
            # nn.ReLU(inplace=True),
            # nn.BatchNorm2d(num_features=256),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=256, momentum=0),

            
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        if num_classes == 10:
            self.fc_layers = nn.Sequential(
                nn.Flatten(),
                nn.Linear(256, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 10),
            )
        elif num_classes == 200:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 200),
            )
        elif num_classes == 1000:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=4),
                nn.Flatten(),
                nn.Linear(9216, 4096),
                nn.ReLU(),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Linear(4096, 1000),
            )

    def forward(self, x):
        x = self.features(x)
        x = self.fc_layers(x)
        return x
    
class AlexNet_Falcon(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_Falcon, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=9),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=96),

            nn.Conv2d(96, 256, kernel_size=5, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=256),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        if num_classes == 10:
            self.fc_layers = nn.Sequential(
                nn.Flatten(),
                nn.Linear(256, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 10),
                # nn.ReLU(inplace=True),
            )
        elif num_classes == 200:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 200),
            )
        elif num_classes == 1000:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=4),
                nn.Flatten(),
                nn.Linear(9216, 4096),
                nn.ReLU(),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Linear(4096, 1000),
            )

    def forward(self, x):
        x = self.features(x)
        x = self.fc_layers(x)
        return x

class AlexNet_Official(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_Official, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x
    
class AlexNet_Official_modify(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_Official_modify, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            # nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            # nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x


class AlexNet_SPDZ(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_SPDZ, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=9),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=96, momentum=1),

            nn.Conv2d(96, 256, kernel_size=5, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=256, momentum=1),
            nn.MaxPool2d(kernel_size=2, stride=1),
            # nn.ReLU(inplace=True),
            # nn.BatchNorm2d(num_features=256),

            
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        if num_classes == 10:
            self.fc_layers = nn.Sequential(
                nn.Flatten(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 10),
            )
        elif num_classes == 200:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 200),
            )
        elif num_classes == 1000:
            self.fc_layers = nn.Sequential(
                nn.AvgPool2d(kernel_size=4),
                nn.Flatten(),
                nn.Linear(9216, 4096),
                nn.ReLU(),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Linear(4096, 1000),
            )

    def forward(self, x):
        x = self.features(x)
        x = self.fc_layers(x)
        return x
    

In [33]:
from tensorboardX import SummaryWriter
logger = SummaryWriter(log_dir = 'log')

USE_SPDZ = True

model_fn = AlexNet_SPDZ if USE_SPDZ else AlexNet_CryptGPU

model = model_fn(num_classes=10)
summary(model, (3, 32, 32))

# GPU run
gpu_list = [0, 1, 2, 3]
gpu_list_str = ','.join(map(str, gpu_list))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_list_str)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.nn.DataParallel(model)
model.to(device)

def evaluate(model):
    # model.eval()
    total = 0
    correct = 0
    for batch, data in enumerate(testloader, 0):
        images, labels = data
        output = model(images.to(device))
        
        correct += torch.sum(torch.argmax(output, dim=1) == labels.to(device))
        total += len(images)
    return correct * 1.0 / total

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 256, 1, 1]           --
|    └─Conv2d: 2-1                       [-1, 96, 10, 10]          34,944
|    └─ReLU: 2-2                         [-1, 96, 10, 10]          --
|    └─MaxPool2d: 2-3                    [-1, 96, 4, 4]            --
|    └─BatchNorm2d: 2-4                  [-1, 96, 4, 4]            192
|    └─Conv2d: 2-5                       [-1, 256, 2, 2]           614,656
|    └─ReLU: 2-6                         [-1, 256, 2, 2]           --
|    └─BatchNorm2d: 2-7                  [-1, 256, 2, 2]           512
|    └─MaxPool2d: 2-8                    [-1, 256, 1, 1]           --
|    └─Conv2d: 2-9                       [-1, 384, 1, 1]           885,120
|    └─ReLU: 2-10                        [-1, 384, 1, 1]           --
|    └─Conv2d: 2-11                      [-1, 384, 1, 1]           1,327,488
|    └─ReLU: 2-12                        [-1, 384, 1, 1]      

## Hyper-parameters

In [34]:
# hyper-parameters
epochs = 10

lr = 1. / (1 << 4)
use_softmax = False
use_MSE = False

momentum = 0.8
# CE This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.
criterion = nn.MSELoss() if use_MSE else nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

debug_epochs = 1
debug_max_iter = 1

# AlexNet Official
# lr=0.001, momentum=0.9 converges about 2~3 epoch. Without momentum, it does not converge.
# lr=0.01 (remove Dropout) converges about 3~4 epoch.

# AlexNet Official Modified: MaxPooling --> AvgPooing, remove Dropout
# lr=0.01 converges

# AlexNet CryptGPU
# lr=0.01. momentum=0.9 not converge. No BN, remove the last avgpooling layer.

# AlexNet Falcon
# lr=0.01 converges

## Debug code

In [286]:
for i, x in enumerate(trainloader, 0):
    img, label = x
    print(img.shape)
    print(img[0][0][0])
    print(label[0])
    break

print(f'Use shuffle: {is_shuffle}, use transform: {is_transform}, batch size: {batch_size}, lr: {lr}, model_SPDZ: {USE_SPDZ}')

countMax = 0
outputs = []
for epoch in range(debug_epochs):
    for batch, data in enumerate(trainloader, 0):
        print(f'---------------{countMax}')
        images, labels = data

        x = images.to(device)
        for i in range(len(model.features)):
            # if isinstance(model.features[i], nn.Conv2d):
            #     print(f'Conv2d input: {x[0]}')
                # print(f'Conv2d weights: {model.features[i]}')
            # elif isinstance(model.features[i], nn.ReLU):
            #     print(f'ReLU input: {x[0]}')
            # if isinstance(model.features[i], nn.MaxPool2d):
            #     print(f'MaxPool2d input: {x[0]}')
            # elif isinstance(model.features[i], nn.BatchNorm2d):
            #     print(f'BatchNorm2d input: {x[0]}')
            if i == 7:
                maxPool_input = x
                x = model.features[i](maxPool_input)
            elif i == 6:
                bn_input = x
                x = model.features[i](bn_input)
            elif i == 5:
                relu_input_feature_2 = x
                x = model.features[i](relu_input_feature_2)
            elif i == 8:
                cnn_input_3 = x
                # print(f'CNN input: {cnn_input}')
                x = model.features[i](cnn_input_3)
                # print(f'CNN output: {x}')
            elif i == 0:
                cnn_input_1 = x
                # print(f'CNN input: {cnn_input}')
                x = model.features[i](cnn_input_1)
                # print(f'CNN output: {x}')
            elif i == 4:
                cnn_input_2 = x
                # print(f'CNN input: {cnn_input}')
                x = model.features[i](cnn_input_2)
                # print(f'CNN output: {x}')
            elif i == 1:
                relu_input_feature_1 = x
                # print(f'CNN input: {cnn_input}')
                x = model.features[i](relu_input_feature_1)
                # print(f'CNN output: {x}')
            elif i == 3:
                bn_input_1 = x
                x = model.features[i](bn_input_1)
            elif i == 9:
                relu_input_feature_3 = x
                # print(f'ReLU input: {relu_input}')
                x = model.features[i](relu_input_feature_3)
                # print(f'ReLU output: {x}')
            elif i == 11:
                relu_input_11 = x
                # print(f'ReLU input: {relu_input}')
                x = model.features[i](relu_input_11)
                # print(f'ReLU output: {x}')
            elif i == 13:
                relu_input_13 = x
                # print(f'ReLU input: {relu_input}')
                x = model.features[i](relu_input_13)
                # print(f'ReLU output: {x}')
            else:
                x = model.features[i](x)
            # if isinstance(model.features[i], nn.Conv2d):
            #     print(f'Conv2d output: {x[0]}')
            # elif isinstance(model.features[i], nn.ReLU):
            #     print(f'ReLU output: {x[0]}')
            # if isinstance(model.features[i], nn.MaxPool2d):
            #     print(f'MaxPool2d output: {x[0]}')
            # elif isinstance(model.features[i], nn.BatchNorm2d):
            #     print(f'BatchNorm2d mean: {model.features[i].running_mean.data}')
            #     print(f'BatchNorm2d var: {model.features[i].running_var.data}')
            #     print(f'BatchNorm2d output: {x[0]}')
        for i in range(len(model.fc_layers)):
            # if isinstance(model.fc_layers[i], nn.Conv2d):
            #     print(f'Conv2d input: {x[0]}')
            #     print(f'Conv2d weights: {model.fc_layers[i]}')
            # elif isinstance(model.fc_layers[i], nn.ReLU):
            #     print(f'ReLU input: {x[0]}')
            # elif isinstance(model.fc_layers[i], nn.MaxPool2d):
            #     print(f'MaxPool2d input: {x[0]}')
            # elif isinstance(model.fc_layers[i], nn.BatchNorm2d):
            #     print(f'BatchNorm2d input: {x[0]}')
            # if isinstance(model.fc_layers[i], nn.Linear):
            #     print(f'Linear input: {x[0]}')
            # 多了一层nn.Flatten()
            if i == 2:
                relu_input_1 = x
                # print(f'Relu-1 input: {relu_input_1}')
                x = model.fc_layers[i](relu_input_1)
            elif i == 3:
                fc_input_2 = x
                # print(f'Relu-1 input: {relu_input_1}')
                x = model.fc_layers[i](fc_input_2)
            elif i == 1:
                fc_input_1 = x
                # print(f'Relu-1 input: {relu_input_1}')
                x = model.fc_layers[i](fc_input_1)
            # elif i == 5:
            #     fc_input = x
            #     x = model.fc_layers[i](fc_input)
            elif i == 4:
                relu_input_2 = x
                # print(f'Relu-2 input: {relu_input_2}')
                x = model.fc_layers[i](relu_input_2)
                # print(f'Relu-2 output: {x}')
            else:
                x = model.fc_layers[i](x)
            # if isinstance(model.fc_layers[i], nn.Conv2d):
            #     print(f'Conv2d output: {x[0]}')
            # elif isinstance(model.fc_layers[i], nn.ReLU):
            #     print(f'ReLU output: {x[0]}')
            # elif isinstance(model.fc_layers[i], nn.MaxPool2d):
            #     print(f'MaxPool2d output: {x[0]}')
            # elif isinstance(model.fc_layers[i], nn.BatchNorm2d):
            #     print(f'BatchNorm2d mean: {model.features[i].running_mean.data}')
            #     print(f'BatchNorm2d var: {model.features[i].running_var.data}')
            #     print(f'BatchNorm2d output: {x[0]}')
            if isinstance(model.fc_layers[i], nn.Linear):
                logger.add_histogram(f'FC-Linear-output-{i}', x, countMax)
                logger.add_histogram(f'FC-Linear-weight{i}', model.fc_layers[i].weight, countMax)
                logger.add_histogram(f'FC-Linear-bias{i}', model.fc_layers[i].bias, countMax)
                # print(f'Linear output: {x[0]}')
        
        # print('output: ', x.cpu()[0])
        # print("softmax output: ", torch.nn.functional.softmax(x.cpu(), dim=1)[0])
        # print(f'BN var: {bn_input.var([0, 2, 3])}')
        # print("loss: ", (torch.nn.functional.softmax(x.cpu(), dim=1) - torch.nn.functional.one_hot(labels, num_classes=10).float())[0])
        # print("grad: ", torch.sum(torch.nn.functional.softmax(x.cpu(), dim=1) - torch.nn.functional.one_hot(labels, num_classes=10).float(), dim=0))
        outputs.append(x.cpu())
        optimizer.zero_grad()

        maxPool_input.retain_grad()
        cnn_input_3.retain_grad()
        # fc_input.retain_grad()
        relu_input_feature_3.retain_grad()
        # BN_input.retain_grad()
        # relu_input_2.retain_grad()
        # relu_input_1.retain_grad()
        # relu_input_11.retain_grad()
        # relu_input_13.retain_grad()
        # bn_input.retain_grad()
        # fc_input_2.retain_grad()
        # relu_input_feature_2.retain_grad()
        bn_input_1.retain_grad()
        cnn_input_2.retain_grad()
        # output = x
        # output.retain_grad()
        loss = criterion(x, labels.to(device))
        loss.backward()

        # print(f'BN-2 delta: {(maxPool_input.grad*batch_size)[0]}')
        # print(f'maxpool-2 delta: {(cnn_input.grad*batch_size)[0]}')
        # print(f'Relu-6 delta: {BN_input.grad}')
        # print(f'relu-4 grad: {fc_input.grad}')  # correct
        # print(f'fc-3 grad {relu_input_4.grad.shape}: {relu_input_4.grad}')
        # print(f'cnn-5 grad: {relu_input_13.grad.shape} {(relu_input_13.grad[0]*batch_size)}')
        # print(f'cnn-4 grad: {relu_input_11.grad.shape} {(relu_input_11.grad[0]*batch_size)}')
        # print(f'cnn-3 grad: {relu_input.grad.shape} {(relu_input.grad[0]*batch_size)}')
        # print(f'relu-2 delta: {(bn_input.grad*batch_size)[0]}')
        # print(f'fc-1 delta: {(relu_input_1.grad*batch_size)[0]}')
        # print(f'fc-2 delta: {(relu_input_2.grad*batch_size)[0]}')
        # print(f'relu-1 delta: {(fc_input_2.grad*batch_size)[0]}')
        

        # print(f'Theo output grad{torch.nn.functional.softmax(output) - torch.nn.functional.one_hot(labels.to(device), num_classes=10)}')
        # print(f'Actual output grad: {output.grad}')

        # for i in range(len(model.features)):
        #     if isinstance(model.features[i], nn.Conv2d):
        #         print(f'Conv2d weights grad: {model.features[i].weight.grad[0]}')
        #         print(f'Conv2d bias grad: {model.features[i].bias.shape}, {model.features[i].bias.grad}')
            
        #     if isinstance(model.features[i], nn.BatchNorm2d):
        #         print(f'BatchNorm2d weights grad: {model.features[i].weight.grad[0]}')
        #         print(f'BatchNorm2d bias grad: {model.features[i].bias.shape}, {model.features[i].bias.grad}')
        # print(f'BN weights grad: {(model.features[6].weight.grad/lr)}')
        # print(f'BN bias grad: {(model.features[6].bias.grad*batch_size)}')

        # print(f'CNN weights grad: {model.features[8].weight.grad}')
        # for idx in range(3):
        #     print(f'cnn-4 channel-{idx} grad: {(relu_input_11.grad*batch_size)[:, idx, :,:]}')

        #     print(f'cnn-4 channel-{idx} grad sum: {(relu_input_11.grad*batch_size).sum([0, 2, 3])[idx]}')
        # print(f'CNN-5 bias grad: {model.features[12].bias.grad.shape}, {model.features[12].bias.grad*batch_size}')
        # print(f'CNN-4 bias grad: {model.features[10].bias.grad.shape}, {model.features[10].bias.grad*batch_size}')
        # print(f'CNN-3 bias grad: {model.features[8].bias.grad*batch_size}')
        

        # for i in range(len(model.fc_layers)):
        #     if isinstance(model.fc_layers[i], nn.Linear):
        #         # print(f'Linear weights grad: {model.fc_layers[i].weight.grad[0]}')
        #         print(f'Linear bias grad: {model.fc_layers[i].bias.grad}')
        # print(f'Linear weights grad: {model.fc_layers[5].weight.grad}')
        # print(f'Linear bias grad: {model.fc_layers[5].bias.grad}')
        # print(f'Expect bias grad: {torch.sum(output.grad, dim=0)}')

        
        acc = torch.sum(torch.argmax(x, dim=1) == labels.to(device)) * 1.0 / len(labels.to(device))
        optimizer.step()

        print('Batch {}, Loss: {}, Acc: {}'.format(countMax, loss.item(), acc))
        # print(f'Input: {images[0]}')
        # print(f'Output: {x[0]}')
        countMax += 1
        if countMax >= debug_max_iter:
            break
# torch.save((relu_input_feature_2.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn2-delta.pt")
# torch.save((relu_input_11.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn4-delta.pt")
# torch.save((relu_input_13.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn5-delta.pt")
torch.save((relu_input_feature_3.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn3-delta.pt")
# torch.save((relu_input_1.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-fc1-delta.pt")
# torch.save((fc_input_2.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-relu1-delta.pt")
torch.save((bn_input_1.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-mp1-delta.pt")
torch.save((cnn_input_2.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn1-delta.pt")
torch.save((maxPool_input.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn2-delta.pt")
# torch.save((fc_input_2), os.path.expanduser("~")+ "/DNN/output/plaintext-relu1-output.pt")
# torch.save((relu_input_1), os.path.expanduser("~")+ "/DNN/output/plaintext-relu1-input.pt")
torch.save((fc_input_1), os.path.expanduser("~")+ "/DNN/output/plaintext-fc1-input.pt")
# torch.save((cnn_input_1), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn1-input.pt")
torch.save((cnn_input_3), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn3-input.pt")
torch.save((cnn_input_2), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn2-input.pt")
torch.save((bn_input), os.path.expanduser("~")+ "/DNN/output/plaintext-bn2-input.pt")
torch.save((maxPool_input), os.path.expanduser("~")+ "/DNN/output/plaintext-bn2-output.pt")
torch.save((maxPool_input), os.path.expanduser("~")+ "/DNN/output/plaintext-mp2-input.pt")
torch.save((cnn_input_3), os.path.expanduser("~")+ "/DNN/output/plaintext-mp2-output.pt")
torch.save((bn_input_1), os.path.expanduser("~")+ "/DNN/output/plaintext-bn1-input.pt")
torch.save((cnn_input_2), os.path.expanduser("~")+ "/DNN/output/plaintext-bn1-output.pt")
torch.save((relu_input_feature_3), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn3-output.pt")
# torch.save((relu_input_feature_1), os.path.expanduser("~")+ "/DNN/output/plaintext-cnn1-output.pt")
# torch.save((relu_input_2.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-fc2-delta.pt")
torch.save((model.features[6].bias.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn2-bias_grad.pt")
torch.save((model.features[6].weight.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn2-weight_grad.pt")
torch.save((model.features[3].bias.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn1-bias_grad.pt")
torch.save((model.features[3].bias.grad*batch_size), os.path.expanduser("~")+ "/DNN/output/plaintext-bn1-weight_grad.pt")
torch.save(outputs, os.path.expanduser("~")+ "/DNN/output/plaintext-log.pt")

torch.Size([32, 3, 32, 32])
tensor([-0.5373, -0.6627, -0.6078, -0.4667, -0.2314, -0.0667,  0.0902,  0.1373,
         0.1686,  0.1686,  0.0275, -0.0196,  0.1137,  0.1294,  0.0745,  0.0118,
         0.0745,  0.0510, -0.0275,  0.0902,  0.0902,  0.0431,  0.0667,  0.0902,
         0.1922,  0.2784,  0.3176,  0.2471,  0.2392,  0.2392,  0.1922,  0.1608])
tensor(6)
Use shuffle: False, use transform: True, batch size: 32, lr: 0.0625, model_SPDZ: True
---------------0
Batch 0, Loss: 2.290346145629883, Acc: 0.1875


## Train

In [35]:
for i, x in enumerate(trainloader, 0):
    img, label = x
    print(img.shape)
    print(img[0][0][0])
    print(label[0])
    break

torch.Size([64, 3, 32, 32])
tensor([-0.5373, -0.6627, -0.6078, -0.4667, -0.2314, -0.0667,  0.0902,  0.1373,
         0.1686,  0.1686,  0.0275, -0.0196,  0.1137,  0.1294,  0.0745,  0.0118,
         0.0745,  0.0510, -0.0275,  0.0902,  0.0902,  0.0431,  0.0667,  0.0902,
         0.1922,  0.2784,  0.3176,  0.2471,  0.2392,  0.2392,  0.1922,  0.1608])
tensor(6)


In [36]:
acc_file_path = (
    os.path.expanduser("~")
    + "/DNN/output/"
    + "AlexNet_"
    + ("SPDZ" if USE_SPDZ else "Falcon")
    + ("_train_GPU" if True else "_train_CPU")
    + "_"
    + "CIFAR10"
    + "_"
    + ("Transform_" if is_transform else "")
    + ("Shuffle_" if is_shuffle else "")
    + str(batch_size) + "_"
    + str(lr) + "_"
    + str(momentum) + "_"
    + ("MSE_" if use_MSE else "CE_")
    + "acc"
    + "_" + str(epochs)
    + '-epoch'
    + ".txt"
)

loss_file_path = (
    os.path.expanduser("~")
    + "/DNN/output/"
    + "AlexNet_"
    + ("SPDZ" if USE_SPDZ else "Falcon")
    + ("_train_GPU" if True else "_train_CPU")
    + "_"
    + "CIFAR10"
    + "_"
    + ("Transform_" if is_transform else "")
    + ("Shuffle_" if is_shuffle else "")
    + str(batch_size) + "_"
    + str(lr) + "_"
    + str(momentum) + "_"
    + ("MSE_" if use_MSE else "CE_")
    + "loss"
    + "_" + str(epochs)
    + '-epoch'
    + ".txt"
)

accF = open(acc_file_path, 'w')
lossF = open(loss_file_path, 'w')

for epoch in range(epochs):
    model.train()
    for batch, data in enumerate(trainloader, 0):
        images, labels = data
        output = model(images.to(device))

        optimizer.zero_grad()
        # MSEloss
        if use_MSE:
            t_labels = torch.nn.functional.one_hot(labels, num_classes=10).float()
        else:
            t_labels = labels
        loss = criterion(output, t_labels.to(device))
        loss.backward()
        
        acc = torch.sum(torch.argmax(output, dim=1) == labels.to(device)) * 1.0 / len(labels.to(device))
        optimizer.step()

        accF.write(f'{epoch * len(trainloader) + batch}\t{acc*100.0}\n')
        lossF.write(f'{epoch * len(trainloader) + batch}\t{loss.item()}\n')

        if batch % 100 == 0:
            print('Epoch: {}, Batch {}, Loss: {}, Acc: {}'.format(epoch, batch, loss.item(), acc))
        
            # print(model.features[6].running_var.data[0:2])

    test_acc = evaluate(model)
    print('Epoch: {}, Test Acc: {}'.format(epoch, test_acc))
accF.close()
lossF.close()

Epoch: 0, Batch 0, Loss: 2.2982542514801025, Acc: 0.125
Epoch: 0, Batch 100, Loss: 2.3020567893981934, Acc: 0.09375
Epoch: 0, Batch 200, Loss: 2.304307222366333, Acc: 0.078125
Epoch: 0, Batch 300, Loss: 2.300875663757324, Acc: 0.0625
Epoch: 0, Batch 400, Loss: 2.302187204360962, Acc: 0.125
Epoch: 0, Batch 500, Loss: 2.3046960830688477, Acc: 0.109375
Epoch: 0, Batch 600, Loss: 2.3027138710021973, Acc: 0.125
Epoch: 0, Batch 700, Loss: 2.296396017074585, Acc: 0.125
Epoch: 0, Test Acc: 0.16089999675750732
Epoch: 1, Batch 0, Loss: 2.297051429748535, Acc: 0.09375
Epoch: 1, Batch 100, Loss: 2.1815133094787598, Acc: 0.1875
Epoch: 1, Batch 200, Loss: 1.87067711353302, Acc: 0.25
Epoch: 1, Batch 300, Loss: 2.122049331665039, Acc: 0.15625
Epoch: 1, Batch 400, Loss: 1.917684555053711, Acc: 0.1875
Epoch: 1, Batch 500, Loss: 1.837367296218872, Acc: 0.234375
Epoch: 1, Batch 600, Loss: 1.8541369438171387, Acc: 0.1875
Epoch: 1, Batch 700, Loss: 1.6632658243179321, Acc: 0.359375
Epoch: 1, Test Acc: 0.296

In [248]:
# test_sss = iter(testloader)
# images, labels = test_sss.next()
# images, labels = test_sss.next()
# output = model(images)
# _, output = torch.max(output, 1)

# correct = 0
# for i in range(len(output)):
#     if output[i] == labels[i]:
#         correct += 1
# print(f"acc: {correct * 1. / 128}")
# print(f"Output: ${output}")
# print(f"Target: ${labels}")

# Save Pretrained Parameters

In [32]:
params = [(name, p.data.cpu().numpy()) for (name, p) in model.named_parameters()]

for (name, p) in params:
    print(name)
    print(f"Layer {str(name.split('.')[0])+'.'+str(name.split('.')[1])}, type {name.split('.')[2]}, shape {p.shape}")

features.0.weight
Layer features.0, type weight, shape (96, 3, 11, 11)
features.0.bias
Layer features.0, type bias, shape (96,)
features.3.weight
Layer features.3, type weight, shape (96,)
features.3.bias
Layer features.3, type bias, shape (96,)
features.4.weight
Layer features.4, type weight, shape (256, 96, 5, 5)
features.4.bias
Layer features.4, type bias, shape (256,)
features.6.weight
Layer features.6, type weight, shape (256,)
features.6.bias
Layer features.6, type bias, shape (256,)
features.8.weight
Layer features.8, type weight, shape (384, 256, 3, 3)
features.8.bias
Layer features.8, type bias, shape (384,)
features.10.weight
Layer features.10, type weight, shape (384, 384, 3, 3)
features.10.bias
Layer features.10, type bias, shape (384,)
features.12.weight
Layer features.12, type weight, shape (256, 384, 3, 3)
features.12.bias
Layer features.12, type bias, shape (256,)
fc_layers.1.weight
Layer fc_layers.1, type weight, shape (256, 256)
fc_layers.1.bias
Layer fc_layers.1, typ

In [33]:
import os
subdir = 'trained' if epochs != 0 else 'init'
path = f"{os.path.expanduser('~')}/DNN/params/{subdir}/AlexNet/"
if not os.path.exists(path):
    os.makedirs(path)

print(f'Save parameters to {path}')

# print(f'cnn1_weight_0: {params[0][1].reshape(3*11*11, 96).tolist()}')
np.savetxt(fname=path+"cnn1_weight_0", delimiter=" ", X=params[0][1].reshape(3*11*11, 96).tolist())
np.savetxt(fname=path+"cnn1_bias_0", delimiter=" ",  X=params[1][1].tolist())
np.savetxt(fname=path+"bn1_gamma_0", delimiter=" ",  X=params[2][1].tolist())
np.savetxt(fname=path+"bn1_beta_0", delimiter=" ",  X=params[3][1].tolist())
np.savetxt(fname=path+"cnn2_weight_0", delimiter=" ", X=params[4][1].reshape(96*5*5, 256).tolist())
np.savetxt(fname=path+"cnn2_bias_0", delimiter=" ", X=params[5][1].tolist())
np.savetxt(fname=path+"bn2_gamma_0", delimiter=" ",  X=params[6][1].tolist())
np.savetxt(fname=path+"bn2_beta_0", delimiter=" ",  X=params[7][1].tolist())
np.savetxt(fname=path+"cnn3_weight_0", delimiter=" ", X=params[8][1].reshape(256*3*3, 384).tolist())
np.savetxt(fname=path+"cnn3_bias_0", delimiter=" ", X=params[9][1].tolist())
np.savetxt(fname=path+"cnn4_weight_0", delimiter=" ", X=params[10][1].reshape(384*3*3, 384).tolist())
np.savetxt(fname=path+"cnn4_bias_0", delimiter=" ", X=params[11][1].tolist())
np.savetxt(fname=path+"cnn5_weight_0", delimiter=" ", X=params[12][1].reshape(384*3*3, 256).tolist())
np.savetxt(fname=path+"cnn5_bias_0", delimiter=" ", X=params[13][1].tolist())
# FC
np.savetxt(fname=path+"fc1_weight_0", delimiter=" ", X=params[14][1].tolist())
np.savetxt(fname=path+"fc1_bias_0", delimiter=" ", X=params[15][1].tolist())
np.savetxt(fname=path+"fc2_weight_0", delimiter=" ", X=params[16][1].tolist())
np.savetxt(fname=path+"fc2_bias_0", delimiter=" ", X=params[17][1].tolist())
np.savetxt(fname=path+"fc3_weight_0", delimiter=" ", X=params[18][1].tolist())
np.savetxt(fname=path+"fc3_bias_0", delimiter=" ", X=params[19][1].tolist())

Save parameters to /home/haoqi.whq/DNN/params/init/AlexNet/


# Batch Normalization Pseudocode

In [251]:
import math

IT_N = 3
def inverse_sqrt(x):
    init_g = np.exp(-(x/2+0.2))*2.2 + 0.2
    init_g -= x/1024

    for i in range(IT_N):
        init_g = init_g*(3-x * init_g * init_g)/2
    return init_g
    
class BN:
    def __init__(self, dims, gamma=0, beta=0):
        print('BN')
        self.dims = dims
        self.eps = 0
        self.gamma = np.ones((dims, ), dtype="float32")
        self.beta = np.zeros((dims, ), dtype="float32")
        
        self.inv_sqrt = None
        self.norm_x = None

        self.beta_grad = None
        self.gamma_grad = None
        self.act_grad = None

    def forward(self, x):
        mean = np.mean(x, axis=0)   # 1 truncation by batchSize [1, D]
        x_mean = x - mean   # [B, D]
        var = np.mean(x_mean * x_mean, axis=0)  # 1 multiplication, 1 truncation by batchsize [1, D]
        var_eps = var + self.eps

        # protocol inv_sqrt
        self.inv_sqrt = 1. / np.sqrt(var_eps)   # 1 inverse sqrt [1, D]
        self.inv_sqrt = inverse_sqrt(var_eps)   # 1 inverse sqrt [1, D]
        # print("======")
        print(var_eps)
        print('Target: ', 1. / np.sqrt(var_eps))
        print('Compute: ', inverse_sqrt(var_eps))
        self.norm_x = x_mean * self.inv_sqrt    # 1 multiplication [B, D] * [1, D]. Falcon here has bug.
        # print(self.inv_sqrt)
        print(self.norm_x)

        return self.gamma * self.norm_x + self.beta     # 1 multiplication

    def backward(self, grad):
        B, D = grad.shape
        self.beta_grad = np.sum(grad, axis=0)
        self.gamma_grad = np.sum(self.norm_x * grad, axis=0)    # 1 multiplication

        dxhat = grad * self.gamma   # 1 multiplication

        print('+' * 20)
        print(B*dxhat)
        # print(np.sum(dxhat, axis=0))
        # print(np.sum(dxhat * self.norm_x, axis=0))
        # print(self.norm_x * np.sum(dxhat * self.norm_x, axis=0))
        print((B*dxhat - np.sum(dxhat, axis=0) - self.norm_x * np.sum(dxhat * self.norm_x, axis=0)))
        self.act_grad = self.inv_sqrt * \
                        (B*dxhat - np.sum(dxhat, axis=0) - self.norm_x * np.sum(dxhat * self.norm_x, axis=0)) \
                        / B # 3 multiplication, 1 truncation

        return self.act_grad, self.gamma_grad, self.beta_grad

bn = BN(5)

data = np.array([[0.02, 2, 3, 4, 5],
                 [0.03, 3, 5, 7, 8],
                 [0.04, 2, 3, 6, 6],
                 [0.01, 2, 4, 5, 20]]).astype(np.float32)
x_raw = torch.from_numpy(data)
x = x_raw.numpy()

grad = np.array([[1, 2, 3, 4, 5],
                 [1, 3, 5, 7, 8],
                 [1, 2, 3, 6, 6],
                 [1, 2, 4, 5, 6]]).astype(np.float32)
# grad = grad * 1024

print("forward")
f_o = bn.forward(x)
print(f_o)

# print("backward")
# b_o = bn.backward(grad)
# print(b_o)

# print(inverse_sqrt(np.array([16, 100, 0.01])))

BN
forward
[1.2499999e-04 1.8750000e-01 6.8750000e-01 1.2500000e+00 3.6187500e+01]
Target:  [89.442726    2.309401    1.2060454   0.8944272   0.16623433]
Compute:  [6.744339   2.3093176  1.2058791  0.8931092  0.16623433]
[[-0.03372169 -0.5773294  -0.9044093  -1.3396637  -0.78961307]
 [ 0.0337217   1.7319882   1.5073489   1.3396637  -0.29091007]
 [ 0.10116509 -0.5773294  -0.9044093   0.4465546  -0.62337875]
 [-0.10116508 -0.5773294   0.30146977 -0.4465546   1.7039019 ]]
[[-0.03372169 -0.5773294  -0.9044093  -1.3396637  -0.78961307]
 [ 0.0337217   1.7319882   1.5073489   1.3396637  -0.29091007]
 [ 0.10116509 -0.5773294  -0.9044093   0.4465546  -0.62337875]
 [-0.10116508 -0.5773294   0.30146977 -0.4465546   1.7039019 ]]


In [252]:
class MyBN:
    def __init__(self, momentum, eps, num_features):
        """
        初始化参数值
        :param momentum: 追踪样本整体均值和方差的动量
        :param eps: 防止数值计算错误
        :param num_features: 特征数量
        """
        # 对每个batch的mean和var进行追踪统计
        self._running_mean = 0
        self._running_var = 1
        # 更新self._running_xxx时的动量
        self._momentum = momentum
        # 防止分母计算为0
        self._eps = eps
        # 对应论文中需要更新的beta和gamma，采用pytorch文档中的初始化值
        self._beta = np.zeros(shape=(num_features, ))
        self._gamma = np.ones(shape=(num_features, ))

    def batch_norm(self, x):
        """
        BN向传播
        :param x: 数据
        :return: BN输出
        """
        x_mean = x.mean(axis=0)
        x_var = x.var(axis=0)
        # 对应running_mean的更新公式
        self._running_mean = (1-self._momentum)*x_mean + self._momentum*self._running_mean
        self._running_var = (1-self._momentum)*x_var + self._momentum*self._running_var
        # 对应论文中计算BN的公式
        x_hat = (x-x_mean)/np.sqrt(x_var+self._eps)
        y = self._gamma*x_hat + self._beta
        return y

my_bn = MyBN(momentum=0.01, eps=1e-5, num_features=5)
bn_output = my_bn.batch_norm(x)
print(bn_output)

[[-0.43033141 -0.57733488 -0.90452743 -1.34163547 -0.78961295]
 [ 0.43033159  1.73200464  1.50754571  1.34163547 -0.29091004]
 [ 1.29099452 -0.57733488 -0.90452743  0.44721183 -0.62337863]
 [-1.29099441 -0.57733488  0.30150914 -0.44721183  1.70390165]]


In [253]:
input = torch.randn(2, 2, 1, 2)
print(input)
print(torch.mean(input, [0, 2, 3]))

input.var([0,2,3])

tensor([[[[ 0.9448,  1.1164]],

         [[-0.6701, -0.7322]]],


        [[[ 0.4516,  0.5333]],

         [[ 1.6078, -0.7906]]]])
tensor([ 0.7615, -0.1463])


tensor([0.1026, 1.3700])

In [254]:

def batchnorm_forward(x, gamma, beta, eps):
    # read some useful parameter
    N, D = x.shape

    # BN forward pass
    sample_mean = x.mean(axis=0)
    sample_var = x.var(axis=0)
    x_ = (x - sample_mean) / np.sqrt(sample_var + eps)
    out = gamma * x_ + beta

    # storage variables for backward pass
    cache = (x_, gamma, x - sample_mean, sample_var + eps)

    return out, cache


def batchnorm_backward(dout, cache):
    # extract variables
    N, D = dout.shape
    x_, gamma, x_minus_mean, var_plus_eps = cache

    # calculate gradients
    dgamma = np.sum(x_ * dout, axis=0)
    dbeta = np.sum(dout, axis=0)

    dx_ = np.matmul(np.ones((N,1)), gamma.reshape((1, -1))) * dout
    dx = N * dx_ - np.sum(dx_, axis=0) - x_ * np.sum(dx_ * x_, axis=0)
    dx *= (1.0/N) / np.sqrt(var_plus_eps)

    return dx, dgamma, dbeta

gamma = np.ones((5, ), dtype="float32")
beta = np.zeros((5, ), dtype="float32")
f_o, cache = batchnorm_forward(x, gamma, beta, 1e-5)
print('forward', f_o)

grad = np.ones((3, 5), dtype='float32')
dx, dgamma, dbeta = batchnorm_backward(grad, cache)
print(dx)
print(dgamma)
print(dbeta)

forward [[-0.4303314  -0.5773349  -0.9045274  -1.3416355  -0.78961295]
 [ 0.4303316   1.7320046   1.5075457   1.3416355  -0.29091004]
 [ 1.2909945  -0.5773349  -0.9045274   0.44721183 -0.62337863]
 [-1.2909944  -0.5773349   0.30150914 -0.44721183  1.7039016 ]]


ValueError: operands could not be broadcast together with shapes (4,5) (3,5) 

In [None]:

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1,-1,1,1) # edit
    # beta = layer.bias
    # avg = layer.running_mean
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3] # edit

    # add new
    mean = input.mean(dim = (0,2,3), keepdim = True)
    variance = input.var(dim = (0,2,3), unbiased=False, keepdim = True)
    x_hat = (input - mean)/(torch.sqrt(variance + eps))
    
    inv_sqrt = 1.0 / torch.sqrt(variance + eps)

    dL_dxi_hat = grad_output * gamma
    # dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True) 
    # dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True)  * ((variance + eps) ** -1.5) # edit
    dL_davg = (-inv_sqrt * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B) #edit
    
    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B) # dL_dxi_hat / sqrt()
    # dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) 
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True) # edit
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

In [None]:
from scipy.special import softmax
x = np.array([
    [2.60365, -0.290006, -0.721366, 0.507656, -0.434107, -0.595984, 0.532808, -0.0694342, -0.377345, -0.0523729],
    [-0.124932, 0.807279, -0.534957, -0.0588999, 0.0464973, -0.855241, 0.116036, 0.219089, 0.0108147, 0.756191],
    [0.41907, 0.301184, -0.434992, 0.190973, 0.18747, -0.639831, 0.21955, 0.151983, -0.169142, 0.523242]
    ])

m = softmax(x, axis=1)
m

array([[0.61762686, 0.0342001 , 0.02221724, 0.07593597, 0.02961045,
        0.02518505, 0.07787013, 0.04264032, 0.03133982, 0.04337406],
       [0.07585606, 0.19268342, 0.05034063, 0.08103407, 0.09004115,
        0.03654442, 0.09652534, 0.10700318, 0.0868849 , 0.18308683],
       [0.13332502, 0.11849893, 0.0567541 , 0.10613299, 0.10576185,
        0.04624202, 0.1092097 , 0.1020745 , 0.07403796, 0.14796295]])

In [None]:
IT_N = 3
def inverse_sqrt(x):
    init_g = np.exp(-(x/2+0.2))*2.2 + 0.2
    init_g -= x/1024

    for i in range(IT_N):
        init_g = init_g*(3-x * init_g * init_g)/2
    return init_g