Increase S from one to many using inferencedata coordinates #5

Closed
hyunjimoon opened this issue Oct 29, 2022 · 4 comments

Comments

@hyunjimoon
Contributor

The following snippet from the draws2data2draws function would let obs_xr carry a prior_draw coordinate in addition to draw and chain. That would make plotting with the ArviZ library easier.

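    import xarray as xr  # needed for xr.concat below

    obs_xr = draws2data(model, numeric, iter_sampling=S)

    prior_pred_check(setting)
    numeric["process_noise_scale"] = 0.0
    # assign_coords returns a new object, so the result has to be kept
    obs_xr = obs_xr.assign_coords({'prior_draw': list(range(S))})

    posteriors = []  # one posterior per prior draw
    for s in range(S):
        # observations simulated from the s-th prior draw
        obs_xr_s = obs_xr[setting_obs].sel(draw=s)
        obs_dict_s = {k: v.values.flatten() for (k, v) in obs_xr_s[setting_obs].items()}
        for key, value in obs_dict_s.items():
            numeric[key] = value

        for target_name in setting['target_simulated_vector_names']:
            model.update_numeric({f'{target_name}_obs': obs_dict_s[f'{target_name}_obs']})
        model.update_numeric({'process_noise_scale': 0.0})

        posterior_s = data2draws(model, numeric, chains=4, iter_sampling=int(M/4))
        posteriors.append(posterior_s)

    # 'prior_draw' == s compares a string with an int and is always False, so the
    # fits are stacked along prior_draw instead
    # (assumes data2draws returns an xarray.Dataset)
    posterior_xr = xr.concat(posteriors, dim='prior_draw')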
@hyunjimoon
Contributor Author

Plot design for S > 1 is needed.

@hyunjimoon
Contributor Author

hyunjimoon commented Nov 5, 2022

The idata for SBC looks like the following, but the prior_draw dimension still needs to be added to every data variable. I think the two dimensions, prior_draw and draw, currently index the same thing.

<xarray.Dataset>
Dimensions:              (chain: 1, draw: 2, initial_outcome_dim_0: 3,
                          integrated_result_dim_0: 20,
                          integrated_result_dim_1: 3, predator_dim_0: 20,
                          prey_dim_0: 20, process_noise_dim_0: 20,
                          prey_obs_dim_0: 20, predator_obs_dim_0: 20,
                          prior_draw: 2)
Coordinates:
  * chain                (chain) int64 1
  * draw                 (draw) int64 0 1
  * prior_draw           (prior_draw) int64 1 2
Dimensions without coordinates: initial_outcome_dim_0, integrated_result_dim_0,
                                integrated_result_dim_1, predator_dim_0,
                                prey_dim_0, process_noise_dim_0,
                                prey_obs_dim_0, predator_obs_dim_0
Data variables: (12/13)
    prey_birth_frac      (chain, draw) float64 ...
    pred_birth_frac      (chain, draw) float64 ...
    m_noise_scale        (chain, draw) float64 ...
    predator__init       (chain, draw) float64 ...
    prey__init           (chain, draw) float64 ...
    process_noise__init  (chain, draw) float64 ...
    ...                   ...
    integrated_result    (chain, draw, integrated_result_dim_0, integrated_result_dim_1) float64 ...
    predator             (chain, draw, predator_dim_0) float64 ...
    prey                 (chain, draw, prey_dim_0) float64 ...
    process_noise        (chain, draw, process_noise_dim_0) float64 ...
    prey_obs             (chain, draw, prey_obs_dim_0) float64 ...
    predator_obs         (chain, draw, predator_obs_dim_0) float64 ...

idata for data2draws

<xarray.Dataset>
Dimensions:                 (chain: 4, draw: 25, initial_outcome_dim_0: 3,
                             integrated_result_dim_0: 20,
                             integrated_result_dim_1: 3, prey_dim_0: 20,
                             predator_dim_0: 20, process_noise_dim_0: 20,
                             prey_obs_posterior_dim_0: 20,
                             predator_obs_posterior_dim_0: 20)
Coordinates:
  * chain                   (chain) int64 1 2 3 4
  * draw                    (draw) int64 0 1 2 3 4 5 6 ... 18 19 20 21 22 23 24
Dimensions without coordinates: initial_outcome_dim_0, integrated_result_dim_0,
                                integrated_result_dim_1, prey_dim_0,
                                predator_dim_0, process_noise_dim_0,
                                prey_obs_posterior_dim_0,
                                predator_obs_posterior_dim_0
Data variables: (12/14)
    m_noise_scale           (chain, draw) float64 ...
    pred_birth_frac         (chain, draw) float64 ...
    prey_birth_frac         (chain, draw) float64 ...
    prey__init              (chain, draw) float64 ...
    predator__init          (chain, draw) float64 ...
    process_noise__init     (chain, draw) float64 ...
    ...                      ...
    prey                    (chain, draw, prey_dim_0) float64 ...
    predator                (chain, draw, predator_dim_0) float64 ...
    process_noise           (chain, draw, process_noise_dim_0) float64 ...
    prey_obs_posterior      (chain, draw, prey_obs_posterior_dim_0) float64 ...
    predator_obs_posterior  (chain, draw, predator_obs_posterior_dim_0) float64 ...
    loglik                  (chain, draw) float64 ...
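
One way to read the two printouts above: in the SBC dataset each `draw` is already one prior draw, so a rename may be enough, while each `data2draws` result needs a new length-one `prior_draw` dimension before it can be combined with the others. The following is a minimal sketch on synthetic data, assuming both objects are plain `xarray.Dataset`s shaped like the printouts. The values are random, only two or three variables are included, and the synthetic SBC dataset omits the unused `prior_draw` coordinate for simplicity.

```python
import numpy as np
import xarray as xr

S = 2  # number of prior draws, as in the SBC printout above

# Synthetic stand-in for the SBC dataset: (chain: 1, draw: S)
sbc = xr.Dataset(
    {
        "prey_birth_frac": (("chain", "draw"), np.random.rand(1, S)),
        "prey_obs": (("chain", "draw", "prey_obs_dim_0"), np.random.rand(1, S, 20)),
    },
    coords={"chain": [1], "draw": np.arange(S)},
)

# If draw and prior_draw index the same thing, renaming the dimension
# makes every data variable carry prior_draw.
sbc_prior = sbc.rename({"draw": "prior_draw"})

# Synthetic stand-in for one data2draws result: (chain: 4, draw: 25)
posterior_s = xr.Dataset(
    {"prey_birth_frac": (("chain", "draw"), np.random.rand(4, 25))},
    coords={"chain": [1, 2, 3, 4], "draw": np.arange(25)},
)

# expand_dims prepends a length-one prior_draw dimension to every data variable.
posterior_s = posterior_s.expand_dims(prior_draw=[0])

print(sbc_prior["prey_obs"].dims)
print(posterior_s["prey_birth_frac"].dims)
```

Whether the rename is appropriate depends on whether `draw` in the SBC dataset really is the prior-draw index, which the comment above only suspects.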

@hyunjimoon
Contributor Author

Another option would be to work with cmdstanpy's dataset:

data2draws_data <xarray.Dataset>
Dimensions:                 (draw: 25, chain: 4, initial_outcome_dim_0: 3,
                             integrated_result_dim_0: 20,
                             integrated_result_dim_1: 3, prey_dim_0: 20,
                             predator_dim_0: 20, process_noise_dim_0: 20,
                             prey_obs_posterior_dim_0: 20,
                             predator_obs_posterior_dim_0: 20)
Coordinates:
  * chain                   (chain) int64 1 2 3 4
  * draw                    (draw) int64 0 1 2 3 4 5 6 ... 18 19 20 21 22 23 24
Dimensions without coordinates: initial_outcome_dim_0, integrated_result_dim_0,
                                integrated_result_dim_1, prey_dim_0,
                                predator_dim_0, process_noise_dim_0,
                                prey_obs_posterior_dim_0,
                                predator_obs_posterior_dim_0
Data variables: (12/14)
    m_noise_scale           (chain, draw) float64 0.009761 0.01051 ... 0.0116
    pred_birth_frac         (chain, draw) float64 0.04619 0.04625 ... 0.04621
    prey_birth_frac         (chain, draw) float64 0.7822 0.7825 ... 0.7822
    prey__init              (chain, draw) float64 30.0 30.0 30.0 ... 30.0 30.0
    predator__init          (chain, draw) float64 4.0 4.0 4.0 ... 4.0 4.0 4.0
    process_noise__init     (chain, draw) float64 0.0 0.0 0.0 ... 0.0 0.0 0.0
    ...                      ...
    prey                    (chain, draw, prey_dim_0) float64 30.18 ... 40.82
    predator                (chain, draw, predator_dim_0) float64 4.024 ... 6...
    process_noise           (chain, draw, process_noise_dim_0) float64 0.0 .....
    prey_obs_posterior      (chain, draw, prey_obs_posterior_dim_0) float64 3...
    predator_obs_posterior  (chain, draw, predator_obs_posterior_dim_0) float64 ...
    loglik                  (chain, draw) float64 nan nan nan ... nan nan nan

@hyunjimoon
Contributor Author

For now this is handled by calling data2draws S times, once per prior draw.
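
The workaround above can still end in a single dataset with a `prior_draw` dimension if the S results are concatenated afterwards. This is a rough sketch, not the project's code: `fake_data2draws` is a hypothetical stand-in that returns random numbers in the shape of the data2draws printout, and the values of `S` and `M` are arbitrary.

```python
import numpy as np
import xarray as xr


def fake_data2draws(seed, chains=4, iter_sampling=25):
    """Hypothetical stand-in for data2draws: random draws, no Stan fit."""
    rng = np.random.default_rng(seed)
    return xr.Dataset(
        {
            "prey_birth_frac": (
                ("chain", "draw"),
                rng.normal(0.8, 0.01, (chains, iter_sampling)),
            )
        },
        coords={"chain": np.arange(1, chains + 1), "draw": np.arange(iter_sampling)},
    )


S, M = 3, 100  # S prior draws, M posterior draws per fit (assumed values)

# One call per prior draw, as described above.
fits = [fake_data2draws(seed=s, chains=4, iter_sampling=M // 4) for s in range(S)]

# Stack the S results along a new prior_draw dimension and label it 0..S-1.
stacked = xr.concat(fits, dim="prior_draw").assign_coords(prior_draw=np.arange(S))

print(stacked["prey_birth_frac"].dims)
print(dict(stacked.sizes))
```

A single fit can then be pulled back out with `stacked.sel(prior_draw=s)`, which keeps the per-fit `chain` and `draw` layout intact.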
