In [None]:
import numpy as np
from collections import defaultdict
from math import sqrt, inf
from random import uniform
import matplotlib.pyplot as plt

def point_avg(points):
    """Return the coordinate-wise mean of ``points``.

    The number of coordinates is taken from ``points[0]``. Each coordinate is
    summed over every point and divided by the number of points.

    Args:
        points: Non-empty sequence of indexable points.

    Returns:
        A list with one mean per coordinate.

    Raises:
        IndexError: if ``points`` is an empty list (``points[0]`` fails).
    """
    count = len(points)
    centroid = []
    for axis in range(len(points[0])):
        column_total = sum(p[axis] for p in points)
        centroid.append(column_total / count)
    return centroid

def update_centers(data_set, assignments, previous_centers=None):
    """Recompute each cluster's center as the mean of the points assigned to it.

    Centers are returned in ascending order of cluster label. Previously they
    were returned in the order each label first appeared in ``assignments``,
    so ``centers[i]`` did not necessarily describe cluster ``i``.

    Args:
        data_set: Sequence of points.
        assignments: Cluster label for each point, paired positionally with
            ``data_set`` via ``zip`` (extra items in the longer one are ignored).
        previous_centers: Optional list of the centers that produced
            ``assignments``; labels are assumed to be indices into it. When
            given, a cluster that received no points keeps its previous
            center, so the result has ``len(previous_centers)`` entries and
            label ``i`` keeps referring to the same cluster. When omitted,
            clusters with no points are dropped (original behaviour).

    Returns:
        List of centers, each the list returned by ``point_avg``
        (or a copy of the previous center for an empty cluster).
    """
    members = defaultdict(list)
    for label, point in zip(assignments, data_set):
        members[label].append(point)

    if previous_centers is None:
        # Sort so that centers line up with their labels rather than with the
        # order in which labels happen to first occur in ``assignments``.
        return [point_avg(members[label]) for label in sorted(members)]

    # Keep one center per previous cluster; an empty cluster keeps its old
    # position instead of vanishing and shifting every later label down.
    return [
        point_avg(members[label]) if members.get(label) else list(center)
        for label, center in enumerate(previous_centers)
    ]

def assign_points(data_points, centers):
    """Label every point with the index of its nearest center.

    Args:
        data_points: Iterable of points.
        centers: Sequence of centers; distances are measured with ``distance``.

    Returns:
        A list with one entry per point: the result of ``np.argmin`` over that
        point's distances to ``centers`` (on ties, the lowest index wins).
    """
    def nearest_center(point):
        gaps = [distance(point, c) for c in centers]
        return np.argmin(gaps)

    return [nearest_center(p) for p in data_points]

def distance(a, b):
    """Return the Euclidean distance between points ``a`` and ``b``.

    Only the first ``len(a)`` coordinates are compared. ``b`` must therefore
    have at least that many coordinates; any extra ones in ``b`` are ignored.
    """
    squared_gaps = ((a[axis] - b[axis]) ** 2 for axis in range(len(a)))
    return sqrt(sum(squared_gaps))

def k_means(dataset, k_points, k, *, max_iter=None, plot=True):
    """Alternate center updates and point reassignment until labels stop changing.

    Points are first assigned to the nearest of ``k_points``. Then each round
    recomputes centers from the current assignments (``update_centers``) and
    reassigns every point (``assign_points``). The loop stops when a round
    leaves the assignment list unchanged, or when ``max_iter`` rounds have run.
    The number of rounds performed is printed.

    Args:
        dataset: Sequence of points.
        k_points: Initial centers.
        k: Requested number of clusters. Not read here; the clusters come from
            ``k_points``. Kept for interface compatibility.
        max_iter: Optional cap on the number of update/reassign rounds.
            ``None`` (default) keeps the original behaviour of looping until
            the assignments are unchanged, with no upper bound.
        plot: If true (default), call ``plot_data_colored_by_groups`` after
            every round, exactly as before. Pass ``False`` to run without plotting.

    Returns:
        The final list of per-point labels produced by ``assign_points``.
    """
    assignments = assign_points(dataset, k_points)
    old_assignments = None
    iter_count = 0
    while assignments != old_assignments:
        if max_iter is not None and iter_count >= max_iter:
            break
        new_centers = update_centers(dataset, assignments)
        old_assignments = assignments
        assignments = assign_points(dataset, new_centers)
        iter_count += 1
        if plot:
            # NOTE(review): `groups` and `plot_data_colored_by_groups` are not
            # defined in this function; presumably notebook-level globals — confirm.
            plot_data_colored_by_groups(groups, assignments, dataset, f'K-means iter {iter_count}', True)
    print('iter', iter_count)
    return assignments

def generate_k(data_set, k):
    """Draw ``k`` random starting centers inside the data's bounding box.

    The number of dimensions is taken from ``data_set[0]``. For each dimension,
    the minimum and maximum coordinate over all points are found. Then every
    coordinate of every center is sampled with ``uniform`` between those
    bounds, center by center and dimension by dimension.

    Args:
        data_set: Non-empty sequence of indexable points.
        k: Number of centers to generate.

    Returns:
        List of ``k`` centers, each a list with one sampled value per dimension.
    """
    n_dims = len(data_set[0])
    lows = highs = None
    for point in data_set:
        coords = [point[axis] for axis in range(n_dims)]
        if lows is None:
            # The first point seeds both bounds.
            lows, highs = coords, list(coords)
        else:
            lows = [min(value, low) for value, low in zip(coords, lows)]
            highs = [max(value, high) for value, high in zip(coords, highs)]

    centers = []
    for _ in range(k):
        centers.append([uniform(low, high) for low, high in zip(lows, highs)])
    return centers

def cluster_by_kmeans(points, num_clusters):
    """Cluster ``points`` into ``num_clusters`` groups and return their labels.

    Starting centers come from ``generate_k``, and the iteration is delegated
    to ``k_means``. The resulting label list is converted to a NumPy array.
    """
    starting_centers = generate_k(points, num_clusters)
    return np.array(k_means(points, starting_centers, num_clusters))