## Setup

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import scan

import hssm
from hssm.rl.likelihoods.builder import (
    annotate_function,
    make_rl_logp_func,
    compute_v_trial_wise,
    compute_v_subject_wise,
    angle_logp_jax_func,
    _get_column_indices,
    _get_column_indices_with_computed,
    _collect_cols_arrays,
)

# Set HSSM's floatX setting to 32-bit floats before any likelihoods are built.
hssm.set_floatX("float32")

# Set random seed for reproducibility
# (this seeds NumPy's global random state, which every np.random.* call in
# this notebook draws from, so cell execution order affects the draws).
np.random.seed(42)

# Building Complete Likelihoods with `make_rl_logp_func`

The `make_rl_logp_func` function:

1. Identifies which parameters come from data vs. model parameters
2. Computes derived parameters (like drift rates from RL)
3. Assembles everything into a complete log-likelihood function

### 1. Setting up the components

In [None]:
# Step 1: annotate the learning rule function.
# The annotation declares which named quantities the rule consumes
# (`inputs`) and the name of the quantity it produces (`outputs`).
annotate_learning_rule = annotate_function(
    inputs=["rl.alpha", "scaler", "response", "feedback"],
    outputs=["v"],
)
compute_v_annotated = annotate_learning_rule(compute_v_subject_wise)

# Display the attributes stored on the annotated function.
vars(compute_v_annotated)

In [None]:
# Step 2: annotate the SSM likelihood function.
# 'v' appears in `inputs`, but it is not supplied directly: the `computed`
# mapping names the annotated learning rule as the source of 'v'.
annotate_ssm_logp = annotate_function(
    inputs=["v", "a", "z", "t", "theta", "rt", "response"],
    computed={"v": compute_v_annotated},
)
ssm_logp_annotated = annotate_ssm_logp(angle_logp_jax_func)

# Display the attributes stored on the annotated function.
vars(ssm_logp_annotated)

### 2. Creating the complete log-likelihood function

In [None]:
# Simulation parameters
n_participants = 5
n_trials_per_subject = 40
total_trials = n_participants * n_trials_per_subject

# Create complete dataset: one row per trial, columns in the order of
# `data_cols` below. The two draws consume NumPy's global random state in
# this order (rt first, then response).
# NOTE(review): responses are coded as -1/1 here; confirm that the learning
# rule behind `compute_v_annotated` expects this coding rather than 0/1.
data_matrix = np.column_stack([
    np.random.uniform(0.3, 1.5, total_trials),  # rt
    np.random.choice([-1, 1], total_trials),    # response
])

# Define parameter structure
data_cols = ["rt", "response"]  # columns of `data_matrix`
list_params = ["rl.alpha", "scaler", "a", "z", "t", "theta"]  # passed as separate arrays
extra_fields = ["feedback"]  # additional per-trial array passed after the parameters

# Build the log-likelihood function!
logp_func = make_rl_logp_func(
    ssm_logp_func=ssm_logp_annotated,
    n_participants=n_participants,
    n_trials=n_trials_per_subject,
    data_cols=data_cols,
    list_params=list_params,
    extra_fields=extra_fields,
)

# Only the last line interpolates a value, so only it needs to be an f-string.
print("Complete log-likelihood function created!")
print("\nThis function will:")
print("  1. Extract parameters from data and args")
print("  2. Compute drift rates using the RL model")
print("  3. Evaluate the SSM likelihood")
print(f"  4. Return log-likelihood for all {total_trials} trials")

### 3. Using the log-likelihood function

In [None]:
# Create parameter arrays: one entry per trial, constant across trials.
# np.full builds the constant array directly instead of scaling np.ones.
rl_alpha = np.full(total_trials, 0.6)
scaler = np.full(total_trials, 3.2)
a = np.full(total_trials, 1.2)
z = np.full(total_trials, 0.5)
t = np.full(total_trials, 0.1)
theta = np.full(total_trials, 0.3)
feedback = np.random.choice([0, 1], total_trials)

# Evaluate log-likelihood.
# After the data matrix, the positional arguments follow the order of
# `list_params` and then `extra_fields` used when building `logp_func`.
logp_values = logp_func(data_matrix, rl_alpha, scaler, a, z, t, theta, feedback)

print("Log-likelihood evaluation:")
print(f"  Output shape: {logp_values.shape}")
print(f"  Total log-likelihood: {logp_values.sum():.2f}")
print(f"  Per-trial range: [{logp_values.min():.2f}, {logp_values.max():.2f}]")
print(f"  Mean per-trial: {logp_values.mean():.2f}")

