In [None]:
%matplotlib widget
import spike_data_augmentation
from spike_data_augmentation.datasets.dataloader import Dataloader
import spike_data_augmentation.transforms as transforms
from sklearn.cluster import MiniBatchKMeans
import ipdb
import numpy as np
from utils.helper import plot_centers
from tqdm.auto import tqdm

### Choose dataset and representation

In [None]:
%%capture
# --- Tunable parameters shared by the cells below ---------------------------
surface_dimensions = (11, 11)   # shape that each flattened surface is reshaped from/to later on
dropout_probability = 0.5       # passed to DropEvent as drop_probability
time_constant = 5e3             # passed to Timesurface as tau (units not visible here — TODO confirm)
n_of_centers = 25               # number of k-means clusters / histogram bins

# Augmentation pipeline: a single DropEvent step.
transform = transforms.Compose(
    [
        transforms.DropEvent(drop_probability=dropout_probability),
    ]
)

# Event representation handed to the dataset alongside the transform.
representation = spike_data_augmentation.representations.Timesurface(
    surface_dimensions=surface_dimensions,
    tau=time_constant,
    merge_polarities=True,
)

# N-MNIST with train=False; download=True is passed so the data is fetched into ./data.
testset = spike_data_augmentation.datasets.NMNIST(
    save_to='./data',
    train=False,
    transform=transform,
    representation=representation,
    download=True,
)

### Read timesurfaces and use minibatch clustering

In [None]:
# Stream the dataset once and fit the k-means codebook incrementally, so all
# time surfaces never have to be held in memory at the same time.
testloader = Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

batch_size = 100
kmeans = MiniBatchKMeans(n_clusters=n_of_centers, batch_size=batch_size)

for surfaces, label in tqdm(testiterator):
    # Flatten every surface of this recording into one row vector.
    surfaces = surfaces.reshape(-1, np.prod(surface_dimensions))
    if len(surfaces) == 0:
        # Nothing to fit on (e.g. every event was dropped by the augmentation).
        continue
    # A recording with fewer than `batch_size` surfaces would give 0 sections,
    # which np.array_split rejects with a ValueError; always use at least one.
    split = max(1, len(surfaces) // batch_size)
    surfs = np.array_split(surfaces, split)
    for surf in surfs:
        kmeans.partial_fit(surf)

### Plot centers

In [None]:
# Fold each flat cluster centre back into the surface_dimensions shape so it can be drawn.
centers = kmeans.cluster_centers_.reshape(-1, *surface_dimensions)

# NOTE(review): `counts_` appears to be deprecated/removed in newer
# scikit-learn releases — confirm against the installed version.
activations = kmeans.counts_

plot_centers(centers, activations)

### Build histograms for each datapoint

In [None]:
# Second pass over the dataset: assign every time surface to its nearest
# cluster centre and keep the per-recording assignments next to the label.
testloader = Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

all_kmeans_labels = []
all_labels = []
for surfaces, label in tqdm(testiterator):
    surfaces = surfaces.reshape(-1, np.prod(surface_dimensions))
    surf_labels = kmeans.predict(surfaces)
    all_kmeans_labels.append(surf_labels)
    all_labels.append(label)

# One histogram per recording: how many of its surfaces fell into each cluster.
# The bin edges 0..n_of_centers are loop-invariant, so build them once.
hist_bins = np.arange(0, n_of_centers + 1)
hists = [np.histogram(x, bins=hist_bins)[0] for x in all_kmeans_labels]
print('Histogram of first data point: ' + str(hists[0]))

# Every recording must contribute exactly one feature vector and one label.
assert len(hists) == len(all_labels)

### Split into training and testing set

In [None]:
import collections
from sklearn.model_selection import train_test_split

# Hold out a quarter of the histogram features for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    hists,
    all_labels,
    test_size=0.25,
    shuffle=True,
)

# Count the labels on each side of the split so class balance can be checked by eye.
training = collections.Counter(y_train)
testing = collections.Counter(y_test)
print(f'Training: {training}')
print(f'Testing:  {testing}')

### Classify

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

# Multinomial logistic regression on the per-recording cluster histograms.
logreg = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=10000)
logreg.fit(X_train, y_train)

# Predict once and reuse the result for both reports.
y_pred = logreg.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))