In [1]:
import functools
import time
from time import localtime, strftime

from sklearn import datasets
from snntorch import spikegen
from snntorch import functional as SF
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn

from z3 import *
from collections import defaultdict

import logging
logging.basicConfig(filename=f"dead_neuron_pruning_{strftime('%m%d_%H-%M-%S', localtime())}.log", level=logging.INFO)
logger = logging.getLogger()

np.random.seed(42)
torch.manual_seed(42)
torch.use_deterministic_algorithms(True)

shuffle = True
beta = 0.95
num_steps = 25
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train = False
file_name = 'model_iris.pth'


def compare(x, y):
    xx, yy = int(x.name().split('_')[-1]), int(y.name().split('_')[-1])
    return xx-yy


num_input = 4
num_hidden = 5
num_output = 3


class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_input, num_hidden, bias=False)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_output, bias=False)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)


iris = datasets.load_iris()

iris_data = iris.data / iris.data.max(axis=0)
iris_targets = iris.target

if shuffle:
    assert len(iris_data) == len(iris_data)
    perm = np.random.permutation(len(iris_data))
    iris_data, iris_targets = iris_data[perm], iris_targets[perm]


num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

if train:
    net = Net()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
    #loss = nn.CrossEntropyLoss()
    loss = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

    # Outer training loop
    for epoch in range(num_epochs):
        iter_counter = 0

        # Minibatch training loop
        for number in range(len(iris_targets)):
            data = torch.tensor(iris_data[number], dtype=torch.float)
            #targets = torch.tensor([0 if i != iris_targets[number] else 1 for i in range(max(iris_targets)+1)],dtype=torch.float)
            targets = torch.tensor([iris_targets[number]])

            # make spike trains
            data_spike = spikegen.rate(data, num_steps=num_steps)

            # forward pass
            net.train()
            spk_rec, mem_rec = net(data_spike.view(num_steps, -1))

            # initialize the loss & sum over time
            loss_val = torch.zeros((1), dtype=torch.float)
            for step in range(num_steps):
                loss_val += loss(mem_rec[step], targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())

            if counter % 20 == 0:
                print(f"Epoch {epoch}, Iteration {iter_counter}")
            counter += 1
            iter_counter += 1
    # print("Saving model.pth")
    logger.info("Saving model.pth")
    torch.save(net, file_name)
else:
    net = torch.load(file_name)
    # print("Model loaded")
    logger.info("Model loaded")

check = True
if check:
    acc = 0
    perm = np.random.permutation(len(iris_data))
    test_data, test_targets = torch.tensor(iris_data[perm][:100], dtype=torch.float), torch.tensor(iris_targets[perm][:100])
    for i, data in enumerate(test_data):
        spike_data = spikegen.rate(data, num_steps=num_steps)
        spk_rec, mem_rec = net(spike_data.view(num_steps, -1))
        idx = np.argmax(spk_rec.sum(dim=0).detach().numpy())
        if idx == test_targets[i]:
            #print(f'match for {test_targets[i]}')
            acc += 1
        else:
            #print(f'Not match for {test_targets[i]}')
            pass
    # print(f'Accuracy of the model : {acc}%')
    logger.info(f'Accuracy of the model : {acc}%')

# print()
logger.info("")

$$1 \le t \le num\_steps$$
$$0 \le j \le num\_layers-1$$
$$0 \le i \le num\_nodes-1$$

In [2]:
# SMT encoding

# take a random input and make it into a spike train
layers = [num_input, num_hidden, num_output]
spike_indicators = {}
for t in range(num_steps):
    for j, m in enumerate(layers):
        for i in range(m):
            spike_indicators[(i, j, t+1)] = Bool(f'x_{i}_{j}_{t+1}')

potentials = {}
for t in range(num_steps+1):
    for j, m in enumerate(layers):
        if j == 0:
            continue
        for i in range(m):
            potentials[(i, j, t)] = Real(f'P_{i}_{j}_{t}')

weights = defaultdict(float)
w1 = net.fc1.weight
for j in range(len(w1)):
    for i in range(len(w1[j])):
        weights[(i, j, 0)] = float(w1[j][i])
w2 = net.fc2.weight
for j in range(len(w2)):
    for i in range(len(w2[j])):
        weights[(i, j, 1)] = float(w2[j][i])


$$\epsilon_0(i,0) \triangleq (P_{i,0}=0)$$

In [3]:
w1.shape

torch.Size([5, 4])

In [4]:
weights.keys()

