# Visualization of Clusters

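In this notebook, we recompute K-Means clusters (k = 5) on the standardized Age, Annual Income and Spending Score columns of the Mall Customers data. We then visualize the resulting clusters with a 2D scatter plot (Income vs Spending Score) and a 3D scatter plot (Age, Income, Spending Score).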

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Load the raw customer data. Cluster labels are not read from the file:
# they are recomputed below with K-Means so this notebook runs on its own.
N_CLUSTERS = 5     # number of K-Means clusters
RANDOM_STATE = 0   # fixed seed so cluster assignments are reproducible
FEATURES = ['Age', 'Annual Income (k$)', 'Spending Score (1-100)']

df = pd.read_csv('data/Mall_Customers.csv')
X = df[FEATURES]

# Standardize the features so that no single column (e.g. income, which has a
# larger numeric range) dominates the Euclidean distances K-Means relies on.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

kmeans = KMeans(
    n_clusters=N_CLUSTERS,
    init='k-means++',
    max_iter=300,
    n_init=10,
    random_state=RANDOM_STATE,
)
y_kmeans = kmeans.fit_predict(X_scaled)
# Attach the integer cluster id (0 .. N_CLUSTERS-1) to each customer row.
df['Cluster'] = y_kmeans

# 2D scatter plot: Income vs Spending Score, colored by cluster.
# `Cluster` holds integer labels. Left alone, seaborn would treat it as a numeric
# variable and color points along a continuous gradient. Passing an explicit list
# of colors (one per cluster) makes seaborn use a categorical mapping instead,
# with a distinct color and legend entry for each cluster.
cluster_palette = sns.color_palette('tab10', n_colors=df['Cluster'].nunique())

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=df,
    x='Annual Income (k$)',
    y='Spending Score (1-100)',
    hue='Cluster',
    palette=cluster_palette,
    ax=ax,
)
ax.set_title('Income vs Spending Score by Cluster')
plt.show()

# Optional: 3D scatter plot with Age, Income, and Spending Score.
# Each cluster is drawn as its own series, so it gets a distinct qualitative
# color and a legend entry. A single scatter colored by a continuous colormap
# would give no way to tell which color belongs to which cluster.
cluster_groups = df.groupby('Cluster')  # groups are sorted by cluster id
cluster_palette = sns.color_palette('tab10', n_colors=cluster_groups.ngroups)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
for (cluster_id, members), color in zip(cluster_groups, cluster_palette):
    ax.scatter(
        members['Age'],
        members['Annual Income (k$)'],
        members['Spending Score (1-100)'],
        color=color,
        label=str(cluster_id),
    )
ax.set_xlabel('Age')
ax.set_ylabel('Annual Income (k$)')
ax.set_zlabel('Spending Score (1-100)')
ax.set_title('3D Clustering of Customers')
ax.legend(title='Cluster')
plt.show()