In [2]:
#libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import pandas as pd
import scipy.stats
import random

In [3]:
#load MNIST data

train_data = pd.read_csv('./mnist_train.csv', sep=',', header=None)
train_labels = train_data[0]
train_data = train_data.drop(0, axis=1)

test_data = pd.read_csv('./mnist_test.csv', sep=',', header=None)
test_labels = test_data[0]
test_data = test_data.drop(0, axis=1)

#separate data for generating graphs
graph_data = train_data.sample(n = 10000, random_state=100)
graph_labels = train_labels.sample(n = 10000, random_state=100)
train_data = train_data.drop(graph_data.index)
train_labels = train_labels.drop(graph_data.index)

#convert data to pytorch tensors
train_data = torch.FloatTensor(train_data.to_numpy())
train_labels = torch.LongTensor(train_labels.to_numpy())
test_data = torch.FloatTensor(test_data.to_numpy())
graph_data = torch.FloatTensor(graph_data.to_numpy())

In [4]:
input_size = 784
output_size = 10

In [6]:
#(50, 50) Vanilla FCN

class Vanilla_Net(nn.Module):
    def __init__(self):
        super(Vanilla_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.lrelu1 = nn.LeakyReLU(0.01) #default negative slope
        self.fc2 = nn.Linear(50, 50)
        self.lrelu2 = nn.LeakyReLU(0.01) #default negative slope
        self.fc3 = nn.Linear(50, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.lrelu1(x)
        x = self.fc2(x)
        x = self.lrelu2(x)
        x = self.fc3(x)
        return F.log_softmax(x)

In [7]:
def SGD(net, optimizer, loss, epochs, train_data, train_labels, batch_size):
    for i in range(epochs):
        for j in range(0, train_data.shape[0], batch_size):
            data_minibatch = Variable(train_data[j : j+batch_size])
            label_minibatch = Variable(train_labels[j: j+batch_size])
            optimizer.zero_grad()
            net_out = net(data_minibatch)
            net_loss = loss(net_out, label_minibatch)
            net_loss.backward()
            optimizer.step()

In [8]:
def test_accuracy(net, test_data, test_labels):
    net_out = net(test_data)
    test_out = torch.max(net_out.data, 1)[1].numpy()
    return np.count_nonzero(test_out==test_labels) / len(test_labels)

In [417]:
#determine optimal # of epochs for SGD

epochs = 5
batch_size = 50 #typical value
learning_rate = 0.0001 
mmt = 0.9 #typical value
cur_accuracy = 0
prev_accuracy = 0
while True:
    prev_accuracy = cur_accuracy
    my_net = Vanilla_Net()
    optimizer = torch.optim.SGD(my_net.parameters(), lr=learning_rate, momentum=mmt)
    loss = nn.CrossEntropyLoss()
    SGD(my_net, optimizer, loss, epochs, train_data, train_labels, batch_size)
    cur_accuracy = test_accuracy(my_net, test_data, test_labels)
    print(cur_accuracy)
    if (cur_accuracy <= prev_accuracy-0.005):
        break
    epochs += 1
epochs -= 1
print(epochs)

  return F.log_softmax(x)


0.9487
0.9479
0.9511
0.9519


KeyboardInterrupt: 

After running the above code a few times, it seems like the network typically achieves high accuracy after around 20 epochs. 

In [9]:
# join train and test data

train_test_data = torch.cat((train_data, test_data))
train_test_labels = torch.cat((train_labels, torch.LongTensor(test_labels.to_numpy())))

In [10]:
# train vanilla network with train+test data

vanilla_net = Vanilla_Net()
optimizer = torch.optim.SGD(vanilla_net.parameters(), lr=0.0001, momentum=0.9)
loss = nn.CrossEntropyLoss()
SGD(vanilla_net, optimizer, loss, epochs=20, train_data=train_test_data, train_labels=train_test_labels, batch_size=50)

  return F.log_softmax(x)


In [12]:
# (50, 50) FCN trained with batch norm

class BN_Net(nn.Module):
    def __init__(self):
        super(BN_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.bn1 = nn.BatchNorm1d(50)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(50, 50)
        self.bn2 = nn.BatchNorm1d(50)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(50, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return F.log_softmax(x)

In [213]:
#determine optimal # of epochs for batch_norm_net

epochs = 20
batch_size = 50 #typical value
learning_rate = 0.0001 #tested to not cause neuron death
mmt = 0.9 #typical value
cur_accuracy = 0
prev_accuracy = 0
while True:
    prev_accuracy = cur_accuracy
    my_net = BN_Net()
    optimizer = torch.optim.SGD(my_net.parameters(), lr=learning_rate, momentum=mmt)
    loss = nn.CrossEntropyLoss()
    SGD(my_net, optimizer, loss, epochs, train_data, train_labels, batch_size)
    cur_accuracy = test_accuracy(my_net, test_data, test_labels)
    print(cur_accuracy)
    if (cur_accuracy <= prev_accuracy-0.01):
        break
    epochs += 1
epochs -= 1
print(epochs)

  return F.log_softmax(x)


0.9529


KeyboardInterrupt: 

The batch norm network achieves high accuracy after 20 epochs. 

In [13]:
# train batch norm network

batch_norm_net = BN_Net()
optimizer = torch.optim.SGD(batch_norm_net.parameters(), lr=0.0001, momentum=0.9)
loss = nn.CrossEntropyLoss()
SGD(batch_norm_net, optimizer, loss, epochs=20, train_data=train_test_data, train_labels=train_test_labels, batch_size=50)

  return F.log_softmax(x)


In [14]:
def vanilla_neuron_values(net, data):
    activations = []
    def get_activation():
        def hook(model, input, output):
            activations.append(output.detach())
        return hook
    
    net.lrelu1.register_forward_hook(get_activation())
    net.lrelu2.register_forward_hook(get_activation())
    net(data)
    
    activations[0] = activations[0].numpy()
    activations[1] = activations[1].numpy()
    neurons = np.concatenate((activations[0].T, activations[1].T))
    return neurons

In [15]:
def bn_neuron_values(net, data):
    activations = []
    def get_activation():
        def hook(model, input, output):
            activations.append(output.detach())
        return hook
    
    net.relu1.register_forward_hook(get_activation())
    net.relu2.register_forward_hook(get_activation())
    net(data)
    
    activations[0] = activations[0].numpy()
    activations[1] = activations[1].numpy()
    neurons = np.concatenate((activations[0].T, activations[1].T))
    return neurons

In [16]:
vanilla_neurons = vanilla_neuron_values(vanilla_net, graph_data)
batch_norm_neurons = bn_neuron_values(batch_norm_net, graph_data)

  return F.log_softmax(x)
  return F.log_softmax(x)


In [55]:
def correlation_graph(neurons):
    adj_matrix = abs(np.corrcoef(neurons))
    np.fill_diagonal(adj_matrix, 0)
    return adj_matrix

In [56]:
#split neuron data into n subsets to construct n graphs

n_subsets = 100

vanilla_subsets = np.array_split(vanilla_neurons, n_subsets, 1)
batch_norm_subsets = np.array_split(batch_norm_neurons, n_subsets, 1)

In [57]:
#construct network graphs

network_graphs = []

for i in range(n_subsets):
    network_graphs.append(correlation_graph(vanilla_subsets[i]))
    network_graphs.append(correlation_graph(batch_norm_subsets[i]))

In [20]:
#topological clustering framework

import sys
import math
import random
import numpy as np

from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from sklearn.metrics.cluster import contingency_matrix
from matplotlib import pyplot as plt


class TopClustering:
    """Topological clustering.
    
    Attributes:
        n_clusters: 
          The number of clusters.
        top_relative_weight:
          Relative weight between the geometric and topological terms.
          A floating point number between 0 and 1.
        max_iter_alt:
          Maximum number of iterations for the topological clustering.
        max_iter_interp:
          Maximum number of iterations for the topological interpolation.
        learning_rate:
          Learning rate for the topological interpolation.
        
    Reference:
        Songdechakraiwut, Tananun, Bryan M. Krause, Matthew I. Banks, Kirill V. Nourski, and Barry D. Van Veen. 
        "Fast topological clustering with Wasserstein distance." 
        International Conference on Learning Representations (ICLR). 2022.
    """

    def __init__(self, n_clusters, top_relative_weight, max_iter_alt,
                 max_iter_interp, learning_rate):
        self.n_clusters = n_clusters
        self.top_relative_weight = top_relative_weight
        self.max_iter_alt = max_iter_alt
        self.max_iter_interp = max_iter_interp
        self.learning_rate = learning_rate

    def fit_predict(self, data):
        """Computes topological clustering and predicts cluster index for each sample.
        
            Args:
                data:
                  Training instances to cluster.
                  
            Returns:
                Cluster index each sample belongs to.
        """
        data = np.asarray(data)
        n_node = data.shape[1]
        n_edges = math.factorial(n_node) // math.factorial(2) // math.factorial(
            n_node - 2)  # n_edges = (n_node choose 2)
        n_births = n_node - 1
        self.weight_array = np.append(
            np.repeat(1 - self.top_relative_weight, n_edges),
            np.repeat(self.top_relative_weight, n_edges))

        # Networks represented as vectors concatenating geometric and topological info
        X = []
        for adj in data:
            X.append(self._vectorize_geo_top_info(adj))
        X = np.asarray(X)

        # Random initial condition
        self.centroids = X[random.sample(range(X.shape[0]), self.n_clusters)]

        # Assign the nearest centroid index to each data point
        assigned_centroids = self._get_nearest_centroid(
            X[:, None, :], self.centroids[None, :, :])
        prev_assigned_centroids = assigned_centroids

        for it in range(self.max_iter_alt):
            for cluster in range(self.n_clusters):
                # Previous iteration centroid
                prev_centroid = np.zeros((n_node, n_node))
                prev_centroid[np.triu_indices(
                    prev_centroid.shape[0],
                    k=1)] = self.centroids[cluster][:n_edges]

                # Determine data points belonging to each cluster
                cluster_members = X[assigned_centroids == cluster]

                # Compute the sample mean and top. centroid of the cluster
                cc = cluster_members.mean(axis=0)
                sample_mean = np.zeros((n_node, n_node))
                sample_mean[np.triu_indices(sample_mean.shape[0],
                                            k=1)] = cc[:n_edges]
                top_centroid = cc[n_edges:]
                top_centroid_birth_set = top_centroid[:n_births]
                top_centroid_death_set = top_centroid[n_births:]

                # Update the centroid
                try:
                    cluster_centroid = self._top_interpolation(
                        prev_centroid, sample_mean, top_centroid_birth_set,
                        top_centroid_death_set)
                    self.centroids[cluster] = self._vectorize_geo_top_info(
                        cluster_centroid)
                except:
                    print(
                        'Error: Possibly due to the learning rate is not within appropriate range.'
                    )
                    sys.exit(1)

            # Update the cluster membership
            assigned_centroids = self._get_nearest_centroid(
                X[:, None, :], self.centroids[None, :, :])

            # Compute and print loss as it is progressively decreasing
            loss = self._compute_top_dist(
                X, self.centroids[assigned_centroids]).sum() / len(X)
            print('Iteration: %d -> Loss: %f' % (it, loss))

            if (prev_assigned_centroids == assigned_centroids).all():
                break
            else:
                prev_assigned_centroids = assigned_centroids
        return assigned_centroids

    def _vectorize_geo_top_info(self, adj):
        birth_set, death_set = self._compute_birth_death_sets(
            adj)  # topological info
        vec = adj[np.triu_indices(adj.shape[0], k=1)]  # geometric info
        return np.concatenate((vec, birth_set, death_set), axis=0)

    def _compute_birth_death_sets(self, adj):
        """Computes birth and death sets of a network."""
        mst, nonmst = self._bd_demomposition(adj)
        birth_ind = np.nonzero(mst)
        death_ind = np.nonzero(nonmst)
        return np.sort(mst[birth_ind]), np.sort(nonmst[death_ind])

    def _bd_demomposition(self, adj):
        """Birth-death decomposition."""
        eps = np.nextafter(0, 1)
        adj[adj == 0] = eps
        adj = np.triu(adj, k=1)
        Xcsr = csr_matrix(-adj)
        Tcsr = minimum_spanning_tree(Xcsr)
        mst = -Tcsr.toarray()  # reverse the negative sign
        nonmst = adj - mst
        birth_ind = np.nonzero(mst)
        return mst, nonmst

    def _get_nearest_centroid(self, X, centroids):
        """Determines cluster membership of data points."""
        dist = self._compute_top_dist(X, centroids)
        nearest_centroid_index = np.argmin(dist, axis=1)
        return nearest_centroid_index

    def _compute_top_dist(self, X, centroid):
        """Computes the pairwise top. distances between networks and centroids."""
        return np.dot((X - centroid)**2, self.weight_array)

    def _top_interpolation(self, init_centroid, sample_mean,
                           top_centroid_birth_set, top_centroid_death_set):
        """Topological interpolation."""
        curr = init_centroid
        for _ in range(self.max_iter_interp):
            # Geometric term gradient
            geo_gradient = 2 * (curr - sample_mean)

            # Topological term gradient
            sorted_birth_ind, sorted_death_ind = self._compute_optimal_matching(
                curr)
            top_gradient = np.zeros_like(curr)
            top_gradient[sorted_birth_ind] = top_centroid_birth_set
            top_gradient[sorted_death_ind] = top_centroid_death_set
            top_gradient = 2 * (curr - top_gradient)

            # Gradient update
            curr -= self.learning_rate * (
                (1 - self.top_relative_weight) * geo_gradient +
                self.top_relative_weight * top_gradient)
        return curr

    def _compute_optimal_matching(self, adj):
        mst, nonmst = self._bd_demomposition(adj)
        birth_ind = np.nonzero(mst)
        death_ind = np.nonzero(nonmst)
        sorted_temp_ind = np.argsort(mst[birth_ind])
        sorted_birth_ind = tuple(np.array(birth_ind)[:, sorted_temp_ind])
        sorted_temp_ind = np.argsort(nonmst[death_ind])
        sorted_death_ind = tuple(np.array(death_ind)[:, sorted_temp_ind])
        return sorted_birth_ind, sorted_death_ind


In [58]:
#clustering network graphs

n_clusters = 2
top_relative_weight = 0
max_iter_alt = 300
max_iter_interp = 300
learning_rate = 0.05

labels_pred = TopClustering(n_clusters, top_relative_weight, max_iter_alt,
                                max_iter_interp,
                                learning_rate).fit_predict(network_graphs)

print(labels_pred)

Iteration: 0 -> Loss: 46.732040
Iteration: 1 -> Loss: 29.903178
[1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0
 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0
 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0]


In [59]:
print(network_graphs[0])
print(network_graphs[1])

[[0.         0.25912957 0.16489568 ... 0.11437594 0.18770911 0.25446728]
 [0.25912957 0.         0.22676125 ... 0.33733986 0.08048746 0.28707708]
 [0.16489568 0.22676125 0.         ... 0.50387375 0.06584253 0.10875955]
 ...
 [0.11437594 0.33733986 0.50387375 ... 0.         0.23926586 0.23021397]
 [0.18770911 0.08048746 0.06584253 ... 0.23926586 0.         0.14453173]
 [0.25446728 0.28707708 0.10875955 ... 0.23021397 0.14453173 0.        ]]
[[0.         0.22596773 0.11430681 ... 0.07174153 0.12941904 0.11712096]
 [0.22596773 0.         0.33555112 ... 0.0534523  0.15135171 0.11295671]
 [0.11430681 0.33555112 0.         ... 0.07889193 0.05567947 0.5364643 ]
 ...
 [0.07174153 0.0534523  0.07889193 ... 0.         0.27542288 0.13725879]
 [0.12941904 0.15135171 0.05567947 ... 0.27542288 0.         0.02819753]
 [0.11712096 0.11295671 0.5364643  ... 0.13725879 0.02819753 0.        ]]
