# Clustering Using Minimum Spanning Trees

### By Daniel Lee, Udaikaran Singh, and Justin Eldridge

In this notebook, we will show how minimum spanning trees can be used to cluster data. Along the way, we will implement Kruskal's Algorithm for computing MSTs.

We will start by importing some familiar packages:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import collections
import random
import math

# make plots bigger by default
plt.rcParams['figure.figsize'] = (8,8)
plt.rcParams['font.size'] = 20

We will use the course's own simple dict-of-sets data structure to represent graphs:

In [None]:
from dsc40graph import UndirectedGraph

## Geyser Eruptions

The file `faithful.csv` contains data on 269 eruptions of the Old Faithful geyser in Yellowstone National Park. A scatter plot of the data is shown below:

In [None]:
geyser = np.loadtxt('faithful.csv', skiprows=1, delimiter=',')
plt.scatter(*geyser.T)

plt.xlabel('Duration')
plt.ylabel('Wait')
plt.title('Eruptions of Old Faithful')

We can see that the geyser appears to erupt in two "modes": either with high frequency and short duration, or low frequency and long duration. We wish to "cluster" the above data in order to recover these two modes automatically. Our strategy will be as follows:

1. Construct the distance graph of the data. This is a weighted graph in which every node is a data point, and the weight of each edge is the distance between its endpoints.
2. Compute the MST of the distance graph.
3. Remove the longest edge in the MST. This creates two connected components; these are the clusters.
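As a rough cross-check of these three steps, here is a minimal sketch that uses SciPy's `scipy.sparse.csgraph` routines in place of the course's `UndirectedGraph` and the Kruskal implementation developed below. The function name `two_clusters_via_mst` is hypothetical. The sketch also assumes the points are distinct, because SciPy treats zero entries of the distance matrix as missing edges.

```python
import numpy as np
from scipy.spatial import distance_matrix
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components

def two_clusters_via_mst(data):
    """Hypothetical helper: label each point by cutting the MST's heaviest edge.

    Assumes distinct points (SciPy reads zero distances as "no edge").
    """
    # Step 1: dense distance graph; entry (i, j) is the weight of edge {i, j}
    dist = distance_matrix(data, data)
    # Step 2: MST, returned as a sparse matrix that stores n - 1 edges
    mst = minimum_spanning_tree(dist).tocoo()
    # Step 3: drop the heaviest MST edge and label the two components left over
    keep = np.arange(len(mst.data)) != np.argmax(mst.data)
    pruned = coo_matrix(
        (mst.data[keep], (mst.row[keep], mst.col[keep])), shape=mst.shape
    )
    _, labels = connected_components(pruned.tocsr(), directed=False)
    return labels

# two well-separated toy groups of three points each
points = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
labels = two_clusters_via_mst(points)
```

The first three points should share one label and the last three the other.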

## Kruskal's Algorithm

In lecture, we saw that Kruskal's Algorithm can be used to compute minimum spanning trees. We will now implement this algorithm:

In [None]:
def kruskal(graph, weight):
    """Compute the MST of the graph.
    
    Parameters
    ----------
    graph : UndirectedGraph
        The graph whose MST will be computed.
    weight : callable
        A function which returns the weight of an edge.

    Returns
    -------
    UndirectedGraph
        A graph containing the edges of a minimum spanning tree.

    """
    mst = UndirectedGraph()
    forest = DisjointSet()
    
    # sort edges from lightest to heaviest
    sorted_edges = sorted(graph.edges, key=weight)
    
    # initialize the disjoint-set forest: one singleton set per node
    for node in graph.nodes:
        forest.make_set(node)
    
    # greedily keep each edge that joins two different trees
    for (u, v) in sorted_edges:
        if forest.find(u) != forest.find(v):
            forest.union(u, v)
            mst.add_edge(u, v)
    
    return mst

Kruskal's Algorithm makes use of a "disjoint set forest" data structure, which provides a way to efficiently union disjoint sets and to determine whether two elements belong to the same set. In the context of Kruskal's algorithm, we use a disjoint set forest to quickly determine whether two vertices are already connected by the edges chosen so far; if they are, adding the edge between them would create a cycle.

We implement a disjoint set forest in the `DisjointSet` class below. For more information, see "Introduction to Algorithms", 3rd edition, by Cormen, Leiserson, Rivest, and Stein.

In [None]:
class DisjointSet:

    def __init__(self):
        self.parent = dict()
        self.rank = dict()
        
    def make_set(self, vertex):
        """Create a new singleton set."""
        self.parent[vertex] = vertex
        self.rank[vertex] = 0

    def find(self, vertex):
        """Find the ID of the set to which the vertex belongs."""
        if self.parent[vertex] != vertex:
            # path compression: point the vertex directly at its root
            self.parent[vertex] = self.find(self.parent[vertex])
        return self.parent[vertex]

    def union(self, v1, v2):
        """Union the set containing v1 with the set containing v2."""
        root1 = self.find(v1)
        root2 = self.find(v2)
        if root1 != root2:
            # union by rank: attach the lower-rank root under the higher-rank one
            if self.rank[root1] > self.rank[root2]:
                self.parent[root2] = root1
            else:
                self.parent[root1] = root2
                # a tie in rank makes the merged tree one level taller
                if self.rank[root1] == self.rank[root2]:
                    self.rank[root2] += 1
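To see how `find` and `union` behave, here is a small self-contained sketch of the same idea. It is a variant, not the class above. It assumes the elements are the integers `0, ..., n - 1`, stores parents in a list, and uses iterative *path halving* in `find` instead of recursive path compression. The class name `ArrayDisjointSet` is hypothetical.

```python
class ArrayDisjointSet:
    """Hypothetical list-based disjoint-set forest over the integers 0..n-1."""

    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root
        self.rank = [0] * n

    def find(self, x):
        # path halving: point every other node on the path at its grandparent
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False  # already in the same set; nothing to merge
        # union by rank: hang the lower-rank tree under the higher-rank one
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1
        return True

forest = ArrayDisjointSet(5)
forest.union(0, 1)
forest.union(1, 2)
print(forest.find(0) == forest.find(2))  # True: 0 and 2 were joined through 1
print(forest.find(0) == forest.find(3))  # False: 3 is still a singleton
print(forest.union(2, 0))                # False: this edge would close a cycle
```

Returning a boolean from `union` lets a Kruskal loop test and merge in a single call.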

We will test our implementation on the simple weighted graph below:

<center>
<img src="./simple_weighted_graph.png" width=30%>
</center>

We convert the graph to code:

In [None]:
graph = UndirectedGraph()
graph.add_edge('a', 'b')
graph.add_edge('b', 'c')
graph.add_edge('c', 'd')
graph.add_edge('a', 'd')
graph.add_edge('a', 'c')

def weight(edge):
    u, v = edge
    weights = {
        ('a', 'b'): 2,
        ('a', 'd'): 8,
        ('d', 'c'): 3,
        ('c', 'b'): 5,
        ('a', 'c'): 1
    }
    try:
        return weights[(u, v)]
    except KeyError:
        return weights[(v, u)]

In [None]:
mst = kruskal(graph, weight)
mst.edges

The edges of the MST are bolded in the graph below:

<center>
<img src="./simple_weighted_mst.png" width=30%>
</center>
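We can double-check this result by tracing Kruskal's algorithm by hand. The sketch below is self-contained. It re-enters the five weighted edges from the figure and uses a bare-bones `find` without union by rank, a simplification made only to keep it short. It then records whether each edge is accepted or rejected.

```python
edges = {("a", "b"): 2, ("a", "d"): 8, ("d", "c"): 3, ("c", "b"): 5, ("a", "c"): 1}

# simplified disjoint sets: no union by rank or path compression
parent = {node: node for edge in edges for node in edge}

def find(x):
    # follow parent pointers up to the root of x's tree
    while parent[x] != x:
        x = parent[x]
    return x

mst_edges = []
for (u, v), w in sorted(edges.items(), key=lambda item: item[1]):
    root_u, root_v = find(u), find(v)
    if root_u != root_v:
        parent[root_u] = root_v  # merge the two trees
        mst_edges.append((u, v))
        print(f"accept {u}-{v} (weight {w})")
    else:
        print(f"reject {u}-{v} (weight {w}): {u} and {v} already connected")

print(sum(edges[e] for e in mst_edges))  # 6
```

The three accepted edges have total weight 1 + 2 + 3 = 6. Edge b-c (weight 5) and edge a-d (weight 8) are rejected because their endpoints are already connected.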

## MST of the Geyser Data

Recall that our approach to clustering will involve computing the MST of the distance graph constructed from the data. We now create a simple helper function which takes in an array of data points and returns an `UndirectedGraph` along with the weight function:

In [None]:
from itertools import combinations

def distance_graph(data):
    """Given an n x d array of data, produces the distance graph on n nodes."""
    graph = UndirectedGraph()
    n = len(data)
    weight = {}
    
    for i, j in combinations(range(n), 2):
        graph.add_edge(i, j)
        weight[(i, j)] = np.linalg.norm(data[i] - data[j])
        
    def weight_function(edge):
        u, v = edge
        if (u, v) in weight:
            return weight[(u,v)]
        else:
            return weight[(v,u)]
            
    return graph, weight_function
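The distance graph is complete, so it has $n(n-1)/2$ edges. For the 269 geyser eruptions that is 36,046 edges, and the loop above computes each one with a separate `np.linalg.norm` call. One possible speed-up is to compute all pairwise distances in a single call with SciPy's `pdist`. This relies on the fact that its condensed output lists pairs in the same order as `combinations(range(n), 2)`. The helper name `pairwise_weights` is hypothetical, and it only builds the weight dictionary, not the graph.

```python
import numpy as np
from itertools import combinations
from scipy.spatial.distance import pdist

def pairwise_weights(data):
    """Hypothetical helper: map each pair (i, j) with i < j to its Euclidean distance."""
    # pdist returns distances for (0, 1), (0, 2), ..., (1, 2), ... in this order,
    # which matches the order of combinations(range(n), 2)
    condensed = pdist(data)
    return dict(zip(combinations(range(len(data)), 2), condensed))

rng = np.random.default_rng(0)
data = rng.normal(size=(6, 2))
weights = pairwise_weights(data)
print(len(weights))  # 15 == 6 * 5 / 2
```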

With this, we have everything we need to compute the MST of the distance graph constructed from the geyser data:

In [None]:
graph, weight = distance_graph(geyser)
mst = kruskal(graph, weight)

Since the graph is a distance graph computed from points in the plane, we can draw the MST by positioning each node at the coordinates of its corresponding data point. The result is as follows:

In [None]:
def plot_mst_edge(points, u, v, **kwargs):
    points = np.vstack((points[u], points[v]))
    plt.plot(points[:,0], points[:,1], **kwargs)

plt.scatter(*geyser.T)
for u, v in mst.edges:
    plot_mst_edge(geyser, u, v, color='black', alpha=.5)

Is this what you expected?

Let's dig a little further. What is the biggest edge in the MST?

In [None]:
biggest_edge = max(mst.edges, key=weight)
biggest_edge

The below plot highlights the biggest edge in red:

In [None]:
plt.scatter(*geyser.T)
for u, v in mst.edges:
    plot_mst_edge(geyser, u, v, color='black', alpha=.5)
plot_mst_edge(geyser, *biggest_edge, color='red')

It may not look like this is the biggest edge, but look again. Is there a bug in our code?

There is not. This is an issue of scale. The values on the $y$-axis span a much wider range than those on the $x$-axis, so a vertical gap that looks small is actually a large jump in Euclidean distance. This becomes clearer if we force the axes of the plot to be drawn at the same scale:

In [None]:
plt.figure(figsize=(20,20))
plt.scatter(*geyser.T)
for u, v in mst.edges:
    plot_mst_edge(geyser, u, v, color='black', alpha=.5)
plot_mst_edge(geyser, *biggest_edge, color='red')

plt.gca().set_aspect('equal')
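A little arithmetic shows why. The points below are hypothetical. They are chosen only to mimic a narrow horizontal spread and a wide vertical spread like the one in this plot. With this setup, a horizontal gap that covers much of the $x$-axis is still shorter than a modest vertical gap.

```python
import numpy as np

# hypothetical (duration, wait) points on roughly the scale of the plot
p = np.array([2.0, 60.0])
q = np.array([4.5, 60.0])  # far away along the x-axis
r = np.array([2.0, 70.0])  # a short hop along the y-axis

print(np.linalg.norm(p - q))  # 2.5
print(np.linalg.norm(p - r))  # 10.0
```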

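Of course, changing how the data is plotted does not affect the MST. We need to scale the data itself in order to get the tree we expected. A natural way to do this is to *standardize* (i.e., $z$-score) the data, subtracting each feature's mean and dividing by its standard deviation so that both features are on a comparable scale: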

In [None]:
def standardize(x):
    """Rescale each column of x to have mean 0 and standard deviation 1."""
    return (x - np.mean(x, axis=0)) / np.std(x, axis=0)

scaled_geyser = standardize(geyser)
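Here is a quick sanity check of `standardize`. After scaling, every column should have mean 0 and standard deviation 1. Because `np.std` uses `ddof=0` by default, the check is exact up to rounding. The sketch repeats the function so it runs on its own and uses made-up data whose columns have very different scales.

```python
import numpy as np

def standardize(x):
    return (x - np.mean(x, axis=0)) / np.std(x, axis=0)

# made-up data: one narrow column and one wide column
rng = np.random.default_rng(0)
x = np.column_stack([rng.uniform(1.5, 5.5, size=100), rng.uniform(40, 100, size=100)])
z = standardize(x)

print(np.allclose(z.mean(axis=0), 0))  # True
print(np.allclose(z.std(axis=0), 1))   # True
```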

The standardized data looks the same, since standardizing only shifts and rescales each axis. But check out the axes:

In [None]:
plt.scatter(*scaled_geyser.T)

We now compute the MST of the scaled data:

In [None]:
graph, scaled_weight = distance_graph(scaled_geyser)
scaled_mst = kruskal(graph, scaled_weight)

In [None]:
plt.scatter(*scaled_geyser.T)
for u, v in scaled_mst.edges:
    points = np.vstack((scaled_geyser[u], scaled_geyser[v]))
    plt.plot(points[:,0], points[:,1], color='black', alpha=.5)

## Clustering the Geyser MST

Recall that our strategy for clustering has us delete the longest edge of the MST and find the resulting connected components; these are our clusters. To find the connected components, we can use BFS or DFS. Here we use BFS:

In [None]:
from collections import deque

def bfs_component(graph, source, status=None):
    """Run a BFS from `source` and return the set of nodes in its connected component."""
    if status is None:
        status = {node: 'undiscovered' for node in graph.nodes}

    status[source] = 'pending'
    pending = deque([source])
    component = set()

    # while there are still pending nodes
    while pending: 
        u = pending.popleft() # pop from left (front of queue)
        component.add(u)
        
        for v in graph.neighbors(u, sort=True):
            # explore edge (u,v)
            if status[v] == 'undiscovered':
                status[v] = 'pending'
                pending.append(v) # append to right (back of queue)
        status[u] = 'visited'
        
    return component
        
def connected_components(graph):
    status = {node: 'undiscovered' for node in graph.nodes}
    
    components = []
    for node in graph.nodes:
        if status[node] == 'undiscovered':
            component = bfs_component(graph, node, status=status)
            components.append(component)
            
    return components
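DFS works just as well. Here is a sketch of the same computation using an iterative DFS. To stay self-contained, it assumes a plain dict-of-sets adjacency (node → set of neighbors) rather than the `UndirectedGraph` class. The function name `dfs_components` is hypothetical.

```python
def dfs_components(adjacency):
    """Hypothetical helper: connected components of a dict-of-sets graph via iterative DFS."""
    seen = set()
    components = []
    for start in adjacency:
        if start in seen:
            continue
        seen.add(start)
        stack = [start]  # LIFO stack in place of BFS's FIFO queue
        component = set()
        while stack:
            u = stack.pop()
            component.add(u)
            for v in adjacency[u]:
                if v not in seen:
                    seen.add(v)
                    stack.append(v)
        components.append(component)
    return components

# the path 0-1-2-3 with edge {1, 2} removed, plus an isolated node 4
adjacency = {0: {1}, 1: {0}, 2: {3}, 3: {2}, 4: set()}
print(dfs_components(adjacency))  # [{0, 1}, {2, 3}, {4}]
```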

Let's try it out. The MST is shown with the largest edge removed:

In [None]:
biggest_edge = max(scaled_mst.edges, key=scaled_weight)
scaled_mst.remove_edge(*biggest_edge)

plt.scatter(*scaled_geyser.T)
for u, v in scaled_mst.edges:
    points = np.vstack((scaled_geyser[u], scaled_geyser[v]))
    plt.plot(points[:,0], points[:,1], color='black', alpha=.5)

We then find the connected components:

In [None]:
components = connected_components(scaled_mst)

def points_by_component(data, components):
    for component in components:
        yield data[list(component)]
        
for points in points_by_component(scaled_geyser, components):
    plt.scatter(*points.T)

We have successfully separated the data into two reasonable-looking clusters. We can generalize this process to produce $k$ clusters by deleting the $k - 1$ heaviest edges from the MST; each deletion splits one tree into two. The resulting algorithm is often called *single-linkage clustering*. We encapsulate it in a function below:

In [None]:
def single_linkage_clustering(data, n_clusters=2):
    """Cluster data by deleting the n_clusters - 1 heaviest edges of its MST."""
    graph, weight = distance_graph(data)
    mst = kruskal(graph, weight)
    # each deletion splits one tree into two, adding one cluster
    for _ in range(n_clusters - 1):
        biggest_edge = max(mst.edges, key=weight)
        mst.remove_edge(*biggest_edge)
    components = connected_components(mst)
    return components, mst

The function is used as follows:

In [None]:
clusters, mst = single_linkage_clustering(scaled_geyser, n_clusters=2)

for points in points_by_component(scaled_geyser, clusters):
    plt.scatter(*points.T)
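Single-linkage clustering is also available in SciPy's hierarchical clustering module. Cutting the single-linkage dendrogram into $k$ flat clusters should give the same partition as deleting the $k - 1$ heaviest MST edges, assuming no ties among edge weights, although the cluster labels may differ. The sketch below uses made-up data with three well-separated groups. It is one way to cross-check `single_linkage_clustering`, not how this notebook computes its clusters.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# made-up data: three tight groups of 20 points, far apart from one another
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
data = np.vstack([c + 0.1 * rng.normal(size=(20, 2)) for c in centers])

# method="single" merges clusters by their closest pair of points
merges = linkage(data, method="single")
# criterion="maxclust" cuts the dendrogram into at most t flat clusters;
# assumption: with no tied distances this matches deleting the 2 heaviest MST edges
labels = fcluster(merges, t=3, criterion="maxclust")

print(len(set(labels)))  # 3
```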

## More Examples on Synthetic Data

### Dataset 1

We now load a different data set, remembering to standardize it:

In [None]:
dense_and_sparse = standardize(np.loadtxt('sample_dataset_2.txt'))
dense_and_sparse = dense_and_sparse[np.random.choice(len(dense_and_sparse), 500)]
plt.scatter(*dense_and_sparse.T)

plt.gca().set_aspect('equal')

How many clusters do you see in the data?

Experiment with the number of clusters by changing `n_clusters` in the next cell.

In [None]:
clusters, mst = single_linkage_clustering(dense_and_sparse, n_clusters=5)

for points in points_by_component(dense_and_sparse, clusters):
    plt.scatter(*points.T)
    
for (u,v) in mst.edges:
    plot_mst_edge(dense_and_sparse, u, v, color='black')
    
plt.gca().set_aspect('equal')

-----------

### Dataset 2

We will now try single-linkage clustering on a noisier data set:

In [None]:
noisy = standardize(np.loadtxt('sample_dataset_1.txt'))
noisy = noisy[np.random.choice(len(noisy), 500)]
plt.scatter(*noisy.T)

How many clusters do you see in the data?

Experiment with the number of clusters by changing `n_clusters` in the next cell.

In [None]:
clusters, mst = single_linkage_clustering(noisy, n_clusters=15)

for points in points_by_component(noisy, clusters):
    plt.scatter(*points.T)
    
for (u,v) in mst.edges:
    plot_mst_edge(noisy, u, v, color='black')
    
plt.gca().set_aspect('equal')

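The noise makes clustering difficult. Single-linkage clustering only looks at the shortest links between points, so a chain of noisy points can join two clusters that would otherwise be well separated. One way to improve things is to remove outliers. An outlier is a point that does not have many points around it. We can quantify this by calculating the distance from each point to its $k$th closest neighbor: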

In [None]:
from scipy.spatial import distance_matrix

In [None]:
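def distance_to_kth_neighbor(data, k):
    """Distance from each point to its kth closest neighbor."""
    distances = distance_matrix(data, data)
    # entry 0 of each partitioned row is the point's zero distance to itself,
    # so the kth neighbor sits at index k
    return np.partition(distances, k, axis=1)[:, k]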
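A tiny worked example helps confirm the indexing. Each row of the distance matrix contains a 0 for the point's distance to itself, so after partitioning, index `k` holds the distance to the $k$th closest *other* point. The one-dimensional points below are made up, and the function is repeated so the sketch runs on its own.

```python
import numpy as np
from scipy.spatial import distance_matrix

def distance_to_kth_neighbor(data, k):
    distances = distance_matrix(data, data)
    # index 0 of each partitioned row is the zero self-distance
    return np.partition(distances, k, axis=1)[:, k]

points = np.array([[0.0], [1.0], [3.0], [7.0]])  # made-up 1-D points
print(distance_to_kth_neighbor(points, k=1))  # [1. 1. 2. 4.]
print(distance_to_kth_neighbor(points, k=2))  # [3. 2. 3. 6.]
```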

In [None]:
r_k = distance_to_kth_neighbor(noisy, k=5)

The plot below shows the distance from each point to its 5th closest neighbor by coloring the point. This distance is small for dark points, and large for light points:

In [None]:
plt.scatter(*noisy.T, c=r_k)
plt.gca().set_aspect('equal')

A histogram of the distances to the $k$th neighbor shows a long tail:

In [None]:
plt.hist(r_k, bins=20);

A reasonable approach is to throw out the top 10% of points by their distance to the $k$th neighbor:

In [None]:
np.percentile(r_k, 90)

In [None]:
# keep the 90% of points with the smallest distance to their kth neighbor
less_noisy = noisy[r_k < np.percentile(r_k, 90)]
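The 90% rule can be wrapped in a small helper so that the cutoff is always recomputed from the current scores. This matters because the data set is a fresh random subsample on every run. The name `drop_top_fraction` and its `fraction` parameter are hypothetical. The sketch keeps the points whose score falls strictly below the corresponding percentile.

```python
import numpy as np

def drop_top_fraction(data, scores, fraction=0.1):
    """Hypothetical helper: remove the points whose outlier score is in the top `fraction`."""
    cutoff = np.percentile(scores, 100 * (1 - fraction))
    return data[scores < cutoff]

# made-up scores 0.01, 0.02, ..., 1.00 for 100 one-dimensional points
scores = np.arange(1, 101) / 100
data = np.arange(100).reshape(-1, 1)
kept = drop_top_fraction(data, scores, fraction=0.1)
print(len(kept))  # 90
```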

With these points removed, the data looks cleaner:

In [None]:
plt.scatter(*less_noisy.T)

Now clustering gives a much better result:

In [None]:
clusters, mst = single_linkage_clustering(less_noisy, n_clusters=15)

for (u,v) in mst.edges:
    plot_mst_edge(less_noisy, u, v, color='black', alpha=.5)

for points in points_by_component(less_noisy, clusters):
    plt.scatter(*points.T)
    
plt.gca().set_aspect('equal')