import torch
import snntorch as snn
import snntorch.spikeplot as splt
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from torch import nn
CUT_TRAINING_SHORT = False
TRAINING_CUTOFF = 1
dtype = torch.float32
device = 'cpu'
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=0)
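# iris.data is a (150, 4) float array of measurements and iris.target a
# (150,) integer array with labels in {0, 1, 2}, so the 80/20 split leaves
# 120 training rows and 30 test rows.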
input_nodes = 4 # Sepal length, sepal width, petal length, petal width
hidden_neurons = 500 # Arbitrary number of hidden layer neurons
output_neurons = 3 # Setosa, Versicolor, Virginica
dt_steps = 500 # simulation steps
beta = 0.95 # used for LIF decay rate
epochs = 300
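# A note on beta, based on the snn.Leaky documentation: with the default
# "subtract" reset mechanism, each simulation step updates the membrane
# potential as
#     U[t + 1] = beta * U[t] + I_in[t] - S[t] * threshold
# where S[t] is 1 on steps where the neuron spiked, so beta = 0.95 means
# the membrane keeps 95% of its charge from one step to the next.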
class SpikingIrisClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Could the architecture skip the transfer1 layer? Something still
        # has to map 500 neurons down to 3; maybe the hidden layer could
        # just have 3 neurons the whole time.
        # In the end, it's neuromorphic, not neuroreal
        self.input = nn.Linear(input_nodes, hidden_neurons)
        self.lif1 = snn.Leaky(beta=beta)
        self.transfer1 = nn.Linear(hidden_neurons, output_neurons)
        self.output = snn.Leaky(beta=beta)

    def forward(self, x):
        # We need to init the membranes since they carry the spiking state
        # between time steps
        membrane_lif1 = self.lif1.init_leaky()
        membrane_output = self.output.init_leaky()
        spikes_output = []
        membrane_output_voltages = []
        # Run the simulation for dt_steps, recording the output spikes and
        # membrane potentials at every step
        for _ in range(dt_steps):
            input_current = self.input(x)
            spikes_lif1, membrane_lif1 = self.lif1(input_current,
                                                   membrane_lif1)
            transfer_current = self.transfer1(spikes_lif1)
            spikes_transfer, membrane_output = self.output(transfer_current,
                                                           membrane_output)
            spikes_output.append(spikes_transfer)
            membrane_output_voltages.append(membrane_output)
        # Both records have shape [dt_steps, batch, output_neurons]
        return torch.stack(spikes_output), torch.stack(membrane_output_voltages)
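# Quick shape sanity check (illustrative only, not part of training): a
# throwaway model on a batch of two random inputs should return records
# with a leading time dimension.
_probe_spikes, _probe_voltages = SpikingIrisClassifier()(
    torch.rand(2, input_nodes))
assert _probe_spikes.shape == (dt_steps, 2, output_neurons)
assert _probe_voltages.shape == (dt_steps, 2, output_neurons)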
# Push the whole dataset through on every pass: on a dataset this small,
# minibatching tended to cause overfitting and bad classifications.
def train_classifier(classifier, loss_function, optimizer, inputs, actuals):
    print('Training started\n')
    loss_history = []
    targets = torch.as_tensor(actuals, dtype=torch.long, device=device)
    for epoch in range(epochs):
        print(f'Epoch: {epoch + 1}/{epochs}')
        classifier.train()
        spikes, membrane_voltages = classifier(
            torch.tensor(inputs).to(device, dtype=dtype))
        # Accumulate the loss over every simulation step: the output
        # layer's membrane potential at each step is scored against the
        # class labels
        loss_values = torch.zeros(1, dtype=dtype, device=device)
        for step in range(dt_steps):
            loss_values += loss_function(membrane_voltages[step], targets)
        optimizer.zero_grad()
        loss_values.backward()
        optimizer.step()
        current_loss_value = loss_values.item()
        loss_history.append(current_loss_value)
        print(f'Current loss: {current_loss_value}')
        if epoch % 25 == 0:
            fig = plt.figure(facecolor="w", figsize=(10, 5))
            ax = fig.add_subplot(111)
            # Raster of the output-layer spikes for the first sample;
            # s: size of scatter points; c: color of scatter points
            splt.raster(spikes[:, 0].detach(), ax, s=1.5, c="black")
            plt.title("Output Layer")
            plt.xlabel("Time step")
            plt.ylabel("Neuron Number")
            plt.show()
        if CUT_TRAINING_SHORT and current_loss_value < TRAINING_CUTOFF:
            break
    return loss_history
def test_classifier(classifier, inputs, actuals):
    print('Testing started\n')
    with torch.no_grad():
        classifier.eval()
        test_inputs = torch.tensor(inputs).to(device, dtype=dtype)
        testing_spikes, _ = classifier(test_inputs)
        # Rate-code decoding: the predicted class is the output neuron
        # that spiked most often over the whole simulation
        _, predicted = testing_spikes.sum(dim=0).max(1)
        tensored_actuals = torch.as_tensor(actuals, dtype=torch.long)
        return (predicted == tensored_actuals).sum().item(), \
            tensored_actuals.size(0)
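# Worked example of the rate-code decode above, on made-up spike counts:
# neuron 1 spiked most often, so the predicted class is 1 (Versicolor).
_demo_counts = torch.tensor([[3., 9., 1.]])
assert _demo_counts.max(1).indices.item() == 1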
classifier = SpikingIrisClassifier().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=5e-4,
betas=(0.9, 0.999))
training_loss_history = train_classifier(classifier, loss_function, optimizer,
X_train, y_train)
correct_guesses, total_guesses = test_classifier(classifier, X_test, y_test)
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(training_loss_history)
plt.title("Loss Curve")
plt.legend(["Train Loss"])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
print("------")
print(f"Testing state: {correct_guesses}/{total_guesses}")
print(f"Test Set Accuracy: {100 * correct_guesses / total_guesses:.2f}%")
print("------")