In [2]:
#libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import pandas as pd
import scipy.stats
import random

In [3]:
#load MNIST data

train_data = pd.read_csv('./mnist_train.csv', sep=',', header=None)
train_labels = train_data[0]
train_data = train_data.drop(0, axis=1)

test_data = pd.read_csv('./mnist_test.csv', sep=',', header=None)
test_labels = test_data[0]
test_data = test_data.drop(0, axis=1)

#separate data for generating graphs
graph_data = train_data.sample(n = 10000, random_state=100)
graph_labels = train_labels.sample(n = 10000, random_state=100)
train_data = train_data.drop(graph_data.index)
train_labels = train_labels.drop(graph_data.index)

#convert data to pytorch tensors
train_data = torch.FloatTensor(train_data.to_numpy())
train_labels = torch.LongTensor(train_labels.to_numpy())
test_data = torch.FloatTensor(test_data.to_numpy())
graph_data = torch.FloatTensor(graph_data.to_numpy())

In [4]:
input_size = 784
output_size = 10

In [364]:
#(64, 64) Vanilla FCN

class Vanilla_Net(nn.Module):
    def __init__(self):
        super(Vanilla_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.lrelu1 = nn.LeakyReLU(0.01) #default negative slope
        self.fc2 = nn.Linear(50, 50)
        self.lrelu2 = nn.LeakyReLU(0.01) #default negative slope
        self.fc3 = nn.Linear(50, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.lrelu1(x)
        x = self.fc2(x)
        x = self.lrelu2(x)
        x = self.fc3(x)
        return F.log_softmax(x)

In [243]:
def SGD(net, optimizer, loss, epochs, train_data, train_labels, batch_size):
    for i in range(epochs):
        for j in range(0, train_data.shape[0], batch_size):
            data_minibatch = Variable(train_data[j : j+batch_size])
            label_minibatch = Variable(train_labels[j: j+batch_size])
            optimizer.zero_grad()
            net_out = net(data_minibatch)
            net_loss = loss(net_out, label_minibatch)
            net_loss.backward()
            optimizer.step()

In [6]:
def test_accuracy(net, test_data, test_labels):
    net_out = net(test_data)
    test_out = torch.max(net_out.data, 1)[1].numpy()
    return np.count_nonzero(test_out==test_labels) / len(test_labels)

In [279]:
#determine optimal # of epochs for SGD

epochs = 30
batch_size = 50 #typical value
learning_rate = 0.0001 #tested to not cause dead neurons
mmt = 0.9 #typical value
cur_accuracy = 0
prev_accuracy = 0
while True:
    prev_accuracy = cur_accuracy
    my_net = Vanilla_Net()
    optimizer = torch.optim.SGD(my_net.parameters(), lr=learning_rate, momentum=mmt)
    loss = nn.CrossEntropyLoss()
    SGD(my_net, optimizer, loss, epochs, train_data, train_labels, batch_size)
    cur_accuracy = test_accuracy(my_net, test_data, test_labels)
    print(cur_accuracy)
    if (cur_accuracy <= prev_accuracy-0.005):
        break
    epochs += 1
epochs -= 1
print(epochs)

  return F.log_softmax(x)


KeyboardInterrupt: 

After running the above code a few times, it seems like the network typically achieves high accuracy after around 20 epochs. 

In [8]:
# join train and test data

train_test_data = torch.cat((train_data, test_data))
train_test_labels = torch.cat((train_labels, torch.LongTensor(test_labels.to_numpy())))

In [377]:
# train vanilla network with train+test data

vanilla_net = Vanilla_Net()
optimizer = torch.optim.SGD(vanilla_net.parameters(), lr=0.0001, momentum=0.9)
loss = nn.CrossEntropyLoss()
SGD(vanilla_net, optimizer, loss, epochs=20, train_data=train_test_data, train_labels=train_test_labels, batch_size=50)

  return F.log_softmax(x)


In [83]:
#(300, 100) FCN trained with dropout

p = 0.5 #same as paper

class Dropout_Net(nn.Module):
    def __init__(self):
        super(Dropout_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p)
        x = F.relu(self.fc2(x))
        #x = F.dropout(x, p)
        x = self.fc3(x)
        return F.log_softmax(x)

In [84]:
#determine optimal # of epochs for dropout_net

epochs = 1
batch_size = 20 #typical value
learning_rate = 0.001 #default value
mmt = 0.9 #typical value
cur_accuracy = 0
prev_accuracy = 0
while True:
    prev_accuracy = cur_accuracy
    my_net = Dropout_Net()
    optimizer = torch.optim.SGD(my_net.parameters(), lr=learning_rate, momentum=mmt)
    loss = nn.CrossEntropyLoss()
    SGD(my_net, optimizer, loss, epochs, train_data, train_labels, batch_size)
    cur_accuracy = test_accuracy(my_net, test_data, test_labels)
    print(cur_accuracy)
    if (cur_accuracy <= prev_accuracy-0.01):
        break
    epochs += 1
epochs -= 1
print(epochs)

  return F.log_softmax(x)


0.8501
0.8736
0.8971
0.9055
0.9127
0.9146
0.9201
0.9263
0.9291
0.9344
0.9218
10


It takes around 10 epochs to train the dropout network. 

In [314]:
# train dropout network

dropout_net = Dropout_Net()
optimizer = torch.optim.SGD(dropout_net.parameters(), lr=0.001, momentum=0.9)
loss = nn.CrossEntropyLoss()
SGD(dropout_net, optimizer, loss, epochs=10, train_data=train_test_data, train_labels=train_test_labels, batch_size=20)

NameError: name 'Dropout_Net' is not defined

In [212]:
# (64, 64) FCN trained with batch norm

class BN_Net(nn.Module):
    def __init__(self):
        super(BN_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.bn1 = nn.BatchNorm1d(50)
        self.fc2 = nn.Linear(50, 50)
        self.bn2 = nn.BatchNorm1d(50)
        self.fc3 = nn.Linear(50, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return F.log_softmax(x)

In [213]:
#determine optimal # of epochs for batch_norm_net

epochs = 20
batch_size = 50 #typical value
learning_rate = 0.0001 #tested to not cause neuron death
mmt = 0.9 #typical value
cur_accuracy = 0
prev_accuracy = 0
while True:
    prev_accuracy = cur_accuracy
    my_net = BN_Net()
    optimizer = torch.optim.SGD(my_net.parameters(), lr=learning_rate, momentum=mmt)
    loss = nn.CrossEntropyLoss()
    SGD(my_net, optimizer, loss, epochs, train_data, train_labels, batch_size)
    cur_accuracy = test_accuracy(my_net, test_data, test_labels)
    print(cur_accuracy)
    if (cur_accuracy <= prev_accuracy-0.01):
        break
    epochs += 1
epochs -= 1
print(epochs)

  return F.log_softmax(x)


0.9529


KeyboardInterrupt: 

The batch norm network achieves high accuracy after 20 epochs. 

In [214]:
# train batch norm network

batch_norm_net = BN_Net()
optimizer = torch.optim.SGD(batch_norm_net.parameters(), lr=0.0001, momentum=0.9)
loss = nn.CrossEntropyLoss()
SGD(batch_norm_net, optimizer, loss, epochs=20, train_data=train_test_data, train_labels=train_test_labels, batch_size=50)

  return F.log_softmax(x)


In [372]:
def neuron_values(net, data):
    activations = []
    def get_activation():
        def hook(model, input, output):
            activations.append(output.detach())
        return hook
    
    net.lrelu1.register_forward_hook(get_activation())
    net.lrelu2.register_forward_hook(get_activation())
    net(data)
    
    activations[0] = activations[0].numpy()
    activations[1] = activations[1].numpy()
    neurons = np.concatenate((activations[0].T, activations[1].T))
    return neurons

In [378]:
vanilla_neurons = neuron_values(vanilla_net, graph_data)
#dropout_neurons = neuron_values(dropout_net, graph_data)
#batch_norm_neurons = neuron_values(batch_norm_net, graph_data)

  return F.log_softmax(x)


In [193]:
def correlation_graph(neurons):
    n = len(neurons)
    adj_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if (i == j):
                adj_matrix[i][j] = 0
            else: 
                adj_matrix[i][j] = abs(scipy.stats.pearsonr(neurons[i], neurons[j])[0])
    return adj_matrix

In [379]:
#split neuron data into 10 subsets to construct 10 graphs
vanilla_subsets = np.array_split(vanilla_neurons, 5, 1)
#dropout_subsets = np.array_split(dropout_neurons, 10, 1)
#batch_norm_subsets = np.array_split(batch_norm_neurons, 10, 1)

In [256]:
network_graphs = []

for i in range(10):
    network_graphs.append(correlation_graph(vanilla_subsets[i]))
    #network_graphs.append(correlation_graph(dropout_subsets[i]))
    network_graphs.append(correlation_graph(batch_norm_subsets[i]))

  adj_matrix[i][j] = abs(scipy.stats.pearsonr(neurons[i], neurons[j])[0])


In [384]:
#test code

vanilla_test = neuron_values(vanilla_net, graph_data)
count = 0
for i in range(5):
    for j in range(100):
        if (np.count_nonzero(vanilla_subsets[i][j]) == 0):
            print((i, j))
            count +=1
    
print(count)

for i in vanilla_neurons:
    if (np.count_nonzero(i) == 0):
        print("no")

print(np.matmul(vanilla_net.fc1.weight.data.numpy()[0], graph_data[0].numpy())*0.1)
print(np.matmul(vanilla_net.fc1.weight.data.numpy()[0], graph_data[1].numpy()))
print(vanilla_neurons)


#test_graph = correlation_graph(vanilla_test)
#print(np.count_nonzero(~np.isnan(test_graph)))

0
-2.860476303100586
146.9675
[[-2.8592715e-01  1.4697955e+02  4.5925129e+01 ...  9.6395020e+01
   5.0377380e+01  5.8842236e+01]
 [ 3.4607075e+01 -9.0028960e-01 -1.1234105e+00 ... -5.5461705e-01
  -5.3529197e-01  1.2493665e+01]
 [-3.8031331e-01 -1.3355006e+00 -1.2671062e+00 ... -8.0515796e-01
  -8.1621784e-01 -5.5112785e-01]
 ...
 [-2.8656247e-01 -1.1366471e-01 -3.3285168e-01 ...  6.0003614e-01
  -1.3351262e-01 -2.2673206e-01]
 [-5.9106003e-02 -4.3972909e-02  1.8050117e+01 ...  2.5170645e+01
  -1.1794520e-01  5.9268980e+00]
 [ 7.8760047e+00 -1.4147338e-01  2.5203075e+01 ... -3.1218180e-01
  -1.7134462e-01  3.2964664e+01]]


  return F.log_softmax(x)


In [241]:
#more test code

test_comp = np.matmul(vanilla_net.fc1.weight.data.numpy(), graph_data.numpy()[0])
test_comp = test_comp + vanilla_net.fc1.bias.data.numpy()
test_comp = (abs(test_comp) + test_comp)/2
#test_comp = np.matmul(vanilla_net.fc2.weight.data.numpy(), test_comp)
#test_comp = test_comp + vanilla_net.fc2.bias.data.numpy()
#test_comp = (abs(test_comp) + test_comp)/2
print(test_comp)

test = neuron_values(vanilla_net, graph_data[0])[:100]

print(test)

[  0.         0.         0.         0.       249.30411    0.
  11.832406   0.         0.       445.0323     0.         0.
   0.         0.         0.         0.         0.       660.44714
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.       458.20786    0.         0.
   0.         0.         0.       320.37054    0.         0.
   0.       104.83368    0.         0.         0.         0.
   0.         0.      ]
[  0.         0.         0.         0.       249.30411    0.
  11.832425   0.         0.       445.0323     0.         0.
   0.         0.         0.         0.         0.       660.44714
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.       458.2079     0.         0.
   0.         0.         0.       320.37057    0.         0.
   0.       104.83367    0.         0.         0.  

  return F.log_softmax(x)


In [260]:
#topological clustering framework

import sys
import math
import random
import numpy as np

from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from sklearn.metrics.cluster import contingency_matrix
from matplotlib import pyplot as plt


class TopClustering:
    """Topological clustering.
    
    Attributes:
        n_clusters: 
          The number of clusters.
        top_relative_weight:
          Relative weight between the geometric and topological terms.
          A floating point number between 0 and 1.
        max_iter_alt:
          Maximum number of iterations for the topological clustering.
        max_iter_interp:
          Maximum number of iterations for the topological interpolation.
        learning_rate:
          Learning rate for the topological interpolation.
        
    Reference:
        Songdechakraiwut, Tananun, Bryan M. Krause, Matthew I. Banks, Kirill V. Nourski, and Barry D. Van Veen. 
        "Fast topological clustering with Wasserstein distance." 
        International Conference on Learning Representations (ICLR). 2022.
    """

    def __init__(self, n_clusters, top_relative_weight, max_iter_alt,
                 max_iter_interp, learning_rate):
        self.n_clusters = n_clusters
        self.top_relative_weight = top_relative_weight
        self.max_iter_alt = max_iter_alt
        self.max_iter_interp = max_iter_interp
        self.learning_rate = learning_rate

    def fit_predict(self, data):
        """Computes topological clustering and predicts cluster index for each sample.
        
            Args:
                data:
                  Training instances to cluster.
                  
            Returns:
                Cluster index each sample belongs to.
        """
        data = np.asarray(data)
        n_node = data.shape[1]
        n_edges = math.factorial(n_node) // math.factorial(2) // math.factorial(
            n_node - 2)  # n_edges = (n_node choose 2)
        n_births = n_node - 1
        self.weight_array = np.append(
            np.repeat(1 - self.top_relative_weight, n_edges),
            np.repeat(self.top_relative_weight, n_edges))

        # Networks represented as vectors concatenating geometric and topological info
        X = []
        for adj in data:
            X.append(self._vectorize_geo_top_info(adj))
        for l in X:
            print(l.shape)
        X = np.asarray(X)

        # Random initial condition
        self.centroids = X[random.sample(range(X.shape[0]), self.n_clusters)]

        # Assign the nearest centroid index to each data point
        assigned_centroids = self._get_nearest_centroid(
            X[:, None, :], self.centroids[None, :, :])
        prev_assigned_centroids = assigned_centroids

        for it in range(self.max_iter_alt):
            for cluster in range(self.n_clusters):
                # Previous iteration centroid
                prev_centroid = np.zeros((n_node, n_node))
                prev_centroid[np.triu_indices(
                    prev_centroid.shape[0],
                    k=1)] = self.centroids[cluster][:n_edges]

                # Determine data points belonging to each cluster
                cluster_members = X[assigned_centroids == cluster]

                # Compute the sample mean and top. centroid of the cluster
                cc = cluster_members.mean(axis=0)
                sample_mean = np.zeros((n_node, n_node))
                sample_mean[np.triu_indices(sample_mean.shape[0],
                                            k=1)] = cc[:n_edges]
                top_centroid = cc[n_edges:]
                top_centroid_birth_set = top_centroid[:n_births]
                top_centroid_death_set = top_centroid[n_births:]

                # Update the centroid
                try:
                    cluster_centroid = self._top_interpolation(
                        prev_centroid, sample_mean, top_centroid_birth_set,
                        top_centroid_death_set)
                    self.centroids[cluster] = self._vectorize_geo_top_info(
                        cluster_centroid)
                except:
                    print(
                        'Error: Possibly due to the learning rate is not within appropriate range.'
                    )
                    sys.exit(1)

            # Update the cluster membership
            assigned_centroids = self._get_nearest_centroid(
                X[:, None, :], self.centroids[None, :, :])

            # Compute and print loss as it is progressively decreasing
            loss = self._compute_top_dist(
                X, self.centroids[assigned_centroids]).sum() / len(X)
            print('Iteration: %d -> Loss: %f' % (it, loss))

            if (prev_assigned_centroids == assigned_centroids).all():
                break
            else:
                prev_assigned_centroids = assigned_centroids
        return assigned_centroids

    def _vectorize_geo_top_info(self, adj):
        birth_set, death_set = self._compute_birth_death_sets(
            adj)  # topological info
        vec = adj[np.triu_indices(adj.shape[0], k=1)]  # geometric info
        return np.concatenate((vec, birth_set, death_set), axis=0)

    def _compute_birth_death_sets(self, adj):
        """Computes birth and death sets of a network."""
        mst, nonmst = self._bd_demomposition(adj)
        birth_ind = np.nonzero(mst)
        death_ind = np.nonzero(nonmst)
        return np.sort(mst[birth_ind]), np.sort(nonmst[death_ind])

    def _bd_demomposition(self, adj):
        """Birth-death decomposition."""
        eps = np.nextafter(0, 1)
        adj[adj == 0] = eps
        adj = np.triu(adj, k=1)
        Xcsr = csr_matrix(-adj)
        Tcsr = minimum_spanning_tree(Xcsr)
        mst = -Tcsr.toarray()  # reverse the negative sign
        nonmst = adj - mst
        birth_ind = np.nonzero(mst)
        print(nonmst[birth_ind])
        print(birth_ind[0][0], birth_ind[1][0])
        print(mst[birth_ind[0][0]][birth_ind[1][0]])
        print(adj[birth_ind[0][0]][birth_ind[1][0]])
        print(nonmst[birth_ind[0][0]][birth_ind[1][0]])
        return mst, nonmst

    def _get_nearest_centroid(self, X, centroids):
        """Determines cluster membership of data points."""
        dist = self._compute_top_dist(X, centroids)
        nearest_centroid_index = np.argmin(dist, axis=1)
        return nearest_centroid_index

    def _compute_top_dist(self, X, centroid):
        """Computes the pairwise top. distances between networks and centroids."""
        return np.dot((X - centroid)**2, self.weight_array)

    def _top_interpolation(self, init_centroid, sample_mean,
                           top_centroid_birth_set, top_centroid_death_set):
        """Topological interpolation."""
        curr = init_centroid
        for _ in range(self.max_iter_interp):
            # Geometric term gradient
            geo_gradient = 2 * (curr - sample_mean)

            # Topological term gradient
            sorted_birth_ind, sorted_death_ind = self._compute_optimal_matching(
                curr)
            top_gradient = np.zeros_like(curr)
            top_gradient[sorted_birth_ind] = top_centroid_birth_set
            top_gradient[sorted_death_ind] = top_centroid_death_set
            top_gradient = 2 * (curr - top_gradient)

            # Gradient update
            curr -= self.learning_rate * (
                (1 - self.top_relative_weight) * geo_gradient +
                self.top_relative_weight * top_gradient)
        return curr

    def _compute_optimal_matching(self, adj):
        mst, nonmst = self._bd_demomposition(adj)
        birth_ind = np.nonzero(mst)
        death_ind = np.nonzero(nonmst)
        sorted_temp_ind = np.argsort(mst[birth_ind])
        sorted_birth_ind = tuple(np.array(birth_ind)[:, sorted_temp_ind])
        sorted_temp_ind = np.argsort(nonmst[death_ind])
        sorted_death_ind = tuple(np.array(death_ind)[:, sorted_temp_ind])
        return sorted_birth_ind, sorted_death_ind


In [257]:
#clustering network graphs

n_clusters = 2
top_relative_weight = 0.9
max_iter_alt = 300
max_iter_interp = 300
learning_rate = 0.05

labels_pred = TopClustering(n_clusters, top_relative_weight, max_iter_alt,
                                max_iter_interp,
                                learning_rate).fit_predict(network_graphs)

print(labels_pred)

(9905,)
(9900,)
(9902,)
(9900,)
(9905,)
(9900,)
(9906,)
(9900,)
(9903,)
(9900,)
(9905,)
(9900,)
(9905,)
(9900,)
(9905,)
(9900,)
(9904,)
(9900,)
(9903,)
(9900,)


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (20,) + inhomogeneous part.

In [262]:
test_birth, test_death = TopClustering(n_clusters, top_relative_weight, max_iter_alt,
                                max_iter_interp,
                                learning_rate)._compute_birth_death_sets(network_graphs[6])

print(network_graphs[6][0][55])
print(np.count_nonzero(vanilla_subsets[0][0]))
print(np.count_nonzero(vanilla_neurons[55]))
print(scipy.stats.pearsonr(vanilla_subsets[3][0], vanilla_subsets[3][55])[0])

#print(test_death.shape)

[nan  0. nan nan nan nan nan  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.]
0 11
nan
nan
nan
0.014910176558950763
0
2878
-0.014910176558950763
