# K-means with WaveletEMD 
EMD stands for Earth Mover's Distance, also known as the Wasserstein distance. This distance comes from the field of optimal transport. Intuitively, given two distributions, it corresponds to the minimal cost of transforming one distribution into the other (i.e., of moving the mass of one onto the other).
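For 1-dimensional histograms that share the same equally spaced bins and have the same total mass, EMD has a closed form: the $L_1$ distance between the two cumulative distributions, times the bin width. Below is a minimal sketch that assumes unit bin spacing and histograms already normalized to the same mass. The function name `emd_1d` is ours, for illustration only.

```python
import numpy as np

def emd_1d(p, q):
    """Exact EMD between two 1-D histograms on the same unit-spaced bins.

    Assumes p and q have the same length and the same total mass.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # |CDF_p(i) - CDF_q(i)| is the mass that must cross the boundary
    # between bin i and bin i+1; each crossing costs one unit of distance.
    return float(np.abs(np.cumsum(p) - np.cumsum(q)).sum())

# Moving all the mass by 3 bins costs 3; moving half of it by 1 bin costs 0.5.
print(emd_1d([1, 0, 0, 0], [0, 0, 0, 1]))  # 3.0
print(emd_1d([1, 0, 0], [0.5, 0.5, 0]))    # 0.5
```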

Computing EMD exactly is expensive, with $O(N^3 \log N)$ computational complexity [1], where $N$ is the number of bins in the histogram. However, a linear-time approximation of EMD can be computed using the wavelet transform [1].
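A wavelet-based approximation can run in linear time because each level of a discrete wavelet transform halves the signal it works on. A full decomposition therefore touches about $N + N/2 + N/4 + \dots < 2N$ values. Below is a minimal pure-NumPy sketch with the Haar wavelet (the simplest orthogonal wavelet), assuming the input length is a power of two. `haar_wavedec` is a hypothetical helper, not part of pywt or of this project.

```python
import numpy as np

def haar_wavedec(x):
    """Full multilevel Haar decomposition of a signal whose length is a power of two.

    Returns [cA_L, cD_L, ..., cD_1], the same ordering pywt.wavedec uses.
    Total work is O(N): the signal is halved at every level.
    """
    a = np.asarray(x, dtype=float)
    n = len(a)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    details = []
    while len(a) > 1:
        even, odd = a[0::2], a[1::2]
        details.append((even - odd) / np.sqrt(2))  # detail (high-pass) coefficients
        a = (even + odd) / np.sqrt(2)              # approximation (low-pass) coefficients
    return [a] + details[::-1]

coeffs = haar_wavedec([4.0, 2.0, 5.0, 5.0])
print([c.tolist() for c in coeffs])
```

The transform is orthogonal, so the energy of the coefficients equals the energy of the input signal.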

K-means with this metric has been used to cluster tomographic projections (images) [2], where it outperformed classical k-means with the Euclidean distance.

In this project, we use K-means with the WaveletEMD given in equation (2) of [1] to compare k-mer distributions (or any other 1-dimensional distribution).
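As we read [1], the approximation replaces the optimal-transport problem by a weighted $L_1$ norm of the wavelet coefficients of the difference between the two distributions. We use our own notation below, and the exact constants should be checked against equation (2) of [1]:

- $p = p_1 - p_2$ is the difference of the two histograms.
- $p_\lambda$ is its wavelet coefficient at index $\lambda = (j, k)$, with scale $j$ increasing towards finer scales and position $k$.
- $n$ is the dimension of the domain ($n = 1$ for k-mer distributions).

$$
\widehat{d}(p_1, p_2) = \sum_{\lambda} 2^{-j\left(1 + \frac{n}{2}\right)} \left| p_\lambda \right|,
\qquad \text{which for } n = 1 \text{ becomes} \qquad
\widehat{d}(p_1, p_2) = \sum_{j,k} 2^{-3j/2} \left| p_{j,k} \right| .
$$

Coarse scales (small $j$) receive larger weights than fine scales. A difference that can only be removed by moving mass far away therefore costs more than a local one.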

References
- [1] [Approximate earth mover’s distance in linear time](https://ieeexplore.ieee.org/document/4587662)
- [2] [Wasserstein K-Means for Clustering Tomographic Projections](https://doi.org/10.48550/arXiv.2010.09989)


```python
import pywt
import numpy as np
```

## pywt usage

```python
# families of available wavelet functions
pywt.families()
```

```
['haar',
 'db',
 'sym',
 'coif',
 'bior',
 'rbio',
 'dmey',
 'gaus',
 'mexh',
 'morl',
 'cgau',
 'shan',
 'fbsp',
 'cmor']
```
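Not every family above can be used for a multilevel decomposition. `pywt.wavedec` needs discrete wavelets. Families such as `gaus`, `mexh`, `morl`, `cgau`, `shan`, `fbsp` and `cmor` are continuous wavelets, intended for `pywt.cwt`. A quick way to list the names valid for `wavedec` (and presumably for `WaveletEMD`, assuming it relies on `wavedec`, as its `level` and `mode` parameters suggest):

```python
import pywt

# Names of discrete wavelets (usable with pywt.wavedec / pywt.dwt)
discrete = pywt.wavelist(kind="discrete")
# Names of continuous wavelets (usable with pywt.cwt only)
continuous = pywt.wavelist(kind="continuous")

print(len(discrete), discrete[:5])
print(continuous[:5])
```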

```python
# Information about a wavelet
w = pywt.Wavelet("sym5")
print(w)
```

```
Wavelet sym5
  Family name:    Symlets
  Short name:     sym
  Filters length: 10
  Orthogonal:     True
  Biorthogonal:   True
  Symmetry:       near symmetric
  DWT:            True
  CWT:            False
```

```python
# Filters

# decomposition filters
w.dec_lo  # lowpass
w.dec_hi  # highpass

# reconstruction filters
w.rec_lo  # lowpass
w.rec_hi  # highpass (only this last expression is displayed below)
```

```
[0.027333068345077982,
 -0.029519490925774643,
 -0.039134249302383094,
 -0.1993975339773936,
 0.7234076904024206,
 -0.6339789634582119,
 0.01660210576452232,
 0.17532808990845047,
 -0.021101834024758855,
 -0.019538882735286728]
```
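For an orthogonal wavelet such as `sym5`, the four filters are not independent. The standard properties, with PyWavelets' normalization, are:

- The reconstruction filters are the time-reversed decomposition filters.
- The low-pass filter has unit energy and sums to $\sqrt{2}$.
- The high-pass filter sums to zero.

A small check:

```python
import numpy as np
import pywt

w = pywt.Wavelet("sym5")
dec_lo, dec_hi = np.array(w.dec_lo), np.array(w.dec_hi)
rec_lo, rec_hi = np.array(w.rec_lo), np.array(w.rec_hi)

# Orthogonal wavelet: reconstruction filters are the time-reversed decomposition filters
print(np.allclose(rec_lo, dec_lo[::-1]), np.allclose(rec_hi, dec_hi[::-1]))
# Low-pass: unit energy and sum sqrt(2); high-pass: zero sum (vanishing moment)
print(np.isclose((dec_lo ** 2).sum(), 1.0),
      np.isclose(dec_lo.sum(), np.sqrt(2)),
      np.isclose(dec_hi.sum(), 0.0))
```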

___
# Multilevel 1d-wavelet transform

```python
data = np.array([0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 0.1, 1.0, 0.9, 2.0, 0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 0.1, 1.0, 0.9, 2.0, 0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 0.1, 1.0, 0.9, 2.0])
coeff = pywt.wavedec(data=data, wavelet="sym5", mode="symmetric", level=3)
```



```python
# wavedec returns level+1 coefficient arrays: [cA_level, cD_level, ..., cD_1]
# (approximation A_level first, then details D_j for j = level, ..., 1)
len(coeff)
```

```
4
```
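The coefficient arrays shrink roughly by half at each level. PyWavelets also reports the deepest useful level for a given signal and filter length, via `pywt.dwt_max_level`. For the 30-sample `data` above and the 10-tap `sym5` filters this is 1, so requesting `level=3` should make `wavedec` emit a warning about boundary effects. A sketch that checks this, recreating the same signal:

```python
import numpy as np
import pywt

# Same 30-sample signal as above: a 10-value pattern repeated three times
data = np.tile([0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 0.1, 1.0, 0.9, 2.0], 3)
w = pywt.Wavelet("sym5")

# Deepest level at which some coefficients are still free of boundary effects
max_level = pywt.dwt_max_level(len(data), w.dec_len)
print("max level:", max_level)  # max level: 1

coeff = pywt.wavedec(data, w, mode="symmetric", level=3)
# Lengths of [cA3, cD3, cD2, cD1]
print([len(c) for c in coeff])  # [11, 11, 14, 19]
```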

```python
# level 1
coeff = pywt.wavedec(data=data, wavelet="sym5", mode="symmetric", level=1)
cA1, cD1 = coeff
```

```python
# level 2
coeff = pywt.wavedec(data=data, wavelet="sym5", mode="symmetric", level=2)
cA2, cD2, cD1 = coeff
```

```python
# level 3
coeff = pywt.wavedec(data=data, wavelet="sym5", mode="symmetric", level=3)
cA3, cD3, cD2, cD1 = coeff
```
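The decomposition is invertible: `pywt.waverec` rebuilds the signal from the coefficient list, given the same wavelet and mode. A sketch checking perfect reconstruction for the three decompositions above:

```python
import numpy as np
import pywt

data = np.tile([0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 0.1, 1.0, 0.9, 2.0], 3)

for level in (1, 2, 3):
    coeff = pywt.wavedec(data, "sym5", mode="symmetric", level=level)
    rec = pywt.waverec(coeff, "sym5", mode="symmetric")
    # Slicing guards against a possible extra trailing sample in the reconstruction
    print(level, len(rec), np.allclose(rec[: len(data)], data))
```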

___
## WaveletEMD
A linear-time approximation to EMD, based on [1].
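The `WaveletEMD` class comes from this project's `src` package, and its implementation is not shown here. As an illustration only, here is a minimal sketch of the idea. It assumes:

- the weighting $2^{-j(1+n/2)}$ with $n = 1$ from [1];
- a scale index $j$ that grows towards finer scales;
- that only the detail coefficients of $p_1 - p_2$ are used.

`wavelet_emd_sketch` is a hypothetical function and may differ from `WaveletEMD` in normalization and in how it handles the approximation coefficients.

```python
import numpy as np
import pywt

def wavelet_emd_sketch(p1, p2, wavelet="sym5", level=2, mode="symmetric"):
    """Hypothetical wavelet approximation of EMD between two 1-D histograms."""
    diff = np.asarray(p1, dtype=float) - np.asarray(p2, dtype=float)
    # By linearity, these are the coefficients of p1 minus those of p2
    coeffs = pywt.wavedec(diff, wavelet, mode=mode, level=level)
    details = coeffs[1:]  # [cD_level, ..., cD_1], coarsest first
    total = 0.0
    for j, cD in enumerate(details):
        # j = 0 for the coarsest detail level; finer levels get weight 2^(-1.5 j)
        total += 2.0 ** (-1.5 * j) * np.abs(cD).sum()
    return total

rng = np.random.default_rng(0)
p = rng.random(1000)
q = p.copy()
q[10] += 0.1
print(wavelet_emd_sketch(p, q), wavelet_emd_sketch(p, p))
```

Shifting every $j$ by the same constant only multiplies the sum by a constant factor. Where $j = 0$ sits therefore does not change which center is closest in k-means.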

```python
import random
import numpy as np
from src import WaveletEMD
```

```python
N = 1000
vec1 = np.array([random.random() for _ in range(N)])

# vec2 is a copy of vec1 with one entry (index 10) increased by 0.1
vec2 = vec1.copy()
vec2[10] = vec2[10] + 0.1
```

```python
wemd = WaveletEMD(level=2)
```

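```python
# Euclidean distance vs WaveletEMD
np.linalg.norm(vec1 - vec2), wemd(vec1, vec2)
```

Since `vec1` and `vec2` differ only at index 10, by 0.1, the Euclidean distance is approximately 0.1. In the run of this notebook, `wemd(vec1, vec2)` returned `0.031249999999999986`.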
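The reason to prefer an EMD-like metric for distributions: the Euclidean distance only sees *that* mass differs between bins, not *how far* it has to move. A small numerical sketch, using the exact 1-D EMD for unit-spaced bins (the $L_1$ distance between cumulative sums; `emd_1d` is our own helper):

```python
import numpy as np

def emd_1d(p, q):
    # Exact 1-D EMD for histograms on the same unit-spaced bins with equal total mass
    return float(np.abs(np.cumsum(p) - np.cumsum(q)).sum())

n = 100
base = np.zeros(n)
base[0] = 1.0
near = np.zeros(n)
near[1] = 1.0   # all mass moved by 1 bin
far = np.zeros(n)
far[50] = 1.0   # all mass moved by 50 bins

for name, other in [("near", near), ("far", far)]:
    print(name, np.linalg.norm(base - other), emd_1d(base, other))
# near: Euclidean sqrt(2), EMD 1.0; far: Euclidean sqrt(2), EMD 50.0
```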

___
# K-means with WaveletEMD

```python
import random
from src import KmeansOT
from src import WaveletEMD
```

```python
# data to train the model: 100 random 10-dimensional vectors
n_samples = 100
data = [[random.random() for _ in range(10)] for _ in range(n_samples)]
```

```python
num_clusters = 2
count_matrix = data

kmeans = KmeansOT(n_clusters=num_clusters, random_state=2)
kmeans.fit(count_matrix)
```
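`KmeansOT` is part of this project's `src` package, and its internals are not shown. As a rough illustration of what k-means with a user-supplied distance does, here is a minimal Lloyd-style sketch. Each point is assigned to the closest center under an arbitrary callable `metric(a, b)`, and each center is then updated as the arithmetic mean of its cluster.

The function names `kmeans_custom_metric` and `predict`, and the mean update, are our assumptions, not a description of `KmeansOT`. Mean centers are what standard k-means uses; a Wasserstein barycenter would be the more transport-faithful choice.

```python
import numpy as np

def kmeans_custom_metric(X, n_clusters, metric, n_iter=100, seed=0):
    """Hypothetical Lloyd iterations with a user-supplied distance metric(a, b) -> float."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize centers with distinct random samples
    centers = X[rng.choice(len(X), size=n_clusters, replace=False)].copy()
    labels = np.full(len(X), -1)
    for _ in range(n_iter):
        # Assignment step: nearest center under the custom metric
        dists = np.array([[metric(x, c) for c in centers] for x in X])
        new_labels = dists.argmin(axis=1)
        if np.array_equal(new_labels, labels):
            break  # assignments are stable: converged
        labels = new_labels
        # Update step: mean of the points in each non-empty cluster
        for k in range(n_clusters):
            members = X[labels == k]
            if len(members):
                centers[k] = members.mean(axis=0)
    return centers, labels

def predict(X, centers, metric):
    """Assign each row of X to its nearest center."""
    X = np.asarray(X, dtype=float)
    return np.array([np.argmin([metric(x, c) for c in centers]) for x in X])

def emd(a, b):
    # L1 distance between cumulative sums (exact 1-D EMD when masses are equal)
    return float(np.abs(np.cumsum(a) - np.cumsum(b)).sum())

rng = np.random.default_rng(1)
X = rng.random((100, 10))
centers, labels = kmeans_custom_metric(X, 2, emd)
print(centers.shape, np.bincount(labels, minlength=2))
print(predict(X, centers, emd)[[0, 7, 9]])
```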



```python
kmeans.centers
```

```
[[0.4989833901879933,
  0.41089414575344835,
  0.5779812548321506,
  0.4250396968267552,
  0.36524235009151057,
  0.6425030542102316,
  0.5557030355827277,
  0.4673782260266972,
  0.3765753396496643,
  0.5605073375078665],
 [0.6388573966117741,
  0.54166938027865,
  0.35573626369139405,
  0.546660667126009,
  0.5779806900851033,
  0.46799282708480644,
  0.41575600380457023,
  0.5954634643321924,
  0.5994643336400639,
  0.44903191241512713]]
```

```python
# cluster labels of samples 0, 7 and 9
kmeans.predict(data)[[0, 7, 9]]
```

```
array([0, 0, 1])
```

To use other wavelet configurations, see the [PyWavelets documentation on multilevel decomposition](https://pywavelets.readthedocs.io/en/latest/ref/dwt-discrete-wavelet-transform.html#multilevel-decomposition-using-wavedec) and the [list of wavelet families and functions](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html).

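```python
# Other wavelet functions
from pyclustering.utils.metric import distance_metric, type_metric

wemd = WaveletEMD(wavelet="haar", level=2, mode="zero")
# wrap WaveletEMD as a user-defined pyclustering metric
new_metric = distance_metric(type_metric.USER_DEFINED, func=wemd)
```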


```python
# wavelets available in the "haar" family
pywt.wavelist("haar")
```

```
['haar']
```

```python
kmeans = KmeansOT(n_clusters=num_clusters, metric=new_metric, random_state=2)
kmeans.fit(count_matrix)
```

```python
# cluster labels of samples 0, 7 and 9
kmeans.predict(data)[[0, 7, 9]]
```

```
array([0, 0, 1])
```