In [6]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn import datasets
from fastcluster import linkage

def seriation(Z,N,cur_index):
    '''
        input:
            - Z is a hierarchical tree (dendrogram) in linkage-matrix form:
              row k merges the clusters Z[k,0] and Z[k,1] into cluster N+k
            - N is the number of points given to the clustering process
            - cur_index is the node of the tree at which the traversal starts
              (indices < N are leaves; N + N-2 is the root)
        output:
            - order implied by the hierarchical tree Z: the leaves under
              cur_index, listed with the left subtree before the right one

        seriation computes the order implied by a hierarchical tree (dendrogram).
        The traversal uses an explicit stack rather than recursion, so very
        unbalanced (chained) trees cannot exceed Python's recursion limit and
        the order is built in linear time.
    '''
    order = []
    stack = [cur_index]
    while stack:
        node = stack.pop()
        if node < N:
            # Leaf: an original observation index.
            order.append(node)
        else:
            row = node - N
            # Push the right child first so the left child is popped (and
            # emitted) first, matching the left-then-right recursive order.
            stack.append(int(Z[row, 1]))
            stack.append(int(Z[row, 0]))
    return order
    
def compute_serial_matrix(dist_mat,method="ward"):
    '''
        input:
            - dist_mat is a square distance matrix (array-like); it must be
              symmetric with a zero diagonal, which squareform enforces
            - method = ["ward","single","average","complete"]
        output:
            - seriated_dist is the input dist_mat (as float64),
              but with re-ordered rows and columns
              according to the seriation, i.e. the
              order implied by the hierarchical tree
            - res_order is the order implied by
              the hierarchical tree
            - res_linkage is the hierarchical tree (dendrogram)

        compute_serial_matrix transforms a distance matrix into
        a sorted distance matrix according to the order implied
        by the hierarchical tree (dendrogram)
    '''
    dist_mat = np.asarray(dist_mat)
    N = len(dist_mat)
    if N < 2:
        # Nothing to cluster: linkage needs at least two observations.
        return np.zeros((N, N)), list(range(N)), np.empty((0, 4))
    # squareform checks exact symmetry / zero diagonal and condenses the matrix.
    flat_dist_mat = squareform(dist_mat)
    res_linkage = linkage(flat_dist_mat, method=method,preserve_input=True)
    # N + N-2 is the index of the root node (the final merge) of the tree.
    res_order = seriation(res_linkage, N, N + N-2)
    # Permute rows and columns in one vectorized step. Since squareform
    # accepted dist_mat, it is symmetric with a zero diagonal, so this equals
    # rebuilding the matrix from its permuted upper triangle.
    order = np.asarray(res_order, dtype=np.intp)
    seriated_dist = dist_mat[np.ix_(order, order)].astype(np.float64)

    return seriated_dist, res_order, res_linkage

In [7]:
# Load the iris measurements and shuffle the rows, so the seriation step has
# to recover the cluster structure from an unordered distance matrix.
SEED = 0  # fixed seed: makes the shuffle (and every printed order below) reproducible
rng = np.random.default_rng(SEED)

iris = datasets.load_iris()
print(iris.data.shape)

N = len(iris.data)

X = iris.data[rng.permutation(N),:]

dist_mat = squareform(pdist(X))
assert dist_mat.shape == (N, N)

print(dist_mat.shape)
print(dist_mat)

(150, 4)
(150, 150)
[[0.         1.1045361  4.41701257 ... 4.40567815 1.22474487 0.97467943]
 [1.1045361  0.         3.91535439 ... 3.87427413 1.37840488 1.00498756]
 [4.41701257 3.91535439 0.         ... 0.14142136 3.30605505 3.54964787]
 ...
 [4.40567815 3.87427413 0.14142136 ... 0.         3.30907842 3.54118624]
 [1.22474487 1.37840488 3.30605505 ... 3.30907842 0.         0.45825757]
 [0.97467943 1.00498756 3.54964787 ... 3.54118624 0.45825757 0.        ]]


In [11]:
# Seriate the same shuffled distance matrix with each linkage criterion and
# keep every result (not only the last loop iteration) for later comparison.
methods = ["ward","single","average","complete"]
serial_results = {}
for method in methods:
    print("Method:\t",method)

    ordered_dist_mat, res_order, res_linkage = compute_serial_matrix(dist_mat,method)
    # A seriation must be a permutation of the N input points.
    assert sorted(res_order) == list(range(len(dist_mat)))
    serial_results[method] = (ordered_dist_mat, res_order, res_linkage)
    print(ordered_dist_mat.shape)
    print(res_order)

Method:	 ward
(150, 150)
[133, 52, 80, 94, 128, 125, 41, 73, 12, 127, 56, 93, 55, 136, 69, 79, 106, 88, 137, 28, 45, 48, 5, 76, 6, 71, 146, 140, 141, 102, 96, 19, 36, 31, 147, 68, 2, 30, 82, 114, 18, 35, 120, 145, 62, 97, 44, 59, 46, 142, 67, 113, 37, 25, 32, 108, 115, 51, 87, 0, 83, 27, 105, 7, 61, 17, 121, 112, 54, 74, 84, 138, 92, 129, 89, 130, 104, 42, 99, 20, 11, 135, 29, 86, 22, 81, 47, 95, 100, 110, 16, 39, 3, 122, 90, 131, 116, 101, 124, 119, 23, 64, 43, 14, 72, 33, 24, 57, 117, 143, 38, 148, 34, 77, 9, 8, 126, 109, 26, 65, 66, 10, 13, 53, 40, 107, 111, 15, 1, 60, 70, 78, 144, 103, 123, 149, 21, 50, 118, 63, 139, 91, 98, 49, 4, 75, 132, 85, 58, 134]
Method:	 single
(150, 150)
[106, 88, 6, 71, 82, 5, 76, 120, 145, 44, 62, 97, 59, 55, 137, 114, 18, 35, 12, 136, 128, 127, 56, 93, 69, 79, 125, 41, 73, 52, 80, 94, 28, 45, 48, 146, 140, 141, 46, 142, 133, 68, 2, 30, 102, 96, 19, 36, 31, 147, 29, 86, 43, 95, 47, 100, 110, 22, 7, 66, 81, 20, 11, 135, 26, 65, 16, 34, 89, 130, 109, 108, 