# In [None]:
"""
Deep Reinforcement Clustering (DRC) - PAPER IMPLEMENTATION
Dataset: MNIST
Based on: Li et al., "Deep Reinforcement Clustering" (IEEE TMM 2023)

"""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.optimize import linear_sum_assignment
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Seed NumPy's and PyTorch's global generators with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

# Use CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


class AutoEncoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder and decoder.

    The encoder maps ``input_dim -> hidden_dims[0] -> ... -> latent_dim`` with
    a ReLU after every hidden layer. The decoder walks the same hidden widths
    in reverse order back to ``input_dim``. Neither the latent layer nor the
    output layer is followed by an activation.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dims: Widths of the encoder's hidden layers, in order. Any
            iterable of ints is accepted (list, tuple, generator, ...).
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim=784, hidden_dims=(500, 500, 2000), latent_dim=10):
        super().__init__()

        # Materialise once: the widths are walked forwards for the encoder and
        # backwards for the decoder, which a one-shot iterator cannot support.
        hidden_dims = list(hidden_dims)

        # Encoder layers are created first, then decoder layers.
        self.encoder = self._build_stack([input_dim, *hidden_dims], latent_dim)
        self.decoder = self._build_stack([latent_dim, *reversed(hidden_dims)], input_dim)

    @staticmethod
    def _build_stack(widths, out_dim):
        """Return Linear+ReLU pairs over consecutive ``widths`` plus a final Linear."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Final projection carries no activation.
        layers.append(nn.Linear(widths[-1], out_dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            Tuple ``(z, x_recon)``: the latent code and the reconstruction.
        """
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return z, x_recon


class DRC(nn.Module):
    """Deep Reinforcement Clustering model: an autoencoder plus cluster centers.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dims: Hidden-layer widths forwarded to ``AutoEncoder``.
        latent_dim: Size of the latent code and of each cluster center.
        n_clusters: Number of cluster centers.

    Attributes:
        autoencoder: The wrapped ``AutoEncoder``.
        cluster_centers: ``nn.Parameter`` of shape ``(n_clusters, latent_dim)``,
            initialised from a standard normal distribution.
    """

    def __init__(self, input_dim=784, hidden_dims=(500, 500, 2000),
                 latent_dim=10, n_clusters=10):
        super().__init__()

        self.n_clusters = n_clusters
        self.latent_dim = latent_dim
        # The autoencoder is created before the centers are sampled.
        self.autoencoder = AutoEncoder(input_dim, hidden_dims, latent_dim)
        self.cluster_centers = nn.Parameter(torch.randn(n_clusters, latent_dim))

    def forward(self, x):
        """Return ``(z, x_recon)`` from the wrapped autoencoder."""
        z, x_recon = self.autoencoder(x)
        return z, x_recon

    def cauchy_similarity(self, z, kappa=1.0):
        """Cauchy-kernel similarity between every latent point and every center.

        Computes ``(1 / pi) * kappa / (||z_i - mu_j||^2 + kappa^2)``.

        Args:
            z: Latent codes; indexed as ``[batch, dim]`` by the broadcasting below.
            kappa: Scale of the Cauchy kernel.

        Returns:
            Tensor of shape ``[batch, K]``.
        """
        # [batch, 1, dim] - [1, K, dim] broadcasts to [batch, K, dim].
        diff = z.unsqueeze(1) - self.cluster_centers.unsqueeze(0)
        distances_sq = torch.sum(diff ** 2, dim=2)
        return (1.0 / np.pi) * (kappa / (distances_sq + kappa**2))

    def decision_probability(self, z, kappa=1.0):
        """Return the element-wise sigmoid of ``cauchy_similarity(z, kappa)``."""
        return torch.sigmoid(self.cauchy_similarity(z, kappa))

    def get_cluster_assignments(self, z):
        """Return, for each row of ``z``, the index of its most probable cluster."""
        return torch.argmax(self.decision_probability(z), dim=1)


