-
Notifications
You must be signed in to change notification settings - Fork 3
Algorithms
protruser edited this page Dec 16, 2024
·
7 revisions
Affinity Propagation is a clustering algorithm that identifies a set of "exemplars" among the data points and forms clusters around these exemplars. Unlike other clustering methods (e.g., K-Means), it does not require the number of clusters to be specified beforehand. Instead, it works by exchanging messages between data points until a good set of exemplars and clusters emerges.
This example demonstrates how to perform Affinity Propagation clustering using the scikit-learn library.
# Import required libraries
from numpy import unique
from numpy import where
from sklearn.datasets import make_classification
from sklearn.cluster import AffinityPropagation
from matplotlib import pyplot
# Define the dataset: 1000 samples with two informative features and a fixed
# seed. The class labels returned by make_classification are discarded
# because clustering is unsupervised.
X, _ = make_classification(
    n_samples=1000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=4,
)
# Define the model
model = AffinityPropagation(damping=0.9)
# Fit the model
model.fit(X)
# Assign a cluster to each example
yhat = model.predict(X)
# Retrieve unique clusters
clusters = unique(yhat)
# Create scatter plot for samples from each cluster
for cluster in clusters:
    # Get row indexes for samples with this cluster. where() returns a tuple
    # of index arrays; take its first element so the selections below are
    # plain 1-D columns rather than (1, n) arrays.
    row_ix = where(yhat == cluster)[0]
    # Create scatter plot of these samples
    pyplot.scatter(X[row_ix, 0], X[row_ix, 1])
# Show the plot
pyplot.show()
# BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) clustering example
from numpy import unique
from numpy import where
from sklearn.datasets import make_classification
from sklearn.cluster import Birch
from matplotlib import pyplot
# define dataset: 1000 samples with two informative features and a fixed
# seed; the class labels are discarded because clustering is unsupervised
X, _ = make_classification(
    n_samples=1000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=4,
)
# define the model
model = Birch(threshold=0.01, n_clusters=2)
# fit the model
model.fit(X)
# assign a cluster to each example
yhat = model.predict(X)
# retrieve unique clusters
clusters = unique(yhat)
# create scatter plot for samples from each cluster
for cluster in clusters:
    # get row indexes for samples with this cluster; where() returns a tuple
    # of index arrays, so take its first element to keep the selections
    # below 1-D rather than (1, n)
    row_ix = where(yhat == cluster)[0]
    # create scatter of these samples
    pyplot.scatter(X[row_ix, 0], X[row_ix, 1])
# show the plot
pyplot.show()