In [None]:
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth

from corc.generation import GenerationModel

# Render matplotlib figures inline below each cell.
%matplotlib inline

# References

- https://en.wikipedia.org/wiki/Mean_shift
- https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html
- https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py

# Equidistant triangle

In [None]:
# Generator settings for the first experiment: 3 centers in an 'equidistant_triangle'
# layout, 1000 samples in 2-D, with file saving disabled.
params = dict(
    center_structure='equidistant_triangle',
    n_centers=3,
    distance=1,
    n_samples=1000,
    dim=2,
    save_file=False,
    outdir='.',
)

In [None]:
# Build the synthetic-data generator from `params` and run its generate() step.
# NOTE(review): presumably generate() places the cluster centers that sample_embedding() later
# samples around — confirm in corc.generation.
gen = GenerationModel(**params)
gen.generate()

## std = 0.01

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.01
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Bandwidth from scikit-learn's nearest-neighbour heuristic (quantile of pairwise neighbour distances).
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.3)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

## std = 0.1

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.1
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Bandwidth from scikit-learn's nearest-neighbour heuristic (quantile of pairwise neighbour distances).
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.3)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

## std = 0.2

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.2
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Bandwidth from scikit-learn's nearest-neighbour heuristic (quantile of pairwise neighbour distances).
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.3)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

## std = 0.4

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.4
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title shows the std actually used, so the figure cannot be confused with another noise level.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Bandwidth from scikit-learn's nearest-neighbour heuristic (quantile of pairwise neighbour distances).
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.3)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

# More than 3 clusters (7 centers, `uniform` center structure)

In [None]:
# Generator settings for the second experiment: 7 centers with the 'uniform' center
# structure, 1000 samples in 2-D, with file saving disabled.
params = dict(
    center_structure='uniform',
    n_centers=7,
    distance=1,
    n_samples=1000,
    dim=2,
    save_file=False,
    outdir='.',
)

In [None]:
# Rebuild the generator with the 7-center `params`; this replaces the triangle generator above.
# NOTE(review): presumably generate() places the cluster centers that sample_embedding() later
# samples around — confirm in corc.generation.
gen = GenerationModel(**params)
gen.generate()

## std = 0.01

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.01
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Smaller quantile than in the triangle case: with 7 clusters, each cluster holds a smaller share of the points.
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.1)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

## std = 0.05

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.05
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Smaller quantile than in the triangle case: with 7 clusters, each cluster holds a smaller share of the points.
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.1)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

## std = 0.1

In [None]:
# Sample the synthetic embedding at this noise level and colour points by gen.labels.
std = 0.1
data = gen.sample_embedding(std=std)
plt.scatter(data[:,0], data[:,1], c=gen.labels)
# Title identifies the noise level; trailing ';' suppresses the artist repr.
plt.title(f'Generated data coloured by gen.labels, std={std}');

In [None]:
# Quantile raised to 0.2 for this noisier setting (0.1 was used at lower std).
# bin_seeding=True seeds mean shift from a bandwidth-sized grid of bins instead of every sample.
bandwidth = estimate_bandwidth(data, quantile=0.2)
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

