In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchsummaryX import summary

In [2]:
import os
cwd = os.getcwd()
print(cwd)

/Users/wudidaizi/Project/stochasticSim/pytorchSim


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [4]:
# MNIST data loader
transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)

testset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, num_workers=4)


In [5]:
class SC_Weight_Clipper(object):
    
    def __init__(self, frequency=1, representation="bipolar", constraint="clip", bits=8):
        self.frequency = frequency
        # "unipolar" or "bipolar"
        self.representation = representation
        # "clip" or "norm"
        self.constraint = constraint
        self.bits = bits
        self.scale = 2 ** self.bits

    def __call__(self, module):
        # filter the variables to get the ones you want
        if self.frequency > 1:
            self.constraint = "clip"
        else:
            self.constraint = "norm"
                
        if hasattr(module, 'weight'):
            w = module.weight.data
            self.clipping(w)
        
        if hasattr(module, 'bias'):
            w = module.bias.data
            self.clipping(w)
        
        self.frequency = self.frequency + 1
            
    def clipping(self, w):
        if self.representation == "unipolar":
            if self.constraint == "norm":
                w.sub_(torch.min(w)).div_(torch.max(w) - torch.min(w)) \
                .mul_(self.scale).round_().clamp_(0.0,self.scale).div_(self.scale)
            elif self.constraint == "clip":
                w.clamp_(0.0,1.0) \
                .mul_(self.scale).round_().clamp_(0.0,self.scale).div_(self.scale)
            else:
                raise TypeError("unknown constraint type '{}' in SC_Weight, should be 'clip' or 'norm'"
                                .format(self.constraint))
        elif self.representation == "bipolar":
            if self.constraint == "norm":
                w.sub_(torch.min(w)).div_(torch.max(w) - torch.min(w)).mul_(2).sub_(1) \
                .mul_(self.scale/2).round_().clamp_(-self.scale/2,self.scale/2).div_(self.scale/2)
            elif self.constraint == "clip":
                w.clamp_(-1.0,1.0) \
                .mul_(self.scale/2).round_().clamp_(-self.scale/2,self.scale/2).div_(self.scale/2)
            else:
                raise TypeError("unknown constraint type '{}' in SC_Weight, should be 'clip' or 'norm'"
                                .format(self.constraint))
        else:
            raise TypeError("unknown representation type '{}' in SC_Weight, should be 'unipolar' or 'bipolar'"
                            .format(self.representation))

In [6]:
 class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
#         self.conv1 = nn.Conv2d(1, 6, 5)
#         self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.conv1 = nn.Linear(1*1*32*32, 512)  # 6*6 from image dimension

        self.conv2 = nn.Linear(512, 400)  # 6*6 from image dimension

#         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 6*6 from image dimension


        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x.view(-1, 32*32)
        x = self.conv1(x)
        x = torch.clamp(x, -1, 1)
#         x = F.avg_pool2d(x, (2, 2))
        x = F.relu(x)
        # If the size is a square you can only specify a single number
        x = self.conv2(x)
        x = torch.clamp(x, -1, 1)
#         x = F.avg_pool2d(x, (2, 2))
        x = F.relu(x)
        
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = torch.clamp(x, -1, 1)
        x = F.relu(self.fc2(x))
        x = torch.clamp(x, -1, 1)
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(32*32, 50)
        self.fc1_drop = nn.Dropout(0.2)
        self.fc2 = nn.Linear(50, 50)
        self.fc2_drop = nn.Dropout(0.2)
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        x = x.view(-1, 32*32)
        x = F.relu(self.fc1(x))
        x = self.fc1_drop(x)        
#         x = torch.clamp(x, -1, 1)
        x = F.relu(self.fc2(x))
        x = self.fc2_drop(x)
