In [None]:
%matplotlib widget
import os
import pickle
import time

import ipdb
import numpy as np
import sklearn as skl
# `import sklearn` alone does not guarantee that submodules are loaded, so the
# ones accessed through the `skl.` alias are imported explicitly.
import sklearn.metrics
import sklearn.preprocessing
import spike_data_augmentation as sda
from sklearn.cluster import MiniBatchKMeans
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from tqdm.auto import tqdm

from utils import plot_centers, create_histograms
# Record when the run began; time.gmtime() means the timestamp is in UTC.
started_at = time.gmtime()
print(time.strftime("Started on %a, %d %b %Y %H:%M:%S", started_at))

### Parametrise notebook using papermill

In [None]:
# Notebook parameters. Kept as plain assignments so papermill can override them.

# Passed to ToTimesurface(surface_dimensions=...); the product of the entries is
# the length of each flattened surface fed to k-means.
surface_dimensions = [11,11]
# drop_probability of the DropEvents transform. Augmentation is switched on when
# this or refractory_period is non-zero.
dropout_probability = 0
# refractory_period of the RefractoryPeriod transform (0 leaves augmentation off
# unless dropout_probability is set).
refractory_period = 0
# tau of the ToTimesurface transform.
# NOTE(review): units are presumably those of the event timestamps - confirm.
time_constant = 20e3
# Number of k-means clusters; also passed to create_histograms.
n_of_centers = 500
# Which dataset to load: 'IBMGesture', 'NCARS' or 'NMNIST'.
dataset = 'NCARS'
# Only forwarded to the NMNIST dataset constructor.
first_saccade_only = False
# Suffix of the pickled k-means model file and name of the notebook file that is
# renamed inside ./milled_nbs at the end of the run.
file_name = 'placeholder'

### Choose training dataset and transforms

In [None]:
# Augmenting pipeline: refractory period + event dropping, followed by the same
# time-surface transform that the plain pipeline uses.
transform = sda.transforms.Compose([sda.transforms.RefractoryPeriod(refractory_period=refractory_period),
                                    sda.transforms.DropEvents(drop_probability=dropout_probability)])

# Plain pipeline: time surfaces only.
surface_transform = sda.transforms.Compose([sda.transforms.ToTimesurface(surface_dimensions=surface_dimensions, tau=time_constant, merge_polarities=True)])
# Append the time-surface step to the augmenting pipeline (extends its list in place).
transform.transforms += surface_transform.transforms


def _build_dataset(name, **kwargs):
    """Instantiate the dataset selected by ``name`` with the given keyword arguments.

    ``first_saccade_only`` (notebook parameter) is forwarded to NMNIST only.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets. Without this
            check an unknown name would leave the dataset variable unset, or
            silently reuse a stale one from an earlier run in the same kernel.
    """
    if name == 'IBMGesture':
        return sda.datasets.IBMGesture(**kwargs)
    if name == 'NCARS':
        return sda.datasets.NCARS(**kwargs)
    if name == 'NMNIST':
        return sda.datasets.NMNIST(**kwargs, first_saccade_only=first_saccade_only)
    raise ValueError("Unknown dataset {!r}; expected 'IBMGesture', 'NCARS' or 'NMNIST'".format(name))


# Un-augmented training set.
args = dict(save_to='./data', train=True, transform=surface_transform, download=False)
trainset = _build_dataset(dataset, **args)
trainloader = sda.datasets.dataloader.Dataloader(trainset, shuffle=True)

# Augmentation is active as soon as either augmentation parameter is non-zero.
augmentation = dropout_probability != 0 or refractory_period != 0
if augmentation:
    # Second view of the same training data, passed through the augmenting pipeline.
    args_augmented = dict(save_to='./data', train=True, transform=transform, download=False)
    trainset_augmented = _build_dataset(dataset, **args_augmented)
    trainloader_augmented = sda.datasets.dataloader.Dataloader(trainset_augmented, shuffle=True)

### Read timesurfaces and use minibatch clustering

In [None]:
# Fit the k-means codebook incrementally, one recording's time surfaces at a time.
# NOTE(review): no random_state is set and the loaders shuffle, so the resulting
# centers differ between runs.
kmeans = MiniBatchKMeans(n_clusters=n_of_centers, verbose=True, reassignment_ratio=0.001)
# Length of one flattened time surface.
dims_prod = np.prod(surface_dimensions)

if augmentation:  # mix normal and augmented training sets
    # zip() stops as soon as the shorter of the two loaders is exhausted.
    mixed_loaders = zip(trainloader, trainloader_augmented)
    for (surf, _label), (surf_aug, _label_aug) in tqdm(mixed_loaders):
        kmeans.partial_fit(surf.reshape(-1, dims_prod))
        kmeans.partial_fit(surf_aug.reshape(-1, dims_prod))
else:  # only take training set without transforms
    trainiterator = iter(trainloader)
    # Plain loop instead of a comprehension: partial_fit is called for its side
    # effect only, so there is nothing worth collecting.
    for surfaces, _label in tqdm(trainiterator):
        kmeans.partial_fit(surfaces.reshape(-1, dims_prod))

In [None]:
# Fold every flattened cluster center back into the shape of a time surface.
centers = kmeans.cluster_centers_.reshape([-1] + list(surface_dimensions))
# Per-cluster counts accumulated during fitting. The public `counts_` attribute
# was deprecated in scikit-learn 0.24 and removed in 1.1, so fall back to the
# private `_counts` attribute on newer versions.
activations = getattr(kmeans, 'counts_', None)
if activations is None:
    activations = kmeans._counts

### Model persistence

