In [None]:
%matplotlib widget
import spike_data_augmentation
from spike_data_augmentation.datasets.dataloader import Dataloader
from spike_data_augmentation import datasets
import spike_data_augmentation.transforms as transforms
import ipdb
import numpy as np
from utils.helper import plot_centers, create_histograms
from tqdm.auto import tqdm
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
from time import gmtime, strftime
# Log when this run began (gmtime -> UTC wall clock).
started_at = strftime("Started on %a, %d %b %Y %H:%M:%S", gmtime())
print(started_at)

### Choose dataset and representation

In [None]:
# Representation settings: number of k-means prototypes and the spatial size of each time surface.
n_of_centers = 100
surface_dimensions = (11, 11)

# Time-surface transform; tau is passed straight through to ToTimesurface
# (NOTE(review): units of tau are defined by that library — confirm).
timesurface = transforms.ToTimesurface(
    surface_dimensions=surface_dimensions,
    tau=5e3,
    merge_polarities=True,
)
transform = transforms.Compose([timesurface])

# Single dataset object; it is split into train/test further below.
testset = datasets.NMNIST(save_to='./data', transform=transform, download=False)

### Read all time surfaces and their labels

In [None]:
testloader = Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

# Keep one array of time surfaces per sample (so the k-means labels can later be split
# back per sample) together with the sample's label.
all_surfaces, all_labels = [], []
for sample_surfaces, sample_label in tqdm(testiterator):
    all_surfaces.append(sample_surfaces.reshape((-1,) + surface_dimensions))
    all_labels.append(sample_label)

# All surfaces of all samples stacked along the first axis.
stack = np.vstack(all_surfaces)
print(f'Read {stack.shape[0]} surfaces.')

### Clustering

In [None]:
# KMeans needs a 2-D (n_samples, n_features) array, but `stack` holds one 2-D time surface
# per row (shape (n, *surface_dimensions)); flatten each surface into a feature vector.
stack_flat = stack.reshape(len(stack), -1)
# Fixed random_state so the centroid initialisation, and hence the clusters, are reproducible.
kmeans = KMeans(n_clusters=n_of_centers, random_state=0).fit(stack_flat)
# Turn the flat centroids back into surface-shaped images for plotting.
centers = kmeans.cluster_centers_.reshape((-1,) + surface_dimensions)
# KMeans has no `counts_` attribute; count how many surfaces were assigned to each center.
activations = np.bincount(kmeans.labels_, minlength=n_of_centers)

plot_centers(centers, activations)

### Build a histogram of cluster assignments for each data point

In [None]:
# Split points into the flat kmeans.labels_ array: cumulative surface counts per sample,
# minus the final total (np.split takes interior boundaries only).
recording_indices = np.cumsum([len(x) for x in all_surfaces])[:-1]

# build histograms from split kmean labels, normalised by the sample's number of surfaces
kmeans_split = np.split(kmeans.labels_, recording_indices)
bins = np.arange(0, n_of_centers + 1)
hists = [np.histogram(x, bins=bins)[0] / len(x) for x in kmeans_split]
print('Histogram of first data point: ' + str(hists[0]))

# Normalised histograms sum to 1 (the previous check compared against len(), which only
# held for single-surface samples).
assert all(np.isclose(h.sum(), 1.0) for h in hists)
assert len(hists) == len(all_labels)

### Split into training and testing set

In [None]:
# split X and y into training and testing sets
import collections
from sklearn.model_selection import train_test_split
# Fixed random_state so the shuffled split, and therefore the reported scores, are reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    hists, all_labels, test_size=0.25, shuffle=True, random_state=0
)
# Class counts per split, to check the label balance.
training = collections.Counter(y_train)
testing = collections.Counter(y_test)
print('Training: ' + str(training) + ', testing: ' + str(testing))

### Classify

In [None]:
# Logistic regression on the per-sample cluster histograms
# (LogisticRegression / classification_report come from the imports cell).
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases, where multinomial
# is already the lbfgs default; kept for compatibility with older versions — confirm.
logreg = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=1000)
logreg.fit(X_train, y_train)
# Predict once and keep the result for any further metrics (e.g. confusion_matrix).
y_pred = logreg.predict(X_test)
print(classification_report(y_test, y_pred))