# Balanced Clustering

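This notebook demonstrates a simple greedy algorithm for balancing clusters, that is, for adjusting a set of clusters so that they all contain the same number of elements (or are within one element of each other).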

## Helper functions

The main algorithm relies on several helper functions. Most are self-explanatory, and the less intuitive ones have docstrings. They are not the focus of this notebook.

In [4]:
import numpy as np

def get_centroid(cluster_elements):
    return cluster_elements.mean(axis=0)

In [5]:
def get_centroids(clusters_dict):
    return {k : get_centroid(clusters_dict[k]) for k in clusters_dict.keys()}

In [6]:
def get_group_num_elements(group_name, clusters_dict):
    return clusters_dict[group_name].shape[0]

In [7]:
def get_all_group_sizes(clusters_dict):
    return {group_name : get_group_num_elements(group_name, clusters_dict) for group_name in clusters_dict.keys()}

In [8]:
def get_group_with_fewest_elements(clusters_dict):
    '''
    Given a dictionary mapping from group name
    to a numpy array of group element coordinates,
    return the name (AKA key) of the group that has
    the fewest elements. Ties are broken by dictionary
    order: np.argmin returns the first minimum, so the
    first tied group encountered wins.
    '''
    group_names = list(clusters_dict.keys()) # create indexable list
    num_elements_list = [ get_group_num_elements(group_name, clusters_dict) for group_name in group_names ] # find number of elements for each key
    index_of_min = np.argmin(num_elements_list) # find index of min for the key
    return group_names[index_of_min] # return the key at the index where the min occurred
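A quick check of the tie-breaking behavior: `np.argmin` returns the index of the first minimum, so ties go to whichever tied group comes first in the dictionary's iteration order. The helper `fewest_elements_key` below is a hypothetical one-line alternative (not part of the original notebook) that uses Python's built-in `min`, which also returns the first minimal item it encounters.

```python
import numpy as np

def fewest_elements_key(clusters_dict):
    # Hypothetical alternative: min() over the keys returns the first key with
    # the smallest size, matching np.argmin's first-occurrence rule
    return min(clusters_dict, key=lambda name: clusters_dict[name].shape[0])

clusters_dict = {
    'A': np.zeros((3, 2)),
    'B': np.zeros((2, 2)),
    'C': np.zeros((2, 2)),  # ties with 'B'
}

sizes = [clusters_dict[name].shape[0] for name in clusters_dict]
print(np.argmin(sizes))                    # 1, the first of the tied minima
print(fewest_elements_key(clusters_dict))  # B
```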

In [9]:
def distance_metric(a_coords, b_coords):
    '''
    This is flexible, but right now it's Euclidean
    distance. It could alternatively be Manhattan
    (absolute) distance, the cube of the distance,
    or any other measure.'''
    difference_array = a_coords - b_coords
    return np.linalg.norm(difference_array)
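Since the metric is meant to be swappable, it helps to see which swaps actually matter. The sketch below (function names are illustrative, not from the original) compares Euclidean distance, Manhattan distance via `np.linalg.norm(..., ord=1)`, and cubed Euclidean distance. Cubing is monotonic, so it never changes which element is nearest; Manhattan distance can.

```python
import numpy as np

def euclidean(a, b):
    return np.linalg.norm(np.asarray(a) - np.asarray(b))

def manhattan(a, b):
    # Sum of absolute coordinate differences (the L1 norm)
    return np.linalg.norm(np.asarray(a) - np.asarray(b), ord=1)

def euclidean_cubed(a, b):
    # A monotonic transform of Euclidean distance: same ordering of candidates
    return euclidean(a, b) ** 3

reference = np.array([0.0, 0.0])
candidates = np.array([[2.0, 2.0], [0.0, 3.0]])

for metric in (euclidean, manhattan, euclidean_cubed):
    distances = [metric(reference, c) for c in candidates]
    print(metric.__name__, np.round(distances, 3), 'nearest index:', int(np.argmin(distances)))
```

With these two points, Euclidean distance (and its cube) picks `[2, 2]` (about 2.83 versus 3), while Manhattan distance picks `[0, 3]` (3 versus 4), so the choice of metric can change which element the greedy step moves.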

In [10]:
def get_distances_from_small_group_centroid(cluster_dict, group_names, small_group_centroid, 
                                            small_group_num_elements):
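    '''
    Looks at every specified group in a given cluster
    dictionary and reports the distance from each of its
    elements to the smallest group's centroid. Only groups
    with at least 2 more elements than the smallest group
    are considered.

    cluster_dict: a mapping from the name of each cluster to a numpy array of its elements
    group_names: iterable of strings representing keys in cluster_dict (the names of the clusters to search)
    small_group_centroid: 1-dimensional numpy array, the centroid of the smallest group
    small_group_num_elements: the number of elements in the smallest group
    return:
        dictionary {
            distance : dictionary {
                'group_name': <string group name>,
                'element': <np.array element values>
            }
        }
        Because distances are used as keys, two elements at exactly the
        same distance would overwrite each other; with continuous random
        data this is unlikely, but it is a limitation.
    '''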
    distances_from_small_group_centroid = {} 
    for group_name in group_names:
        group_num_elements = get_group_num_elements(group_name, cluster_dict)
        assert (group_num_elements >= small_group_num_elements), 'One of the non-smallest groups has fewer elements than the alleged smallest group. Ensure that the smallest group is being chosen correctly.'

        # The rule is that we can only take elements from a cluster that currently
        # has at least 2 more elements than the cluster with the least number of elements
        if (group_num_elements - small_group_num_elements) >= 2:
            group_elements = cluster_dict[group_name]
            # This could be sped up by vectorizing the distance calculation with numpy,
            # but the explicit loop is kept for clarity
            for element in group_elements:
                distance_to_small_group_centroid = distance_metric(small_group_centroid, element)
                distances_from_small_group_centroid[distance_to_small_group_centroid] = {
                    'group_name':group_name, 'element':element
                }
    return distances_from_small_group_centroid
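The loop above computes one distance at a time. One possible speed-up, sketched here under the assumption that each cluster is a 2-D numpy array of shape `(n, d)`, is a single `np.linalg.norm(..., axis=1)` call per group, returning `(distance, group_name, row_index)` tuples instead of a dictionary keyed by distance. Storing a row index also avoids the edge case where two elements at exactly the same distance overwrite each other. `eligible_distances` is a hypothetical name, not part of the original notebook.

```python
import numpy as np

def eligible_distances(cluster_dict, small_group_name):
    '''
    Return a list of (distance, group_name, row_index) tuples for every element
    that may move into the smallest group, i.e. every element of a group with at
    least 2 more elements than the smallest group.
    '''
    small_size = cluster_dict[small_group_name].shape[0]
    centroid = cluster_dict[small_group_name].mean(axis=0)
    candidates = []
    for group_name, elements in cluster_dict.items():
        if group_name == small_group_name:
            continue
        # Moving one element from a group of size n + 1 to a group of size n only
        # swaps the two sizes, so only groups at least 2 larger are eligible
        if elements.shape[0] - small_size < 2:
            continue
        # Distances from every row to the centroid in one vectorized call
        distances = np.linalg.norm(elements - centroid, axis=1)
        candidates.extend((float(d), group_name, i) for i, d in enumerate(distances))
    return candidates

clusters = {
    'small': np.array([[0.0, 0.0], [2.0, 0.0]]),  # centroid (1, 0)
    'big': np.array([[1.0, 1.0], [1.0, -1.0], [5.0, 5.0], [9.0, 9.0]]),
}
candidates = eligible_distances(clusters, 'small')
best = min(candidates)  # ties fall back to group name, then row index
print(best)             # (1.0, 'big', 0)
```

Here two elements sit at exactly distance 1.0 from the centroid; both are kept, and `min` deterministically picks the lower row index.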


In [12]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def plot_clusters(cluster_dict):
    '''
    The gist of this is to create a
    dataframe with columns being
    the group name, x, and y, and
    then plotting this in seaborn
    with the hue being the group name.
    This only works for 2-dimensional
    elements.
    '''
    def get_total_num_elements(cluster_dict):
        '''
        Just finds how many elements there
        are total in the cluster_dict (not
        how many numbers, but how many
        elements; an element may have
        multiple dimensions)'''
        total = 0
        for group_name in cluster_dict.keys():
            total += get_group_num_elements(group_name, cluster_dict)
        return total
    
    num_elements = get_total_num_elements(cluster_dict)
    all_elements = [ [0 for j in range(3)] for i in range(num_elements) ]
    row_counter = 0
    for group_name in cluster_dict.keys():
        cluster = cluster_dict[group_name]
        for element in cluster:
            all_elements[row_counter] = [group_name] + element.tolist()
            row_counter += 1
    all_elements_df = pd.DataFrame(all_elements, columns=['group_name', 'x', 'y'])
    sns.scatterplot(x='x', y='y', hue='group_name', data=all_elements_df)
    plt.show()
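If you'd rather avoid filling a preallocated list row by row, the same long-form data (one row per element, plus its group name) can be assembled with numpy alone. This is a sketch; `to_long_form` is a hypothetical helper, not part of the original notebook.

```python
import numpy as np

def to_long_form(cluster_dict):
    '''
    Stack all clusters into one (N, d) array of points plus a parallel
    array of group labels, the "long form" a scatter plot with a hue needs.
    '''
    names = list(cluster_dict.keys())
    points = np.vstack([cluster_dict[name] for name in names])
    # Repeat each group name once per element in that group
    labels = np.repeat(names, [cluster_dict[name].shape[0] for name in names])
    return points, labels

clusters = {
    'A': np.array([[0.0, 0.0], [1.0, 1.0]]),
    'B': np.array([[10.0, 0.0]]),
}
points, labels = to_long_form(clusters)
print(points.shape, labels.tolist())  # (3, 2) ['A', 'A', 'B']
```

For 2-D data, the result could then be passed to pandas as `pd.DataFrame({'group_name': labels, 'x': points[:, 0], 'y': points[:, 1]})` before plotting.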

## Data

For this to make sense, we will generate random data that forms four separate clusters in 2D space, centered near the corners of a 10 by 10 square. The approach should, in principle, generalize to n-dimensional data. The clusters here are arbitrary examples; in practice they could come from any clustering algorithm. In the original write-up, I imagined them coming from k-means, but the balancing step does not depend on how the clusters were produced.
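Most clustering algorithms report their result as a flat array of labels, one per point, rather than the `{group_name: array}` dictionary used here. A minimal adapter might look like the following; `labels_to_cluster_dict` is a hypothetical helper, and the hand-written labels stand in for real clustering output. The example also uses 3-D points to illustrate that nothing in this format is specific to 2D.

```python
import numpy as np

def labels_to_cluster_dict(points, labels):
    '''
    Convert an (N, d) array of points and a length-N array of cluster labels
    into the {group_name: array} form the balancing code expects.
    '''
    points = np.asarray(points)
    labels = np.asarray(labels)
    # .item() turns numpy scalars into plain Python ints/strings for the keys
    return {label.item(): points[labels == label] for label in np.unique(labels)}

rng = np.random.default_rng(0)
points = rng.uniform(0, 10, size=(12, 3))  # 12 points in 3-D space
labels = np.array([0, 1, 2] * 4)           # stand-in for clustering output
cluster_dict = labels_to_cluster_dict(points, labels)
print({k: v.shape for k, v in cluster_dict.items()})  # {0: (4, 3), 1: (4, 3), 2: (4, 3)}
```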

In [13]:
import numpy as np
def make_cluster(approx_centroid, noise_max, num_elements):
    '''
    Generates a cluster of points scattered uniformly
    within +/- noise_max of the approx_centroid
    coordinates in every dimension.

    Returns a numpy array of shape
    (num_elements, len(approx_centroid)).
    '''
    cluster = []
    for i in range(num_elements):
        noise = np.random.uniform(low=-noise_max, high=noise_max, size=len(approx_centroid))
        coordinates = np.array(approx_centroid) + noise
        cluster += [coordinates]
    return np.array(cluster)
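`make_cluster` draws from numpy's global random state without a seed, so every run produces different clusters (and a different log from the balancing step). A vectorized variant that also accepts a seeded generator might look like this; `make_cluster_vectorized` and its `rng` parameter are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def make_cluster_vectorized(approx_centroid, noise_max, num_elements, rng=None):
    '''
    Same idea as make_cluster: points drawn uniformly within +/- noise_max of
    approx_centroid in every dimension, but generated in one call.
    '''
    rng = np.random.default_rng() if rng is None else rng
    centroid = np.asarray(approx_centroid, dtype=float)
    noise = rng.uniform(-noise_max, noise_max, size=(num_elements, centroid.size))
    # Broadcasting adds the centroid to every row of the noise matrix
    return centroid + noise

rng = np.random.default_rng(42)  # seeded so the example is reproducible
cluster = make_cluster_vectorized([10, 0], 2, 15, rng=rng)
print(cluster.shape)  # (15, 2)
```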

In [14]:
approx_centroids = [[0, 0], [10, 0], [0, 10], [10, 10]]
clusters = {
    'A' : make_cluster([0, 0], 2, 10), # based around (0, 0)
    'B' : make_cluster([10, 0], 2, 15), # based around (10, 0)
    'C' : make_cluster([0, 10], 2, 8), # based around (0, 10)
    'D' : make_cluster([10, 10], 2, 7) # based around (10, 10)
}
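It is worth knowing in advance what "balanced" should look like. With N elements in k groups, `divmod(N, k)` gives a quotient q and remainder r: r groups end up with q + 1 elements and the rest with q. The helper below is an illustrative sketch, not part of the original notebook.

```python
def balanced_sizes(total_elements, num_groups):
    # q elements per group, with the remainder r spread one each over r groups
    q, r = divmod(total_elements, num_groups)
    return [q + 1] * r + [q] * (num_groups - r)

sizes = {'A': 10, 'B': 15, 'C': 8, 'D': 7}  # the cluster sizes generated above
print(balanced_sizes(sum(sizes.values()), len(sizes)))  # [10, 10, 10, 10]
print(balanced_sizes(41, 4))                             # [11, 10, 10, 10]
```

With the sizes above, at least 15 - 10 = 5 moves are needed, since B must shed 5 elements. The run logged at the end of this notebook makes 6, because one element leaves A and A later receives one back from B; the greedy rule picks the nearest eligible element rather than minimizing the number of moves.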

## Main algorithm

In [15]:
def balance_clusters(cluster_dict_inp, max_iterations=100, verbose='text'):
    '''
    Description:
        This method takes in clusters and then balances all of the clusters so
        that they have the same number of elements (or are within one of each other).
    
    Parameters:
        cluster_dict_inp: a dictionary that maps from cluster names to numpy arrays;
                            the dimension of the array is the dimension of a sample
        max_iterations: the maximum number of iterations before the algorithm stops;
                        this is a guard against infinite looping (although this
                        should not be an issue if the code is correct)
        verbose: string specifying the mode of verbosity; if it is 'all', then
                 both text and plots will be displayed; if 'plot', then only plots
                 will be displayed, not text; if 'text', then only text will be
                 displayed, not plots. Any other value (e.g. 'none') displays nothing.
    Return:
        a dictionary of the exact same form as cluster_dict_inp, except with balanced
        clusters
    '''
    cluster_dict = cluster_dict_inp.copy() # copy, so as not to edit original
    
    counter = 0 
    while counter < max_iterations:
        # Get centroid of the smallest group
        if verbose.lower() in ('text', 'all'): print(f'Group sizes: {get_all_group_sizes(cluster_dict)}')
        small_group_name = get_group_with_fewest_elements(cluster_dict)
        small_group_num_elements = get_group_num_elements(small_group_name, cluster_dict)
        centroids = get_centroids(cluster_dict)
        small_group_centroid = centroids[small_group_name]

        # Find the nearest element that's in a different group with at least
        # 2 more elements than the smallest group
        other_group_names = set(cluster_dict.keys())
        other_group_names.remove(small_group_name)

        # Measure how far each eligible element in the other groups is from the small group's centroid
        distances_from_small_group_centroid = get_distances_from_small_group_centroid(cluster_dict,
                other_group_names, small_group_centroid, small_group_num_elements)
        if not distances_from_small_group_centroid:
            # If the method returned an empty dictionary of possible distances,
            # this implies that there are no further changes to make, and thus the
            # function is finished running.
            if verbose.lower() in ('text', 'all'): print('There are no more small groups.')
            return cluster_dict
        else:
            # Now find the element closest to the small group's centroid and move it
            # from its current cluster to the smallest cluster
            distances_from_small_group_centroid_list = distances_from_small_group_centroid.keys()
            shortest_dist_from_small_centroid = min(distances_from_small_group_centroid_list)
            moving_element_information = distances_from_small_group_centroid[shortest_dist_from_small_centroid]
            moving_element_previous_group_name = moving_element_information['group_name']
            moving_element_previous_cluster = cluster_dict[moving_element_previous_group_name]
            moving_element = moving_element_information['element']

            # Get the index of the row that should be taken out of the old cluster;
            # the whole row must match, not just a single coordinate
            row_matches = np.all(moving_element_previous_cluster == moving_element, axis=1)
            index_of_deletion = np.flatnonzero(row_matches)[0]
            # Move row out of old cluster
            cluster_dict[moving_element_previous_group_name] = np.delete(moving_element_previous_cluster, index_of_deletion, axis=0)
            # Move row to new, small cluster
            cluster_dict[small_group_name] = np.append( cluster_dict[small_group_name], [moving_element], axis=0 )

            if verbose.lower() in ('text', 'all'):
                print(f'Moving {moving_element} from cluster {moving_element_previous_group_name} '
                      f'to cluster {small_group_name}')

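            # Plotting only works for 2-dimensional data, so check the dimension first
            is2d = next(iter(cluster_dict.values())).shape[1] == 2
            if verbose.lower() in ('plot', 'all') and is2d: plot_clusters(cluster_dict)
            counter += 1

    # Only reached if max_iterations ran out before the clusters were balanced
    if verbose.lower() in ('text', 'all'): print(f'Stopped after {max_iterations} iterations without fully balancing.')
    return cluster_dict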
            
            
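For reference, here is a condensed, self-contained sketch of the same greedy loop: find the smallest group, then move in the element closest to its centroid from any group with at least 2 more elements, until no such group exists. It picks the nearest element per group with `np.argmin`, identifies the moving row by index (so no array comparison is needed), and returns the clusters if `max_iterations` runs out. The name `balance_clusters_compact` and the synthetic example data are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def balance_clusters_compact(cluster_dict, max_iterations=100):
    '''
    Condensed greedy balancing: repeatedly move the element closest to the
    smallest group's centroid, taken from any group with at least 2 more
    elements than the smallest group, until no such group exists.
    '''
    clusters = dict(cluster_dict)  # shallow copy; arrays are replaced, not mutated
    for _ in range(max_iterations):
        small = min(clusters, key=lambda name: clusters[name].shape[0])
        small_size = clusters[small].shape[0]
        centroid = clusters[small].mean(axis=0)

        best = None  # (distance, group_name, row_index)
        for name, elements in clusters.items():
            if name == small or elements.shape[0] - small_size < 2:
                continue
            distances = np.linalg.norm(elements - centroid, axis=1)
            i = int(np.argmin(distances))
            if best is None or distances[i] < best[0]:
                best = (distances[i], name, i)

        if best is None:  # every group is within one element of the smallest
            return clusters
        _, source, i = best
        moving = clusters[source][i]
        clusters[source] = np.delete(clusters[source], i, axis=0)
        clusters[small] = np.vstack([clusters[small], moving])
    return clusters

rng = np.random.default_rng(1)
example = {
    'A': rng.uniform(-2, 2, size=(10, 2)),
    'B': rng.uniform(-2, 2, size=(15, 2)) + [10, 0],
    'C': rng.uniform(-2, 2, size=(8, 2)) + [0, 10],
    'D': rng.uniform(-2, 2, size=(7, 2)) + [10, 10],
}
balanced = balance_clusters_compact(example)
print({name: len(elements) for name, elements in balanced.items()})
# {'A': 10, 'B': 10, 'C': 10, 'D': 10}
```

Each move takes an element from a group of size s >= m + 2 and gives it to a group of size m, which lowers the sum of squared group sizes by at least 2, so the loop always terminates.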

In [16]:
balanced_clusters = balance_clusters(clusters, verbose='text')

Group sizes: {'A': 10, 'B': 15, 'C': 8, 'D': 7}
Moving [11.73172424  1.97746702] from cluster B to cluster D
Group sizes: {'A': 10, 'B': 14, 'C': 8, 'D': 8}
Moving [-0.71884931  1.26885987] from cluster A to cluster C
Group sizes: {'A': 9, 'B': 14, 'C': 9, 'D': 8}
Moving [9.88412339 1.77237694] from cluster B to cluster D
Group sizes: {'A': 9, 'B': 13, 'C': 9, 'D': 9}
Moving [9.13842474 0.30012904] from cluster B to cluster A
Group sizes: {'A': 10, 'B': 12, 'C': 9, 'D': 9}
Moving [9.16027345 0.57506863] from cluster B to cluster C
Group sizes: {'A': 10, 'B': 11, 'C': 10, 'D': 9}
Moving [10.46424925  1.54655366] from cluster B to cluster D
Group sizes: {'A': 10, 'B': 10, 'C': 10, 'D': 10}
There are no more small groups.
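The log shows all four groups ending at 10 elements. A small check like the one below (a sketch; `check_balanced` is a hypothetical helper) can confirm that balancing left every size within one of the others and that no element was lost, duplicated, or altered along the way. In this notebook it could be called as `check_balanced(clusters, balanced_clusters)`.

```python
import numpy as np

def check_balanced(original, balanced):
    '''
    Verify that balancing preserved every element and left all
    cluster sizes within one of each other.
    '''
    sizes = [len(elements) for elements in balanced.values()]
    if max(sizes) - min(sizes) > 1:
        raise ValueError(f'unbalanced sizes: {sizes}')
    # Same points before and after, ignoring which cluster holds them:
    # sort the rows lexicographically (first column as primary key) and compare
    before = np.vstack(list(original.values()))
    after = np.vstack(list(balanced.values()))
    if before.shape != after.shape:
        raise ValueError('element count changed')
    before_sorted = before[np.lexsort(before.T[::-1])]
    after_sorted = after[np.lexsort(after.T[::-1])]
    if not np.array_equal(before_sorted, after_sorted):
        raise ValueError('elements changed')
    return True

original = {'A': np.array([[0.0, 0.0]]),
            'B': np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])}
balanced = {'A': np.array([[0.0, 0.0], [1.0, 1.0]]),
            'B': np.array([[2.0, 2.0], [3.0, 3.0]])}
print(check_balanced(original, balanced))  # True
```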