In [None]:
# Set to True to reload a previously saved model instead of saving the one that
# was just fitted.
load_saved_model = False
model_path = os.path.join('saved_models', 'mb-kmeans{0}.pkl'.format(file_name))

if load_saved_model:
    # SECURITY: unpickling executes arbitrary code; only load files produced by
    # this notebook, never files from an untrusted source.
    with open(model_path, 'rb') as f:
        kmeans = pickle.load(f)
else:
    # Create the output directory on first use so the save cannot fail on a
    # fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, 'wb') as f:
        pickle.dump(kmeans, f)

### Train classifiers

In [None]:
# Second pass over the un-augmented training set: assign every time surface to
# its nearest cluster center and keep the label of each recording.
trainloader = sda.datasets.dataloader.Dataloader(trainset, shuffle=True)
trainiterator = iter(trainloader)

training_cluster_assignments, Y_train = [], []
for surfaces, label in tqdm(trainiterator):
    surfaces = surfaces.reshape(-1, dims_prod)
    training_cluster_assignments.append(kmeans.predict(surfaces))
    Y_train.append(label)

# Turn the cluster assignments into features via create_histograms (presumably
# one histogram over the n_of_centers clusters per recording - confirm in utils),
# then standardise them. The fitted scaler is reused for the test features.
X_train = create_histograms(training_cluster_assignments, n_of_centers)
scaler = skl.preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)

# Three classifiers are trained on the same features so they can be compared.
logreg = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=1000).fit(X_train, Y_train)
gnb = GaussianNB().fit(X_train, Y_train)
knn = KNeighborsClassifier().fit(X_train, Y_train)

### Build testing features and classify

In [None]:
# Test split, passed through the plain (non-augmenting) time-surface transform.
# NOTE(review): download=True here while the training sets use download=False -
# confirm this is intentional.
args_test = dict(save_to='./data', train=False, transform=surface_transform, download=True)
if dataset == 'IBMGesture':
    testset = sda.datasets.IBMGesture(**args_test)
elif dataset == 'NCARS':
    testset = sda.datasets.NCARS(**args_test)
elif dataset == 'NMNIST':
    testset = sda.datasets.NMNIST(**args_test, first_saccade_only=first_saccade_only)
else:
    # Fail loudly instead of evaluating on a stale `testset` left over from an
    # earlier run in the same kernel.
    raise ValueError("Unknown dataset {!r}; expected 'IBMGesture', 'NCARS' or 'NMNIST'".format(dataset))
testloader = sda.datasets.dataloader.Dataloader(testset, shuffle=True)
testiterator = iter(testloader)

# Length of one flattened time surface; computed once instead of per recording.
n_surface_features = np.prod(surface_dimensions)

# Assign every test time surface to its nearest cluster center and keep the
# label of each recording.
testing_cluster_assignments = []
Y_test = []
for surfaces, label in tqdm(testiterator):
    surfaces = surfaces.reshape(-1, n_surface_features)
    surf_labels = kmeans.predict(surfaces)
    testing_cluster_assignments.append(surf_labels)
    Y_test.append(label)

In [None]:
# Build the test features the same way as the training features and apply the
# scaler that was fitted on the training set.
X_test = scaler.transform(create_histograms(testing_cluster_assignments, n_of_centers))
assert len(X_test) == len(Y_test)

# Score of every classifier on the test set, rounded to 4 decimals.
classifiers = (('logreg', logreg), ('gnb', gnb), ('knn', knn))
scores = {name: round(clf.score(X_test, Y_test), 4) for name, clf in classifiers}
# Name of the best-scoring classifier; on ties the first one listed wins.
winner_classifier = max(scores, key=scores.get)
print(scores)
print(skl.metrics.confusion_matrix(Y_test, logreg.predict(X_test)))
print(time.strftime("Finished on %a, %d %b %Y %H:%M:%S", time.gmtime()))

In [None]:
# Persist the score dict; a dict is stored as a pickled object array.
# NOTE(review): the file name is fixed and does not depend on `dataset` or
# `file_name`, so every run overwrites the same file.
np.save(file="ncars_normalised_new.npy", arr=scores, allow_pickle=True)

In [None]:
# Read the saved scores back as a check. The result is a 0-d object array
# wrapping the dict; call .item() on it to get the dict itself.
np.load(file="ncars_normalised_new.npy", allow_pickle=True)

In [None]:
# X_train / X_test were overwritten with their standardised versions earlier, so
# the unscaled histograms are rebuilt from the cluster assignments for export.
X_train_orig = create_histograms(training_cluster_assignments, n_of_centers)
X_test_orig = create_histograms(testing_cluster_assignments, n_of_centers)

# Write features and labels to .npy files in the working directory.
exports = (
    ("X_train.npy", X_train_orig),
    ("X_test.npy", X_test_orig),
    ("Y_train.npy", Y_train),
    ("Y_test.npy", Y_test),
)
for export_path, export_array in exports:
    np.save(export_path, export_array)


### Hacky bit: rename the notebook generated by papermill so that its filename lists the winning score and classifier

In [None]:
# Prefix the papermill output notebook with the best score and the name of the
# classifier that achieved it.
old_file_name = './milled_nbs/' + file_name
new_file_name = './milled_nbs/' + str(scores[winner_classifier]) + '_' + winner_classifier + file_name
if os.path.exists(old_file_name):
    os.rename(old_file_name, new_file_name)
else:
    # With the default `file_name` (interactive run, no papermill output) there
    # is nothing to rename; report it instead of failing the whole run.
    print('Skipping rename: {0} does not exist.'.format(old_file_name))