In [1]:
%matplotlib widget
import spike_data_augmentation
from spike_data_augmentation.datasets.dataloader import Dataloader
import spike_data_augmentation.transforms as transforms
from sklearn.cluster import MiniBatchKMeans
import ipdb
import numpy as np
from utils.helper import plot_centers, create_histograms
from tqdm.auto import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from time import gmtime, strftime
# Record when this run started; gmtime() yields UTC, not local time.
print(strftime("Started on %a, %d %b %Y %H:%M:%S", gmtime()))

Started on Tue, 05 Nov 2019 11:29:12


### Parametrise notebook using papermill

In [2]:
# Default values for the notebook parameters (intended to be overridden by papermill).
# NOTE(review): papermill presumably needs this cell tagged `parameters` — confirm in the .ipynb metadata.
surface_dimensions = [9,9]  # passed to Timesurface; its product is the flattened length used when reshaping surfaces
dropout_probability = 0.9  # drop_probability handed to transforms.DropEvent
time_constant = 50e3  # tau of the Timesurface representation; units not visible here (presumably microseconds — TODO confirm)
n_of_centers = 49  # number of k-means clusters; also passed to create_histograms
file_name = 'placeholder'  # name of the file inside ./milled_nbs/ that the final cell renames

### Choose training dataset and representation

In [5]:
#%%capture
# (the commented-out cell magic above can be re-enabled to hide the download/extraction messages)

# Augmentation pipeline: a single event-dropping step driven by the notebook parameter.
transform = transforms.Compose(
    [
        transforms.DropEvent(drop_probability=dropout_probability),
    ]
)

# Event representation: time surfaces configured from the notebook parameters.
representation = spike_data_augmentation.representations.Timesurface(
    surface_dimensions=surface_dimensions,
    tau=time_constant,
    merge_polarities=True,
)

# Training split of NMNIST, stored under ./data (downloaded if necessary).
trainset = spike_data_augmentation.datasets.NMNIST(
    save_to='./data',
    train=True,
    transform=transform,
    representation=representation,
    download=True,
)

Using downloaded and verified file: ./data/nmnist_train.zip
Extracting ./data/nmnist_train.zip to ./data


KeyboardInterrupt: 

### Read timesurfaces and use minibatch clustering

In [None]:
# Stream the training set once and fit the k-means centers incrementally.
trainloader = Dataloader(trainset, shuffle=True)
trainiterator = iter(trainloader)

kmeans = MiniBatchKMeans(n_clusters=n_of_centers)

# Length of one flattened time surface.
dims_prod = np.prod(surface_dimensions)

# A plain loop rather than a list comprehension used for its side effects:
# partial_fit returns the estimator itself, so the comprehension collected one
# reference per loader item and, being the cell's last expression, made the
# notebook render that whole list as output.
for surfaces, _label in tqdm(trainiterator):
    # Each row handed to partial_fit is one flattened surface.
    kmeans.partial_fit(surfaces.reshape(-1, dims_prod))

### Plot centers

In [None]:
# Give every cluster center the shape of a time surface again for plotting.
# Unpacking the dimensions works whether surface_dimensions is a list or a tuple
# (the previous `[-1,] + surface_dimensions` only accepted a list).
centers = kmeans.cluster_centers_.reshape(-1, *surface_dimensions)

# Per-cluster weight sums accumulated during partial_fit. The public `counts_`
# attribute was deprecated in scikit-learn 0.24 and removed in later releases,
# so fall back to the private attribute when it is gone.
# NOTE(review): `_counts` is private API — confirm against the installed scikit-learn version.
activations = kmeans.counts_ if hasattr(kmeans, "counts_") else kmeans._counts
plot_centers(centers, activations)

### Train classifier

In [None]:
# Second pass over the training data: map every time surface of a loader item
# to its k-means cluster id and keep those ids next to the item's label.
trainloader = Dataloader(trainset, shuffle=True)
trainiterator = iter(trainloader)

dims_prod = np.prod(surface_dimensions)
all_kmeans_labels, all_labels = [], []
for surfaces, label in tqdm(trainiterator):
    flat_surfaces = surfaces.reshape(-1, dims_prod)
    all_kmeans_labels.append(kmeans.predict(flat_surfaces))
    all_labels.append(label)

# Cluster assignments -> classifier inputs.
hists = create_histograms(all_kmeans_labels, n_of_centers)

logreg = LogisticRegression(
    solver='lbfgs',
    multi_class='multinomial',
    max_iter=2000,
)
logreg.fit(hists, all_labels)

### Build histograms for each testing datapoint and classify

In [None]:
# Test split of NMNIST, built with the same transform and representation as the training split.
# NOTE(review): this means DropEvent is also applied at test time — confirm that is intended.
testset = spike_data_augmentation.datasets.NMNIST(save_to='./data', train=False, transform=transform, representation=representation, download=True)
testloader = Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

# Flattened surface length, computed once instead of on every iteration
# (mirrors the training cell).
dims_prod = np.prod(surface_dimensions)

all_kmeans_labels = []
all_labels = []
for surfaces, label in tqdm(testiterator):
    surfaces = surfaces.reshape(-1, dims_prod)
    # Cluster id for every surface of this loader item.
    surf_labels = kmeans.predict(surfaces)
    all_kmeans_labels.append(surf_labels)
    all_labels.append(label)

In [None]:
# Build the classifier inputs for the test items and make sure none was lost.
hists = create_histograms(all_kmeans_labels, n_of_centers)
assert len(hists) == len(all_labels)

# `score` is read again by the final cell, which puts it into the output file name.
score = logreg.score(hists, all_labels)

# Predict once and reuse the result for both reports instead of running the
# classifier separately for each of them.
predictions = logreg.predict(hists)
print(classification_report(all_labels, predictions))
print(confusion_matrix(all_labels, predictions))
print(strftime("Finished on %a, %d %b %Y %H:%M:%S", gmtime()))

### Hacky bit: prefix the papermill-generated notebook's filename with the test score

In [None]:
import os

# Rename the notebook inside ./milled_nbs/ so that its file name starts with
# the test score rounded to two decimals.
milled_dir = './milled_nbs/'
score_prefix = str(round(score, 2))
new_file_name = milled_dir + score_prefix + file_name
os.rename(milled_dir + file_name, new_file_name)