In [8]:
import numpy as np
from scipy.spatial.distance import cdist

In [14]:
def wss(X: np.ndarray, Y: np.ndarray = None, centers: np.ndarray = None, method: str = 'conventional'):
    if centers is not None:
        assert (centers.shape[1] == X.shape[1])
        distances = cdist(X, centers).min(axis=1)
        if method == 'conventional':
            distances = distances ** 2
        return distances.sum()
    elif Y is not None:
        assert (Y.shape[0] == X.shape[0])
        unique = np.unique(Y)
        total_wss = 0.0
        for i in unique:
            centroid = np.mean(X[np.where(Y == i)], axis=0)
            distances = cdist(X[np.where(Y == i)], np.array([centroid]))
            if method == 'conventional':
                distances = distances ** 2
            total_wss += distances.sum()
        return total_wss
    else:
        print("Error, no partition passed")


# aux rewritten
def wss_axis(Y: np.ndarray, X: np.ndarray, method: str = 'conventional'):
    return wss(X=X, Y=Y, method=method)


def wss_matrix(X: np.ndarray, L: np.ndarray = None, method: str = 'conventional'):
    inertia = np.zeros((L.shape[0], L.shape[1]))
    for i in range(L.shape[0]):
        inertia[i] = np.apply_along_axis(wss_axis, 1, L[i], X, method)
    return inertia

In [10]:
X = np.random.random((1000, 5))
L = np.random.randint(0, 5, (30, 20, 1000))

In [11]:
L.shape

(30, 20, 1000)

In [19]:
matrix = wss_matrix(X, L)

In [21]:
def ufo(m, agg):
    aggregated = agg(m, axis = 0)
    return aggregated

In [26]:
ufo(matrix, np.min).shape

(20,)

In [25]:
matrix.shape

(30, 20)

In [35]:
radost = np.array(sorted(ufo(matrix, np.mean), reverse=True))

In [36]:
radost

array([410.39317725, 410.34215051, 410.26855159, 410.25962909,
       410.24134472, 410.24032289, 410.22603497, 410.20459086,
       410.19549144, 410.18340159, 410.16287751, 410.15041923,
       410.14648916, 410.1391571 , 410.13900903, 410.12292982,
       410.08626492, 410.02793911, 410.00848658, 409.99567653])

In [37]:
vals = radost[:-1]/radost[1:]

In [76]:
def elbow(SSW: np.ndarray, levels: (int, int) = (1, 1), aggregation=np.mean):
    # if method == 'fixed':
    aggregated = np.array(sorted(aggregation(SSW, axis=1), reverse=True))
    print(aggregated.shape)
    indices = np.full((SSW.shape[0],), -np.inf)
    print(indices.shape)
    frac_up = (aggregated[:-(levels[0] + levels[1])] - aggregated[levels[0]:-levels[1]]) / (
               aggregated[levels[0]:-levels[1]] - aggregated[(levels[0] + levels[1]):])
    print(frac_up.shape)
    indices[levels[0]:-levels[1]] = frac_up
    return indices

In [48]:
lvl=(1,2)

In [49]:
frac_up = (radost[:-(lvl[0] + lvl[1])] - radost[lvl[0]:-lvl[1]]) / (
               radost[lvl[0]:-lvl[1]] - radost[(lvl[0] + lvl[1]):])

In [50]:
frac_up.shape

(17,)

In [77]:
elbow(matrix, (1,1))

(30,)
(30,)
(28,)


array([          -inf, 1.08832285e+01, 3.68322534e-02, 4.43350106e+00,
       1.08166986e+00, 7.16821402e-01, 3.09616304e+00, 4.09614269e-01,
       4.10778961e-01, 1.57735032e+00, 4.71736791e-01, 3.16978848e+01,
       4.32931000e-01, 2.12439838e-01, 9.36206531e-01, 1.69870133e+00,
       7.51844665e-01, 2.89674242e-01, 2.56318030e+00, 1.16557658e+00,
       4.37226343e+00, 1.25489221e-01, 6.95722044e-01, 4.43676695e+00,
       5.73525893e+00, 2.89647940e-02, 8.45044066e+00, 1.30124611e-01,
       1.83341575e+00,           -inf])

In [58]:
def count(X: np.ndarray):
    return np.unique(X).shape[0]

In [67]:
classes = np.zeros((L.shape[0], L.shape[1]))
for i in range(L.shape[0]):
    classes[i] = np.apply_along_axis(count, 1, L[i])
num_classes = np.min(classes, axis=1)

In [68]:
num_classes.shape

(30,)

In [71]:
matrix

array([[409.37712427, 409.36471652, 410.71975192, 410.40246568,
        410.59899905, 410.6812361 , 410.49029936, 409.74735556,
        409.27105161, 409.89737721, 409.22503795, 410.89660817,
        409.68896169, 410.68654591, 409.77086721, 409.67995288,
        411.2829908 , 409.63162273, 410.46479199, 410.03774609],
       [410.46798389, 410.20634339, 409.98286609, 410.9379998 ,
        410.4056633 , 410.79338848, 410.35474797, 410.57711523,
        410.52304099, 410.48121962, 409.28819369, 410.59990025,
        409.4993055 , 410.08334667, 410.28053526, 410.49115705,
        410.19828613, 410.32898422, 410.31914659, 410.02878409],
       [410.03243434, 410.65560342, 410.27782244, 410.52200171,
        410.25124289, 410.12657099, 410.8370152 , 409.78113175,
        410.11864891, 409.8702622 , 410.04873721, 409.73121973,
        410.20898366, 409.37997238, 410.32026711, 410.04757924,
        411.03190423, 409.45799893, 409.52851831, 410.47099882],
       [410.85559875, 410.33939082, 4