dict_keys([(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (1, 1, 0), (2, 1, 0), (3, 1, 0), (0, 2, 0), (1, 2, 0), (2, 2, 0), (3, 2, 0), (0, 3, 0), (1, 3, 0), (2, 3, 0), (3, 3, 0), (0, 4, 0), (1, 4, 0), (2, 4, 0), (3, 4, 0), (0, 0, 1), (1, 0, 1), (2, 0, 1), (3, 0, 1), (4, 0, 1), (0, 1, 1), (1, 1, 1), (2, 1, 1), (3, 1, 1), (4, 1, 1), (0, 2, 1), (1, 2, 1), (2, 2, 1), (3, 2, 1), (4, 2, 1)])

In [5]:
#=====================================================
# Potential Initializations
pot_init = []
for j, m in enumerate(layers):
    if j == 0:
        continue
    for i in range(m):
        pot_init.append(potentials[(i, j, 0)] == 0)

In the paper, S means:
$$\epsilon_1(i,t) \triangleq \left(S_{i,t}=\sum_{j\in inSynapse(N_i)}x_{j,t}\cdot w_{j,i}\right)$$
but in this code, it has different meaning:
$$S = P_{i,j,t-1} + \sum_{i\in Layer_{j-1}}x_{i,j-1,t}\cdot w_{j-1,i}$$

In [6]:
# Assign Inputs
'''
assign = []
for i, spikes_t in enumerate(sample_spike):
    for j, spike in enumerate(spikes_t):
        if spike == 1:
            assign.append(spike_indicators[(j, 0, i+1)])
        else:
            assign.append(Not(spike_indicators[(j, 0, i + 1)]))
'''
assign = []

# Node eqn
node_eqn = []

for i in range(len(w1)):
    node_eqn.append(
        Implies(
            Sum(
                [
                    If(weights[(k, i, 0)]>=0, weights[(k, i, 0)], 0) for k in range(len(w1[i]))
                ]
            ) < 1 * (1-0.95), # threshold * (1-lambda)
            Not(Or(
                [
                    spike_indicators[(i, 1, t)] for t in range(1, num_steps+1)
                ]
            ))
        )
    )

for i in range(len(w2)):
    node_eqn.append(
        Implies(
            Sum(
                [
                    If(weights[(k, i, 1)]>=0, weights[(k, i, 1)], 0) for k in range(len(w2[i]))
                ]
            ) < 1 * (1-0.95), # threshold * (1-lambda)
            Not(Or(
                [
                    spike_indicators[(i, 2, t)] for t in range(1, num_steps+1)
                ]
            ))
        )
    )
    
for t in range(1, num_steps+1):
    for j, m in enumerate(layers):
        if j == 0:
            continue

        for i in range(m):
            S = sum([spike_indicators[(k, j-1, t)]*weights[(k, i, j-1)] for k in range(layers[j-1])]) + potentials[(i, j, t-1)] # epsilon_1
            node_eqn.append(
                And(
                    Implies(
                        S >= 1.0,
                        And(spike_indicators[(i, j, t)], potentials[(i, j, t)] == S - 1) # epsilon_2 & epsilon_4
                    ),
                    Implies(
                        S < 1.0,
                        And(Not(spike_indicators[(i, j, t)]), potentials[(i, j, t)] == beta*S) # epsilon_3 & epsilon_5
                    )
                )
            )
            #print(f'==========================================================\nAdded equation {(i,j,t)}')
                

#S.push()
#print("Equations Created")

In [7]:
num_samples = 15


samples = iris_data[np.random.choice(range(len(iris_data)), num_samples)]
# print(samples)
logger.info(samples)
deltas = [1, 2, 3]

delta_v = {d: 0 for d in deltas}

for delta in deltas:
    avt = 0
    for sample_no, sample in enumerate(samples):
        sample_spike = spikegen.rate(torch.tensor(sample, dtype=torch.float), num_steps=num_steps)

        spk_rec, mem_rec = net(sample_spike.view(num_steps, -1)) # epsilon 1~5
        label = int(spk_rec.sum(dim=0).argmax())

        S = Solver()
        S.add(assign+node_eqn+pot_init)
        # S.add(node_eqn + pot_init)

        sum_val = []
        for timestep, spike_train in enumerate(sample_spike):
            for i, spike in enumerate(spike_train.view(num_input)):
                if spike == 1:
                    sum_val.append(If(spike_indicators[(i, 0, timestep + 1)], 0.0, 1.0))
                else:
                    sum_val.append(If(spike_indicators[(i, 0, timestep + 1)], 1.0, 0.0))
        prop = [sum(sum_val) <= delta]
        S.add(prop)
        '''
        s = [[] for i in range(num_steps)]
        sv = [Int(f's_{i + 1}') for i in range(num_steps)]
        prop = []
        for timestep, spike_train in enumerate(sample_spike):
            for i, spike in enumerate(spike_train.view(num_input)):
                if spike == 1:
                    s[timestep].append(If(spike_indicators[(i, 0, timestep + 1)], 0.0, 1.0))
                else:
                    s[timestep].append(If(spike_indicators[(i, 0, timestep + 1)], 1.0, 0.0))
        prop = [sv[i] == sum(s[i]) for i in range(num_steps)]
        prop.append(sum(sv) <= delta)
        # print(prop[0])
        #print(f"Inputs Property Done in {time.time() - tx} sec")
        '''

        # Output property
        #tx = time.time()
        op = []
        intend_sum = sum([2 * spike_indicators[(label, 2, timestep + 1)] for timestep in range(num_steps)])
        for t in range(num_output):
            if t != op:
                op.append(
                    Not(intend_sum > sum([2 * spike_indicators[(t, 2, timestep + 1)] for timestep in range(num_steps)]))
                )
        #print(f'Output Property Done in {time.time() - tx} sec')
        S.add(op)
        tx = time.time()
        res = S.check()
        if str(res) == 'unsat':
            delta_v[delta] += 1
        else:
            '''
            sadv = np.zeros((num_steps, num_input), dtype=float)
            m = S.model()
            for tt in range(num_steps):
                for k in range(num_input):
                    sadv[tt][k] = 1 if str(m[spike_indicators[(k, 0, tt + 1)]]) == 'True' else 0
            print()
            '''
            pass
        del S
        tss = time.time()-tx
        # print(f'Completed for delta = {delta}, sample = {sample_no} in {tss} sec as {res}')
        logger.info(f'Completed for delta = {delta}, sample = {sample_no} in {tss} sec as {res}')
        avt = (avt*sample_no + tss)/(sample_no+1)
    # print(f'Completed for delta = {delta} with {delta_v[delta]} in avg time {avt} sec')
    logger.info(f'Completed for delta = {delta} with {delta_v[delta]} in avg time {avt} sec')


'''
m = S.model()
for k in range(num_output):
    names = []
    for i in m.decls():
        t = i.name().split('_')
        if t[0] == 'x' and t[1] == f'{k}' and t[2] == '2':
            names.append(i)
    for i in sorted(names, key=functools.cmp_to_key(compare)):
        print(f'{i}->{m[i]}')
    input()
    print()
'''



print()




In [8]:
# [[0.84810127 0.70454545 0.63768116 0.56      ]
#  [0.81012658 0.65909091 0.62318841 0.52      ]
#  [0.59493671 0.72727273 0.23188406 0.08      ]
#  [0.81012658 0.70454545 0.79710145 0.72      ]
#  [0.73417722 0.90909091 0.17391304 0.08      ]]
# Completed for delta = 1, sample = 0 in 14.885828018188477 sec as unsat
# Completed for delta = 1, sample = 1 in 14.03385615348816 sec as unsat
# Completed for delta = 1, sample = 2 in 8.778179407119751 sec as unsat
# Completed for delta = 1, sample = 3 in 13.298641443252563 sec as unsat
# Completed for delta = 1, sample = 4 in 2.512458562850952 sec as sat
# Completed for delta = 1 with 4 in avg time 10.70179271697998 sec
# Completed for delta = 2, sample = 0 in 182.78500318527222 sec as unsat
# Completed for delta = 2, sample = 1 in 50.21033048629761 sec as unsat
# Completed for delta = 2, sample = 2 in 1.3425533771514893 sec as sat
# Completed for delta = 2, sample = 3 in 121.07727599143982 sec as unsat
# Completed for delta = 2, sample = 4 in 3.9044606685638428 sec as sat
# Completed for delta = 2 with 3 in avg time 71.863924741745 sec

# [[0.84810127 0.75       0.82608696 0.84      ]
#  [0.87341772 0.70454545 0.71014493 0.6       ]
#  [0.62025316 0.68181818 0.20289855 0.08      ]
#  [0.6835443  0.77272727 0.24637681 0.08      ]
#  [0.84810127 0.75       0.82608696 0.84      ]]
# Completed for delta = 1, sample = 0 in 13.993758916854858 sec as unsat
# Completed for delta = 1, sample = 1 in 16.232289791107178 sec as unsat
# Completed for delta = 1, sample = 2 in 3.728602647781372 sec as sat
# Completed for delta = 1, sample = 3 in 3.0936505794525146 sec as sat
# Completed for delta = 1, sample = 4 in 15.692162275314331 sec as unsat
# Completed for delta = 1 with 3 in avg time 10.54809284210205 sec
# Completed for delta = 2, sample = 0 in 206.85294008255005 sec as unsat
# Completed for delta = 2, sample = 1 in 91.93414258956909 sec as unsat
# Completed for delta = 2, sample = 2 in 5.711108207702637 sec as sat
# Completed for delta = 2, sample = 3 in 3.172511577606201 sec as sat
# Completed for delta = 2, sample = 4 in 117.94588279724121 sec as unsat
# Completed for delta = 2 with 3 in avg time 85.12331705093384 sec

# [[0.84810127 0.70454545 0.63768116 0.56      ]
#  [0.64556962 0.86363636 0.23188406 0.08      ]
#  [0.72151899 0.59090909 0.50724638 0.4       ]
#  [0.7721519  0.68181818 0.71014493 0.72      ]
#  [0.65822785 0.79545455 0.2173913  0.08      ]]
# Completed for delta = 1, sample = 0 in 15.12482500076294 sec as unsat
# Completed for delta = 1, sample = 1 in 2.7559990882873535 sec as sat
# Completed for delta = 1, sample = 2 in 6.47345232963562 sec as unsat
# Completed for delta = 1, sample = 3 in 9.045020818710327 sec as unsat
# Completed for delta = 1, sample = 4 in 2.6043717861175537 sec as sat
# Completed for delta = 1 with 3 in avg time 7.200733804702759 sec
# Completed for delta = 2, sample = 0 in 43.941203355789185 sec as unsat
# Completed for delta = 2, sample = 1 in 2.511519432067871 sec as sat
# Completed for delta = 2, sample = 2 in 94.44714975357056 sec as unsat
# Completed for delta = 2, sample = 3 in 153.38948512077332 sec as unsat
# Completed for delta = 2, sample = 4 in 5.546740531921387 sec as sat
# Completed for delta = 2 with 3 in avg time 59.96721963882446 sec

# [[0.59493671 0.72727273 0.1884058  0.08      ]
#  [0.72151899 0.65909091 0.60869565 0.52      ]
#  [0.87341772 0.70454545 0.73913043 0.92      ]
#  [0.84810127 0.56818182 0.84057971 0.72      ]
#  [0.84810127 0.56818182 0.84057971 0.72      ]]
# Completed for delta = 1, sample = 0 in 3.5639681816101074 sec as sat
# Completed for delta = 1, sample = 1 in 13.699343204498291 sec as unsat
# Completed for delta = 1, sample = 2 in 16.448461771011353 sec as unsat
# Completed for delta = 1, sample = 3 in 14.851109743118286 sec as unsat
# Completed for delta = 1, sample = 4 in 11.387865781784058 sec as unsat
# Completed for delta = 1 with 4 in avg time 11.990149736404419 sec
# Completed for delta = 2, sample = 0 in 4.561338901519775 sec as sat
# Completed for delta = 2, sample = 1 in 121.3413896560669 sec as unsat
# Completed for delta = 2, sample = 2 in 108.89543128013611 sec as unsat
# Completed for delta = 2, sample = 3 in 121.82413840293884 sec as unsat
# Completed for delta = 2, sample = 4 in 65.35597896575928 sec as unsat
# Completed for delta = 2 with 4 in avg time 84.39565544128418 sec

# [[0.59493671 0.72727273 0.23188406 0.08      ]
#  [0.60759494 0.70454545 0.23188406 0.08      ]
#  [0.55696203 0.65909091 0.20289855 0.08      ]
#  [0.86075949 0.72727273 0.85507246 0.92      ]
#  [0.82278481 0.63636364 0.66666667 0.6       ]]
# Completed for delta = 1, sample = 0 in 12.557798862457275 sec as unsat
# Completed for delta = 1, sample = 1 in 2.2979931831359863 sec as sat
# Completed for delta = 1, sample = 2 in 7.2437708377838135 sec as sat
# Completed for delta = 1, sample = 3 in 19.177116870880127 sec as unsat
# Completed for delta = 1, sample = 4 in 11.036794185638428 sec as unsat
# Completed for delta = 1 with 3 in avg time 10.462694787979126 sec
# Completed for delta = 2, sample = 0 in 4.473829507827759 sec as sat
# Completed for delta = 2, sample = 1 in 2.118896484375 sec as sat
# Completed for delta = 2, sample = 2 in 5.217900276184082 sec as sat
# Completed for delta = 2, sample = 3 in 125.89900875091553 sec as unsat
# Completed for delta = 2, sample = 4 in 66.86198091506958 sec as unsat
# Completed for delta = 2 with 2 in avg time 40.91432318687439 sec

# 6min 33s ± 1min 26s per loop (mean ± std. dev. of 5 runs, 1 loop each)