plt.scatter(data[:,0], data[:,1], c=clustering.labels_)
# Centers live in the same 2-D space as `data`, so they can be overlaid directly.
plt.scatter(clustering.cluster_centers_[:,0], clustering.cluster_centers_[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift clusters (n={n_clusters}), std={std}');

# On Mara's GraphDINO morphological embeddings

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Load the precomputed embeddings table.
# NOTE(review): pd.read_pickle unpickles the file, which can execute arbitrary code —
# only load this file from a trusted source.
df = pd.read_pickle('../../graphdino_morphological_embeddings_tsne.pkl')
# Each cell of these columns is stacked as one row, giving one row per sample.
latents = np.stack(df['latent_emb'].to_numpy()).astype(float)
tsne = np.stack(df['tsne'].to_numpy()).astype(float)

# Later cells fit on `latents` and plot tsne[:, 0] vs tsne[:, 1], so both must be 2-D
# and tsne needs at least two coordinates per sample.
assert latents.ndim == 2, f'expected 2-D latents, got shape {latents.shape}'
assert tsne.ndim == 2 and tsne.shape[1] >= 2, f'expected (n, >=2) t-SNE coordinates, got shape {tsne.shape}'

In [None]:
# Overview of the stored t-SNE coordinates before any clustering.
plt.scatter(tsne[:,0], tsne[:,1], s=1)
plt.title("t-SNE coordinates (df['tsne'])");

In [None]:
# Fixed bandwidth in latent space (see the estimate_bandwidth quantile sweep in this section).
bandwidth = 0.05
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth).fit(latents)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

# cluster_centers_ are coordinates in the latent space the model was fit on, not in t-SNE space,
# so they cannot be drawn on the t-SNE axes directly. Mark each cluster at the mean t-SNE
# position of its members instead (an approximate visual anchor, not the latent mode).
tsne_centers = pd.DataFrame(tsne).groupby(clustering.labels_).mean().to_numpy()

plt.scatter(tsne[:,0], tsne[:,1], c=clustering.labels_, s=1)
plt.scatter(tsne_centers[:,0], tsne_centers[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift on latents (bandwidth={bandwidth}, n={n_clusters}), shown in t-SNE space');

In [None]:
# Fixed bandwidth in latent space (see the estimate_bandwidth quantile sweep in this section).
bandwidth = 0.1
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth).fit(latents)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

# cluster_centers_ are coordinates in the latent space the model was fit on, not in t-SNE space,
# so they cannot be drawn on the t-SNE axes directly. Mark each cluster at the mean t-SNE
# position of its members instead (an approximate visual anchor, not the latent mode).
tsne_centers = pd.DataFrame(tsne).groupby(clustering.labels_).mean().to_numpy()

plt.scatter(tsne[:,0], tsne[:,1], c=clustering.labels_, s=1)
plt.scatter(tsne_centers[:,0], tsne_centers[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift on latents (bandwidth={bandwidth}, n={n_clusters}), shown in t-SNE space');

In [None]:
# Fixed bandwidth in latent space (see the estimate_bandwidth quantile sweep in this section).
bandwidth = 0.2
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth).fit(latents)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

# cluster_centers_ are coordinates in the latent space the model was fit on, not in t-SNE space,
# so they cannot be drawn on the t-SNE axes directly. Mark each cluster at the mean t-SNE
# position of its members instead (an approximate visual anchor, not the latent mode).
tsne_centers = pd.DataFrame(tsne).groupby(clustering.labels_).mean().to_numpy()

plt.scatter(tsne[:,0], tsne[:,1], c=clustering.labels_, s=1)
plt.scatter(tsne_centers[:,0], tsne_centers[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift on latents (bandwidth={bandwidth}, n={n_clusters}), shown in t-SNE space');

In [None]:
# Sweep the quantile to see which bandwidth scale estimate_bandwidth suggests for the latents.
# quantile=0 is skipped: scikit-learn clamps it to a single nearest neighbour, which for each
# query point is the point itself, so the estimate is always 0. Building the grid from integers
# also avoids float-accumulation values such as 0.30000000000000004 from np.arange.
for quantile in [step / 10 for step in range(1, 10)]:
    bandwidth = estimate_bandwidth(latents, quantile=quantile, n_samples=1000)
    print(f'For {quantile=} the {bandwidth=}')

In [None]:
# Fixed bandwidth in latent space (see the estimate_bandwidth quantile sweep in this section).
bandwidth = 2.5
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth).fit(latents)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

# cluster_centers_ are coordinates in the latent space the model was fit on, not in t-SNE space,
# so they cannot be drawn on the t-SNE axes directly. Mark each cluster at the mean t-SNE
# position of its members instead (an approximate visual anchor, not the latent mode).
tsne_centers = pd.DataFrame(tsne).groupby(clustering.labels_).mean().to_numpy()

plt.scatter(tsne[:,0], tsne[:,1], c=clustering.labels_, s=1)
plt.scatter(tsne_centers[:,0], tsne_centers[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift on latents (bandwidth={bandwidth}, n={n_clusters}), shown in t-SNE space');

In [None]:
# Fixed bandwidth in latent space (see the estimate_bandwidth quantile sweep in this section).
bandwidth = 5
print(f'{bandwidth=}')
clustering = MeanShift(bandwidth=bandwidth).fit(latents)

n_clusters = len(clustering.cluster_centers_)
print(f'n clusters = {n_clusters}')

# cluster_centers_ are coordinates in the latent space the model was fit on, not in t-SNE space,
# so they cannot be drawn on the t-SNE axes directly. Mark each cluster at the mean t-SNE
# position of its members instead (an approximate visual anchor, not the latent mode).
tsne_centers = pd.DataFrame(tsne).groupby(clustering.labels_).mean().to_numpy()

plt.scatter(tsne[:,0], tsne[:,1], c=clustering.labels_, s=1)
plt.scatter(tsne_centers[:,0], tsne_centers[:,1], marker='x', s=30, c='black')
plt.title(f'MeanShift on latents (bandwidth={bandwidth}, n={n_clusters}), shown in t-SNE space');