In [81]:
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
import matplotlib.pyplot as plt

def kernel_k_means(clusters, kernel, weights, n_clusters, display=False):
  tolerance = 0.0001
  error = 0
  n = kernel.shape[0]
  center_dists = np.zeros(n)
  iter_num = 1
  
  while True:
    old_error = error
    error = 0
    
    intra_cluster = np.zeros(n_clusters)
    center_dists = np.zeros((n, n_clusters))

    for i in range(n_clusters):

      cluster_points = np.where(clusters == i)[0]
      cluster_weights = weights[cluster_points]

      intra_cluster[i] = np.dot(np.dot(kernel[cluster_points, :][:, cluster_points], 
                                       cluster_weights), cluster_weights)
      # intra_cluster[i] /= np.sum(cluster_weights) ** 2
      intra_cluster[i] /= (np.sum(cluster_weights) ** 2 + 1e-10)

      center_dists[:,i] = np.dot(kernel[:, cluster_points], cluster_weights) 
      # center_dists[:,i] = -2 * center_dists[:,i] / np.sum(cluster_weights) + intra_cluster[i]
      center_dists[:,i] = -2 * center_dists[:,i] / (np.sum(cluster_weights) + 1e-10) + intra_cluster[i]
      center_dists[:,i] += np.diag(kernel)

      center_dists[cluster_points, i] = center_dists[cluster_points, i]
      error += np.dot(cluster_weights, center_dists[cluster_points, i])
      print(center_dists[cluster_points, i].shape)
      print(cluster_weights.shape)
      break
      

    min_dists, clusters = np.min(center_dists, axis=1), np.argmin(center_dists, axis=1)
    
    if display:
      print(f'Iteration {iter_num}: Error = {error}')

    # if iter_num > 1 and abs(1 - (error / old_error)) < tolerance:
    if iter_num > 1 and abs(1 - (error / (old_error + 1e-10))) < tolerance:
      break

    iter_num += 1
      
  return clusters, error, center_dists

def CW2KM(K, n_clusters, n_views, p, init_clusters, init_weights):
  if p < 1:
    raise ValueError('p must be greater than or equal to 1')

  # if np.any(init_weights < 0) or abs(np.sum(init_weights) - 1) > 1e-15:
  #   raise ValueError('Weights must be positive and sum to unity')

  n_points = K.shape[0] // n_views

  iter_num = 1
  old_error = np.inf

  clusters = init_clusters
  weights = init_weights

  trace_per_cluster_view = [[] for _ in range(n_clusters)]
  for i in range(n_views):
    K_view = K[i*n_points:(i+1)*n_points, :]  
    for j in range(n_clusters):
      cluster_points = np.where(clusters == j)[0]
      K_cluster = K_view[cluster_points[:, None], cluster_points]
      trace_per_cluster_view[j].append(np.trace(K_cluster))
  trace_per_cluster_view = np.array(trace_per_cluster_view)
  # trace_per_view = [np.trace(K[i*n_points:(i+1)*n_points, :]) for i in range(n_views)]

  while True:
    print('--------------- CW2KM Iteration {} ---------------'.format(iter_num))

    # Update clusters
    print('Updating clusters...')

    K_combined = np.zeros((n_points, n_points))
    for i in range(n_views):
      for k in range(n_clusters):
        K_combined += weights[i, k]**p * K[i*n_points:(i+1)*n_points, :]

    clusters, error, _ = kernel_k_means(clusters, K_combined, np.ones(n_points), n_clusters) 
    print('Objective after updating clusters:', error)

    if len(np.unique(clusters)) < n_clusters:
      raise ValueError('Empty clusters detected')

    # Y = csr_matrix((n_points, n_clusters))
    Y = lil_matrix((n_points, n_clusters))
    for i in range(n_clusters):
      cluster_points = np.where(clusters == i)[0]
      Y[cluster_points, i] = 1 / np.sqrt(len(cluster_points))

    trace_per_cluster_view_new = []
    for i in range(n_clusters):
      cluster_traces = []
      cluster_points = np.where(clusters == i)[0]
      Y_cluster = Y[cluster_points, i]  
      for j in range(n_views):
        K_view = K[j*n_points:(j+1)*n_points, :]
        K_cluster = K_view[cluster_points[:, None], cluster_points]
        trace_cluster = np.trace(Y_cluster.T @ K_cluster @ Y_cluster)
        cluster_traces.append(trace_cluster)
      trace_per_cluster_view_new.append(cluster_traces)
    trace_per_cluster_view_new = np.array(trace_per_cluster_view_new)
    # trace_per_view_new = [np.trace(Y.T @ K[i*n_points:(i+1)*n_points, :] @ Y) for i in range(n_views)]

    trace_diff = np.array([a - b for a, b in zip(trace_per_cluster_view, trace_per_cluster_view_new)]).T
    # trace_diff = [a - b for a, b in zip(trace_per_view, trace_per_view_new)]

    if abs(1 - (error / old_error)) < 0.0001:
      print('CW2KM reached convergence')
      break

    old_error = error

    # Update weights 
    print('Updating weights...')

    if p != 1:
      weights = 1 / np.power(trace_diff, 1/(p-1))
      weights /= np.sum(weights)
      print(weights.shape)
    else:
      min_idx = np.argmin(trace_diff)
      weights = np.zeros(n_views)
      weights[min_idx] = 1

    weights[weights < 1e-5] = 0 
    weights /= np.sum(weights)
    
    # print('Objective after updating weights:', (weights**p) @ trace_diff)
    print('Objective after updating weights:', np.sum((weights**p).T @ trace_diff))
    print()

    iter_num += 1

  return clusters, weights, error

