In [1]:
import numpy as np

In [2]:
def InputParser(file_path: str) -> dict:
    """Parse a soft k-means input file into a parameter dictionary.

    Layout read by this function:
      * line 1: two whitespace-separated integers, stored under 'k' and 'm';
      * line 2: one float, stored under 'beta';
      * following lines: whitespace-separated floats, one data row per line.
        Reading stops at the first blank (or whitespace-only) line, or at
        end of file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A dict with the keys 'k' (int), 'm' (int), 'beta' (float) and
        'data' (np.ndarray built from the parsed rows).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the header lines or a data row cannot be converted.
    """
    params = {}

    with open(file_path, 'r') as file:
        # Header: exactly two integers, then the beta value on its own line.
        params['k'], params['m'] = map(int, file.readline().split())
        params['beta'] = float(file.readline().rstrip())

        rows = []
        for raw_line in file:
            line = raw_line.rstrip()

            # A blank line terminates the data section.
            if not line:
                break

            rows.append([float(value) for value in line.split()])

    params['data'] = np.array(rows)

    return params

In [21]:
def GetDistance(point: np.ndarray, center: np.ndarray) -> float:
    """Return the Euclidean distance between `point` and `center`.

    Args:
        point: Coordinates of the first point (array or array-like).
        center: Coordinates of the second point, broadcast-compatible
            with `point`.

    Returns:
        The square root of the summed squared coordinate differences.
    """
    # Convert to float before subtracting: unsigned or narrow integer arrays
    # would otherwise wrap around / overflow and yield a wrong distance.
    difference = np.asarray(point, dtype=float) - np.asarray(center, dtype=float)
    return np.sqrt(np.square(difference).sum())

In [61]:
def _EstimateSoftCenters(k: int, beta: float, datapoints: np.ndarray, max_steps: int) -> np.ndarray:
    """Run the soft k-means EM iterations and return the matrix of centers (one row per center)."""
    points = np.asarray(datapoints, dtype=float)
    if points.ndim != 2 or points.shape[0] == 0:
        raise ValueError("datapoints must be a non-empty 2-D array (one row per point)")
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Initialize centers with the first k points (copied, so the input array is never aliased).
    centers = points[:k].copy()

    for _ in range(max_steps):
        # E-step (Estimating Hidden Matrix): row = center, column = data point.
        offsets = centers[:, np.newaxis, :] - points[np.newaxis, :, :]
        distances = np.sqrt(np.square(offsets).sum(axis=2))

        # Shift every column by its largest exponent before exponentiating.
        # The normalized result is mathematically unchanged, but the largest
        # weight per point becomes exp(0) == 1, so a column can no longer
        # underflow to all zeros and produce 0/0 = NaN.
        exponents = -beta * distances
        exponents -= exponents.max(axis=0)
        weights = np.exp(exponents)
        hidden_matrix = weights / weights.sum(axis=0)  # axis=0 : column-wise

        # M-step (Estimating Parameters): responsibility-weighted mean of the points.
        totals = hidden_matrix.sum(axis=1, keepdims=True)  # axis=1 : row-wise
        has_weight = totals > 0
        # A center that received no weight at all keeps its previous position
        # instead of being divided by zero.
        new_centers = np.where(
            has_weight,
            (hidden_matrix @ points) / np.where(has_weight, totals, 1.0),
            centers,
        )

        # Stop early once an iteration leaves the centers exactly unchanged.
        if np.array_equal(centers, new_centers):
            break

        centers = new_centers

    return centers


def SoftKmeansClustering(k: int, beta: float, datapoints: np.ndarray, max_steps: int = 100) -> None:
    """Run soft k-means clustering and print the resulting centers.

    The centers start as the first `k` rows of `datapoints`. Each iteration
    computes responsibilities proportional to exp(-beta * distance) (E-step)
    and then moves every center to the responsibility-weighted mean of the
    data (M-step). Iteration stops after `max_steps` rounds or as soon as the
    centers no longer change.

    Args:
        k: Number of centers; must be at least 1.
        beta: Stiffness parameter used in exp(-beta * distance).
        datapoints: 2-D array (or array-like) with one data point per row.
        max_steps: Maximum number of EM iterations (defaults to 100).

    Returns:
        None. One line per center is printed, with the coordinates formatted
        to three decimals and separated by single spaces.

    Raises:
        ValueError: If `k` is smaller than 1 or `datapoints` is not a
            non-empty 2-D array.
    """
    centers = _EstimateSoftCenters(k, beta, datapoints, max_steps)

    for center in centers:
        print(' '.join(format(coordinate, '.3f') for coordinate in center))

In [64]:
# Parse the two sample datasets. Each result is a dict holding 'k', 'm',
# 'beta' and 'data'; inspect those keys to verify the parser output.
test_params = InputParser(file_path="test_datasets/testdata_ba8d.txt")
test_params2 = InputParser(file_path="test_datasets/testdata_ba8d-2.txt")

In [63]:
# Cluster the first sample dataset and print its centers.
SoftKmeansClustering(
    k=test_params['k'],
    beta=test_params['beta'],
    datapoints=test_params['data'],
)

1.662 2.623
1.075 1.148


In [65]:
# Cluster the second sample dataset and print its centers.
SoftKmeansClustering(
    k=test_params2['k'],
    beta=test_params2['beta'],
    datapoints=test_params2['data'],
)

5.889 16.921 6.873
20.404 8.236 9.055
3.590 4.853 4.970
11.329 5.448 5.319
5.761 6.494 17.227
