In [14]:
number_personas = 8

import numpy as np
from numba import jit, prange
from scipy.stats import dirichlet
from scipy.spatial.distance import pdist, squareform
from tqdm import trange, tqdm

@jit(nopython=True, parallel=True)
def calculate_distances(persons):
    n = len(persons)
    distances = np.empty((n, n), dtype=np.float64)
    for i in prange(n):
        for j in prange(i + 1, n):
            distances[i, j] = np.sqrt(np.sum((persons[i] - persons[j]) ** 2))
            distances[j, i] = distances[i, j]
    return distances

@jit(nopython=True)
def filter_similar_personas(persons, threshold=1e-5):
    n = len(persons)
    distances = calculate_distances(persons)
    keep_mask = np.ones(n, dtype=np.bool_)

    for i in range(n):
        if not keep_mask[i]:
            continue
        for j in range(i + 1, n):
            if distances[i, j] < threshold:
                keep_mask[j] = False

    filtered_persons = persons[keep_mask]
    return filtered_persons

def generate_personas(alpha_values, n_rm, n_persons, filter_persona_threshold=None, same_alpha=True, random_alpha=False):
    all_persons = []
    random_state = 42

    if same_alpha:
        for alpha in alpha_values:
            alphas = np.array([alpha] * n_rm)
            persons = dirichlet.rvs(alphas, size=n_persons, random_state=random_state)
            all_persons.append(persons)

    if random_alpha:
        alphas = np.random.choice(alpha_values, size=n_rm)
        persons = dirichlet.rvs(alphas, size=n_persons, random_state=random_state)
        all_persons.append(persons)

    all_persons = np.vstack(all_persons)
    if filter_persona_threshold:
        all_persons = filter_similar_personas(all_persons, filter_persona_threshold)
    return all_persons

n_persons = 10000
alpha_values = [0.1, 0.5, 1.0, 5.0]

persons = generate_personas(
    alpha_values,
    number_personas,
    n_persons,
    filter_persona_threshold=2e-1,
    same_alpha=True,
    random_alpha=True,
)

In [15]:
len(persons)

672

In [8]:
persons

array([[1.07828757e-03, 8.76649576e-01, 1.69635230e-07, 8.67535647e-12,
        1.22271967e-01, 2.71114771e-16],
       [9.54650302e-01, 2.35958872e-07, 4.05972179e-05, 1.35078376e-03,
        4.39310236e-02, 2.70570526e-05],
       [3.59578609e-04, 9.30442171e-08, 4.91740803e-03, 6.32794062e-03,
        1.25330942e-12, 9.88394980e-01],
       ...,
       [8.33887861e-02, 1.20410642e-01, 1.18616173e-01, 3.36975948e-03,
        1.72858132e-01, 5.01356508e-01],
       [2.57537108e-07, 6.85158004e-01, 1.68795675e-01, 1.45977129e-01,
        4.25528425e-06, 6.46793522e-05],
       [3.62259030e-05, 7.23197441e-01, 3.92218428e-05, 2.81336460e-04,
        1.36040526e-01, 1.40405248e-01]])