In [1]:
from generate_example import generate_multilevel_params_and_data
from log_densites.multilevel_model import client_model_log_density, global_model_log_density
from sfvi import init, fit

import jax
import optax

from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


#### Generate a synthetic multilevel data set

In [2]:
# Fixed seed so the synthetic data set is identical on every Restart & Run All.
seed = 0

# NOTE(review): `params` presumably holds the ground-truth parameters used to
# simulate `data`; `data` is indexed below by keys "client_0_data" ... "client_4_data".
params, data = generate_multilevel_params_and_data(
    seed
)

#### Initialize the algorithm states

In [3]:
# One Adam optimizer shared by the global state and every client state.
optimizer = optax.adam(learning_rate=1e-2)

# Number of clients; the data dict holds one "client_{i}_data" entry per client.
NUM_CLIENTS = 5

# Initial position drawn from the notebook-wide `seed` (previously a hard-coded
# PRNGKey(0), which ignored `seed`). It has 6 entries for the global state, and
# each client starts from the first 3.
# NOTE(review): the 6/3 split presumably matches the parameter counts expected by
# global_model_log_density / client_model_log_density — confirm in log_densites.
position = jax.random.normal(jax.random.PRNGKey(seed), (6,))

# initialize global state
global_state = init(position, optimizer)

# initialize one client state per client, all from the same starting point
client_states = [init(position[:3], optimizer) for _ in range(NUM_CLIENTS)]

# bind each client's own data to the shared client log-density function
client_logdensity_fns = [
    partial(client_model_log_density, data[f"client_{i}_data"])
    for i in range(NUM_CLIENTS)
]

#### Run the algorithm 

In [5]:
# Monte Carlo samples per objective estimate and number of optimisation steps.
NUM_SAMPLES = 20
NUM_STEPS = 10000

# fit prints its global objective at every step (see the log below).
# NOTE(review): the leading positional 0 is presumably an RNG seed/key for
# sfvi.fit — confirm against its signature.
global_state, client_states = fit(
    0,
    client_states,
    global_state,
    client_logdensity_fns,
    global_model_log_density,
    optimizer,
    num_samples=NUM_SAMPLES,
    num_steps=NUM_STEPS,
)

Step 1/10000 | Global Objective: 6432094.5
Step 2/10000 | Global Objective: 6391494.5
Step 3/10000 | Global Objective: 6351203.0
Step 4/10000 | Global Objective: 6338551.5
Step 5/10000 | Global Objective: 6261708.5
Step 6/10000 | Global Objective: 6314294.5
Step 7/10000 | Global Objective: 6316167.5
Step 8/10000 | Global Objective: 6276223.0
Step 9/10000 | Global Objective: 6224546.5
Step 10/10000 | Global Objective: 6172988.0
Step 11/10000 | Global Objective: 6104008.5
Step 12/10000 | Global Objective: 6226114.5
Step 13/10000 | Global Objective: 6093897.5
Step 14/10000 | Global Objective: 6142234.5
Step 15/10000 | Global Objective: 6078097.5
Step 16/10000 | Global Objective: 6096756.0
Step 17/10000 | Global Objective: 6003050.0
Step 18/10000 | Global Objective: 5989262.5
Step 19/10000 | Global Objective: 5967594.5
Step 20/10000 | Global Objective: 6018327.5
Step 21/10000 | Global Objective: 6008092.5
Step 22/10000 | Global Objective: 6002676.5
Step 23/10000 | Global Objective: 5861457

In [6]:
# NOTE(review): `mu` is presumably the mean of the fitted global variational
# approximation (6 entries, matching the length-6 initial position).
global_mu = global_state.mu
global_mu

Array([ 2.1830456 , -0.23055132,  0.17939022,  1.277871  ,  0.1333185 ,
        0.20590883], dtype=float32)

Note that the estimate of the *intercept* term is off, whereas the slope estimates are close to the values obtained with MCMC.