# Implementation of the Pólya urn process described in Box 1 of the paper

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [5]:
def draw_ball(urn):
    """Pick one colour from ``urn`` with probability proportional to its count.

    Args:
        urn: mapping of colour -> current number of balls of that colour.

    Returns:
        The sampled colour, as produced by ``np.random.choice``.
    """
    colors = list(urn.keys())
    n_balls = sum(urn.values())
    # Weights are listed in the same order as ``colors`` (dict iteration order).
    weights = [urn[c] / n_balls for c in colors]
    return np.random.choice(colors, p=weights)

def add_ball(urn, color):
    """Put one extra ball of ``color`` into ``urn``, mutating it in place.

    Raises KeyError if ``color`` is not already present in ``urn``.
    """
    urn[color] = urn[color] + 1

def urn_process(urn, num_trials):
    """Run the urn scheme for ``num_trials`` steps, updating ``urn`` in place.

    Each step draws a colour proportionally to the current counts and then
    adds one more ball of that same colour, reinforcing it.

    Returns:
        The same ``urn`` object that was passed in, now holding the final counts.
    """
    for _trial in range(num_trials):
        add_ball(urn, draw_ball(urn))
    return urn

# Starting urn for the single-run demonstration: one ball of each colour.
urn = dict(red=1, blue=1, green=1)

def urn_sample(alpha, num_samples, num_trials=100_000):
    """Collect final colour proportions from independent runs of the urn process.

    Each run starts a fresh urn whose initial ball counts are the first three
    entries of ``alpha`` (red, blue, green, in that order; extra entries are
    ignored). It runs ``urn_process`` for ``num_trials`` draws and records the
    resulting proportions. Because every draw adds one ball of the drawn colour
    (a Pólya urn), these proportions approach a Dirichlet(alpha) draw as
    ``num_trials`` grows.

    Args:
        alpha: sequence whose first three entries are the initial red, blue and
            green counts.
        num_samples: number of independent runs.
        num_trials: draws per run. Defaults to 100_000, the previously
            hard-coded value.

    Returns:
        np.ndarray of shape (num_samples, 3). Each row holds the (red, blue,
        green) proportions, which sum to 1 up to floating-point rounding.
    """
    samples = []
    for _ in tqdm(range(num_samples)):
        # A fresh urn per sample keeps the runs independent. It uses a distinct
        # name so it does not shadow the module-level ``urn``.
        sample_urn = {"red": alpha[0], "blue": alpha[1], "green": alpha[2]}
        urn_process(sample_urn, num_trials)
        total = sum(sample_urn.values())
        samples.append(
            (
                sample_urn["red"] / total,
                sample_urn["blue"] / total,
                sample_urn["green"] / total,
            )
        )
    # reshape keeps the documented 2-D shape even when num_samples == 0.
    # Otherwise np.array([]) would have shape (0,).
    return np.array(samples).reshape(-1, 3)
# Single long run on the module-level urn: show the final counts...
print(urn_process(urn, 100_000))

# ...and the same counts normalised to proportions.
total = sum(urn.values())
print("proportions ", dict(zip(urn, (count / total for count in urn.values()))))

# Ten independent runs starting from one ball of each colour.
samples = urn_sample(alpha=[1, 1, 1], num_samples=10)
print(samples)

{'red': 53596, 'blue': 3442, 'green': 42965}
proportions  {'red': 0.5359439216823495, 'blue': 0.03441896743097707, 'green': 0.4296371108866734}


  0%|          | 0/10 [00:00<?, ?it/s]

[[0.57382279 0.03231903 0.39385818]
 [0.75398738 0.03247903 0.21353359]
 [0.05607832 0.68903933 0.25488235]
 [0.02837915 0.68551943 0.28610142]
 [0.01239963 0.81158565 0.17601472]
 [0.04206874 0.70487885 0.25305241]
 [0.21351359 0.61413158 0.17235483]
 [0.1133966  0.16304511 0.72355829]
 [0.7866864  0.14139576 0.07191784]
 [0.1167565  0.1634251  0.71981841]]
