In [None]:
%matplotlib widget
import spike_data_augmentation
from spike_data_augmentation.datasets.dataloader import Dataloader
import spike_data_augmentation.transforms as transforms
import ipdb
import numpy as np
from utils.helper import plot_centers, create_histograms
from tqdm.auto import tqdm
from sklearn.cluster import MiniBatchKMeans
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
from time import gmtime, strftime
# Stamp the notebook output with the wall-clock start time (gmtime() -> UTC).
start_banner = strftime("Started on %a, %d %b %Y %H:%M:%S", gmtime())
print(start_banner)

### Parametrise notebook using papermill

In [None]:
# Default run parameters. Per the section header these are meant to be
# overridden by papermill; the values below apply to an interactive run.
surface_dimensions = [9,9]  # passed to Timesurface(surface_dimensions=...); its product is the flattened feature length
dropout_probability = 0  # passed to transforms.DropEvent(drop_probability=...)
refractory_period = 0  # passed to transforms.RefractoryPeriod(refractory_period=...)
# Augmentation is switched off when both dropout_probability and
# refractory_period are 0.
time_constant = 10e3  # passed as tau to Timesurface; units not visible here - presumably microseconds, TODO confirm
n_of_centers = 500  # number of k-means clusters, also handed to create_histograms
first_saccade_only = True  # only referenced by the commented-out NMNIST dataset lines
file_name = 'placeholder'  # name of the executed notebook inside ./milled_nbs/; 'placeholder' is the default when no name is supplied

### Choose training dataset and representation

In [None]:
# A non-default file_name means a name was supplied for this run (see the
# papermill section header). Wait a random 0-100 s in that case - presumably
# to stagger parallel runs; confirm intent.
if file_name != 'placeholder':
    import time
    # np.random.rand() returns a plain float. The previous rand(1) handed a
    # 1-element array to time.sleep, which depends on an array-to-scalar
    # conversion that NumPy 1.25 deprecated.
    time.sleep(np.random.rand() * 100)

# Event-level augmentation pipeline; only used for the augmented training set.
transform = transforms.Compose([
    transforms.RefractoryPeriod(refractory_period=refractory_period),
    transforms.DropEvent(drop_probability=dropout_probability),
])
# Every dataset below is converted to time surfaces with merged polarities.
representation = spike_data_augmentation.representations.Timesurface(
    surface_dimensions=surface_dimensions, tau=time_constant, merge_polarities=True)

# Plain (un-augmented) training set.
trainset = spike_data_augmentation.datasets.IBMGesture(
    save_to='./data', train=True, representation=representation, download=True)
#trainset = spike_data_augmentation.datasets.NMNIST(save_to='./data', train=True, representation=representation, download=True, first_saccade_only=first_saccade_only)
trainloader = Dataloader(trainset, shuffle=True)

# Augmentation is active as soon as either transform parameter is non-zero.
augmentation = dropout_probability != 0 or refractory_period != 0

if augmentation:
    # Same recordings as `trainset`, but with `transform` applied to the events.
    trainset_augmented = spike_data_augmentation.datasets.IBMGesture(
        save_to='./data', train=True, transform=transform,
        representation=representation, download=True)
    #trainset_augmented = spike_data_augmentation.datasets.NMNIST(save_to='./data', train=True, transform=transform, representation=representation, download=True, first_saccade_only=first_saccade_only)
    trainloader_augmented = Dataloader(trainset_augmented, shuffle=True)

### Read timesurfaces and use minibatch clustering

In [None]:
# Learn `n_of_centers` prototype time surfaces incrementally, one recording
# at a time, so the full set of surfaces never has to be held in memory.
kmeans = MiniBatchKMeans(n_clusters=n_of_centers)
# Length of one flattened time surface.
dims_prod = int(np.prod(surface_dimensions))

if augmentation:
    # Mix normal and augmented training sets: for every pair of recordings
    # the plain surfaces are fed first, then the augmented ones.
    surface_batches = (
        surfaces
        for paired_records in zip(trainloader, trainloader_augmented)
        for surfaces, _label in paired_records
    )
else:
    # Only take the training set without transforms.
    surface_batches = (surfaces for surfaces, _label in trainloader)

