In [1]:
import numpy as np

from numpy.random import default_rng
from numba import njit
from typing import Set, Tuple

In [2]:
def generate_problem(num_groups: int,
                     num_classes: int,
                     min_group_size: int,
                     max_group_size: int,
                     class_percent: np.array) -> np.ndarray:

    problem = np.zeros((num_groups, num_classes), dtype=int)

    rng = default_rng()
    group_sizes = rng.integers(low=min_group_size,
                               high=max_group_size,
                               size=num_groups)

    for i in range(num_groups):
        # Calculate the
        proportions = np.random.normal(class_percent, class_percent / 10)

        problem[i, :] = proportions * group_sizes[i]
    return problem

In [3]:
@njit
def calculate_cost(problem: np.ndarray,
                   solution: np.ndarray,
                   k: int) -> float:
    cost = 0.0
    total = np.sum(problem)
    class_sums = np.sum(problem, axis=0)
    num_classes = problem.shape[1]

    for i in range(k):
        idx = solution == i
        fold_sum = np.sum(problem[idx, :])

        # Start by calculating the fold imbalance cost
        cost += (fold_sum / total - 1.0 / k) ** 2

        # Now calculate the cost associated with the class imbalances
        for j in range(num_classes):
            cost += (np.sum(problem[idx, j]) / fold_sum - class_sums[j] / total) ** 2
    return cost

In [4]:
@njit
def calculate_decision_space(problem: np.ndarray,
                             solution: np.ndarray,
                             k: int) -> np.ndarray:
    num_groups = problem.shape[0]
    cost = calculate_cost(problem, solution, k)

    space = np.zeros((num_groups, k))
    sol = solution.copy()

    for i in range(num_groups):
        for j in range(k):
            if solution[i] == j:
                space[i,j] = np.infty
            else:
                sol[i] = j
                space[i,j] = calculate_cost(problem, sol, k) - cost
        sol[i] = solution[i]
    return space

In [5]:
@njit
def solution_to_str(solution: np.ndarray) -> str:
    return "".join([str(n) for n in solution])

In [6]:
def select_move(decision: np.ndarray,
                solution: np.ndarray,
                history: Set) -> Tuple:
    candidates = np.argsort(decision.flatten())

    for c in candidates:
        p = np.unravel_index(c, decision.shape)
        s = solution.copy()
        s[p[0]] = p[1]
        sol_str = solution_to_str(s)

        if sol_str not in history:
            return p
    return -1, -1 # No move found!

In [7]:
num_groups = 1000

In [8]:
prb = generate_problem(num_groups=num_groups,
                       num_classes=4,
                       min_group_size=400,
                       max_group_size=2000,
                       class_percent=np.array([0.4, 0.3, 0.2, 0.1]))

In [9]:
prb.shape

(1000, 4)

In [10]:
rng = default_rng()
sol = rng.integers(low=0, high=5, size=num_groups)

In [11]:
calculate_cost(prb, sol, k=5)

0.00108105244072681

In [12]:
decision = calculate_decision_space(prb, sol, k=5)

In [13]:
np.argmin(decision)

3941

In [14]:
np.unravel_index(136, (100, 5))

(27, 1)

In [15]:
sol[27] = 1

In [16]:
# decision = calculate_decision_space(prb, sol, k=5)
# p = np.unravel_index(np.argmin(decision), (100, 5))
# sol[p[0]] = p[1]
# calculate_cost(prb, sol)

In [17]:
hist = set()
hist.add(solution_to_str(sol))

In [18]:
retry = 0
incumbent = sol.copy()
low_cost = calculate_cost(prb, sol, k=5)
while retry < num_groups:
    decision = calculate_decision_space(prb, sol, k=5)
    move = select_move(decision, sol, hist)

    if move != (-1, -1):
        sol[move[0]] = move[1]
        cost = calculate_cost(prb, sol, k=5)
        if cost < low_cost:
            low_cost = cost
            incumbent = sol.copy()
            retry = 0
            print(cost)
        else:
            retry += 1
        hist.add(solution_to_str(sol))
    else:
        print("No more possible moves!")

