In [None]:
%matplotlib widget
import collections

import ipdb  # NOTE(review): not used in any visible cell — likely leftover debugging; consider removing
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import spike_data_augmentation
import spike_data_augmentation.transforms as transforms
from spike_data_augmentation.datasets.dataloader import Dataloader
from utils.helper import plot_centers

### Choose dataset and representation

In [None]:
%%capture
surface_dimensions = (11,11)
transform = transforms.Compose([transforms.DropEvent(drop_probability=0.0)])
representation = spike_data_augmentation.representations.Timesurface(surface_dimensions=surface_dimensions, tau=5e3, merge_polarities=True)

testset = spike_data_augmentation.datasets.POKERDVS(save_to='./data', transform=transform, representation=representation)

### Read time surfaces and cluster them with mini-batch k-means

In [None]:
# Stream every sample once and refine the cluster centers incrementally.
testloader = Dataloader(testset, shuffle=False)
testiterator = iter(testloader)

n_of_centers = 16
# random_state makes the learned centers (and everything downstream) reproducible.
kmeans = MiniBatchKMeans(n_clusters=n_of_centers, batch_size=n_of_centers,
                         max_no_improvement=None, random_state=42)
# Derive the flattened feature length from the configured surface size instead of hard-coding it.
surface_size = surface_dimensions[0] * surface_dimensions[1]

all_labels = []
for surfaces, label in tqdm(testiterator):
    # Flatten each time surface into a single feature vector.
    surfaces = surfaces.reshape(-1, surface_size)
    # Keep a label for EVERY sample so it stays aligned with the per-sample histograms built later.
    all_labels.append(label)
    # np.array_split(..., 0) raises for samples with fewer than n_of_centers surfaces, and
    # MiniBatchKMeans needs at least n_clusters samples to initialise; such samples are not used for fitting.
    if len(surfaces) < n_of_centers:
        continue
    # Batches of at least n_of_centers surfaces each.
    n_batches = len(surfaces) // n_of_centers
    for batch in np.array_split(surfaces, n_batches):
        kmeans.partial_fit(batch)

### Plot the learned cluster centers

In [None]:
# Reshape the flattened centers back into 2-D time surfaces for plotting.
centers = kmeans.cluster_centers_.reshape(-1, surface_dimensions[0], surface_dimensions[1])

# Per-center sample counts accumulated while fitting. scikit-learn deprecated the public
# `counts_` attribute in 0.24 and removed it in 1.1 (it lives on as `_counts`), so fall back
# to the private name on newer versions instead of raising AttributeError.
activations = getattr(kmeans, "counts_", None)
if activations is None:
    activations = kmeans._counts

plot_centers(centers, activations)

### Build a histogram of nearest-center counts for each sample

In [None]:
# Assign every time surface of every sample to its nearest cluster center.
testloader = Dataloader(testset, shuffle=False)
testiterator = iter(testloader)

surface_size = surface_dimensions[0] * surface_dimensions[1]
all_kmeans_labels = []
for surfaces, _ in testiterator:  # labels were already collected in the clustering cell
    surfaces = surfaces.reshape(-1, surface_size)
    if len(surfaces) == 0:
        # predict() rejects empty input; an empty sample simply yields an all-zero histogram.
        all_kmeans_labels.append(np.empty(0, dtype=np.intp))
        continue
    all_kmeans_labels.append(kmeans.predict(surfaces))

# One histogram per sample: how often each of the n_of_centers centers was the nearest one.
# bincount with minlength gives the same counts as np.histogram over integer bins 0..n_of_centers.
hists = [np.bincount(x, minlength=n_of_centers) for x in all_kmeans_labels]
print('Histogram of first data point: ' + str(hists[0]))

# Both passes iterate the unshuffled loader, so histograms line up with the collected labels.
assert len(hists) == len(all_labels)

### Split into training and test sets

In [None]:
# Hold out a quarter of the samples for evaluation.
# stratify keeps every class represented proportionally in both splits, and
# random_state makes the split reproducible under Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(
    hists, all_labels, test_size=0.25, shuffle=True,
    stratify=all_labels, random_state=42,
)

# Per-class sample counts, as a sanity check on the split.
training = collections.Counter(y_train)
testing = collections.Counter(y_test)
print('Training: ' + str(training) + ', testing: ' + str(testing))

### Classify

In [None]:
# Logistic regression on the per-sample center-activation histograms.
# With solver='lbfgs', a multiclass target is already fit as a multinomial model, so the
# `multi_class` argument (deprecated since scikit-learn 1.5) is no longer passed.
logreg = LogisticRegression(solver='lbfgs', max_iter=1000)
logreg.fit(X_train, y_train)

# Predict once and reuse the predictions for every report.
y_pred = logreg.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))