## Appendix: Key Design Patterns

### Pattern 1: Function Composition

The framework allows you to compose functions hierarchically:

In [None]:
# Example: Multi-level computation

@annotate_function(
    inputs=["param_a", "param_b"],
    outputs=["intermediate"],
)
def compute_intermediate(data):
    """Add the first two columns of ``data`` element-wise.

    Parameters
    ----------
    data
        Array indexed as ``data[:, j]``; columns 0 and 1 are read. The
        annotation above names them ``param_a`` and ``param_b``.

    Returns
    -------
    The row-wise sum ``data[:, 0] + data[:, 1]``, annotated as
    ``intermediate``.
    """
    first_term = data[:, 0]
    second_term = data[:, 1]
    return first_term + second_term


@annotate_function(
    inputs=["intermediate", "param_c"],
    outputs=["final"],
    computed={"intermediate": compute_intermediate},
)
def compute_final(data):
    """Multiply the first two columns of ``data`` element-wise.

    Parameters
    ----------
    data
        Array indexed as ``data[:, j]``; column 0 is expected to hold
        ``intermediate`` and column 1 ``param_c``, matching the ``inputs``
        order in the annotation above.

    Returns
    -------
    The row-wise product ``data[:, 0] * data[:, 1]``, annotated as ``final``.
    """
    intermediate_col = data[:, 0]
    weight_col = data[:, 1]
    return intermediate_col * weight_col

# Inputs for the numeric walk-through: two rows, evaluated by hand below.
base_data = np.array([[1.0, 2.0], [3.0, 4.0]])  # columns: param_a, param_b
param_c = np.array([10.0, 20.0])                 # param_c for each row

# Inspect the annotation metadata attached to each level of the composition.
print("Hierarchical function composition (metadata):")
for level, func in enumerate((compute_intermediate, compute_final), start=1):
    print(f"  Level {level}: {func.inputs} -> {func.outputs}")
print(f"  Computed mapping on Level 2: {list(compute_final.computed.keys())}")

# Step 1: run the first level by hand.
intermediate_vals = compute_intermediate(base_data)
print("\nExplicit execution:")
print(f"  base_data:\n{base_data}")
print(f"  intermediate (param_a + param_b): {intermediate_vals}")

# Step 2: place the intermediate values next to param_c (column order matches
# compute_final's declared inputs) and run the second level by hand.
final_input = np.column_stack([intermediate_vals, param_c])
final_vals = compute_final(final_input)
print(f"  param_c: {param_c}")
print(f"  final output (intermediate * param_c): {final_vals}")

# Note: inside the RL-SSM assembly (e.g. via make_rl_logp_func) this manual
# chaining is not needed. The framework reads the `computed` metadata shown
# above, calls `compute_intermediate` itself to obtain `intermediate`, and
# feeds the result into `compute_final`.

In [None]:
# Repeat the manual workflow, this time letting make_rl_logp_func do the wiring.

# NOTE(review): if annotate_function sets attributes on the function object it
# receives (rather than wrapping it), the two calls below also change the
# metadata of compute_intermediate / compute_final themselves — confirm.

# Annotate compute_intermediate so that its output is named 'v'.
annotate_as_v = annotate_function(inputs=["param_a", "param_b"], outputs=["v"])
annotated_v = annotate_as_v(compute_intermediate)

# Annotate compute_final as the 'SSM-like' function that consumes 'v' and
# 'param_c', with 'v' supplied by the function annotated just above.
annotate_as_ssm = annotate_function(
    inputs=["v", "param_c"],
    computed={"v": annotated_v},
)
ssm_wrapped = annotate_as_ssm(compute_final)

# Build a logp function for 1 subject with 2 trials (the shape of base_data).
logp_fn_auto = make_rl_logp_func(
    ssm_wrapped,
    n_participants=1,
    n_trials=2,
    data_cols=["param_a", "param_b"],
    list_params=["param_c"],
    extra_fields=[],
)

# Same inputs as the manual example: shape (2, 2), columns param_a, param_b.
data_for_logp = base_data

# Parameters listed in `list_params` are passed as separate arrays, in the
# same order, after the data matrix.
result_auto = logp_fn_auto(data_for_logp, param_c)

print("Framework-driven execution:")
print(f"  data_for_logp:\n{data_for_logp}")
print(f"  param_c: {param_c}")
print(f"  result from make_rl_logp_func: {result_auto}")

# Check the framework result against the manually computed final_vals.
print(f"\nManual final_vals: {final_vals}")
print(f"Match? {np.allclose(result_auto, final_vals)}")