In [1]:
from collections import namedtuple
import os
import time
import warnings

from kmodes.kmodes import KModes
from kmodes.util import dissim

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm


In [2]:
# One row of experiment output: the outcome of a single k-modes run.
#   n_clusters      number of clusters requested
#   initialisation  name of the initialisation method used
#   seed            random_state passed to the model
#   initial_cost    first entry of the model's per-epoch cost history
#   final_cost      last entry of the model's per-epoch cost history
#   n_iterations    iteration count reported by the fitted model
#   time            wall-clock seconds spent constructing and fitting
Record = namedtuple(
    "Record",
    "n_clusters initialisation seed "
    "initial_cost final_cost n_iterations time",
)


def find_clustering(data, n_clusters, initialisation, seed):
    """Fit a single k-modes model and report its costs, iterations and runtime.

    Parameters
    ----------
    data : array-like or pandas.DataFrame
        Categorical observations to cluster.
    n_clusters : int
        Number of clusters for ``KModes``.
    initialisation : str
        Value passed as ``KModes(init=...)``.
    seed : int
        Value passed as ``KModes(random_state=...)``.

    Returns
    -------
    tuple
        ``(initial_cost, final_cost, n_iterations, elapsed_seconds)``. The costs
        are the first and last entries of the fitted model's ``epoch_costs_``.
    """

    # Only one initialisation per call, so every seed gives an independent run.
    tic = time.perf_counter()
    model = KModes(n_clusters, init=initialisation, n_init=1, random_state=seed)
    model.fit(data)
    elapsed = time.perf_counter() - tic

    costs = model.epoch_costs_
    return costs[0], costs[-1], model.n_iter_, elapsed


def run_experiment(data, initialisation, repetitions, target="class"):
    """Cluster a dataset repeatedly with a single initialisation method.

    Rows containing any missing value are dropped first. The number of
    clusters is taken to be the number of distinct values in the ``target``
    column, which is then removed so it is not used as a clustering feature.

    Parameters
    ----------
    data : pandas.DataFrame
        The dataset, including the ``target`` column. It is not modified.
    initialisation : str
        Initialisation method forwarded to ``find_clustering``.
    repetitions : int
        Number of runs; run ``i`` uses ``seed=i``.
    target : str, default "class"
        Name of the column holding the ground-truth labels.

    Returns
    -------
    pandas.DataFrame
        One row per run, with the columns of ``Record`` (also when
        ``repetitions`` is 0).
    """

    complete = data.dropna()
    n_clusters = complete[target].nunique()
    features = complete.drop(columns=target)

    results = []
    # Label the bar so the per-initialisation bars can be told apart.
    for seed in tqdm.tqdm(range(repetitions), desc=initialisation):
        # ``elapsed`` rather than ``time``: a local named ``time`` would shadow
        # the ``time`` module imported at the top of the notebook.
        initial_cost, final_cost, n_iter, elapsed = find_clustering(
            features, n_clusters, initialisation, seed
        )
        results.append(
            Record(
                n_clusters,
                initialisation,
                seed,
                initial_cost,
                final_cost,
                n_iter,
                elapsed,
            )
        )

    # Explicit columns so an empty run still yields a correctly-labelled frame.
    return pd.DataFrame(results, columns=Record._fields)


def main(
    name,
    repetitions=250,
    root="../data/",
    destination=None,
    initialisations=("cao", "huang", "matching"),
):
    """Run every initialisation method on one dataset and optionally save results.

    Parameters
    ----------
    name : str
        Dataset name; the data are read from ``<root>/<name>.csv``.
    repetitions : int, default 250
        Number of seeded runs per initialisation method.
    root : str, default "../data/"
        Directory containing the input CSV files.
    destination : str or None, default None
        Directory to write ``<name>_results.csv`` into. It is created if it
        does not exist. Nothing is written when ``None``.
    initialisations : iterable of str, default ("cao", "huang", "matching")
        Initialisation methods to compare, run in this order.

    Returns
    -------
    pandas.DataFrame
        Concatenated per-run results for all initialisations.

    Raises
    ------
    ValueError
        If ``initialisations`` is empty (nothing to concatenate).
    """

    # "?" and "dna" are read as missing values — presumably the missing-value
    # markers used in these dataset files; confirm against the raw CSVs.
    # os.path.join tolerates ``root`` with or without a trailing separator.
    data = pd.read_csv(os.path.join(root, f"{name}.csv"), na_values=["?", "dna"])

    df = pd.concat(
        [run_experiment(data, init, repetitions) for init in initialisations],
        axis=0,
        ignore_index=True,
    )

    if destination is not None:
        # Create the output directory up front so a missing folder does not
        # discard the results of a long-running experiment.
        if destination:
            os.makedirs(destination, exist_ok=True)
        df.to_csv(os.path.join(destination, f"{name}_results.csv"), index=False)

    return df


In [3]:
# Input/output locations and the datasets to benchmark.
DATA_ROOT = "../data/"
RESULTS_DIR = "../data/nclasses/"
DATASETS = ("breast_cancer", "mushroom", "nursery", "soybean")

# Keep each dataset's results in memory (they are also written to RESULTS_DIR)
# so later cells can inspect them without re-reading the CSVs.
results = {}
for name in DATASETS:
    print(name)
    results[name] = main(name, root=DATA_ROOT, destination=RESULTS_DIR)


  0%|          | 1/250 [00:00<00:31,  7.90it/s]

breast_cancer


100%|██████████| 250/250 [00:31<00:00,  7.82it/s]
100%|██████████| 250/250 [00:26<00:00,  9.47it/s]
100%|██████████| 250/250 [00:21<00:00, 11.36it/s]
  0%|          | 0/250 [00:00<?, ?it/s]

mushroom


100%|██████████| 250/250 [03:56<00:00,  1.06it/s]
100%|██████████| 250/250 [08:01<00:00,  1.93s/it]
100%|██████████| 250/250 [05:45<00:00,  1.38s/it]
  0%|          | 0/250 [00:00<?, ?it/s]

nursery


100%|██████████| 250/250 [07:29<00:00,  1.80s/it]
100%|██████████| 250/250 [07:08<00:00,  1.72s/it]
100%|██████████| 250/250 [05:42<00:00,  1.37s/it]
  0%|          | 0/250 [00:00<?, ?it/s]

soybean


100%|██████████| 250/250 [01:23<00:00,  3.01it/s]
100%|██████████| 250/250 [01:53<00:00,  2.21it/s]
100%|██████████| 250/250 [01:00<00:00,  4.12it/s]
