In [1]:
from generate_example import generate_multilevel_params_and_data
from log_densites.multilevel_model import client_model_log_density, global_model_log_density
from sfvi import init, fit

import jax
import optax

from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


#### Generate a fake data set 

In [2]:
# Seed handed to the example generator below.
seed = 0

# The generator returns a (params, data) pair; later cells index `data`
# by keys of the form "client_<i>_data".
params, data = generate_multilevel_params_and_data(seed)

#### Initialize the algorithm states

In [3]:
# Number of clients in this example. `data` is indexed below by the keys
# "client_0_data" ... "client_{NUM_CLIENTS - 1}_data".
NUM_CLIENTS = 5

# A single optimizer definition is used for the global state and every client state.
optimizer = optax.adam(learning_rate=1e-4)

# Random starting vector of length 6, drawn from a standard normal with a fixed key.
# NOTE(review): presumably 6 is the size of the global parameter vector and 3 the
# size of each client's — confirm against `sfvi.init` / the log-density functions.
position = jax.random.normal(jax.random.PRNGKey(0), (6,))

# initialize global state
global_state = init(position, optimizer)

# Initialize one state per client; every client starts from the same first
# three entries of `position`.
client_states = [init(position[:3], optimizer) for _ in range(NUM_CLIENTS)]

# Bind each client's data set as the first positional argument of the client
# log-density, producing one callable per client in the same order as `client_states`.
client_logdensity_fns = [
    partial(client_model_log_density, data[f"client_{client_idx}_data"])
    for client_idx in range(NUM_CLIENTS)
]

#### Run the algorithm 

In [4]:
# Run the fit and rebind the returned global and client states.
# NOTE(review): the leading positional 0 looks like a seed/PRNG argument — confirm
# against the signature of `sfvi.fit`.
global_state, client_states = fit(
    0,
    client_states,
    global_state,
    client_logdensity_fns,
    global_model_log_density,
    optimizer,
    num_samples=5,
    num_steps=1000,
)

Step 1/1000 | Global Objective: 8491898.0
Step 2/1000 | Global Objective: 8654575.0
Step 3/1000 | Global Objective: 8553672.0
Step 4/1000 | Global Objective: 8659135.0
Step 5/1000 | Global Objective: 8617196.0
Step 6/1000 | Global Objective: 8342620.0
Step 7/1000 | Global Objective: 8558232.0
Step 8/1000 | Global Objective: 8537227.0
Step 9/1000 | Global Objective: 8542228.0
Step 10/1000 | Global Objective: 8759372.0
Step 11/1000 | Global Objective: 8669669.0
Step 12/1000 | Global Objective: 8478402.0
Step 13/1000 | Global Objective: 8670322.0
Step 14/1000 | Global Objective: 8560813.0
Step 15/1000 | Global Objective: 8632926.0
Step 16/1000 | Global Objective: 8641859.0
Step 17/1000 | Global Objective: 8591537.0
Step 18/1000 | Global Objective: 8603308.0
Step 19/1000 | Global Objective: 8717656.0
Step 20/1000 | Global Objective: 8821759.0
Step 21/1000 | Global Objective: 8502875.0
Step 22/1000 | Global Objective: 8486172.0
Step 23/1000 | Global Objective: 8599337.0
Step 24/1000 | Globa

KeyboardInterrupt: 