In [82]:
def global_kernel_k_means(kernel, weights, n_clusters, display=False):
  n = kernel.shape[0]
  
  best_errors = np.inf * np.ones(n_clusters)
  best_clusters = np.ones((n, n_clusters))

  # Find 1 cluster solution
  best_clusters[:,0], best_errors[0], _ = kernel_k_means(best_clusters[:,0], kernel, weights, 1, display)

  for m in range(2, n_clusters+1):

    for n in range(n):
    
      clusters = best_clusters[:,m-1].copy()
      clusters[n] = m
      
      if display:
        print(f"\nSearching for {m} clusters. Placing {m}th cluster at point {n} initially.")

      clusters, error, _ = kernel_k_means(clusters, kernel, weights, m, display)

      if display:
        print(f"Final error: {error}")

      if best_errors[m-1] > error:
        best_errors[m-1] = error
        best_clusters[:,m-1] = clusters

  if np.unique(best_clusters[:,m-1]).size < m:
    raise ValueError(f"Could not find more than {m-1} clusters")

  best_idx = np.argmin(best_errors)
  clusters = best_clusters[:, best_idx]
  error = best_errors[best_idx]

  if display:
    print(f"\nBest solution: {best_idx+1} clusters with error {error}")

  return clusters, error

In [83]:
import numpy as np
from sklearn.datasets import make_blobs
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import rbf_kernel
import scipy.io

n_clusters = 3
data = scipy.io.loadmat('datasets/synthetic_data.mat')

X1 = data["View_1"]
X2 = data["View_2"]
X = [X1, X2]
y = data["Ground_truth"]

K = np.array(scipy.io.loadmat('datasets/synthetic_data_kernel.mat')['K'])

# Run CW2KM
n_views = 2
p = 1.5
init_weights = np.full((n_views, n_clusters), 1 / n_views)
# init_weights = np.full(n_views, 1/n_views)
# print(init_weights.shape)

n_points = K.shape[0] // n_views
# Compute composite kernel 
K_combined = np.zeros((n_points, n_points))
for i in range(n_views):
  for k in range(n_clusters):
    start = i * n_points
    end = (i+1) * n_points
    K_combined += init_weights[i, k]**p * K[start:end, :]
# Initialize with Global K-Means
print('Global Kernel K-Means initialization')
init_clusters = global_kernel_k_means(K_combined, np.ones(n_points), n_clusters)
# init_clusters = np.random.randint(0, n_clusters, n_points)
print('End initialization\n')

cluster_labels, weights, error = CW2KM(K, n_clusters, n_views, p, init_clusters, init_weights)

# Plot the original data and the clustered data
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.scatter(X1[:, 0], X1[:, 1], c=y, cmap='viridis')
plt.title('Original Data')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')

plt.subplot(1, 2, 2)
plt.scatter(X2[:, 0], X2[:, 1], c=cluster_labels, cmap='viridis')
plt.title('Clustered Data (CW2KM)')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')

plt.tight_layout()
plt.show()

Global Kernel K-Means initialization
(0,)
(0,)
(700,)
(700,)
(700,)
(700,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)
(0,)


KeyboardInterrupt: 

In [None]:
import pandas as pd

V1 = pd.read_csv('datasets/HW/mfeat-fou', header=None, delimiter="\\s+")
V2 = pd.read_csv('datasets/HW/mfeat-fac', header=None, delimiter="\\s+")
V3 = pd.read_csv('datasets/HW/mfeat-kar', header=None, delimiter="\\s+")
V4 = pd.read_csv('datasets/HW/mfeat-pix', header=None, delimiter="\\s+")

views = [V1, V2, V3, V4]
n_clusters = 3

X = [V1, V2, V3, V4]

K = np.array(scipy.io.loadmat('datasets/synthetic_data_kernel.mat')['K'])

# Run CW2KM
n_views = 2
p = 1.5
init_weights = np.full(n_views, 1/n_views)

n_points = K.shape[0] // n_views
# Compute composite kernel 
K_combined = np.zeros((n_points, n_points))
for i in range(n_views):
  start = i * n_points
  end = (i+1) * n_points
  K_combined += init_weights[i]**p * K[start:end, :]
# Initialize with Global K-Means
print('Global Kernel K-Means initialization')
init_clusters = global_kernel_k_means(K_combined, np.ones(n_points), n_clusters)
# init_clusters = np.random.randint(0, n_clusters, n_samples)
print('End initialization\n')

cluster_labels, weights, error = CW2KM(K, n_clusters, n_views, p, init_clusters, init_weights)

ckm = CW2KM(n_clusters=10, max_iter=1, gamma=0.01)
ckm.fit(views)
labels = ckm.predict()

Global Kernel K-Means initialization


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(cluster_labels, y.flatten() - 1)
print(accuracy)

0.2857142857142857
