In [4]:
from collections import namedtuple
import os
import time

from kmodes.kmodes import KModes
from kmodes.util import dissim

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm


In [5]:
# One row of experiment output. The first three fields identify the run
# (cluster count, initialisation name, seed); the remaining four hold what
# was measured for it (initial cost, final cost, iteration count, time).
Record = namedtuple(
    "Record",
    "n_clusters initialisation seed "
    "initial_cost final_cost n_iterations time",
)


def find_clustering(data, n_clusters, initialisation, seed):
    """Fit one k-modes clustering and report how the run went.

    Parameters
    ----------
    data
        Observations passed to ``KModes.fit`` unchanged.
    n_clusters
        Number of clusters to fit.
    initialisation
        Value forwarded as the ``init`` argument of ``KModes``.
    seed
        Value forwarded as ``random_state``.

    Returns
    -------
    tuple
        ``(epoch_costs_[0], cost_, n_iter_, elapsed)`` read from the fitted
        model. ``elapsed`` is the ``time.perf_counter`` difference measured
        around both model construction and fitting.
    """
    tic = time.perf_counter()
    # n_init=1 is passed so each call corresponds to a single seeded run.
    model = KModes(n_clusters, init=initialisation, n_init=1, random_state=seed)
    model.fit(data)
    elapsed = time.perf_counter() - tic

    # The first entry of the per-epoch costs is what callers store as the
    # "initial" cost of the run.
    first_epoch_cost = model.epoch_costs_[0]
    return first_epoch_cost, model.cost_, model.n_iter_, elapsed


def run_experiment(dataset, initialisation, repetitions, n_clusters=None):
    """Run one initialisation method once per seed and tabulate the results.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain a column named ``"class"``. That column is removed
        before clustering.
    initialisation
        Initialisation name forwarded to ``find_clustering``.
    repetitions : int
        Number of runs; seeds ``0 .. repetitions - 1`` are used.
    n_clusters : int, optional
        Number of clusters. When ``None``, the number of distinct values in
        the ``"class"`` column is used.

    Returns
    -------
    pandas.DataFrame
        One row per seed with the fields of ``Record`` as columns. The
        columns are present even when ``repetitions`` is 0, so the frame can
        always be concatenated or written out with a consistent header.
    """
    data = dataset.drop("class", axis=1)
    if n_clusters is None:
        n_clusters = dataset["class"].nunique()

    results = []
    for seed in tqdm.tqdm(range(repetitions)):
        # Named ``elapsed`` rather than ``time`` so the imported ``time``
        # module is not shadowed inside this function.
        initial_cost, final_cost, n_iter, elapsed = find_clustering(
            data, n_clusters, initialisation, seed
        )
        results.append(
            Record(
                n_clusters,
                initialisation,
                seed,
                initial_cost,
                final_cost,
                n_iter,
                elapsed,
            )
        )

    # Explicit columns keep the schema stable when ``results`` is empty;
    # for non-empty input they equal what pandas infers from the namedtuples.
    return pd.DataFrame(results, columns=list(Record._fields))


def main(name, repetitions=50, root="../data/", destination=None, n_clusters=None):
    """Run every initialisation method on one dataset and collect the results.

    Parameters
    ----------
    name : str
        Dataset name; the file ``<root>/<name>.csv`` is read.
    repetitions : int
        Number of seeded runs per initialisation method.
    root : str
        Directory containing the dataset CSV. A trailing separator is
        optional.
    destination : str, optional
        Directory to write ``<name>_results.csv`` into. It is created if it
        does not exist. Nothing is written when ``None``.
    n_clusters : int, optional
        Forwarded to ``run_experiment``.

    Returns
    -------
    pandas.DataFrame
        The results for the ``"cao"``, ``"huang"`` and ``"matching"``
        initialisations, concatenated in that order with a fresh index.
    """
    # os.path.join gives the same path as plain concatenation when ``root``
    # ends with a separator, and a correct one when it does not.
    raw = pd.read_csv(os.path.join(root, f"{name}.csv"), na_values=["?", "dna"])
    # "?" and "dna" are read as missing; rows with any missing value are dropped.
    dataset = raw.dropna()

    if destination:
        # Create the output directory before the (potentially long)
        # experiments so a bad destination fails early, not after the work.
        os.makedirs(destination, exist_ok=True)

    dfs = [
        run_experiment(dataset, initialisation, repetitions, n_clusters)
        for initialisation in ("cao", "huang", "matching")
    ]

    df = pd.concat(dfs, axis=0, ignore_index=True)
    if destination is not None:
        df.to_csv(os.path.join(destination, f"{name}_results.csv"), index=False)

    return df


In [6]:
# Cluster counts, paired by position with the dataset names below.
optimal_nclusters = (10, 17, 6, 4) # max_clusters = int(sqrt(nrows))
dataset_names = ("breast_cancer", "mushroom", "soybean", "zoo")

# Run the full experiment for each dataset; results are written as CSV
# files under ../data/sqrt_nrows/ and the returned frames are discarded.
for name, n_clusters in zip(dataset_names, optimal_nclusters):
    main(
        name,
        root="../data/",
        destination="../data/sqrt_nrows/",
        n_clusters=n_clusters,
    )


100%|██████████| 50/50 [00:17<00:00,  2.83it/s]
100%|██████████| 50/50 [00:12<00:00,  4.08it/s]
100%|██████████| 50/50 [00:09<00:00,  5.46it/s]
100%|██████████| 50/50 [04:03<00:00,  4.87s/it]
100%|██████████| 50/50 [04:27<00:00,  5.36s/it]
100%|██████████| 50/50 [02:07<00:00,  2.55s/it]
100%|██████████| 50/50 [00:11<00:00,  4.48it/s]
100%|██████████| 50/50 [00:10<00:00,  4.57it/s]
100%|██████████| 50/50 [00:07<00:00,  6.92it/s]
100%|██████████| 50/50 [00:01<00:00, 25.05it/s]
100%|██████████| 50/50 [00:01<00:00, 30.66it/s]
100%|██████████| 50/50 [00:01<00:00, 37.31it/s]
