In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import time
import h5py
import sys
from dataset import Dataset, SpikingDataset, SpikingMNISTDataset
from torch.utils.data.dataloader import DataLoader
import torch
import torch.nn as nn
import torchvision
# Seed every RNG this notebook draws from, not just Python's `random`:
# NumPy picks the layer to quantize in SpikingDenseNet.forward, and PyTorch
# initialises the layer weights.
SEED = 1338
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    # Keep using the second GPU when there is one, but fall back to the first
    # GPU instead of failing on single-GPU machines.
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device("cpu")

# Time constants and simulation step; alpha/beta are the per-step decay
# factors exp(-time_step/tau) applied to the synaptic and membrane state.
tau_mem = 10e-3
tau_syn = 5e-3
time_step = 1e-3
alpha   = float(np.exp(-time_step/tau_syn))
beta    = float(np.exp(-time_step/tau_mem))

weight_scale = 7*(1.0-beta)

print("init done")

#The class is based on the code in: https://github.com/fzenke/spytorch.git
class SurrGradSpike(torch.autograd.Function):
    """Step-function spike nonlinearity with a surrogate gradient.

    Forward pass: 1.0 where the input is strictly positive, 0.0 elsewhere.
    Backward pass: the incoming gradient is divided by
    ``(scale * |input| + 1) ** 2`` instead of using the step's true derivative.
    """

    # Multiplies |input| in the denominator of the surrogate gradient.
    scale = 100.0

    @staticmethod
    def forward(ctx, input):
        # The raw input is needed again in backward().
        ctx.save_for_backward(input)
        spikes = torch.zeros_like(input)
        spikes.masked_fill_(input > 0, 1.0)
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        damping = (SurrGradSpike.scale * saved_input.abs() + 1.0) ** 2
        return grad_output / damping

spike_fn  = SurrGradSpike.apply

class FakeQuantize(torch.autograd.Function):
    """Symmetric fake quantization with a straight-through gradient.

    Forward rounds ``in_data`` onto a uniform grid whose step is
    ``max(|in_data|) / (2**num_bits - 1)`` and returns the de-quantized values.
    The step is also written into ``scale[0]`` so the caller can read it back.
    Backward passes the incoming gradient through unchanged.

    Args:
        in_data: tensor to quantize.
        num_bits: bit width; integer levels lie in ``[-qmax, qmax]`` with
            ``qmax = 2**num_bits - 1``.
        scale: mutable one-element list that receives the step as a tensor.
    """

    @staticmethod
    def forward(ctx, in_data, num_bits, scale):
        qmax = 2**num_bits - 1
        # The largest magnitude in the tensor sets the representable range.
        upper = torch.max(torch.abs(in_data))
        step = upper / qmax
        # An all-zero tensor would give step == 0 and turn the division below
        # into NaNs; clamp to the smallest positive normal number instead.
        step = torch.clamp(step, min=torch.finfo(step.dtype).tiny)
        scale[0] = step
        output = torch.round(in_data / step)
        # Clamp to the range implied by num_bits (previously hard-coded to +-255).
        output = torch.clamp(output, min=-qmax, max=qmax)
        return output * step

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward() input: in_data, num_bits, scale.
        return grad_output, None, None
q_fn = FakeQuantize.apply

class QLinear(nn.Linear):
    """Linear layer that can apply fake quantization to its weights.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has an additive bias (default False).
        num_bits: bit width handed to the quantizer when ``q=True``.
    """

    def __init__(self, in_features, out_features, bias=False, num_bits=8):
        super(QLinear, self).__init__(in_features, out_features, bias)
        self.num_bits = num_bits
        # One-element list so the quantizer can write its step back in place.
        self.scale = [1.0]

    def forward(self, x, q=False):
        """Apply the layer.

        Args:
            x: input tensor.
            q: when True the weights go through ``q_fn`` first; otherwise the
                full-precision weights are used and the scale is reset to one.

        Returns:
            ``(output, scale)`` where ``scale`` is the current ``self.scale[0]``.
        """
        if q:
            weight = q_fn(self.weight, self.num_bits, self.scale)
        else:
            weight = self.weight
            # Follow the layer's own device rather than a notebook-level global.
            self.scale[0] = torch.ones([1], dtype=torch.float32, device=self.weight.device)
        # self.bias is None unless the layer was built with bias=True.
        output = nn.functional.linear(x, weight, self.bias)
        return output, self.scale[0]