0.0008787027701381814
0.0007431331079842229
0.0006202382815112399
0.000511131040571599
0.00041421139774633364
0.0003296694353842572
0.00025815611015707546
0.00019876350969466884
0.00015080011291905348
0.00010985365669982846
7.828612911850918e-05
5.4463756966780216e-05
3.519570409699899e-05
2.2266826014196292e-05
1.4651702470433843e-05
1.1569104521666235e-05
8.619208614099477e-06
7.526463381020699e-06
6.969889855504357e-06
6.153364811870221e-06
5.49179499272739e-06
4.908317943370572e-06
4.68972741866348e-06
4.217141826349522e-06
4.048036123430371e-06
3.4288198460860648e-06
2.966865431557974e-06
2.8294093198339518e-06
2.5160758438843224e-06
2.4040589321383862e-06
2.298909804440592e-06
2.2560909197063692e-06
2.2485525326110256e-06
1.983689027218958e-06
1.8197589310227087e-06
1.713400411942321e-06
1.6704207431504518e-06
1.5322888664999135e-06
1.4359647908642307e-06
1.4109174871844723e-06
1.290888031128267e-06
1.1782737080223696e-06
1.1300630003980308e-06
1.1011372208647455e-06
1.0476063692

In [19]:
retry

1000

In [20]:
sol

array([1, 1, 0, 1, 2, 0, 1, 3, 2, 0, 3, 2, 1, 0, 2, 0, 3, 0, 1, 3, 2, 4,
       1, 4, 0, 4, 4, 1, 2, 1, 4, 1, 4, 2, 3, 1, 4, 0, 0, 0, 0, 1, 0, 4,
       2, 2, 0, 4, 0, 1, 1, 2, 4, 1, 3, 4, 4, 3, 2, 1, 2, 3, 0, 3, 1, 3,
       2, 3, 3, 4, 4, 3, 0, 2, 0, 0, 2, 3, 0, 1, 1, 0, 1, 2, 0, 0, 3, 1,
       4, 3, 3, 4, 1, 1, 3, 0, 2, 0, 3, 0, 3, 3, 0, 0, 3, 1, 4, 3, 3, 4,
       4, 2, 3, 4, 4, 3, 3, 0, 0, 0, 0, 1, 3, 4, 4, 4, 0, 3, 1, 1, 3, 0,
       3, 2, 2, 4, 2, 3, 4, 3, 0, 0, 3, 0, 0, 1, 1, 3, 2, 4, 4, 3, 3, 1,
       0, 3, 1, 1, 4, 3, 4, 2, 3, 0, 2, 0, 3, 3, 1, 4, 0, 1, 0, 4, 1, 4,
       3, 1, 1, 2, 1, 0, 4, 0, 3, 3, 0, 1, 2, 3, 3, 4, 2, 1, 0, 3, 2, 0,
       4, 2, 3, 4, 1, 1, 2, 1, 3, 0, 1, 4, 4, 3, 4, 1, 1, 2, 4, 2, 1, 2,
       2, 2, 2, 2, 2, 3, 1, 0, 3, 3, 3, 3, 0, 3, 1, 2, 2, 3, 1, 2, 1, 2,
       0, 2, 1, 3, 3, 2, 1, 3, 1, 1, 2, 4, 2, 0, 0, 1, 1, 4, 2, 4, 3, 1,
       3, 4, 0, 0, 4, 2, 0, 1, 4, 3, 2, 3, 1, 4, 2, 2, 1, 3, 1, 2, 0, 3,
       0, 0, 0, 0, 3, 4, 2, 1, 2, 2, 0, 1, 4, 1, 3,

In [21]:
len(hist)

1739

In [22]:
np.sum(prb, axis=0) / np.sum(prb)

array([0.3996377 , 0.30109283, 0.19999695, 0.09927252])

In [23]:
folds = [prb[sol==i] for i in range(5)]
fold_percents = np.array([np.sum(folds[i], axis=0) / np.sum(folds[i]) for i in range(5)])
fold_percents

array([[0.3996801 , 0.30130971, 0.19996097, 0.09904922],
       [0.39960673, 0.30100605, 0.19990084, 0.09948638],
       [0.3995654 , 0.30106824, 0.20004584, 0.09932052],
       [0.39951511, 0.30110498, 0.19985928, 0.09952062],
       [0.39982112, 0.30097536, 0.20021788, 0.09898564]])

In [24]:
[np.sum(folds[i]) / np.sum(prb) for i in range(5)]

[0.19989263343221914,
 0.20012246234113396,
 0.19982478725984576,
 0.2000876911777926,
 0.20007242578900858]