<a href="https://colab.research.google.com/github/manuaishika/softkmeans-nn/blob/main/Untitled2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================
# SOFT K-MEANS NEURAL NETWORK - GOOGLE COLAB
# ============================================

# Step 1: Install/Import everything
!pip install torch torchvision scikit-learn matplotlib -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs, make_moons, make_circles
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device once; models and tensors below are moved onto it.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Using device: {device}")
print(f"GPU available: {_cuda_ok}")
if _cuda_ok:
    # Report which accelerator will be used (first visible CUDA device).
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ============================================
# 1. SOFT K-MEANS LAYER
# ============================================

class SoftKMeansLayer(nn.Module):
    """
    Soft K-Means layer that learns cluster centroids.

    For every input row ``x`` and centroid ``c_k`` it returns the squared
    Euclidean distance ``||x - c_k||^2`` and the soft assignment
    ``softmax_k(-||x - c_k||^2 / T)``, where ``T`` is a learnable temperature.

    Args:
        input_dim: Dimensionality of the vectors passed to ``forward``.
        num_clusters: Number of centroids.
        temperature: Initial value of the learnable temperature.
        min_temperature: Floor applied to the learnable temperature in
            ``forward``. Without it the optimiser can drive the temperature
            to zero (division by zero) or below zero. A negative temperature
            flips the softmax and assigns points to their *farthest* centroid.
    """

    def __init__(self, input_dim, num_clusters, temperature=1.0, min_temperature=1e-3):
        super().__init__()
        self.num_clusters = num_clusters
        # Initial temperature as given. forward() uses the learnable
        # ``temperature_factor`` below, not this attribute.
        self.temperature = temperature
        self.min_temperature = min_temperature
        # float() so that an integer temperature still yields a float tensor
        # (only floating-point tensors can require gradients).
        self.temperature_factor = nn.Parameter(torch.tensor([float(temperature)]))

        # Centroids start as small Gaussian noise (std 0.1).
        self.centroids = nn.Parameter(torch.randn(num_clusters, input_dim) * 0.1)

    def forward(self, x):
        """
        Softly assign each row of ``x`` to the centroids.

        Args:
            x: Tensor of shape (batch_size, input_dim).

        Returns:
            (responsibilities, distances), each of shape
            (batch_size, num_clusters). Each row of ``responsibilities`` sums
            to 1; ``distances`` are squared Euclidean distances.
        """
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, computed for all pairs at once.
        x_norm = (x ** 2).sum(dim=1, keepdim=True)                   # (B, 1)
        c_norm = (self.centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, K)
        distances = x_norm + c_norm - 2 * torch.mm(x, self.centroids.t())
        # Floating-point cancellation can make the expansion slightly
        # negative; a squared distance never is.
        distances = distances.clamp_min(0.0)

        # Keep the learnable temperature strictly positive.
        temperature = self.temperature_factor.clamp(min=self.min_temperature)
        logits = -distances / temperature
        responsibilities = F.softmax(logits, dim=1)

        return responsibilities, distances

    def get_centroids(self):
        """Return the centroid values as ``.data`` (detached from autograd)."""
        return self.centroids.data

# ============================================
# 2. COMPLETE NEURAL NETWORK MODEL
# ============================================