def pretrain_autoencoder(model, train_loader, epochs=50, lr=0.001):
    """Train ``model.autoencoder`` on mean-squared reconstruction error.

    Args:
        model: Object exposing an ``autoencoder`` submodule that returns
            ``(z, x_recon)``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    print("\n=== Pretraining AutoEncoder ===")
    autoencoder = model.autoencoder
    adam = optim.Adam(autoencoder.parameters(), lr=lr)
    mse = nn.MSELoss()

    model.train()
    for epoch_idx in range(epochs):
        running_loss = 0
        for batch, _ in train_loader:
            # Flatten every sample to a vector before feeding the MLP.
            flat = batch.view(batch.size(0), -1).to(device)

            adam.zero_grad()
            _, reconstruction = autoencoder(flat)
            batch_loss = mse(reconstruction, flat)
            batch_loss.backward()
            adam.step()

            running_loss += batch_loss.item()

        # Progress is reported on every tenth epoch only.
        epoch_no = epoch_idx + 1
        if epoch_no % 10 == 0:
            print(f"Epoch [{epoch_no}/{epochs}], Loss: {running_loss/len(train_loader):.6f}")
    print("Pretraining completed!")


def initialize_cluster_centers(model, train_loader, n_clusters):
    """Overwrite ``model.cluster_centers`` with k-means centroids of the latents.

    Every batch of ``train_loader`` is encoded with ``model.autoencoder`` (no
    gradients), k-means is fitted on the stacked codes, and the centroids are
    written into ``model.cluster_centers.data`` on ``device``.

    Args:
        model: Object exposing ``autoencoder`` and ``cluster_centers``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        n_clusters: Number of k-means clusters.
    """
    print("\n=== Initializing Cluster Centers ===")
    model.eval()

    with torch.no_grad():
        latent_chunks = [
            model.autoencoder(batch.view(batch.size(0), -1).to(device))[0].cpu()
            for batch, _ in train_loader
        ]
    latents = torch.cat(latent_chunks, dim=0).numpy()

    from sklearn.cluster import KMeans
    kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=42)
    kmeans.fit(latents)

    centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    model.cluster_centers.data = centroids.to(device)
    print("Cluster centers initialized!")


def train_drc(model, train_loader, val_loader, epochs=100, 
              lr=0.0001, gamma=0.01, v=100.0):
    """Fine-tune the autoencoder with reconstruction + reinforcement-clustering loss.

    For every batch, the most probable cluster of each sample is taken as the
    action. A Bernoulli indicator ``y`` is drawn by comparing that probability
    with a uniform random number, and a reward of ``+v`` (``y == 1``) or ``-v``
    (``y == 0``) weights the log-probability term. The optimised loss is
    ``MSE(x_recon, x) + L_rc`` with ``L_rc = -gamma * mean(reward * log_term)``.
    Intended to follow Equation (8) of the DRC paper (not verified here).

    Only the autoencoder weights are handed to the optimizer, so
    ``model.cluster_centers`` keeps the values it had on entry.

    Validation metrics are computed after the first epoch and after every
    tenth epoch.

    Args:
        model: Callable returning ``(z, x_recon)`` that also exposes
            ``autoencoder`` and ``decision_probability``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        val_loader: Loader handed to ``evaluate_model``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        gamma: Weight of the reinforcement-clustering term.
        v: Reward magnitude.

    Returns:
        dict with per-epoch lists ``'loss'``, ``'recon_loss'``, ``'rc_loss'``
        and per-evaluation lists ``'val_acc'``, ``'val_nmi'``, ``'val_ari'``.
    """
    print("\n=== Training DRC ===")
    print(f"Gamma: {gamma}, Reward v: {v}")

    # NOTE(review): the cluster centers are not given to the optimizer, so they
    # stay frozen during this phase. Confirm that is intended before adding
    # them as a second parameter group.
    optimizer = torch.optim.Adam(model.autoencoder.parameters(), lr=lr)
    criterion_recon = nn.MSELoss()

    history = {
        'loss': [], 'recon_loss': [], 'rc_loss': [],
        'val_acc': [], 'val_nmi': [], 'val_ari': []
    }

    best_val_acc = 0.0

    for epoch in range(epochs):
        # evaluate_model() switches to eval mode, so re-enable train mode here.
        model.train()

        total_loss = 0
        total_recon_loss = 0
        total_rc_loss = 0

        for data, _ in train_loader:
            data = data.view(data.size(0), -1).to(device)
            batch_size = data.size(0)

            optimizer.zero_grad()

            # Forward pass and reconstruction loss
            z, x_recon = model(data)
            L_rec = criterion_recon(x_recon, data)

            # Action selection: probability of the most likely cluster per
            # sample. torch.max already returns the selected values, so no
            # extra gather is needed.
            probs = model.decision_probability(z)
            max_probs, _ = torch.max(probs, dim=1)

            # Bernoulli indicator: 1 with probability max_probs, else 0
            p_random = torch.rand(batch_size, device=device)
            y_ij = (max_probs > p_random).float()

            # Rewards: +v when y_ij == 1, -v when y_ij == 0
            rewards = v * (2 * y_ij - 1)

            # NOTE(review): with the signed reward above, rewards * log_term
            # equals v * [y*log(p) + (1 - y)*log(1 - p)]. Confirm the
            # (y_ij - 1) sign against Eq. (8) of the paper.
            log_term = y_ij * torch.log(max_probs) + \
                       (y_ij - 1) * torch.log(1 - max_probs)

            # Cumulative reward, negated and scaled so that it can be minimised
            L_rc = -gamma * torch.mean(rewards * log_term)

            loss = L_rec + L_rc
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon_loss += L_rec.item()
            total_rc_loss += L_rc.item()

        avg_loss = total_loss / len(train_loader)
        avg_recon = total_recon_loss / len(train_loader)
        avg_rc = total_rc_loss / len(train_loader)

        history['loss'].append(avg_loss)
        history['recon_loss'].append(avg_recon)
        history['rc_loss'].append(avg_rc)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            val_metrics = evaluate_model(model, val_loader)
            history['val_acc'].append(val_metrics['ACC'])
            history['val_nmi'].append(val_metrics['NMI'])
            history['val_ari'].append(val_metrics['ARI'])

            print(f"Epoch [{epoch+1}/{epochs}]")
            print(f"  Loss: {avg_loss:.4f} | Recon: {avg_recon:.5f} | RC: {avg_rc:.4f}")
            print(f"  Val ACC: {val_metrics['ACC']:.4f} | NMI: {val_metrics['NMI']:.4f} | ARI: {val_metrics['ARI']:.4f}")

            if val_metrics['ACC'] > best_val_acc:
                best_val_acc = val_metrics['ACC']
                # The check mark was mis-encoded in the original message.
                print(f"  ✓ New best validation ACC: {best_val_acc:.4f}")

    print(f"\nTraining completed! Best Val ACC: {best_val_acc:.4f}")
    return history


def predict_clusters(model, data_loader):
    """Run the model over ``data_loader`` and collect labels and assignments.

    Returns:
        Tuple ``(labels, predictions)`` of concatenated NumPy arrays, in
        loader order.
    """
    model.eval()
    label_chunks, pred_chunks = [], []

    with torch.no_grad():
        for batch, targets in data_loader:
            flat = batch.view(batch.size(0), -1).to(device)
            latent, _ = model(flat)
            assigned = model.get_cluster_assignments(latent)
            label_chunks.append(targets.numpy())
            pred_chunks.append(assigned.cpu().numpy())

    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(pred_chunks)
    return y_true, y_pred


def cluster_accuracy(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster/class mapping.

    A contingency table of (predicted cluster, true class) counts is built and
    the Hungarian algorithm picks the mapping that maximises the number of
    matched samples.

    Args:
        y_true: Ground-truth integer labels (array-like).
        y_pred: Predicted integer cluster ids (array-like), same length.

    Returns:
        Fraction of samples matched by the optimal mapping.

    Raises:
        ValueError: If the inputs differ in length or are empty.
    """
    y_true = np.asarray(y_true).astype(np.int64, copy=False).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64, copy=False).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("cannot compute clustering accuracy of empty input")

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered add so that repeated (pred, true) pairs are all counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / y_pred.size


