# K-Means Clustering for Brain Tumor Analysis

This notebook applies K-Means clustering to brain MRI images for unsupervised analysis and pattern discovery.

## Overview

K-Means clustering is an **unsupervised learning** algorithm that groups similar images together without requiring labeled data. This notebook:
1. Loads and preprocesses brain MRI images
2. Extracts features from images (flattened, standardized pixel values, optionally reduced with PCA)
3. Applies K-Means clustering to group similar images
4. Visualizes clusters and cluster centers
5. Evaluates clustering performance using metrics (inertia, silhouette score)
6. Analyzes cluster assignments
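
The inertia reported in step 5 is the sum of squared distances from each sample to its assigned cluster center. A minimal pure-Python sketch of that definition, using made-up 2D points rather than image features (the `inertia` helper and the toy `points`, `centers`, and `assignments` are illustrative assumptions, not part of the notebook):

```python
def inertia(points, centers, assignments):
    """Sum of squared Euclidean distances from each point to its assigned center."""
    total = 0.0
    for point, cluster_id in zip(points, assignments):
        center = centers[cluster_id]
        total += sum((p - c) ** 2 for p, c in zip(point, center))
    return total

# Toy data: two tight groups, each point exactly 1 unit from its center
points = [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0), (10.0, 2.0)]
centers = [(0.0, 1.0), (10.0, 1.0)]
assignments = [0, 0, 1, 1]

toy_inertia = inertia(points, centers, assignments)
print(f"Inertia: {toy_inertia}")
```

Lower inertia means tighter clusters, but it always decreases as K grows, which is why the notebook also reports the silhouette score.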

**Note:** K-Means is unsupervised - it doesn't use class labels during training, but we can compare cluster assignments with true labels for evaluation purposes.
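
Cluster IDs are arbitrary integers, so comparing them with true labels needs a mapping step. One common approach, sketched below on toy data, assigns each cluster the most frequent true label among its members. The `majority_vote_mapping` helper and the sample values are hypothetical and only illustrate the idea:

```python
from collections import Counter

def majority_vote_mapping(cluster_ids, true_labels):
    """Map each cluster ID to the most frequent true label among its members."""
    counts_by_cluster = {}
    for cluster_id, label in zip(cluster_ids, true_labels):
        counts_by_cluster.setdefault(cluster_id, Counter())[label] += 1
    return {cid: counts.most_common(1)[0][0]
            for cid, counts in counts_by_cluster.items()}

# Toy assignments: each cluster contains one "wrong" sample
cluster_ids = [0, 0, 0, 1, 1, 1]
true_labels = ["GLIOMA", "GLIOMA", "PITUITARY",
               "PITUITARY", "PITUITARY", "GLIOMA"]

mapping = majority_vote_mapping(cluster_ids, true_labels)
predicted = [mapping[cid] for cid in cluster_ids]
purity = sum(p == t for p, t in zip(predicted, true_labels)) / len(true_labels)

print(mapping)
print(f"Purity: {purity:.4f}")
```

Because the mapping is built from the true labels, the resulting score is optimistic and should be read as cluster purity rather than as classification accuracy.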

## Requirements

Make sure you have installed all required packages:
```bash
pip install numpy pandas scikit-learn matplotlib seaborn opencv-python pillow tqdm
```

Or install from requirements.txt:
```bash
pip install -r requirements.txt
```
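
To confirm the installation, a quick check along these lines can help. Several pip package names differ from the module names used in `import` statements (for example, `opencv-python` is imported as `cv2`). The `find_missing` helper is a hypothetical convenience, not part of the notebook:

```python
import importlib.util

# pip package name -> importable module name (they differ for some packages)
REQUIRED = {
    "numpy": "numpy",
    "pandas": "pandas",
    "scikit-learn": "sklearn",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "opencv-python": "cv2",
    "pillow": "PIL",
    "tqdm": "tqdm",
}

def find_missing(required):
    """Return the pip names whose module cannot be located."""
    return [pip_name for pip_name, module_name in required.items()
            if importlib.util.find_spec(module_name) is None]

missing = find_missing(REQUIRED)
print("Missing packages:", ", ".join(missing) if missing else "none")
```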


## 1. Setup and Imports


In [None]:
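# Load images using CSV metadata to prevent data leakage.
# NOTE: this cell uses os, pd, np, cv2 and tqdm (imported in the next cell) and
# DATA_DIR / IMG_SIZE (defined in Section 2), so run those cells first.
print("Loading images using CSV metadata files to prevent data leakage...")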

# Load metadata CSV files
augmented_metadata_path = 'data/augmented_dataset_metadata.csv'
original_metadata_path = 'data/dataset_metadata.csv'

if os.path.exists(augmented_metadata_path):
    aug_metadata_df = pd.read_csv(augmented_metadata_path)
    print(f"Loaded augmented metadata: {len(aug_metadata_df)} rows")
else:
    print(f"Warning: {augmented_metadata_path} not found.")
    aug_metadata_df = pd.DataFrame()

if os.path.exists(original_metadata_path):
    orig_metadata_df = pd.read_csv(original_metadata_path)
    print(f"Loaded original metadata: {len(orig_metadata_df)} rows")
else:
    print(f"Warning: {original_metadata_path} not found.")
    orig_metadata_df = pd.DataFrame()

# Filter training data to exclude test/val images
if len(aug_metadata_df) > 0 and len(orig_metadata_df) > 0:
    # Get test and val original filenames
    test_originals = set(orig_metadata_df[orig_metadata_df['split'] == 'test']['filename'].unique())
    val_originals = set(orig_metadata_df[orig_metadata_df['split'] == 'val']['filename'].unique())
    excluded_originals = test_originals.union(val_originals)
    
    # Filter train data: only keep images whose original_filename is NOT in test/val
    train_df = aug_metadata_df[aug_metadata_df['split'] == 'train'].copy()
    train_df_filtered = train_df[~train_df['original_filename'].isin(excluded_originals)]
    
    print(f"Filtered training data: {len(train_df_filtered)} rows (from {len(train_df)} total train rows)")
    print(f"Excluded {len(train_df) - len(train_df_filtered)} rows that overlap with test/val sets")
    
    # Load images from filtered metadata
    images = []
    labels = []
    image_paths = []
    
    for _, row in tqdm(train_df_filtered.iterrows(), total=len(train_df_filtered), desc="Loading images"):
        # Use full_path if available, otherwise construct from DATA_DIR and image_path
        if pd.notna(row.get('full_path')):
            img_path = row['full_path']
        else:
            img_path = os.path.join(DATA_DIR, row['image_path'])
        
        # Normalize path separators
        img_path = img_path.replace('\\', '/')
        
        if os.path.exists(img_path):
            try:
                # Read image in grayscale
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                if img is not None:
                    # Resize image
                    img = cv2.resize(img, IMG_SIZE)
                    images.append(img)
                    labels.append(row['class'])
                    image_paths.append(img_path)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue
    
    train_images = np.array(images)
    train_labels = np.array(labels)
    train_paths = np.array(image_paths)
    
    print(f"Loaded {len(train_images)} filtered training images")
    USE_CSV_METADATA = True
else:
    print("Warning: CSV metadata not available. Will use directory loading (may have data leakage).")
    USE_CSV_METADATA = False


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, silhouette_samples, adjusted_rand_score
from sklearn.decomposition import PCA
import cv2
from pathlib import Path
import os
from tqdm import tqdm
import pickle
import random

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)

set_seed(42)

# Set style for plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Imports successful")


## 2. Configuration


In [None]:
# Configuration
DATA_DIR = 'data/vgg16_classification'
MODEL_DIR = 'models/kmeans'
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, 'kmeans_clusterer.pkl')
RESULTS_SAVE_PATH = os.path.join(MODEL_DIR, 'kmeans_results.csv')
VISUALIZATION_DIR = os.path.join(MODEL_DIR, 'visualizations')

# Clustering parameters
N_CLUSTERS = 4  # Number of clusters (matching 4 classes: NO_TUMOR, GLIOMA, MENINGIOMA, PITUITARY)
MAX_ITER = 300  # Maximum iterations for K-Means
N_INIT = 10     # Number of times K-Means will run with different centroid seeds
RANDOM_STATE = 42

# Image preprocessing parameters
IMG_SIZE = (224, 224)  # Resize images to this size
USE_FEATURE_EXTRACTION = True  # If True, use PCA for dimensionality reduction
N_COMPONENTS = 50  # Number of PCA components (if USE_FEATURE_EXTRACTION is True)

# Class names (for evaluation purposes)
CLASS_NAMES = ['NO_TUMOR', 'GLIOMA', 'MENINGIOMA', 'PITUITARY']

# Create directories
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(VISUALIZATION_DIR, exist_ok=True)

print(f"Configuration:")
print(f"  Data directory: {DATA_DIR}")
print(f"  Number of clusters: {N_CLUSTERS}")
print(f"  Image size: {IMG_SIZE}")
print(f"  Feature extraction: {USE_FEATURE_EXTRACTION}")
if USE_FEATURE_EXTRACTION:
    print(f"  PCA components: {N_COMPONENTS}")


## 3. Data Loading and Preprocessing


In [None]:
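def load_images_from_directory(data_dir, max_samples=None):
    """
    Load grayscale images from a directory with one subdirectory per class.
    If max_samples is given, at most that many images are loaded per class.
    Returns: (images, labels, image_paths) as NumPy arrays
    """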
    images = []
    labels = []
    image_paths = []
    
    # Get all class directories
    class_dirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    
    for class_name in class_dirs:
        class_path = os.path.join(data_dir, class_name)
        image_files = [f for f in os.listdir(class_path) 
                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        if max_samples:
            image_files = image_files[:max_samples]
        
        for img_file in tqdm(image_files, desc=f"Loading {class_name}"):
            img_path = os.path.join(class_path, img_file)
            try:
                # Read image in grayscale
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                if img is not None:
                    # Resize image
                    img = cv2.resize(img, IMG_SIZE)
                    images.append(img)
                    labels.append(class_name)
                    image_paths.append(img_path)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue
    
    return np.array(images), np.array(labels), np.array(image_paths)

# Load images from the training set, unless the CSV metadata cell already
# produced a leakage-filtered training set (USE_CSV_METADATA is True).
# Use train_augmented directory (created by augment_training_data.py)
# If train_augmented doesn't exist, fall back to train directory
if globals().get('USE_CSV_METADATA', False):
    print("Using images already loaded from CSV metadata (skipping directory loading)")
else:
    train_dir = os.path.join(DATA_DIR, 'train_augmented')
    if not os.path.exists(train_dir):
        print(f"Warning: {train_dir} not found. Using 'train' directory instead.")
        print("Run 'python augment_training_data.py' first to create augmented training data.")
        train_dir = os.path.join(DATA_DIR, 'train')
    else:
        print("Using augmented training data from train_augmented directory")

    print("Loading images from training set...")
    train_images, train_labels, train_paths = load_images_from_directory(train_dir)

print(f"\nLoaded {len(train_images)} images")
print(f"Image shape: {train_images[0].shape}")
print(f"Unique classes: {np.unique(train_labels)}")


### 3.1 Class Distribution


In [None]:
# Display class distribution
class_counts = pd.Series(train_labels).value_counts().sort_index()

plt.figure(figsize=(10, 6))
bars = plt.bar(class_counts.index, class_counts.values, 
               color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])
plt.title('Class Distribution in Training Set', fontsize=14, fontweight='bold')
plt.xlabel('Class', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(height)}',
             ha='center', va='bottom')

plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/class_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nClass Distribution:")
print(class_counts)


### 3.2 Feature Extraction


In [None]:
# Flatten images to feature vectors
print("Extracting features from images...")
n_samples = train_images.shape[0]
n_features = train_images.shape[1] * train_images.shape[2]

# Flatten images
X_flattened = train_images.reshape(n_samples, -1)
print(f"Flattened shape: {X_flattened.shape}")

# Standardize features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_flattened)
print(f"Scaled shape: {X_scaled.shape}")

# Apply PCA for dimensionality reduction (optional)
if USE_FEATURE_EXTRACTION:
    print(f"\nApplying PCA with {N_COMPONENTS} components...")
    pca = PCA(n_components=N_COMPONENTS, random_state=RANDOM_STATE)
    X_features = pca.fit_transform(X_scaled)
    print(f"PCA shape: {X_features.shape}")
    print(f"Explained variance ratio: {pca.explained_variance_ratio_.sum():.4f}")
    print(f"  (First {N_COMPONENTS} components explain {pca.explained_variance_ratio_.sum()*100:.2f}% of variance)")
else:
    X_features = X_scaled
    pca = None

print(f"\nFeature extraction complete")
print(f"Final feature shape: {X_features.shape}")


## 4. K-Means Clustering


In [None]:
# Initialize and fit K-Means
print(f"Training K-Means with {N_CLUSTERS} clusters...")
print(f"Max iterations: {MAX_ITER}")
print(f"Number of initializations: {N_INIT}")

kmeans = KMeans(
    n_clusters=N_CLUSTERS,
    max_iter=MAX_ITER,
    n_init=N_INIT,
    random_state=RANDOM_STATE,
    verbose=1
)

# Fit the model
kmeans.fit(X_features)

print("\nK-Means training completed!")
print(f"Inertia (within-cluster sum of squares): {kmeans.inertia_:.2f}")
print(f"Number of iterations: {kmeans.n_iter_}")


### 4.1 Cluster Assignments


In [None]:
# Get cluster assignments
cluster_labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_

print(f"Cluster assignments shape: {cluster_labels.shape}")
print(f"Cluster centers shape: {cluster_centers.shape}")

# Count samples per cluster
unique, counts = np.unique(cluster_labels, return_counts=True)
print("\nSamples per cluster:")
for cluster_id, count in zip(unique, counts):
    print(f"  Cluster {cluster_id}: {count} samples")

# Create results dataframe
results_df = pd.DataFrame({
    'image_path': train_paths,
    'true_label': train_labels,
    'cluster_id': cluster_labels
})

# Save results
results_df.to_csv(RESULTS_SAVE_PATH, index=False)
print(f"\nResults saved to {RESULTS_SAVE_PATH}")


### 4.2 Cluster-to-Class Mapping Analysis


In [None]:
# Analyze which classes are assigned to which clusters
cluster_class_mapping = pd.crosstab(results_df['cluster_id'], results_df['true_label'])
print("Cluster-to-Class Mapping:")
print(cluster_class_mapping)

# Visualize cluster-class mapping
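plt.figure(figsize=(10, 6))
# Take tick labels from the crosstab itself: its columns are sorted alphabetically,
# which does not match the order of CLASS_NAMES, and empty clusters have no row.
sns.heatmap(cluster_class_mapping, annot=True, fmt='d', cmap='Blues',
            xticklabels=cluster_class_mapping.columns,
            yticklabels=[f'Cluster {i}' for i in cluster_class_mapping.index],
            cbar_kws={'label': 'Number of Samples'})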
plt.title('Cluster-to-Class Mapping', fontsize=14, fontweight='bold')
plt.xlabel('True Class', fontsize=12)
plt.ylabel('Cluster ID', fontsize=12)
plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/cluster_class_mapping.png', dpi=300, bbox_inches='tight')
plt.show()

# Create confusion matrix comparing clusters to true labels
# Map each cluster to the most common class in that cluster
cluster_to_class = {}
for cluster_id in range(N_CLUSTERS):
    cluster_data = results_df[results_df['cluster_id'] == cluster_id]
    if len(cluster_data) > 0:
        most_common_class = cluster_data['true_label'].mode()[0]
        cluster_to_class[cluster_id] = most_common_class
    else:
        cluster_to_class[cluster_id] = CLASS_NAMES[0]

# Map cluster labels to predicted class labels
predicted_labels = np.array([cluster_to_class[cluster_id] for cluster_id in cluster_labels])

# Define y_true and y_pred for evaluation (convert string labels to numeric indices)
label_to_num = {label: idx for idx, label in enumerate(CLASS_NAMES)}
y_true = np.array([label_to_num[label] for label in train_labels])
y_pred = np.array([label_to_num[label] for label in predicted_labels])

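# Calculate and print overall accuracy of the majority-vote mapping.
# The mapping is derived from the true labels, so this is an optimistic,
# purity-style score rather than a true classification accuracy.
from sklearn.metrics import accuracy_score, confusion_matrix
accuracy = accuracy_score(y_true, y_pred)
print(f"\nOverall Accuracy Score: {accuracy:.4f} ({accuracy*100:.2f}%)")

# Generate confusion matrix; passing labels keeps it square (one row and
# column per entry in CLASS_NAMES) even if a class is never predicted
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))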

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - K-Means Clustering', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label (from Cluster)', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Print confusion matrix as table
print("\nConfusion Matrix (Cluster assignments mapped to most common class):")
print(f"{'':<15}", end='')
for name in CLASS_NAMES:
    print(f"{name:<15}", end='')
print()
for i, name in enumerate(CLASS_NAMES):
    print(f"{name:<15}", end='')
    for j in range(len(CLASS_NAMES)):
        print(f"{cm[i][j]:<15}", end='')
    print()


## 5. Model Evaluation


In [None]:
# Calculate silhouette score
silhouette_avg = silhouette_score(X_features, cluster_labels)
print(f"Average Silhouette Score: {silhouette_avg:.4f}")
print("  (Range: -1 to 1, higher is better)")

# Calculate silhouette scores for each sample
sample_silhouette_values = silhouette_samples(X_features, cluster_labels)

# Calculate Adjusted Rand Index (if we want to compare with true labels)
# Convert true labels to numeric
label_to_num = {label: idx for idx, label in enumerate(CLASS_NAMES)}
true_label_numeric = np.array([label_to_num[label] for label in train_labels])
ari_score = adjusted_rand_score(true_label_numeric, cluster_labels)
print(f"\nAdjusted Rand Index: {ari_score:.4f}")
print("  (1 = perfect match, ~0 = random labeling; negative values are worse than random, lower bound -0.5)")

# Print evaluation summary
print("\n" + "=" * 60)
print("Evaluation Summary:")
print("=" * 60)
print(f"Inertia: {kmeans.inertia_:.2f}")
print(f"Average Silhouette Score: {silhouette_avg:.4f}")
print(f"Adjusted Rand Index: {ari_score:.4f}")
print("=" * 60)


### 5.1 Silhouette Analysis


In [None]:
# Visualize silhouette scores
fig, ax = plt.subplots(figsize=(12, 8))
y_lower = 10

for i in range(N_CLUSTERS):
    # Aggregate silhouette scores for samples belonging to cluster i
    ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
    ith_cluster_silhouette_values.sort()
    
    size_cluster_i = ith_cluster_silhouette_values.shape[0]
    y_upper = y_lower + size_cluster_i
    
    color = plt.cm.nipy_spectral(float(i) / N_CLUSTERS)
    ax.fill_betweenx(np.arange(y_lower, y_upper), 0, ith_cluster_silhouette_values,
                     facecolor=color, edgecolor=color, alpha=0.7)
    
    # Label the silhouette plots with their cluster numbers at the middle
    ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
    
    # Compute the new y_lower for next plot
    y_lower = y_upper + 10

ax.set_xlabel('Silhouette Coefficient Values', fontsize=12)
ax.set_ylabel('Cluster Label', fontsize=12)
ax.set_title('Silhouette Analysis for K-Means Clustering', fontsize=14, fontweight='bold')
ax.axvline(x=silhouette_avg, color="red", linestyle="--", 
           label=f'Average Score: {silhouette_avg:.4f}')
ax.set_yticks([])
ax.set_xlim([-0.1, 1])
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/silhouette_analysis.png', dpi=300, bbox_inches='tight')
plt.show()


## 6. Visualization


### 6.1 PCA Visualization (2D)


In [None]:
# Reduce to 2D for visualization using PCA
pca_2d = PCA(n_components=2, random_state=RANDOM_STATE)
X_2d = pca_2d.fit_transform(X_scaled)

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Clusters
scatter1 = axes[0].scatter(X_2d[:, 0], X_2d[:, 1], c=cluster_labels, 
                          cmap='viridis', alpha=0.6, s=20)
axes[0].set_xlabel(f'First Principal Component ({pca_2d.explained_variance_ratio_[0]*100:.1f}% variance)', fontsize=11)
axes[0].set_ylabel(f'Second Principal Component ({pca_2d.explained_variance_ratio_[1]*100:.1f}% variance)', fontsize=11)
axes[0].set_title('K-Means Clustering Results (2D PCA)', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
plt.colorbar(scatter1, ax=axes[0], label='Cluster ID')

# Plot 2: True Labels
label_colors = {CLASS_NAMES[0]: 0, CLASS_NAMES[1]: 1, CLASS_NAMES[2]: 2, CLASS_NAMES[3]: 3}
label_numeric = np.array([label_colors[label] for label in train_labels])
scatter2 = axes[1].scatter(X_2d[:, 0], X_2d[:, 1], c=label_numeric, 
                          cmap='viridis', alpha=0.6, s=20)
axes[1].set_xlabel(f'First Principal Component ({pca_2d.explained_variance_ratio_[0]*100:.1f}% variance)', fontsize=11)
axes[1].set_ylabel(f'Second Principal Component ({pca_2d.explained_variance_ratio_[1]*100:.1f}% variance)', fontsize=11)
axes[1].set_title('True Labels (2D PCA)', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
plt.colorbar(scatter2, ax=axes[1], label='Class ID')

plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/pca_2d_visualization.png', dpi=300, bbox_inches='tight')
plt.show()


### 6.2 Cluster Centers Visualization


In [None]:
# Visualize cluster centers
if USE_FEATURE_EXTRACTION and pca is not None:
    # Transform cluster centers back to original space
    cluster_centers_original = pca.inverse_transform(cluster_centers)
    cluster_centers_original = scaler.inverse_transform(cluster_centers_original)
else:
    cluster_centers_original = scaler.inverse_transform(cluster_centers)

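# Reshape to image format. cv2.resize takes (width, height), so the stored
# arrays have shape (height, width) = (IMG_SIZE[1], IMG_SIZE[0]).
cluster_centers_images = cluster_centers_original.reshape(N_CLUSTERS, IMG_SIZE[1], IMG_SIZE[0])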

# Visualize cluster centers
fig, axes = plt.subplots(1, N_CLUSTERS, figsize=(16, 4))
for i in range(N_CLUSTERS):
    axes[i].imshow(cluster_centers_images[i], cmap='gray')
    axes[i].set_title(f'Cluster {i} Center', fontsize=12, fontweight='bold')
    axes[i].axis('off')

plt.suptitle('K-Means Cluster Centers (Averaged Images)', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/cluster_centers.png', dpi=300, bbox_inches='tight')
plt.show()


### 6.3 Sample Images from Each Cluster


In [None]:
# Display sample images from each cluster
n_samples_per_cluster = 5
fig, axes = plt.subplots(N_CLUSTERS, n_samples_per_cluster, figsize=(15, 12))

for cluster_id in range(N_CLUSTERS):
    cluster_indices = np.where(cluster_labels == cluster_id)[0]
    if len(cluster_indices) > 0:
        # Randomly sample images from this cluster
        sample_indices = np.random.choice(cluster_indices, 
                                         size=min(n_samples_per_cluster, len(cluster_indices)), 
                                         replace=False)
        
        for idx, sample_idx in enumerate(sample_indices):
            img = train_images[sample_idx]
            true_label = train_labels[sample_idx]
            
            axes[cluster_id, idx].imshow(img, cmap='gray')
            axes[cluster_id, idx].set_title(f'True: {true_label}', fontsize=9)
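            axes[cluster_id, idx].axis('off')

        # Hide any unused axes when the cluster has fewer samples than requested
        for idx in range(len(sample_indices), n_samples_per_cluster):
            axes[cluster_id, idx].axis('off')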
    else:
        # Empty cluster
        for idx in range(n_samples_per_cluster):
            axes[cluster_id, idx].axis('off')
            axes[cluster_id, idx].text(0.5, 0.5, 'No samples', 
                                      ha='center', va='center', fontsize=10)

plt.suptitle('Sample Images from Each Cluster', fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/sample_images_by_cluster.png', dpi=300, bbox_inches='tight')
plt.show()


## 7. Elbow Method (Optimal K Selection)


In [None]:
# Test different values of K to find optimal number of clusters
print("Testing different values of K...")
K_range = range(2, 11)
inertias = []
silhouette_scores = []

for k in tqdm(K_range):
    kmeans_test = KMeans(n_clusters=k, max_iter=MAX_ITER, n_init=5, 
                        random_state=RANDOM_STATE, verbose=0)
    kmeans_test.fit(X_features)
    inertias.append(kmeans_test.inertia_)
    silhouette_scores.append(silhouette_score(X_features, kmeans_test.labels_))

# Plot elbow curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Inertia plot (Elbow method)
axes[0].plot(K_range, inertias, 'bo-', linewidth=2, markersize=8)
axes[0].set_xlabel('Number of Clusters (K)', fontsize=12)
axes[0].set_ylabel('Inertia', fontsize=12)
axes[0].set_title('Elbow Method for Optimal K', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].axvline(x=N_CLUSTERS, color='r', linestyle='--', 
                label=f'Selected K={N_CLUSTERS}')
axes[0].legend()

# Silhouette score plot
axes[1].plot(K_range, silhouette_scores, 'go-', linewidth=2, markersize=8)
axes[1].set_xlabel('Number of Clusters (K)', fontsize=12)
axes[1].set_ylabel('Silhouette Score', fontsize=12)
axes[1].set_title('Silhouette Score vs Number of Clusters', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].axvline(x=N_CLUSTERS, color='r', linestyle='--', 
                label=f'Selected K={N_CLUSTERS}')
axes[1].legend()

plt.tight_layout()
plt.savefig(f'{VISUALIZATION_DIR}/elbow_method.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nOptimal K (highest silhouette score): {K_range[np.argmax(silhouette_scores)]}")
print(f"Silhouette scores: {dict(zip(K_range, silhouette_scores))}")


## 8. Save Model


In [None]:
# Save the K-Means model and preprocessing objects
model_data = {
    'kmeans': kmeans,
    'scaler': scaler,
    'pca': pca,
    'n_clusters': N_CLUSTERS,
    'img_size': IMG_SIZE,
    'use_feature_extraction': USE_FEATURE_EXTRACTION,
    'n_components': N_COMPONENTS if USE_FEATURE_EXTRACTION else None
}

with open(MODEL_SAVE_PATH, 'wb') as f:
    pickle.dump(model_data, f)

print(f"Model saved to {MODEL_SAVE_PATH}")
print("\nSaved components:")
print("  - K-Means model")
print("  - StandardScaler")
if USE_FEATURE_EXTRACTION:
    print("  - PCA transformer")


## 9. Summary

### Clustering Results Summary:
- **Number of Clusters**: 4 (configurable via `N_CLUSTERS`)
- **Inertia**: printed in Sections 4 and 5
- **Average Silhouette Score**: printed in Section 5
- **Adjusted Rand Index**: printed in Section 5
- **Model saved to**: `models/kmeans/kmeans_clusterer.pkl`
- **Results saved to**: `models/kmeans/kmeans_results.csv`

### Files Generated:
1. Trained model: `models/kmeans/kmeans_clusterer.pkl`
2. Cluster assignments: `models/kmeans/kmeans_results.csv`
3. Class distribution: `models/kmeans/visualizations/class_distribution.png`
4. Cluster-class mapping: `models/kmeans/visualizations/cluster_class_mapping.png`
5. Confusion matrix: `models/kmeans/visualizations/confusion_matrix.png`
6. Silhouette analysis: `models/kmeans/visualizations/silhouette_analysis.png`
7. PCA 2D visualization: `models/kmeans/visualizations/pca_2d_visualization.png`
8. Cluster centers: `models/kmeans/visualizations/cluster_centers.png`
9. Sample images by cluster: `models/kmeans/visualizations/sample_images_by_cluster.png`
10. Elbow method analysis: `models/kmeans/visualizations/elbow_method.png`

### Using the Trained Model:

**Load and use the model:**
```python
import pickle
import cv2
import numpy as np

# Load model
with open('models/kmeans/kmeans_clusterer.pkl', 'rb') as f:
    model_data = pickle.load(f)

kmeans = model_data['kmeans']
scaler = model_data['scaler']
pca = model_data['pca']

# Predict cluster for a new image
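def predict_cluster(image_path):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising when the file cannot be read
        raise FileNotFoundError(f'Could not read image: {image_path}')
    img = cv2.resize(img, model_data['img_size'])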
    img_flattened = img.reshape(1, -1)
    img_scaled = scaler.transform(img_flattened)
    
    if model_data['use_feature_extraction']:
        img_features = pca.transform(img_scaled)
    else:
        img_features = img_scaled
    
    cluster_id = kmeans.predict(img_features)[0]
    return cluster_id

# Example usage
cluster = predict_cluster('path/to/image.jpg')
print(f'Image belongs to cluster: {cluster}')
```

### Important Notes:
- K-Means is an unsupervised algorithm - it doesn't use class labels during training
- Cluster assignments may not perfectly match true class labels
- The Adjusted Rand Index shows how well clusters align with true classes
- You can adjust `N_CLUSTERS` to experiment with different numbers of clusters
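- Use the elbow and silhouette plots from Section 7 as heuristics for choosing the number of clusters; they may not agree with each other or with the number of true classes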
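
To make the Adjusted Rand Index note concrete, the sketch below computes the index from pair counts in pure Python. It is an illustrative reimplementation (the `adjusted_rand_index` helper is hypothetical and assumes at least two samples). The notebook itself uses `sklearn.metrics.adjusted_rand_score`.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_a, labels_b):
    """Adjusted Rand Index between two labelings of the same samples."""
    n = len(labels_a)
    # Pairs of samples that share a cluster in both labelings
    sum_cells = sum(comb(c, 2) for c in Counter(zip(labels_a, labels_b)).values())
    # Pairs of samples that share a cluster in each labeling separately
    sum_a = sum(comb(c, 2) for c in Counter(labels_a).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_b).values())
    expected = sum_a * sum_b / comb(n, 2)   # agreement expected by chance
    max_index = (sum_a + sum_b) / 2         # agreement for identical partitions
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both labelings use a single cluster
    return (sum_cells - expected) / (max_index - expected)

true_labels = [0, 0, 1, 1]
renamed_clusters = [1, 1, 0, 0]      # same grouping, different cluster IDs
discordant_clusters = [0, 1, 0, 1]   # splits every true class in half

ari_renamed = adjusted_rand_index(true_labels, renamed_clusters)
ari_discordant = adjusted_rand_index(true_labels, discordant_clusters)
print(f"Renamed clusters:    {ari_renamed:.2f}")
print(f"Discordant clusters: {ari_discordant:.2f}")
```

The first comparison shows that the index ignores how cluster IDs are numbered, and the second shows that it can go below zero when the clusters cut across the true classes.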