class SoftKMeansNN(nn.Module):
    """
    Neural network with an MLP feature encoder followed by a Soft K-Means layer.

    For each width in ``hidden_dims`` the encoder stacks
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)``. Clustering happens in
    the encoder's output space, whose dimension is ``hidden_dims[-1]``
    (or ``input_dim`` when ``hidden_dims`` is empty).

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden-layer widths.
        num_clusters: Number of clusters of the Soft K-Means layer.
        temperature: Initial softmax temperature of the Soft K-Means layer.
    """

    def __init__(self, input_dim, hidden_dims, num_clusters, temperature=1.0):
        super().__init__()

        # Build the feature-extraction layers.
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
            prev_dim = hidden_dim
        self.encoder = nn.Sequential(*layers)

        # Soft K-Means clustering layer, operating on the encoder output.
        self.soft_kmeans = SoftKMeansLayer(prev_dim, num_clusters, temperature)

        # Placeholder for an optional reconstruction decoder; unused so far.
        self.decoder_layers = None

    def forward(self, x):
        """Return ``(features, responsibilities, distances)`` for input ``x``."""
        features = self.encoder(x)
        responsibilities, distances = self.soft_kmeans(features)
        return features, responsibilities, distances

    def _inference(self, x):
        """Run forward() in eval mode without gradients, then restore the previous mode.

        In train mode Dropout zeroes random features and BatchNorm normalises
        with batch statistics (updating its running stats), so inference
        results would be noisy and would change model state.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(x)
        finally:
            self.train(was_training)

    def predict(self, x):
        """Get hard cluster assignments (index of the most probable cluster)."""
        _, responsibilities, _ = self._inference(x)
        return torch.argmax(responsibilities, dim=1)

    def get_centroids(self):
        """Return the Soft K-Means centroids (in encoder feature space)."""
        return self.soft_kmeans.get_centroids()

    def get_soft_assignments(self, x):
        """Get soft assignment probabilities, shape (batch_size, num_clusters)."""
        _, responsibilities, _ = self._inference(x)
        return responsibilities

# ============================================
# 3. TRAINER WITH MULTIPLE LOSS OPTIONS
# ============================================

class SoftKMeansTrainer:
    """
    Optimises a Soft K-Means network with Adam and a plateau LR scheduler.

    Per-batch objective::

        total = kmeans_loss + 0.1 * entropy_loss + lambda_reg * reg_loss

    Args:
        model: Module whose ``forward(x)`` returns
            ``(features, responsibilities, distances)``, as ``SoftKMeansNN``
            does. It is moved to the module-level ``device``.
        learning_rate: Initial Adam learning rate.
        lambda_reg: Weight of the centroid-separation penalty.
    """

    def __init__(self, model, learning_rate=0.001, lambda_reg=0.01):
        self.model = model
        self.model.to(device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.lambda_reg = lambda_reg
        # Halve the LR after 10 epochs without improvement in the epoch loss.
        # `verbose` is deprecated on PyTorch LR schedulers, so train()
        # reports LR reductions itself.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=10, factor=0.5
        )

    def kmeans_loss(self, responsibilities, distances):
        """Soft K-Means objective: responsibility-weighted squared distances.

        Summed (not averaged) over the batch, so its scale grows with batch size.
        """
        return torch.sum(responsibilities * distances)

    def entropy_regularization(self, responsibilities):
        """Mean per-sample assignment entropy; minimising it encourages confident assignments."""
        # The 1e-10 keeps log() finite when a probability underflows to 0.
        entropy = -torch.sum(responsibilities * torch.log(responsibilities + 1e-10), dim=1)
        return torch.mean(entropy)

    def _centroid_parameter(self):
        """Return the trainable centroid tensor, with its autograd history."""
        soft_kmeans = getattr(self.model, 'soft_kmeans', None)
        if soft_kmeans is not None and hasattr(soft_kmeans, 'centroids'):
            return soft_kmeans.centroids
        # Fallback for other models. If this is detached, the penalty is
        # still reported but contributes no gradient.
        return self.model.get_centroids()

    def centroid_regularization(self):
        """Prevent centroids from collapsing: 1 / (smallest pairwise centroid distance).

        Uses the live centroid parameter. ``model.get_centroids()`` returns
        detached ``.data``, through which this penalty would never produce a
        gradient. Returns 0 when there are fewer than two centroids.
        """
        centroids = self._centroid_parameter()
        if centroids.size(0) < 2:
            return centroids.new_zeros(())
        # pdist yields only the distinct pairs (no zero self-distances).
        min_distance = torch.min(torch.pdist(centroids, p=2))
        return 1.0 / (min_distance + 1e-10)

    def train_step(self, x_batch):
        """Run one optimisation step on ``x_batch`` and return the scalar losses."""
        self.optimizer.zero_grad()

        _, responsibilities, distances = self.model(x_batch)

        main_loss = self.kmeans_loss(responsibilities, distances)
        entropy_loss = self.entropy_regularization(responsibilities)
        reg_loss = self.centroid_regularization()

        total_loss = main_loss + 0.1 * entropy_loss + self.lambda_reg * reg_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'kmeans_loss': main_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'reg_loss': reg_loss.item()
        }

    def train(self, data_loader, num_epochs=100, verbose=True):
        """Train for ``num_epochs`` over ``data_loader`` (yielding ``(x, label)`` pairs).

        Returns:
            dict mapping each loss name to a list of per-epoch batch-averaged values.
        """
        keys = ('total_loss', 'kmeans_loss', 'entropy_loss', 'reg_loss')
        history = {key: [] for key in keys}

        # Make sure Dropout/BatchNorm are in training mode, even if the model
        # was switched to eval mode for inference earlier.
        self.model.train()

        for epoch in range(num_epochs):
            epoch_losses = dict.fromkeys(keys, 0)

            for x_batch, _ in data_loader:
                losses = self.train_step(x_batch.to(device))
                for key, value in losses.items():
                    epoch_losses[key] += value

            # Average losses over batches.
            for key in keys:
                epoch_losses[key] /= len(data_loader)
                history[key].append(epoch_losses[key])

            # Update the learning rate and report reductions.
            old_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.scheduler.step(epoch_losses['total_loss'])
            if verbose:
                for old_lr, group in zip(old_lrs, self.optimizer.param_groups):
                    if group['lr'] < old_lr:
                        print(f"Epoch [{epoch+1:3d}/{num_epochs}]: "
                              f"reducing learning rate to {group['lr']:.4e}")

            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch [{epoch+1:3d}/{num_epochs}]: "
                      f"Total Loss: {epoch_losses['total_loss']:.4f}, "
                      f"K-Means: {epoch_losses['kmeans_loss']:.4f}, "
                      f"Entropy: {epoch_losses['entropy_loss']:.4f}")

        return history

# ============================================
# 4. DATA GENERATION AND VISUALIZATION
# ============================================

def create_dataset(dataset_type='blobs', n_samples=1000, n_features=2, n_clusters=4):
    """
    Generate a standardized toy clustering dataset and a shuffling loader.

    Supported ``dataset_type`` values are 'blobs', 'moons', 'circles' and
    'aniso'. 'moons' and 'circles' always produce 2 clusters, and only
    'blobs' honours ``n_features``. Features are standardized to zero mean
    and unit variance before being placed on ``device``.

    Returns:
        (X_tensor, y_tensor, data_loader, n_clusters), where the loader
        yields shuffled ``(x, y)`` batches of 64 and ``n_clusters`` is the
        effective cluster count.

    Raises:
        ValueError: If ``dataset_type`` is not recognised.
    """

    def _anisotropic():
        # Blobs sheared by a fixed linear map.
        points, labels = make_blobs(n_samples=n_samples, centers=n_clusters, random_state=170)
        transformation = [[0.6, -0.6], [-0.4, 0.8]]
        return np.dot(points, transformation), labels

    # name -> (generator, effective number of clusters)
    builders = {
        'blobs': (lambda: make_blobs(n_samples=n_samples, n_features=n_features,
                                     centers=n_clusters, cluster_std=0.8,
                                     random_state=42), n_clusters),
        'moons': (lambda: make_moons(n_samples=n_samples, noise=0.1,
                                     random_state=42), 2),
        'circles': (lambda: make_circles(n_samples=n_samples, factor=0.5,
                                         noise=0.05, random_state=42), 2),
        'aniso': (_anisotropic, n_clusters),
    }
    if dataset_type not in builders:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    build, n_clusters = builders[dataset_type]
    X, y = build()

    standardized = StandardScaler().fit_transform(X)

    X_tensor = torch.FloatTensor(standardized).to(device)
    y_tensor = torch.LongTensor(y).to(device)

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_tensor, y_tensor),
        batch_size=64,
        shuffle=True,
    )
    return X_tensor, y_tensor, loader, n_clusters

def visualize_results(X, predictions, soft_probs, centroids,
                      history=None, dataset_name="Dataset"):
    """
    Plot clustering results in a 2x3 grid and print summary metrics.

    Panels:
      1. hard assignments, with cluster centres;
      2. assignment uncertainty (1 - max probability);
      3. gap between the two largest probabilities;
      4-5. loss curves (only when ``history`` is given);
      6. per-cluster histogram of probabilities.

    Args:
        X: Array of input points; only columns 0 and 1 are plotted.
        predictions: Integer hard cluster label per point.
        soft_probs: (n_points, n_clusters) soft assignment probabilities.
        centroids: (n_clusters, d) array or None. If ``d`` matches the
            dimensionality of ``X``, the centres are drawn as given. Otherwise
            (e.g. centroids learned in an encoder's feature space) they cannot
            be placed on the input scatter. In that case each cluster's
            probability-weighted mean of ``X`` is drawn instead.
        history: Optional dict of per-epoch lists with keys 'total_loss',
            'kmeans_loss', 'entropy_loss' and 'reg_loss'.
        dataset_name: Prefix for the first panel's title.
    """
    X = np.asarray(X)
    soft_probs = np.asarray(soft_probs)

    plt.figure(figsize=(18, 10))

    # Plot 1: Clustering results
    ax1 = plt.subplot(2, 3, 1)
    scatter = ax1.scatter(X[:, 0], X[:, 1], c=predictions,
                          cmap='tab20', alpha=0.6, s=30)
    centers, centers_label = _plottable_centers(X, soft_probs, centroids)
    if centers is not None:
        ax1.scatter(centers[:, 0], centers[:, 1],
                    c='red', marker='X', s=300, linewidths=2,
                    edgecolor='black', label=centers_label)
        # Only add a legend when there is a labelled artist to show.
        ax1.legend()
    ax1.set_title(f'{dataset_name} - Cluster Assignments')
    ax1.set_xlabel('Feature 1')
    ax1.set_ylabel('Feature 2')
    plt.colorbar(scatter, ax=ax1)

    # Sort each row once; reused by plots 2 and 3.
    sorted_probs = np.sort(soft_probs, axis=1)

    # Plot 2: Soft assignment uncertainty
    ax2 = plt.subplot(2, 3, 2)
    uncertainty = 1 - sorted_probs[:, -1]
    sc = ax2.scatter(X[:, 0], X[:, 1], c=uncertainty,
                     cmap='viridis', alpha=0.6, s=30)
    ax2.set_title('Assignment Uncertainty')
    ax2.set_xlabel('Feature 1')
    ax2.set_ylabel('Feature 2')
    plt.colorbar(sc, ax=ax2)

    # Plot 3: Top 2 probabilities
    ax3 = plt.subplot(2, 3, 3)
    if sorted_probs.shape[1] >= 2:
        top2_diff = sorted_probs[:, -1] - sorted_probs[:, -2]
    else:
        # A single cluster has no runner-up; treat its probability as 0.
        top2_diff = sorted_probs[:, -1]
    sc = ax3.scatter(X[:, 0], X[:, 1], c=top2_diff,
                     cmap='coolwarm', alpha=0.6, s=30)
    ax3.set_title('Difference: Top 2 Probabilities')
    ax3.set_xlabel('Feature 1')
    ax3.set_ylabel('Feature 2')
    plt.colorbar(sc, ax=ax3)

    # Plots 4-5: Loss curves
    if history is not None:
        ax4 = plt.subplot(2, 3, 4)
        ax4.plot(history['total_loss'], label='Total Loss', linewidth=2)
        ax4.plot(history['kmeans_loss'], label='K-Means Loss', linewidth=2)
        ax4.set_title('Training Loss Curves')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Loss')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        ax5 = plt.subplot(2, 3, 5)
        ax5.plot(history['entropy_loss'], label='Entropy Loss', color='green', linewidth=2)
        ax5.plot(history['reg_loss'], label='Reg Loss', color='red', linewidth=2)
        ax5.set_title('Auxiliary Losses')
        ax5.set_xlabel('Epoch')
        ax5.set_ylabel('Loss')
        ax5.legend()
        ax5.grid(True, alpha=0.3)

    # Plot 6: Probability distribution
    ax6 = plt.subplot(2, 3, 6)
    for cluster in range(soft_probs.shape[1]):
        ax6.hist(soft_probs[:, cluster], bins=30, alpha=0.5,
                 label=f'Cluster {cluster}')
    ax6.set_title('Probability Distribution per Cluster')
    ax6.set_xlabel('Probability')
    ax6.set_ylabel('Frequency')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print metrics
    if predictions is not None:
        predictions = np.asarray(predictions)
        n_labels = len(np.unique(predictions))
        print(f"Number of clusters: {n_labels}")
        print(f"Cluster sizes: {np.bincount(predictions)}")
        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
        # and raises otherwise, e.g. when every point falls into one cluster.
        if 2 <= n_labels <= len(X) - 1:
            silhouette = silhouette_score(X, predictions)
            print(f"Silhouette Score: {silhouette:.4f}")
        else:
            print("Silhouette Score: undefined (needs between 2 and n_samples - 1 clusters)")


def _plottable_centers(X, soft_probs, centroids):
    """Return ``(points, legend_label)`` for cluster centres in X's space, or ``(None, None)``."""
    if centroids is None:
        return None, None
    centroids = np.asarray(centroids)
    if centroids.ndim == 2 and centroids.shape[1] == X.shape[1]:
        return centroids, 'Centroids'
    # Centroids live in another space: use soft-assignment-weighted input means.
    mass = soft_probs.sum(axis=0)
    means = (soft_probs.T @ X) / np.maximum(mass, 1e-12)[:, None]
    return means, 'Cluster means (input space)'

# ============================================
# 5. MAIN EXPERIMENT FUNCTION
# ============================================

def run_experiment(dataset_type='blobs', n_samples=1000,
                   hidden_dims=(32, 16), num_epochs=100, n_clusters=None):
    """
    Run a complete experiment: build data, train, predict, visualize.

    Args:
        dataset_type: Name understood by ``create_dataset``.
        n_samples: Number of points to generate.
        hidden_dims: Encoder hidden-layer widths.
        num_epochs: Number of training epochs.
        n_clusters: Number of clusters to request. Defaults to 4 for 'blobs'
            and 2 otherwise. ``create_dataset`` may override it ('moons' and
            'circles' always use 2).

    Returns:
        (model, predictions, history): the trained model (left in eval
        mode), hard labels as a NumPy array, and the per-epoch loss history.
    """

    print("=" * 60)
    print(f"EXPERIMENT: {dataset_type.upper()} Dataset")
    print("=" * 60)

    if n_clusters is None:
        n_clusters = 4 if dataset_type == 'blobs' else 2

    # Create dataset
    X, _y_true, data_loader, n_clusters = create_dataset(
        dataset_type=dataset_type,
        n_samples=n_samples,
        n_clusters=n_clusters
    )

    # Initialize model
    model = SoftKMeansNN(
        input_dim=X.shape[1],
        hidden_dims=hidden_dims,
        num_clusters=n_clusters,
        temperature=0.5
    )

    # Initialize trainer
    trainer = SoftKMeansTrainer(
        model,
        learning_rate=0.001,
        lambda_reg=0.01
    )

    # Train model
    print("\nTraining Soft K-Means Neural Network...")
    history = trainer.train(
        data_loader,
        num_epochs=num_epochs,
        verbose=True
    )

    # Inference must run in eval mode. Otherwise Dropout randomly zeroes
    # features and BatchNorm uses (and keeps updating) batch statistics,
    # which makes the predictions noisy.
    model.eval()
    with torch.no_grad():
        predictions = model.predict(X).cpu().numpy()
        soft_probs = model.get_soft_assignments(X).cpu().numpy()
        centroids = model.get_centroids().cpu().numpy()

    # Visualize results
    print("\nVisualizing results...")
    visualize_results(
        X.cpu().numpy(),
        predictions,
        soft_probs,
        centroids,
        history,
        dataset_name=dataset_type.capitalize()
    )

    return model, predictions, history

# ============================================
# 6. RUN MULTIPLE EXPERIMENTS
# ============================================

def run_all_experiments():
    """
    Run ``run_experiment`` on the 'blobs', 'moons' and 'circles' datasets in turn.

    Each run uses 1000 samples and 80 epochs. Blobs get a smaller encoder;
    the curved shapes get one extra hidden layer.

    Returns:
        dict mapping dataset name to ``{'model', 'predictions', 'history'}``.
    """
    # Dataset -> encoder widths (simpler network for blobs, deeper for
    # the more complex shapes).
    encoder_widths = {
        'blobs': [16, 8],
        'moons': [32, 16, 8],
        'circles': [32, 16, 8],
    }
    banner = "=" * 60
    results = {}

    for name, widths in encoder_widths.items():
        print("\n" + banner)
        print(f"Running experiment on {name} dataset...")
        print(banner)

        model, predictions, history = run_experiment(
            dataset_type=name,
            n_samples=1000,
            hidden_dims=widths,
            num_epochs=80,
        )
        results[name] = dict(model=model, predictions=predictions, history=history)

    return results

# ============================================
# 7. QUICK DEMO (Run this first!)
# ============================================

def quick_demo():
    """
    Quick demonstration of soft k-means on 500 points from 3 Gaussian blobs.

    Trains a one-hidden-layer model for 30 epochs, then plots:
      - hard assignments, with each cluster's mean position in input space;
      - each point's maximum assignment probability;
      - the training loss curves.
    Finally it prints the learned centroids and the cluster sizes.
    """
    print("🔍 QUICK DEMO: Soft K-Means Neural Network")

    # Simple example
    X_tensor, _, data_loader, n_clusters = create_dataset(
        dataset_type='blobs',
        n_samples=500,
        n_clusters=3
    )

    # Simple model
    model = SoftKMeansNN(
        input_dim=X_tensor.shape[1],
        hidden_dims=[10],
        num_clusters=n_clusters,
        temperature=0.3
    )

    # Quick training
    trainer = SoftKMeansTrainer(model, learning_rate=0.01)

    print("\nTraining for 30 epochs...")
    history = trainer.train(data_loader, num_epochs=30, verbose=True)

    # Inference in eval mode, so Dropout is off and BatchNorm uses its
    # running statistics.
    model.eval()
    with torch.no_grad():
        predictions = model.predict(X_tensor).cpu().numpy()
        soft_probs = model.get_soft_assignments(X_tensor).cpu().numpy()
    centroids = model.get_centroids().cpu().numpy()

    X_np = X_tensor.cpu().numpy()
    # The learned centroids live in the encoder's 10-D feature space and
    # cannot be drawn on the 2-D input scatter. Show each cluster's
    # probability-weighted mean of the input points instead.
    cluster_mass = soft_probs.sum(axis=0)
    cluster_means = (soft_probs.T @ X_np) / np.maximum(cluster_mass, 1e-12)[:, None]

    # Quick visualization
    plt.figure(figsize=(15, 4))

    plt.subplot(1, 3, 1)
    plt.scatter(X_np[:, 0], X_np[:, 1],
                c=predictions, cmap='viridis', alpha=0.6)
    plt.scatter(cluster_means[:, 0], cluster_means[:, 1],
                c='red', marker='X', s=200, label='Cluster means')
    plt.title("Cluster Assignments")
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.scatter(X_np[:, 0], X_np[:, 1],
                c=np.max(soft_probs, axis=1), cmap='plasma', alpha=0.6)
    plt.title("Maximum Probability")
    plt.colorbar()

    plt.subplot(1, 3, 3)
    plt.plot(history['total_loss'], label='Total Loss')
    plt.plot(history['kmeans_loss'], label='K-Means Loss')
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\n✅ Demo completed!")
    print(f"Final centroids:\n{centroids}")
    print(f"\nCluster distribution: {np.bincount(predictions)}")

# ============================================
# 8. HOW TO RUN
# ============================================

"""
INSTRUCTIONS FOR GOOGLE COLAB:

1. Open Google Colab: https://colab.research.google.com
2. Create a new notebook
3. Copy ALL this code into a cell
4. Uncomment and run ONE of the commands below:

Option A: Quick Demo (Recommended first)
"""
# quick_demo()

"""
Option B: Single Experiment
"""
# run_experiment(dataset_type='moons', num_epochs=50)

"""
Option C: All Experiments
"""
# results = run_all_experiments()

"""
Option D: Custom Experiment
"""
# # Create custom data
# X, y = make_blobs(n_samples=1000, centers=5, random_state=42)
# X = StandardScaler().fit_transform(X)
# X_tensor = torch.FloatTensor(X).to(device)

# # Create model
# model = SoftKMeansNN(
#     input_dim=2,
#     hidden_dims=[32, 16],
#     num_clusters=5,
#     temperature=0.4
# )

# # Train
# trainer = SoftKMeansTrainer(model)
# dataset = torch.utils.data.TensorDataset(X_tensor, torch.zeros(len(X_tensor)))
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
# history = trainer.train(data_loader, num_epochs=100)

# ============================================
# RUN THE QUICK DEMO FIRST!
# ============================================

if __name__ == "__main__":
    # Script entry point: only prints usage hints; nothing runs unless
    # one of the calls below is uncommented.
    print("🚀 Soft K-Means Neural Network - Google Colab Ready!")
    print("\nRun one of these functions:")
    print("1. quick_demo() - For a quick test")
    print("2. run_experiment('moons') - For moon dataset")
    print("3. run_all_experiments() - For all datasets")

    # Uncomment one line below to run:
    # quick_demo()
    # run_experiment('blobs')
    # results = run_all_experiments()

Using device: cpu
GPU available: False
🚀 Soft K-Means Neural Network - Google Colab Ready!

Run one of these functions:
1. quick_demo() - For a quick test
2. run_experiment('moons') - For moon dataset
3. run_all_experiments() - For all datasets