def evaluate_clustering(y_true, y_pred):
    """Return a dict with clustering accuracy, NMI and ARI for the labelling."""
    return {
        'ACC': cluster_accuracy(y_true, y_pred),
        'NMI': normalized_mutual_info_score(y_true, y_pred),
        'ARI': adjusted_rand_score(y_true, y_pred),
    }


def evaluate_model(model, data_loader):
    """Predict clusters for ``data_loader`` and score them against its labels."""
    labels_and_preds = predict_clusters(model, data_loader)
    return evaluate_clustering(*labels_and_preds)


def _collect_latents(model, data_loader, max_samples):
    """Encode batches until about ``max_samples`` points have been gathered.

    Returns:
        Tuple ``(z, labels, preds)`` of NumPy arrays, each truncated to at most
        ``max_samples`` rows.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    z_chunks, label_chunks, pred_chunks = [], [], []

    with torch.no_grad():
        for data, labels in data_loader:
            data = data.view(data.size(0), -1).to(device)
            z, _ = model(data)

            z_chunks.append(z.cpu())
            label_chunks.append(labels)
            pred_chunks.append(model.get_cluster_assignments(z).cpu())

            # The running total is estimated from the current batch size; any
            # overshoot is trimmed by the slice below.
            if len(z_chunks) * data.size(0) >= max_samples:
                break

    if not z_chunks:
        raise ValueError("data loader yielded no batches to visualize")

    return tuple(
        torch.cat(chunks, dim=0)[:max_samples].numpy()
        for chunks in (z_chunks, label_chunks, pred_chunks)
    )


def _embed_tsne(all_z):
    """Project latent codes to 2-D with t-SNE using a fixed random state."""
    print(f"Running t-SNE on {len(all_z)} samples...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30, max_iter=1000)
    return tsne.fit_transform(all_z)


def visualize_tsne_clusters(model, test_loader, max_samples=2000):
    """Plot t-SNE of the latent space coloured by true label and by cluster.

    Draws two side-by-side panels (ground truth, predicted clusters), saves the
    figure to ``drc_tsne_clusters.png`` and shows it.

    Args:
        model: Callable returning ``(z, x_recon)`` that also exposes
            ``get_cluster_assignments``.
        test_loader: Iterable of ``(data, label)`` batches.
        max_samples: Upper bound on the number of points embedded.
    """
    print("\n=== Generating t-SNE Visualization ===")
    all_z, all_labels, all_preds = _collect_latents(model, test_loader, max_samples)
    z_2d = _embed_tsne(all_z)

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    # (axis, ids used for colouring, legend prefix, panel title)
    panels = (
        (axes[0], all_labels, 'Class', 'Ground Truth Labels'),
        (axes[1], all_preds, 'Cluster', 'DRC Predicted Clusters'),
    )
    for ax, ids, prefix, panel_title in panels:
        for i in range(10):
            mask = ids == i
            ax.scatter(z_2d[mask, 0], z_2d[mask, 1],
                       c=[colors[i]], label=f'{prefix} {i}',
                       alpha=0.6, s=20, edgecolors='none')
        ax.set_title(panel_title, fontsize=16, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=2)
        ax.set_xlabel('t-SNE 1', fontsize=12)
        ax.set_ylabel('t-SNE 2', fontsize=12)
        ax.grid(True, alpha=0.2)

    plt.tight_layout()
    plt.savefig('drc_tsne_clusters.png', dpi=300, bbox_inches='tight')
    print("t-SNE visualization saved to 'drc_tsne_clusters.png'")
    plt.show()


def visualize_cluster_distribution(model, test_loader, title="MNIST Clusters",
                                   max_samples=2000):
    """Create a single t-SNE plot coloured by predicted cluster.

    The figure has no ticks or spines and carries an ``(a) <title>`` caption.
    It is saved to ``cluster_distribution_<title>.png`` (title lower-cased,
    spaces replaced by underscores) and then shown.

    Args:
        model: Callable returning ``(z, x_recon)`` that also exposes
            ``get_cluster_assignments``.
        test_loader: Iterable of ``(data, label)`` batches.
        title: Caption text, also used to build the output filename.
        max_samples: Upper bound on the number of points embedded.

    Returns:
        Tuple ``(z_2d, all_labels, all_preds)``: the 2-D embedding, the true
        labels and the predicted cluster ids of the plotted points.
    """
    print(f"\n=== Generating {title} Visualization ===")
    all_z, all_labels, all_preds = _collect_latents(model, test_loader, max_samples)
    z_2d = _embed_tsne(all_z)

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    for i in range(10):
        mask = all_preds == i
        ax.scatter(z_2d[mask, 0], z_2d[mask, 1],
                   c=[colors[i]],
                   label=f'{i}',
                   alpha=0.7,
                   s=25,
                   edgecolors='white',
                   linewidth=0.5)

    # Strip ticks and frame so only the point cloud remains.
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    ax.text(0.5, -0.05, f'(a) {title}',
            transform=ax.transAxes,
            ha='center',
            fontsize=14,
            fontweight='bold')

    plt.tight_layout()

    filename = f'cluster_distribution_{title.lower().replace(" ", "_")}.png'
    plt.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Cluster distribution saved to '{filename}'")
    plt.show()

    return z_2d, all_labels, all_preds


def visualize_results(history, test_metrics):
    """Plot loss curves, validation metrics and final test scores side by side.

    The figure is saved to ``drc_results.png`` and then shown.

    Args:
        history: dict with lists under ``'loss'``, ``'recon_loss'``,
            ``'rc_loss'``, ``'val_acc'``, ``'val_nmi'`` and ``'val_ari'``.
        test_metrics: Mapping of metric name to score, drawn as a bar chart.
    """
    plt.figure(figsize=(18, 5))
    n_epochs = len(history['loss'])

    # Panel 1: training losses, one point per epoch.
    loss_ax = plt.subplot(1, 3, 1)
    epoch_axis = range(1, n_epochs + 1)
    loss_series = (
        ('loss', 'Total Loss', 'o'),
        ('recon_loss', 'Recon Loss', 's'),
        ('rc_loss', 'RC Loss', '^'),
    )
    for key, label, marker in loss_series:
        loss_ax.plot(epoch_axis, history[key], label=label,
                     linewidth=2, marker=marker, markersize=3)
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title('Training Loss', fontsize=14, fontweight='bold')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Panel 2: validation scores, placed at epoch 1 and at every tenth epoch.
    val_ax = plt.subplot(1, 3, 2)
    val_epochs = [1]
    val_epochs.extend(range(10, n_epochs + 1, 10))
    val_series = (
        ('val_acc', 'ACC', 'o'),
        ('val_nmi', 'NMI', 's'),
        ('val_ari', 'ARI', '^'),
    )
    for key, label, marker in val_series:
        val_ax.plot(val_epochs, history[key], label=label,
                    linewidth=2, marker=marker, markersize=6)
    val_ax.set_xlabel('Epoch', fontsize=12)
    val_ax.set_ylabel('Score', fontsize=12)
    val_ax.set_title('Validation Metrics', fontsize=14, fontweight='bold')
    val_ax.legend()
    val_ax.grid(True, alpha=0.3)
    val_ax.set_ylim([0, 1])

    # Panel 3: one bar per test metric, annotated with its value.
    test_ax = plt.subplot(1, 3, 3)
    bars = test_ax.bar(list(test_metrics.keys()), list(test_metrics.values()),
                       color=['#2ecc71', '#3498db', '#e74c3c'], alpha=0.7,
                       edgecolor='black', linewidth=2)
    test_ax.set_ylabel('Score', fontsize=12)
    test_ax.set_title('Test Performance', fontsize=14, fontweight='bold')
    test_ax.set_ylim([0, 1])
    test_ax.grid(True, alpha=0.3, axis='y')

    for bar in bars:
        top = bar.get_height()
        center_x = bar.get_x() + bar.get_width()/2.
        test_ax.text(center_x, top, f'{top:.4f}',
                     ha='center', va='bottom',
                     fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig('drc_results.png', dpi=300, bbox_inches='tight')
    print("\nResults saved to 'drc_results.png'")
    plt.show()


def main():   
    """Run the full pipeline: load MNIST, pretrain, init centers, train, evaluate, plot.

    Returns:
        Tuple ``(model, history, test_metrics)``.
    """
    # Model and optimisation settings
    input_dim, latent_dim, n_clusters = 784, 10, 10
    hidden_dims = [500, 500, 2000]
    batch_size = 256
    pretrain_epochs, train_epochs = 50, 200
    lr = 0.0001
    gamma, reward_v = 0.01, 100.0

    banner = (
        "=== DRC ===",
        f"Architecture: {input_dim} -> {hidden_dims} -> {latent_dim}",
        f"Learning rate: {lr}",
        f"Gamma: {gamma}",
        f"Reward v: {reward_v}",
        f"Epochs: {train_epochs}",
    )
    for line in banner:
        print(line)

    def make_loader(dataset, shuffle):
        # All loaders share the batch size and worker count.
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=shuffle, num_workers=2)

    # Load MNIST
    print("\n=== Loading MNIST Dataset ===")
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True,
                                   download=True, transform=to_tensor)
    test_dataset = datasets.MNIST(root='./data', train=False,
                                  download=True, transform=to_tensor)

    # 90/10 train/validation split drawn from a dedicated seeded generator
    n_train = int(0.9 * len(train_dataset))
    n_val = len(train_dataset) - n_train
    train_subset, val_subset = torch.utils.data.random_split(
        train_dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = make_loader(train_subset, True)
    val_loader = make_loader(val_subset, False)
    test_loader = make_loader(test_dataset, False)

    print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(test_dataset)}")

    # Build the model
    model = DRC(input_dim, hidden_dims, latent_dim, n_clusters).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {n_params:,}")

    # Pretraining and center initialisation use the whole training set,
    # including the samples held out for validation.
    full_train_loader = make_loader(train_dataset, True)
    pretrain_autoencoder(model, full_train_loader, pretrain_epochs, lr=0.001)
    initialize_cluster_centers(model, full_train_loader, n_clusters)

    history = train_drc(model, train_loader, val_loader,
                        epochs=train_epochs, lr=lr, gamma=gamma, v=reward_v)

    # Final evaluation on the test split
    print("\n=== Final Test Evaluation ===")
    test_metrics = evaluate_model(model, test_loader)
    for metric in ('ACC', 'NMI', 'ARI'):
        print(f"Test {metric}: {test_metrics[metric]:.4f}")

    # Figures; the values returned by the last call are not used.
    visualize_results(history, test_metrics)
    visualize_tsne_clusters(model, test_loader, max_samples=2000)
    visualize_cluster_distribution(model, test_loader, title="MNIST")
    return model, history, test_metrics


# Run the full pipeline only when this file is executed as a script; the
# trained model, training history and test metrics stay bound at module level.
if __name__ == "__main__":
    model, history, metrics = main()