init done


In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

#This function is based on the code in: https://github.com/fzenke/spytorch.git
def Recurrent(h1, nb_hidden, nb_steps, scale):
    """Run a spiking layer over ``nb_steps`` time steps and return its spikes.

    Each unit keeps a synaptic state ``syn`` and a membrane state ``mem`` that
    decay by the module-level factors ``alpha`` and ``beta``. A unit spikes
    when ``mem`` exceeds 1.0, and 1.0 is subtracted from ``mem`` afterwards.

    Args:
        h1: input currents, indexed as ``h1[:, t]`` for each step ``t``.
        nb_hidden: number of units in the layer.
        nb_steps: number of time steps to simulate.
        scale: accepted for interface compatibility; currently unused.

    Returns:
        Spike tensor stacked along dim 1. The first frame is the all-zero
        initial state, so there are ``nb_steps + 1`` frames.
    """
    batch_size = h1.shape[0]
    syn = torch.zeros((batch_size, nb_hidden), device=h1.device)
    mem = torch.zeros((batch_size, nb_hidden), device=h1.device)

    # Seed the record with the zero initial state (kept from the original).
    spk_rec = [mem]

    for t in range(nb_steps):
        mthr = mem - 1.0
        out = spike_fn(mthr)
        # Reset term: 1 where the threshold was crossed, built without autograd.
        rst = (mthr > 0).to(mem.dtype)

        new_syn = alpha*syn + h1[:, t]
        new_mem = beta*mem + syn - rst

        mem, syn = new_mem, new_syn
        spk_rec.append(out)

    return torch.stack(spk_rec, dim=1)

#This function is based on the code in: https://github.com/fzenke/spytorch.git
def Readout(h2, nb_outputs, nb_steps=None):
    """Non-spiking readout: filter ``h2`` through two leaky integration stages.

    Args:
        h2: input currents, indexed as ``h2[:, t]`` for each step ``t``.
        nb_outputs: number of readout units.
        nb_steps: number of time steps to integrate. When omitted, the
            notebook-level ``nb_steps`` is used if it exists (the previous
            implicit behaviour); otherwise the time length of ``h2``.

    Returns:
        Tensor stacked along dim 1 holding the zero initial state followed by
        the readout value after each step (``nb_steps + 1`` frames).
    """
    if nb_steps is None:
        # Backward compatibility: this function used to read a global that is
        # only defined in a later cell. Prefer passing nb_steps explicitly.
        nb_steps = globals().get("nb_steps", h2.shape[1])

    batch_size = h2.shape[0]
    flt = torch.zeros((batch_size, nb_outputs), device=h2.device)
    out = torch.zeros((batch_size, nb_outputs), device=h2.device)
    out_rec = [out]
    for t in range(nb_steps):
        new_flt = alpha*flt + h2[:, t]
        new_out = beta*out + flt

        flt, out = new_flt, new_out
        out_rec.append(out)

    return torch.stack(out_rec, dim=1)

class SpikingDenseNet(nn.Module):
    """Three-layer spiking network: two Recurrent stages and a Readout.

    Args:
        input_size: number of input features per time step.
        nb_steps: number of simulation steps passed to the Recurrent stages.
        num_output: number of readout units.
        num_h: width of both hidden layers.
    """

    def __init__(self, input_size, nb_steps, num_output, num_h):
        super(SpikingDenseNet, self).__init__()
        self.fc1 = QLinear(input_size, num_h, bias=False, num_bits=8)
        self.fc2 = QLinear(num_h, num_h, bias=False, num_bits=8)
        self.fc3 = QLinear(num_h, num_output, bias=False, num_bits=8)
        self.num_h1 = num_h
        self.num_h2 = num_h
        self.num_output = num_output
        self.nb_steps = nb_steps

    def forward(self, x):
        """Run the network and return the Readout output.

        In training mode exactly one randomly chosen layer uses quantized
        weights for this call. In eval mode no layer is quantized.
        """
        # Defined before the branch: it used to exist only in the training
        # branch, so eval mode raised UnboundLocalError.
        n = 3
        vec = [False] * n
        if self.training:
            vec[np.random.randint(low=0, high=n)] = True

        x, scale = self.fc1(x, vec[0])
        x = Recurrent(x, self.num_h1, self.nb_steps, scale)
        x, scale = self.fc2(x, vec[1])
        x = Recurrent(x, self.num_h2, self.nb_steps, scale)
        x, scale = self.fc3(x, vec[2])
        x = Readout(x, self.num_output)
        return x

