In [14]:
import copy
import random

import cv2
import numpy
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets

In [15]:
# Hyper-parameters
batch_size = 128
learning_rate = 0.001
epochs = 5

# Seed every RNG the notebook relies on so Restart & Run All reproduces the
# same run: torch drives weight initialisation and DataLoader shuffling,
# `random` drives the choice of stashed hidden nodes.
SEED = 0
random.seed(SEED)
numpy.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

In [16]:
# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA (a hard-coded 'cuda' device
# fails there as soon as anything is moved to it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [17]:
# Load both MNIST splits as tensors (downloading them on first use) and wrap
# them in DataLoaders. Only the training split is shuffled.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
train_data, test_data = (
    datasets.MNIST(root='./../data', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [18]:
# Create our model
class mnist_nn(nn.Module):
    """Two-layer fully connected classifier for flattened images.

    Architecture: Linear(in_features -> hidden_features) -> ReLU ->
    Linear(hidden_features -> num_classes, bias=False). The defaults
    reproduce the original hard-coded 784 -> 1600 -> 10 MNIST network.

    Args:
        in_features: length of each flattened input vector (28*28 for MNIST).
        hidden_features: width of the hidden layer whose units the stashing
            experiment zeroes out.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=784, hidden_features=1600, num_classes=10):
        super(mnist_nn, self).__init__()
        # Layers are created in the original order so that, for a fixed seed,
        # parameter initialisation is identical to the hard-coded version.
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.activation = nn.ReLU()
        self.linear_2 = nn.Linear(in_features=hidden_features, out_features=num_classes, bias=False)

    def forward(self, x):
        """Return raw logits for ``x``, whose last dimension must be ``in_features``.

        No softmax is applied; the loss used with this model works on logits.
        """
        x = self.linear_1(x)
        x = self.activation(x)
        return self.linear_2(x)

In [19]:
# Build the classifier and move its parameters onto the selected device.
model = mnist_nn()
model = model.to(device)

In [20]:
# Cross-entropy on raw logits (the model's forward applies no softmax) and
# Adam over every model parameter. The loss keeps the name `criteration`
# because the training loop refers to it by that name.
criteration = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [21]:
def accuracy(net=None, loader=None, dev=None):
    """Return the classification accuracy of a model, in percent.

    Args:
        net: model to evaluate; defaults to the global ``model``.
        loader: iterable of ``(images, labels)`` batches; defaults to the
            global ``test_loader``.
        dev: device batches are moved to; defaults to the global ``device``.

    Each image batch is flattened to ``(batch, -1)`` before the forward pass.
    The prediction is the arg-max logit. The model is evaluated in eval mode
    without gradient tracking, and its previous train/eval mode is restored
    afterwards, so this is safe to call in the middle of training.

    Returns:
        ``n_correct / n_samples * 100`` as a float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net = model if net is None else net
    loader = test_loader if loader is None else loader
    dev = device if dev is None else dev

    was_training = net.training
    net.eval()
    n_correct = 0
    n_samples = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.reshape(images.size(0), -1).to(dev)
                labels = labels.to(dev)
                outputs = net(images)
                _, predicted = torch.max(outputs, dim=1)
                n_samples += labels.size(0)
                n_correct += (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had (e.g. training mode mid-epoch).
        net.train(was_training)

    if n_samples == 0:
        raise ValueError("accuracy(): loader yielded no samples")
    return n_correct / n_samples * 100

In [22]:
# Choose, without replacement, the hidden units to stash (temporarily remove
# from the hidden layer) during training: one distinct node per stash event.
STASH_EVERY = 200  # optimizer steps between stash events; must match the training cell
# The training loop stashes one node every STASH_EVERY steps in every epoch
# except the last, so it needs exactly this many nodes. Deriving the count
# (instead of hard-coding it) avoids an IndexError when epochs or the batch
# count change.
n_stash_events = max(epochs - 1, 0) * (len(train_loader) // STASH_EVERY)
# Sample from the model's actual hidden width rather than a duplicated literal.
stashed = random.sample(range(model.linear_1.out_features), n_stash_events)
stashed

[1264, 958, 1154, 940, 609, 1451, 663, 1426]

In [23]:
%%time
j = 0
local_stashed = []
stashed_lin1_wt_dict = {}
stashed_lin1_bias_dict = {}
stashed_lin2_wt_dict = {}
n_total_steps = len(train_loader)
for epoch_iter in range(epochs):
    if(epoch_iter == (epochs-1)):
        print("Joined all the stashed nodes")
        for i in stashed:
            assert model.linear_2.weight[:,i].any() == False
            assert model.linear_1.weight[i,:].any() == False
            assert model.linear_1.bias[i] == 0
        print("Assertions Passed")
        with torch.no_grad():
            for key in stashed_lin1_wt_dict:
                model.linear_1.weight[key,:] = stashed_lin1_wt_dict[key]
                model.linear_1.bias[key] = stashed_lin1_bias_dict[key]
                model.linear_2.weight[:,key] = stashed_lin2_wt_dict[key]
   
    for i,(data, label) in enumerate(train_loader):
        x = data.reshape(-1,28*28).to(device)
        y = label.to(device)
        optimizer.zero_grad()
        pred = model(x)
        loss = criteration(pred,y)
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
             print (f'Epoch [{epoch_iter+1}/{epochs}], Step[{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
    
        if((epoch_iter<(epochs-1)) and (i+1) % 200 == 0):
            print("Accuracy before stashing", accuracy())
            print("Node stashed", stashed[j])
            print("-------------------------")
            local_stashed.append(stashed[j])
            stashed_lin1_wt_dict[stashed[j]] =model.linear_1.weight[stashed[j],:].clone().detach()
            stashed_lin1_bias_dict[stashed[j]] = model.linear_1.bias[stashed[j]].clone().detach()
            stashed_lin2_wt_dict[stashed[j]] = model.linear_2.weight[:,stashed[j]].clone().detach()
            j+=1
        
        if(epoch_iter<(epochs-1)):
            for k in local_stashed:
                with torch.no_grad():
                    model.linear_1.weight[k,:] = 0
                    model.linear_1.bias[k] = 0
                    model.linear_2.weight[:,k] = 0
            
            if((i+1) % 200 == 0):
                print("Accuracy after_stashing is", accuracy())

Epoch [1/5], Step[100/469], Loss: 0.4226
Epoch [1/5], Step[200/469], Loss: 0.2443
Accuracy before stashing 94.35
Node stashed 1264
-------------------------
Accuracy after_stashing is 94.35
Epoch [1/5], Step[300/469], Loss: 0.1323
Epoch [1/5], Step[400/469], Loss: 0.1663
Accuracy before stashing 96.36
Node stashed 958
-------------------------
Accuracy after_stashing is 96.38
Epoch [2/5], Step[100/469], Loss: 0.0500
Epoch [2/5], Step[200/469], Loss: 0.0608
Accuracy before stashing 97.2
Node stashed 1154
-------------------------
Accuracy after_stashing is 97.2
Epoch [2/5], Step[300/469], Loss: 0.1072
Epoch [2/5], Step[400/469], Loss: 0.1364
Accuracy before stashing 97.23
Node stashed 940
-------------------------
Accuracy after_stashing is 97.22
Epoch [3/5], Step[100/469], Loss: 0.0672
Epoch [3/5], Step[200/469], Loss: 0.0869
Accuracy before stashing 97.75
Node stashed 609
-------------------------
Accuracy after_stashing is 97.72999999999999
Epoch [3/5], Step[300/469], Loss: 0.0599
Ep

In [24]:
final_test_accuracy = accuracy()  # percent correct on the test loader after training
final_test_accuracy

97.99

In [25]:
# Pickle the whole module (architecture + weights). Loading it back with
# torch.load needs the `mnist_nn` class to be defined/importable.
# NOTE(review): newer PyTorch defaults torch.load to weights_only=True, which
# may reject full-module pickles; saving model.state_dict() would avoid that.
torch.save(obj=model, f='software_stashing_randomly.pt')

In [26]:
# batch = next(iter(train_loader))
# yhat = model(batch[0].reshape(-1,28*28).to(device)) # Give dummy batch to forward().

# from torchviz import make_dot

# make_dot(yhat, params=dict(list(model.named_parameters()))).render("simple_nn", format="png")