## Examples of dirichlet_partition

In [1]:
import numpy as np
from dirichlet import dirichlet_partition

In [4]:
# A helper funtion that prints out a summary of a partition
def summarize_partition(y,partition):
    total_n=0
    for k,ixs in partition.items():
        _,counts=np.unique(y[ixs],return_counts=True)
        proportions=np.array2string(counts/len(ixs),formatter={'float': lambda x: f'{x:.2f}'})
        print(f'Client: {k} Counts: {counts} (total: {len(ixs)}) %: {proportions}')
        total_n+=len(ixs)
        
    print(f'Total samples: {total_n}')

In [7]:
# A class balanced y
y=np.random.choice([0,1],size=10000,p=[0.5,0.5])

# Partitioned among 5 clients/segments with a high alpha
partition=dirichlet_partition(y,c_clients=5,alpha=100,debug=True)

summarize_partition(y,partition)

Client: 0 Counts: [ 912 1052] (total: 1964) %: [0.46 0.54]
Client: 1 Counts: [1043  926] (total: 1969) %: [0.53 0.47]
Client: 2 Counts: [ 930 1033] (total: 1963) %: [0.47 0.53]
Client: 3 Counts: [ 907 1055] (total: 1962) %: [0.46 0.54]
Client: 4 Counts: [1120  846] (total: 1966) %: [0.57 0.43]
Total samples: 9824


In [8]:
# A class unbalanced y
y=np.random.choice([0,1],size=10000,p=[0.9,0.1])

# Partitioned among 5 clients/segments with a high alpha
partition=dirichlet_partition(y,c_clients=5,alpha=100,debug=True)

# Note how less total samples are distributed
summarize_partition(y,partition)

Client: 0 Counts: [231 170] (total: 401) %: [0.58 0.42]
Client: 1 Counts: [176 223] (total: 399) %: [0.44 0.56]
Client: 2 Counts: [178 214] (total: 392) %: [0.45 0.55]
Client: 3 Counts: [199 200] (total: 399) %: [0.50 0.50]
Client: 4 Counts: [211 188] (total: 399) %: [0.53 0.47]
Total samples: 1990


In [11]:
# A class balanced y
y=np.random.choice([0,1],size=10000,p=[0.5,0.5])

# Partitioned among 5 clients/segments with a low alpha
partition=dirichlet_partition(y,c_clients=5,alpha=0.1,debug=True)

# Note how low alphas produce unbalanced partitions
summarize_partition(y,partition)

Client: 0 Counts: [   8 1270] (total: 1278) %: [0.01 0.99]
Client: 1 Counts: [  12 1256] (total: 1268) %: [0.01 0.99]
Client: 2 Counts: [1213] (total: 1213) %: [1.00]
Client: 3 Counts: [4966    8] (total: 4974) %: [1.00 0.00]
Client: 4 Counts: [1239] (total: 1239) %: [1.00]
Total samples: 9972


In [12]:
# A class unbalanced y
y=np.random.choice([0,1],size=10000,p=[0.1,0.9])

# Partitioned among 5 clients/segments with a low alpha
partition=dirichlet_partition(y,c_clients=5,alpha=0.1,debug=True)

# Note how low alphas produce unbalanced partitions
# and that y's baseline unbalance reduces total samples
summarize_partition(y,partition)

Client: 0 Counts: [321] (total: 321) %: [1.00]
Client: 1 Counts: [516] (total: 516) %: [1.00]
Client: 2 Counts: [472  20] (total: 492) %: [0.96 0.04]
Client: 3 Counts: [321] (total: 321) %: [1.00]
Client: 4 Counts: [326] (total: 326) %: [1.00]
Total samples: 1976