In [3]:
def train(trainloader, testloader, model, lr=2e-3, nb_epochs=10):
    """Train ``model`` with Adam and return the mean loss of every epoch.

    The loss is the negative log-likelihood of the log-softmax over the
    per-unit maximum of the model output along dim 1. The learning rate is
    multiplied by 0.1 every 30 epochs, and train/test accuracy is printed on
    every 30th epoch (starting with epoch 0).

    Args:
        trainloader: iterable of ``(x, y)`` training batches.
        testloader: iterable of ``(x, y)`` batches used for the periodic report.
        model: network to optimise; batches are moved to the module-level ``device``.
        lr: initial Adam learning rate.
        nb_epochs: number of passes over ``trainloader``.

    Returns:
        List with one mean training loss per epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9,0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    log_softmax_fn = nn.LogSoftmax(dim=1)
    loss_fn = nn.NLLLoss()

    loss_hist = []
    for e in range(nb_epochs):
        local_loss = []
        t0 = time.time()
        for x_local, y_local in trainloader:
            x_local = x_local.float().to(device)
            y_local = y_local.long().to(device)
            output = model(x_local)
            m,_=torch.max(output,1)
            log_p_y = log_softmax_fn(m)
            loss_val = loss_fn(log_p_y, y_local)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            local_loss.append(loss_val.item())
        t1 = time.time()
        print(t1 - t0)
        # Format the epoch number; passing it as a second print argument
        # printed the literal "%d".
        print("epoch: %d" % e)
        if e % 30 == 0:
            # Evaluate the model being trained, not the notebook globals `net`/`name`.
            print("Training accuracy: %.3f"%(compute_classification_accuracy(trainloader, model, "train")))
            print("Test accuracy: %.3f"%(compute_classification_accuracy(testloader, model, "test")))
        scheduler.step()
        loss_hist.append(np.mean(local_loss))

    return loss_hist
        
def Extract(x, y, res, label):
    """Append per-channel non-zero time indices and labels for one batch.

    ``x`` must be 3-dimensional; it is unpacked as (batch, time, channel).
    For every sample and every channel, the indices along the time axis where
    ``x`` is non-zero are appended to ``res[channel]``. Each sample's entry of
    ``y`` is appended to ``label``. Both lists are modified in place.
    """
    spikes = x.cpu().numpy()
    targets = y.cpu().numpy()
    n_samples, _, n_channels = spikes.shape

    for idx in range(n_samples):
        sample = spikes[idx]
        for channel in range(n_channels):
            active_steps = np.nonzero(sample[:, channel])[0]
            res[channel].append(active_steps)
        label.append(targets[idx])
    
def compute_classification_accuracy(dataloader, model, name):
    """Compute classification accuracy over all batches of ``dataloader``.

    The predicted class is the argmax over output units of the maximum of the
    model output along dim 1. Accuracy is correct predictions divided by
    total samples, so a smaller final batch is weighted correctly.

    Args:
        dataloader: iterable of ``(x, y)`` batches.
        model: callable mapping a float batch to the network output. If it
            exposes ``parameters()``, batches are moved to its device.
        name: unused; kept so existing calls keep working.

    Returns:
        Accuracy as a float, or NaN when the dataloader yields no samples.
    """
    # Follow the model's own device instead of a notebook-level global.
    params = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(iter(params), None)
    target_device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for x_local, y_local in dataloader:
            x_local = x_local.float()
            y_local = y_local.long()
            if target_device is not None:
                x_local = x_local.to(target_device)
                y_local = y_local.to(target_device)
            output = model(x_local)
            m, _ = torch.max(output, 1)  # max over time
            _, am = torch.max(m, 1)      # argmax over output units
            correct += int((y_local == am).sum().item())
            total += int(y_local.numel())

    if total == 0:
        return float("nan")
    return correct / total

In [4]:
def GetMNIST():
    """Load MNIST through Keras and return flattened, rescaled images.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Images are reshaped to
        ``(-1, 784)``, cast to float32 and divided by 255. Labels are returned
        exactly as Keras provides them.
    """
    # Imported here so TensorFlow is only loaded when the data is requested.
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    image_size = 28 * 28 * 1

    def _flatten_and_scale(images):
        # One flat row per image, scaled to [0, 1].
        return images.reshape(-1, image_size).astype(np.float32) / 255.0

    return _flatten_and_scale(x_train), y_train, _flatten_and_scale(x_test), y_test
# Load the flattened MNIST arrays once; the training cell below reads these
# module-level names when it builds its datasets.
x_train, y_train, x_test, y_test = GetMNIST()
print(x_train.shape)

(60000, 784)


In [5]:
# Network and simulation sizes for the MNIST run.
nb_inputs  = 28 * 28 * 1
nb_hidden  = 96
nb_outputs = 10

time_step = 1e-3
nb_steps  = 100

batch_size = 512


loss_hists = []
for i in range(1):
    # Build a fresh network on the selected device.
    net = SpikingDenseNet(input_size=nb_inputs, nb_steps=nb_steps, num_output=nb_outputs, num_h=nb_hidden).float().to(device)

    # Load the datasets.
    train_data = SpikingMNISTDataset(x_train, y_train, nb_steps)
    test_data = SpikingMNISTDataset(x_test, y_test, nb_steps)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
    testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8)
    lr = 1e-3
    name = 'MNIST'
    loss_hist = train(trainloader, testloader, net, lr=lr, nb_epochs=80)
    print("Training accuracy: %.3f"%(compute_classification_accuracy(trainloader, net, "train")))
    print("Test accuracy: %.3f"%(compute_classification_accuracy(testloader, net, name)))
    loss_hists.append(loss_hist)
    # All parameter arrays are written back-to-back into one file, so read
    # them back with repeated np.load calls on the same handle. The context
    # manager closes the file even if saving fails.
    with open('weight_QAT_' + name + '.npy', 'wb') as f:
        # Use a separate loop variable so the dataset `name` is not overwritten.
        for param_name, param in net.named_parameters():
            print(param_name, param.shape)
            np.save(f, param.detach().cpu().numpy())


63.00795841217041
epoch: %d 0
Training accuracy: 0.430
Test accuracy: 0.433
128.59041666984558
epoch: %d 1
128.63519740104675
epoch: %d 2
127.83162713050842
epoch: %d 3
128.08943629264832
epoch: %d 4
128.95740628242493
epoch: %d 5
129.44070982933044
epoch: %d 6
128.50268268585205
epoch: %d 7
128.65231370925903
epoch: %d 8
127.91505336761475
epoch: %d 9
127.86331057548523
epoch: %d 10
128.37370944023132
epoch: %d 11
128.77973818778992
epoch: %d 12
128.64366841316223
epoch: %d 13
127.42506122589111
epoch: %d 14
127.74420237541199
epoch: %d 15
127.83533644676208
epoch: %d 16
128.3978168964386
epoch: %d 17
129.07325911521912
epoch: %d 18
128.74108982086182
epoch: %d 19
130.00708961486816
epoch: %d 20
129.51887917518616
epoch: %d 21
127.26081085205078
epoch: %d 22
128.32207345962524
epoch: %d 23
128.65867161750793
epoch: %d 24
129.85766983032227
epoch: %d 25
127.5685362815857
epoch: %d 26
128.26856684684753
epoch: %d 27
127.74617958068848
epoch: %d 28
128.88242864608765
epoch: %d 29
129.352