#         x = torch.clamp(x, -1, 1)
        return F.log_softmax(self.fc3(x), dim=1)

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1024, 512, bias=True) 
        self.lin2 = nn.Linear(512, 256, bias=True)
        self.lin3 = nn.Linear(256, 10, bias=True)

    def forward(self, xb):
        x = xb.view(-1,1024) 
        x = F.relu(self.lin1(x))
        x = torch.clamp(x, -1, 1)
        x = F.relu(self.lin2(x))
        x = torch.clamp(x, -1, 1)
        return self.lin3(x)

net = Mnist_NN()
net.to(device)

Mnist_NN(
  (lin1): Linear(in_features=1024, out_features=512, bias=True)
  (lin2): Linear(in_features=512, out_features=256, bias=True)
  (lin3): Linear(in_features=256, out_features=10, bias=True)
)

In [7]:
num_bits = 8
clipper = SC_Weight_Clipper(bits = num_bits)

In [8]:
summary(net,torch.zeros((1, 1, 32, 32)).to(device))

       Kernel Shape Output Shape  Params  Mult-Adds
Layer                                              
0_lin1  [1024, 512]     [1, 512]  524800     524288
1_lin2   [512, 256]     [1, 256]  131328     131072
2_lin3    [256, 10]      [1, 10]    2570       2560
---------------------------------------------------
                      Totals
Total params          658698
Trainable params      658698
Non-trainable params       0
Mult-Adds             657920


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_lin1,"[1024, 512]","[1, 512]",524800,524288
1_lin2,"[512, 256]","[1, 256]",131328,131072
2_lin3,"[256, 10]","[1, 10]",2570,2560


In [9]:
# dir(net.conv1.state_dict)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3)

In [None]:
for epoch in range(16):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        
    net.apply(clipper)
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # print('Accuracy of the network on the 10000 test images: %f %%' % (
    #     100 * correct / total))
    print('Train - Epoch %d, Loss: %f, Test Accuracy: %f %%' \
          % (epoch, loss.detach().cpu().item(), 100 * correct / total))

print('Finished Training')

In [None]:
model_path = cwd+"/saved_model_state_dict"+"_"+str(num_bits)
torch.save(net.state_dict(), model_path)

In [13]:
print("conv1 weight")
print(torch.min(net.conv1.weight).item(), "~", torch.max(net.conv1.weight).item())
print()
print("conv1 bias")
print(torch.min(net.conv1.bias).item(), "~", torch.max(net.conv1.bias).item())
print()

print("conv2 weight")
print(torch.min(net.conv2.weight).item(), "~", torch.max(net.conv2.weight).item())
print()
print("conv2 bias")
print(torch.min(net.conv2.bias).item(), "~", torch.max(net.conv2.bias).item())
print()

print("fc1 weight")
print(torch.min(net.fc1.weight).item(), "~", torch.max(net.fc1.weight).item())
print()
print("fc1 bias")
print( torch.min(net.fc1.bias).item(), "~", torch.max(net.fc1.bias).item())
print()

print("fc2 weight")
print(torch.min(net.fc2.weight).item(), "~", torch.max(net.fc2.weight).item())
print()
print("fc2 bias")
print(torch.min(net.fc2.bias).item(), "~", torch.max(net.fc2.bias).item())
print()

print("fc3 weight")
print(torch.min(net.fc3.weight).item(), "~", torch.max(net.fc3.weight).item())
print()
print("fc3 bias")
print(torch.min(net.fc3.bias).item(), "~", torch.max(net.fc3.bias).item())


conv1 weight
-0.953125 ~ 1.0

conv1 bias
-1.0 ~ 0.125

conv2 weight
-1.0 ~ 0.796875

conv2 bias
-0.3671875 ~ 0.09375

fc1 weight
-1.0 ~ 1.0

fc1 bias
-0.515625 ~ 0.359375

fc2 weight
-1.0 ~ 1.0

fc2 bias
-0.359375 ~ 0.21875

fc3 weight
-1.0 ~ 1.0

fc3 bias
-0.265625 ~ 0.171875


In [14]:
model = LeNet()
model.load_state_dict(torch.load(model_path))
model.eval()
model.to(device)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [15]:
model.apply(clipper)
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %f %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 98.410000 %