# A plain loop replaces the former side-effect list comprehension, which
# collected one reference to `kmeans` per recording in an unused list.
# Wrap `surface_batches` in tqdm(...) here to get a progress bar.
for surfaces in surface_batches:
    kmeans.partial_fit(surfaces.reshape(-1, dims_prod))

### Plot centers

In [None]:
# Turn the flat cluster centres back into surfaces with shape
# (number of centres, *surface_dimensions). list() also accepts a tuple here.
centers = kmeans.cluster_centers_.reshape([-1] + list(surface_dimensions))

# Per-centre sample counts. `counts_` was deprecated in scikit-learn 0.24 and
# removed in 1.1; newer releases only keep the private `_counts`, so fall back
# to it instead of raising AttributeError. Ends up as None if neither exists.
activations = getattr(kmeans, 'counts_', None)
if activations is None:
    activations = getattr(kmeans, '_counts', None)
#plot_centers(centers, activations)

### Train classifiers

In [None]:
# Second pass over the un-augmented training set: assign every time surface
# of a recording to its nearest k-means centre, then describe the recording
# by the histogram of those assignments.
trainloader = Dataloader(trainset, shuffle=True)
trainiterator = iter(trainloader)

training_cluster_assignments = []
training_labels = []
# Wrap `trainiterator` in tqdm(...) to get a progress bar.
for recording_surfaces, recording_label in trainiterator:
    flat_surfaces = recording_surfaces.reshape(-1, dims_prod)
    training_cluster_assignments.append(kmeans.predict(flat_surfaces))
    training_labels.append(recording_label)

training_features = create_histograms(training_cluster_assignments, n_of_centers)

# Three classifiers are trained on the same histogram features.
logreg = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=2000)
gnb = GaussianNB()
knn = KNeighborsClassifier()
for classifier in (logreg, gnb, knn):
    classifier.fit(training_features, training_labels)

### Build testing features and classify

In [None]:
# Held-out split, converted with the same time-surface representation that
# was used for training.
testset = spike_data_augmentation.datasets.IBMGesture(
    save_to='./data', train=False, representation=representation, download=True)
#testset = spike_data_augmentation.datasets.NMNIST(save_to='./data', train=False, representation=representation, download=True, first_saccade_only=first_saccade_only)
testloader = Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

# Length of one flattened time surface; constant for the whole loop.
flat_surface_size = np.prod(surface_dimensions)

testing_cluster_assignments = []
testing_labels = []
for recording_surfaces, recording_label in tqdm(testiterator):
    flat_surfaces = recording_surfaces.reshape(-1, flat_surface_size)
    # Nearest-centre index for every surface of this recording.
    testing_cluster_assignments.append(kmeans.predict(flat_surfaces))
    testing_labels.append(recording_label)

In [None]:
# One histogram of cluster assignments per test recording.
testing_features = create_histograms(testing_cluster_assignments, n_of_centers)
assert len(testing_features) == len(testing_labels)

# Score each fitted classifier on the test histograms, rounded to 4 decimals.
# Insertion order (logreg, gnb, knn) decides which one wins a tie.
fitted_classifiers = {'logreg': logreg, 'gnb': gnb, 'knn': knn}
scores = {
    name: round(classifier.score(testing_features, testing_labels), 4)
    for name, classifier in fitted_classifiers.items()
}
winner_classifier = max(scores, key=scores.get)
print(scores)
#print(classification_report(testing_labels, logreg.predict(testing_features)))
# The confusion matrix is always reported for logistic regression,
# whichever classifier won.
print(confusion_matrix(testing_labels, logreg.predict(testing_features)))
print(strftime("Finished on %a, %d %b %Y %H:%M:%S", gmtime()))

### Hacky bit: rename the notebook generated by papermill so that its filename lists the best score

In [None]:
import os

# Prefix the executed notebook's filename with the best score and the
# classifier that achieved it, e.g. '<score>_<classifier><file_name>'.
new_file_name = './milled_nbs/' + str(scores[winner_classifier]) + '_' + winner_classifier + file_name

# 'placeholder' is the default file_name of an interactive run. No such file
# exists in ./milled_nbs/, so renaming would raise FileNotFoundError; skip it,
# mirroring the check used for the start-up delay.
if file_name != 'placeholder':
    os.rename('./milled_nbs/' + file_name, new_file_name)