# Generate data with distribution control

In [None]:
import polars as pl
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style="whitegrid")

# Generate data for benchmarking group-by queries under different key distributions while keeping the number of distinct group keys fixed.
# The table has n_rows rows and n_groups distinct group keys; the keys are the integers 0 .. n_groups-1.
# Each row is a tuple (value, group0, group1, group2, ...), where the i-th group-key column is distributed as group_distributions[i]. Every key appears at least once in every group-key column.

# Size of the generated table: total number of rows and number of distinct group keys.
n_rows = 100_000_000
n_groups = 1_000_000
# One group-key column is produced per entry; each entry is both the column
# name and the name of the distribution its keys are drawn from.
group_distributions = ["uniform", "normal", "exponential"]
def _sample_keys_in_range(draw, n_samples, n_groups, max_batch=10_000_000):
    """Draw ``n_samples`` integer keys in ``[0, n_groups)`` by rejection sampling.

    Args:
        draw: Callable taking a count ``k`` and returning ``k`` float samples
            from the source distribution.
        n_samples: Number of keys to return.
        n_groups: Exclusive upper bound of the valid key range.
        max_batch: Largest number of samples drawn per iteration; bounds the
            size of the temporary float buffers.

    Returns:
        An ``int64`` array of length ``n_samples``. Samples are floored and
        kept only if they land in ``[0, n_groups)``, so the keys follow the
        source distribution truncated to the key range.
    """
    keys = np.empty(n_samples, dtype=np.int64)
    filled = 0
    while filled < n_samples:
        # Floor (rather than truncate toward zero) so that samples in (-1, 0)
        # become -1 and are rejected instead of being counted as key 0.
        batch = np.floor(draw(min(n_samples - filled, max_batch)))
        accepted = batch[(batch >= 0) & (batch < n_groups)].astype(np.int64)
        keys[filled:filled + accepted.size] = accepted
        filled += accepted.size
    return keys


def _generate_columns(n_rows, n_groups, group_distributions, rng=None):
    """Build the benchmark table as a mapping of column name to NumPy array.

    The first ``n_groups`` rows carry the keys ``0 .. n_groups-1`` in every
    group-key column, which guarantees that each key occurs at least once.
    The remaining ``n_rows - n_groups`` rows get keys drawn from the
    distribution named by each entry of ``group_distributions``.

    Args:
        n_rows: Total number of rows requested. If it is smaller than
            ``n_groups``, only the ``n_groups`` guaranteed rows are produced.
        n_groups: Number of distinct group keys; must be at least 1.
        group_distributions: Names of the group-key columns/distributions;
            each must be ``"uniform"``, ``"normal"`` or ``"exponential"``.
        rng: Optional ``numpy.random.Generator``; pass a seeded generator to
            make the data reproducible. A fresh unseeded one is used if omitted.

    Returns:
        A dict with a ``"value"`` column (integers in ``[0, 1000)``) followed
        by one ``int64`` key column per entry of ``group_distributions``.

    Raises:
        ValueError: If ``n_groups`` is below 1 or a distribution name is unknown.
    """
    if n_groups < 1:
        raise ValueError(f"n_groups must be at least 1, got {n_groups}")
    if rng is None:
        rng = np.random.default_rng()

    n_random = max(n_rows - n_groups, 0)
    samplers = {
        "uniform": lambda: rng.integers(0, n_groups, size=n_random, dtype=np.int64),
        "normal": lambda: _sample_keys_in_range(
            lambda k: rng.normal(n_groups / 2, n_groups / 8, size=k),
            n_random,
            n_groups,
        ),
        "exponential": lambda: _sample_keys_in_range(
            lambda k: rng.exponential(n_groups / 5, size=k),
            n_random,
            n_groups,
        ),
    }
    unknown = [name for name in group_distributions if name not in samplers]
    if unknown:
        raise ValueError(f"unknown group distribution(s): {unknown}")

    # Rows 0 .. n_groups-1: make sure each group has at least one row.
    every_key = np.arange(n_groups, dtype=np.int64)
    columns = {
        "value": rng.integers(0, 1000, size=n_groups + n_random, dtype=np.int64),
    }
    # All other rows get a random group key according to the column's distribution.
    for name in group_distributions:
        columns[name] = np.concatenate([every_key, samplers[name]()])
    return columns


# Column-oriented table (column name -> NumPy array). Building 10**8 Python
# dicts row by row is far too slow and memory-hungry, so the columns are
# generated in bulk instead; `pl.DataFrame(records)` accepts this mapping and
# yields the same columns in the same order.
records = _generate_columns(n_rows, n_groups, group_distributions)

In [None]:
# Materialize the generated rows as a Polars frame and persist it; the file
# name records the row count and the number of group keys.
df = pl.DataFrame(records)
output_path = f"gen_v1_{n_rows}_{n_groups}.parquet"
df.write_parquet(output_path)

In [None]:
# Show one histogram per group-key column, each as its own figure, to
# visually check the key distribution.
for column in ("uniform", "normal", "exponential"):
    sns.histplot(data=df, x=column, discrete=True)
    plt.show()

In [None]:
# Ground truth for the benchmark: for every group-key column, the sum of
# `value` per key, ordered by key.
for dist in group_distributions:
    query = f"select {dist}, sum(value) as sum_value from self group by {dist} order by {dist}"
    res = df.sql(query)
    display(res)