# Initial Values

In [1]:
import ssms
import hssm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytensor
hssm.set_floatX('float32')
import re
import pymc as pm

Setting PyTensor floatX type to float32.
Setting "jax_enable_x64" to False. If this is not intended, please set `jax` to False.


Folklore suggests that the initial values of our sampler shouldn't matter for the outcome of the analysis, since MCMC will find the relevant region of the parameter space eventually. 

Well, we can't trust the elders blindly, and you will sometimes find that initial value settings need to be corrected to get good results.

There are multiple reasons, but chief among them is that we are routinely dealing with constrained parameter spaces. At the edges of these parameter spaces,
likelihoods can become less well behaved (this is true, e.g., of our `approx_differentiable` likelihoods based on [LANs](https://elifesciences.org/articles/65074)).
These edges may be unlikely a priori, but if your sampler takes a path through the parameter space that passes through these regions (even if just on the way to a distant mode),
you may get undesirable results.

There are two ways in which you can control the bounds of a parameter:

1. For an **individual parameter**, you can specify a `Truncated` distribution, or simply choose a prior distribution that naturally lives in the desired space (e.g. a `Gamma` distribution when dealing with positivity constraints).
2. To constrain the **outcome of a regression** you can always use a `Link` function that will target the desired output space.

Option 2, while computationally kosher in principle, can produce some downstream headaches, since it changes the interpretation of parameter values from straightforward to, e.g., log-odds (*Logistic Regression* with *logit* link). This demands careful thinking about priors and can sometimes make reporting of results more difficult.

Therefore, you may find yourself in a situation where you do not want to use link functions or other a priori constraints, while still needing to avoid pathologies in certain regions of the parameter space.

At this point, the wisdom of the ancients aside, practical considerations will force you to think about initial values. 

We are not here to prescribe how you should deal with this, but we try to provide you with options. This short tutorial illustrates how to **inspect** and **adjust** the initial value settings
of an `HSSM` model.

### Load up some data

In [2]:
cav_data = hssm.load_data("cavanagh_theta")

### Simple model

In [3]:
model = hssm.HSSM(
				data=cav_data,
				model="ddm",
				loglik_kind="approx_differentiable",
			)

{'z_interval__': array(0., dtype=float32), 'a_interval__': array(-1.1175871e-07, dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'v': array(0., dtype=float32)}


Now we can inspect the **initial value setting**.

In [4]:
model.initvals

{'z': array(0.5, dtype=float32),
 'a': array(1.5, dtype=float32),
 't': array(0.025, dtype=float32),
 'v': array(0., dtype=float32)}

### Set Initial Values

We illustrate two ways of setting initial values manually.

#### Route 1

Define a custom dictionary and pass it to the sampler.

In [5]:
from copy import deepcopy
my_initvals = deepcopy(model.initvals)
my_initvals['z'] = np.array(0.6, dtype='float32')
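Before handing a custom dictionary to the sampler, it can be worth checking that every value sits strictly inside the bounds of its parameter, since values at or beyond a boundary are exactly what we want to avoid. The sketch below uses a hypothetical helper, `check_initvals` (not part of the HSSM API), and hard-codes the `ddm` bounds listed in the model summary printed later in this notebook (`v` in (-3, 3), `a` in (0.3, 2.5), `z` in (0, 1), `t` in (0, 2)); treat those bounds as an assumption and confirm them against your own model's printout.

```python
import numpy as np

# Explicit bounds for the "ddm" parameters, copied from an HSSM model summary.
# Assumption: your model uses the same bounds; check `print(model)` to confirm.
DDM_BOUNDS = {
    "v": (-3.0, 3.0),
    "a": (0.3, 2.5),
    "z": (0.0, 1.0),
    "t": (0.0, 2.0),
}


def check_initvals(initvals, bounds):
    """Hypothetical helper: return the names of parameters whose initial
    values are not strictly inside (lower, upper). Parameters without known
    bounds (e.g. regression coefficients) are skipped."""
    offending = []
    for name, value in initvals.items():
        if name not in bounds:
            continue
        lower, upper = bounds[name]
        arr = np.asarray(value)
        if np.any(arr <= lower) or np.any(arr >= upper):
            offending.append(name)
    return offending


# Same values as `my_initvals` above (the defaults, with z moved to 0.6).
example_initvals = {
    "z": np.array(0.6, dtype="float32"),
    "a": np.array(1.5, dtype="float32"),
    "t": np.array(0.025, dtype="float32"),
    "v": np.array(0.0, dtype="float32"),
}
print(check_initvals(example_initvals, DDM_BOUNDS))  # → []

# An `a` below its lower bound of 0.3 gets flagged.
bad_initvals = dict(example_initvals, a=np.array(0.2, dtype="float32"))
print(check_initvals(bad_initvals, DDM_BOUNDS))  # → ['a']
```

In the notebook, the same call works directly on `my_initvals`.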

In [6]:
idata = model.sample(draws = 10, tune=1, initvals=my_initvals)

Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
There were 40 divergences after tuning. Increase `target_accept` or reparameterize.
The number of samples is too small to check convergence reliably.


In [7]:
idata.posterior

We allowed the sampler very little tuning, and therefore our initial values are still apparent in the chains.

#### Route 2

Adjust the `._initvals` attribute.
`model.initvals` is a class property that serves as an accessor to the underlying `._initvals` attribute.
We can adjust this attribute directly, and it will be used as the default for our sampler.

In [8]:
model._initvals

{'z': array(0.5, dtype=float32),
 'a': array(1.5, dtype=float32),
 't': array(0.025, dtype=float32),
 'v': array(0., dtype=float32)}

In [9]:
model._initvals['z'] = np.array(0.6, dtype='float32')
idata2 = model.sample(draws=10, tune=1)

Using default initvals. 

The model has already been sampled. Overwriting the previous inference object. Any previous reference to the inference object will still point to the old object.


Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
There were 40 divergences after tuning. Increase `target_accept` or reparameterize.
The number of samples is too small to check convergence reliably.


In [10]:
idata2

Models can become a lot more complicated, but you can consistently adjust their initial values via this same process.

## Link function

Setting initial values for parameters when working with link functions becomes a little more tricky. 
Below is an example.

Let's define a simple regression with `logit` link functions.

In [11]:
model = hssm.HSSM(
				data=cav_data,
				model="ddm",
				link_settings='log_logit',
				loglik_kind="approx_differentiable",
				include = [{'name': 'a',
							'formula': 'a ~ 1 + theta'
							}],
			)

{'v_interval__': array(0., dtype=float32), 'z_interval__': array(0., dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'a_Intercept': array(0., dtype=float32), 'a_theta': array(0., dtype=float32)}


In [12]:
model.initvals

{'v': array(0., dtype=float32),
 'z': array(0.5, dtype=float32),
 't': array(0.025, dtype=float32),
 'a_Intercept': array(0., dtype=float32),
 'a_theta': array(0., dtype=float32)}

Well, `a_Intercept` gets a default initial value of `0`?


Since `a` is the boundary separation parameter in the drift diffusion model, a value of `0` seems like an extremely bad choice: it would lead to a point mass of `rt` at `0`!


What's going on here? Let's inspect our model...

In [13]:
model

Hierarchical Sequential Sampling Model
Model: ddm

Response variable: rt,response
Likelihood: approx_differentiable
Observations: 3988

Parameters:

v:
    Prior: Uniform(lower: -3.0, upper: 3.0)
    Explicit bounds: (-3.0, 3.0)
 (ignored due to link function)
a:
    Formula: a ~ 1 + theta
    Priors:
        a_Intercept ~ Normal(mu: 0.0, sigma: 2.5)
        a_theta ~ Normal(mu: 0.0, sigma: 2.497499942779541)
    Link: Generalized logit link function with bounds (0.3, 2.5)
    Explicit bounds: (0.3, 2.5)
 (ignored due to link function)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
 (ignored due to link function)
t:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, 2.0)
 (ignored due to link function)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 20.0)

Ah, we apply the **Generalized Logit** link to the `a` parameter.
The initial value of the actual parameter `a` that the likelihood receives is therefore the output of this transformation, not `a_Intercept` itself.

Our `hssm` model class includes everything we need to inspect this further.

What would we expect here?

The generalized logit transformation (link) has the associated generalized sigmoid transformation (inverse link) as its forward transform.
We expect that evaluating this transformation at `0` gives back the midpoint between the explicit bounds we set for the parameter.

So we expect: $(0.3 + 2.5) / 2 = 1.4$... let's check this.
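Writing the generalized sigmoid out makes this expectation explicit. Assuming it is the standard logistic sigmoid $\sigma$ rescaled to the bounds $(l, u)$ (our assumption about the form HSSM uses; it is consistent with the check that follows), with $x$ the unconstrained regression output:

$$
a = l + (u - l)\,\sigma(x), \qquad \sigma(x) = \frac{1}{1 + e^{-x}},
$$

so at $x = 0$, where $\sigma(0) = \tfrac{1}{2}$,

$$
a = l + \frac{u - l}{2} = \frac{l + u}{2} = \frac{0.3 + 2.5}{2} = 1.4 .
$$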

**Note**:

The `forward` link (from the unconstrained regression output to the parameter value that the likelihood receives) is called the *inverse link* function in GLM lingo.
We follow this nomenclature in HSSM. 



In [14]:
model.params['a'].link.linkinv(model._initvals['a_Intercept'])

1.4000000000000001

Voila! This checks out! We note that the `linkinv()` function can come in handy if you want to play around with initial value setting in the context of using link functions yourself.
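Going the other way is often more useful in practice: pick the value of `a` you would like the sampler to start from, and compute the value of `a_Intercept` that maps onto it. The sketch below writes out a generalized logit / generalized sigmoid pair with NumPy, assuming the generalized sigmoid is a logistic sigmoid rescaled to the bounds `(0.3, 2.5)` (this agrees with the `1.4` returned by `linkinv` above, but is our assumption, not HSSM's code). `gen_logit` and `gen_sigmoid` are hypothetical names. Because `a_theta` starts at `0`, the intercept alone determines the initial `a`.

```python
import numpy as np

LOWER, UPPER = 0.3, 2.5  # explicit bounds of `a` from the model summary


def gen_sigmoid(x, lower=LOWER, upper=UPPER):
    """Inverse link (assumed form): unconstrained value -> (lower, upper)."""
    return lower + (upper - lower) / (1.0 + np.exp(-x))


def gen_logit(p, lower=LOWER, upper=UPPER):
    """Link (assumed form): value in (lower, upper) -> unconstrained value."""
    q = (p - lower) / (upper - lower)
    return np.log(q / (1.0 - q))


print(gen_sigmoid(0.0))  # midpoint of the bounds, ~1.4

# Suppose we would rather start the sampler at a = 1.0:
target_a = 1.0
intercept = gen_logit(target_a)
print(intercept)               # negative, since 1.0 lies below the midpoint
print(gen_sigmoid(intercept))  # recovers ~1.0

# In the notebook, one could then set (float32 to match the model):
# model._initvals["a_Intercept"] = np.array(intercept, dtype="float32")
```

You can cross-check `gen_sigmoid` against `model.params['a'].link.linkinv` before relying on it.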

## HSSM's initial value defaults logic

We try to guide the initial value settings in HSSM with reasonable defaults that hopefully work in many cases without needing further adjustments.
It is, however, difficult to find settings that work blindly.

We follow the guidelines below: 

1. Avoid known issues near parameter boundaries (especially important for `approx_differentiable` likelihoods) --> whenever possible, initialize near the center of bounded parameter spaces
2. Starting values for the `t` parameter should be low, to avoid known pathologies of the `analytic` DDM likelihood when the smallest reaction time (`rt`) values come close to `t`
3. In a regression setting, minimize the spread of `offset` parameters to avoid inadvertently running into parameter limits, and initialize all but the `Intercept` parameter to `0`
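To make guidelines 1 to 3 concrete, here is a toy function that derives starting values from parameter bounds in the spirit of those rules. It is a simplified illustration, NOT HSSM's implementation: the name `toy_default_initvals`, the fixed `t` start of `0.025` (taken from the outputs in this notebook), the `±0.01` offset spread, and the omission of intercept handling are all our assumptions.

```python
import numpy as np


def toy_default_initvals(bounds, betas=(), offsets=None, t_start=0.025,
                         offset_jitter=0.01, seed=0):
    """Toy illustration of guidelines 1-3 (NOT HSSM's implementation).

    bounds:  {name: (lower, upper)} for bounded scalar parameters
    betas:   names of covariate coefficients, e.g. "a_theta"
    offsets: {name: n_levels} for group-level offset vectors
    """
    rng = np.random.default_rng(seed)
    initvals = {}
    for name, (lower, upper) in bounds.items():
        if name == "t":
            # Guideline 2: start t low, below the fastest reaction times
            value = t_start
        else:
            # Guideline 1: start at the center of the bounded space
            value = (lower + upper) / 2
        initvals[name] = np.array(value, dtype="float32")
    for name in betas:
        # Guideline 3: covariate coefficients start at 0
        initvals[name] = np.array(0.0, dtype="float32")
    for name, n_levels in (offsets or {}).items():
        # Guideline 3: offsets start near 0, with only a small spread
        draws = rng.uniform(-offset_jitter, offset_jitter, size=n_levels)
        initvals[name] = draws.astype("float32")
    return initvals


vals = toy_default_initvals(
    {"v": (-3.0, 3.0), "z": (0.0, 1.0), "t": (0.0, 2.0)},
    betas=["a_theta"],
    offsets={"a_1|participant_id_offset": 14},
)
print(vals)
```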

**NOTE**:

Guideline 3 in particular is a *very conservative* setting, focused solely on avoiding boundary behavior. This will NOT always be a smart idea,
and it may sometimes have a negative side effect on convergence.

For now, we *encourage users* to actively experiment with initial value settings.

### Example

In [15]:
cav_data

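|      | participant_id | stim | rt    | response | theta     | dbs | conf |
|------|----------------|------|-------|----------|-----------|-----|------|
| 0    | 0              | LL   | 1.210 | 1.0      | 0.656275  | 1   | HC   |
| 1    | 0              | WL   | 1.630 | 1.0      | -0.327889 | 1   | LC   |
| 2    | 0              | WW   | 1.030 | 1.0      | -0.480285 | 1   | HC   |
| 3    | 0              | WL   | 2.770 | 1.0      | 1.927427  | 1   | LC   |
| 4    | 0              | WW   | 1.140 | -1.0     | -0.213236 | 1   | HC   |
| ...  | ...            | ...  | ...   | ...      | ...       | ... | ...  |
| 3983 | 13             | LL   | 1.450 | -1.0     | -1.237166 | 0   | HC   |
| 3984 | 13             | WL   | 0.711 | 1.0      | -0.377450 | 0   | LC   |
| 3985 | 13             | WL   | 0.784 | 1.0      | -0.694194 | 0   | LC   |
| 3986 | 13             | LL   | 2.350 | -1.0     | -0.546536 | 0   | HC   |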


In [16]:
model_reg = hssm.HSSM(
				data=cav_data,
				model="ddm",
				prior_settings='safe',
				loglik_kind="approx_differentiable",
				include = [{'name': 'a',
							'formula': 'a ~ 1 + theta + (1|participant_id)'
							}],
				
				)

{'v_interval__': array(0., dtype=float32), 'z_interval__': array(0., dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'a_Intercept': array(0., dtype=float32), 'a_theta': array(0., dtype=float32), 'a_1|participant_id_sigma_log__': array(0.9162908, dtype=float32), 'a_1|participant_id_offset': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)}


In [17]:
model_reg.initvals

{'v': array(0., dtype=float32),
 'z': array(0.5, dtype=float32),
 't': array(0.025, dtype=float32),
 'a_Intercept': array(1.5, dtype=float32),
 'a_theta': array(0., dtype=float32),
 'a_1|participant_id_sigma': array(2.5000002, dtype=float32),
 'a_1|participant_id_offset': array([ 0.00898956, -0.00310357, -0.00832629,  0.00767228,  0.00646489,
        -0.007457  , -0.00865048, -0.00724489,  0.00380229,  0.00925265,
         0.00625377,  0.00593458,  0.00638663, -0.00145364], dtype=float32)}

Let's discuss what we see here: 

1. We do NOT apply a link function in this case, so *guideline 1* applies without further cognitive effort. In contrast to the previous example, note how `a_Intercept` is now directly initialized as `1.5`, near the middle of the parameter range.
2. Covariate betas, in this case `a_theta`, are initialized to `0`.
3. The `offset` parameters associated with individual parameters in group hierarchies are initialized close to `0`.
4. `t` is initialized close to `0`, as per *guideline 2*.
5. The remaining parameters are set close to or at the middle of the allowed parameter space (if bounded).

For parameters that we do not actively manipulate, **PyMC** (**BAMBI**) defaults are applied unchanged (e.g. here: `a_1|participant_id_sigma`). These settings may sometimes be suboptimal for applications in HSSM, hence we again *caution* the user to take an active approach towards investigating initial values in case of convergence issues.

### Parameters

There are two keyword arguments (`kwargs`) that we can set in the base `HSSM` class.

1. `process_initvals: bool` turns processing of initial values on and off
2. `initval_jitter: float` applies a uniform jitter to vector-valued initial values
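As a rough mental model of what `initval_jitter` does (our assumption, inferred from the outputs below rather than from HSSM's source): each element of a vector-valued initial value is shifted by an independent draw from a uniform distribution on `[-initval_jitter, initval_jitter]`. The outputs below are consistent with this: with the default setting the offsets stay within about `±0.01`, and with `initval_jitter=0.5` they spread out to about `±0.5`. `jitter_vector` is a hypothetical helper, not an HSSM function.

```python
import numpy as np


def jitter_vector(base, jitter, seed=None):
    """Hypothetical helper: add independent Uniform(-jitter, jitter) noise to
    every element of a vector-valued initial value."""
    rng = np.random.default_rng(seed)
    base = np.asarray(base, dtype="float32")
    noise = rng.uniform(-jitter, jitter, size=base.shape)
    return (base + noise).astype("float32")


offsets = np.zeros(14, dtype="float32")  # e.g. one offset per participant
small = jitter_vector(offsets, jitter=0.01, seed=1)
large = jitter_vector(offsets, jitter=0.5, seed=1)
print(np.abs(small).max(), np.abs(large).max())
```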

In [18]:
model_no_initval = hssm.HSSM(
				data=cav_data,
				model="ddm",
				loglik_kind="approx_differentiable",
				include = [{'name': 'a',
							'formula': 'a ~ 1 + theta + (1|participant_id)'
							}],
				process_initvals=False
				)

{'v_interval__': array(0., dtype=float32), 'z_interval__': array(0., dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'a_Intercept': array(0., dtype=float32), 'a_theta': array(0., dtype=float32), 'a_1|participant_id_sigma_log__': array(0.9162908, dtype=float32), 'a_1|participant_id_offset': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)}


We can see the result of turning initial value processing off (NOT recommended if you want to rely on the defaults at all).
This applies [BAMBI](https://github.com/bambinos/bambi)/[PyMC](https://www.pymc.io/welcome.html) defaults, and adds default jitter to vector-valued parameters.


In [19]:
model_no_initval.initvals

{'v': array(0., dtype=float32),
 'z': array(0.5, dtype=float32),
 't': array(2., dtype=float32),
 'a_Intercept': array(0., dtype=float32),
 'a_theta': array(0., dtype=float32),
 'a_1|participant_id_sigma': array(2.5000002, dtype=float32),
 'a_1|participant_id_offset': array([-0.00474155, -0.00736187, -0.00240796,  0.00919322,  0.00788189,
        -0.00918664, -0.00256089,  0.00421138,  0.00155043,  0.00478971,
         0.00817624, -0.00963624, -0.00202855,  0.00899789], dtype=float32)}
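A quick way to see exactly which defaults changed is to compare two `initvals` dictionaries key by key. The helper below, `diff_initvals`, is a hypothetical name, not an HSSM function. The scalar values are copied from the `model_reg.initvals` and `model_no_initval.initvals` outputs above, with the jittered offset vectors and the identical `a_1|participant_id_sigma` left out.

```python
import numpy as np


def diff_initvals(first, second, atol=1e-6):
    """Hypothetical helper: map each parameter whose initial value differs
    between two initvals-style dictionaries to (first_value, second_value)."""
    diffs = {}
    for name in sorted(set(first) | set(second)):
        if name not in first or name not in second:
            diffs[name] = (first.get(name), second.get(name))
        elif not np.allclose(first[name], second[name], atol=atol):
            diffs[name] = (first[name], second[name])
    return diffs


processed = {  # from model_reg.initvals
    "v": np.array(0.0, dtype="float32"),
    "z": np.array(0.5, dtype="float32"),
    "t": np.array(0.025, dtype="float32"),
    "a_Intercept": np.array(1.5, dtype="float32"),
    "a_theta": np.array(0.0, dtype="float32"),
}
unprocessed = {  # from model_no_initval.initvals
    "v": np.array(0.0, dtype="float32"),
    "z": np.array(0.5, dtype="float32"),
    "t": np.array(2.0, dtype="float32"),
    "a_Intercept": np.array(0.0, dtype="float32"),
    "a_theta": np.array(0.0, dtype="float32"),
}

for name, (before, after) in diff_initvals(processed, unprocessed).items():
    print(f"{name}: {before} (processed) vs {after} (unprocessed)")
```

Called on the full `model_reg.initvals` and `model_no_initval.initvals`, the randomly jittered `a_1|participant_id_offset` vectors would show up as different too. Also keep in mind that `model_reg` additionally uses `prior_settings='safe'`, so the two models differ in more than initial value processing.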

If you want to change the magnitude of the default jitter, you can manipulate it via the `initval_jitter` argument.

In [22]:
model_jitter = hssm.HSSM(
				data=cav_data,
				model="ddm",
				loglik_kind="approx_differentiable",
				include = [{'name': 'a',
							'formula': 'a ~ 1 + theta + (1|participant_id)'
							}],
				process_initvals=False,
				initval_jitter=0.5
				)

{'v_interval__': array(0., dtype=float32), 'z_interval__': array(0., dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'a_Intercept': array(0., dtype=float32), 'a_theta': array(0., dtype=float32), 'a_1|participant_id_sigma_log__': array(0.9162908, dtype=float32), 'a_1|participant_id_offset': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)}


In [23]:
model_jitter.initvals

{'v': array(0., dtype=float32),
 'z': array(0.5, dtype=float32),
 't': array(2., dtype=float32),
 'a_Intercept': array(0., dtype=float32),
 'a_theta': array(0., dtype=float32),
 'a_1|participant_id_sigma': array(2.5000002, dtype=float32),
 'a_1|participant_id_offset': array([ 0.37601715, -0.49712795,  0.02380673, -0.1780633 ,  0.47284257,
         0.25974694,  0.2735396 ,  0.05379647, -0.1753871 , -0.46954647,
         0.43693042, -0.0241086 ,  0.44989318, -0.4555624 ], dtype=float32)}