# Factorized Graded Response Model with Stochastic Imputation

This notebook demonstrates fitting a **FactorizedGRModel** from `bayesianquilts.irt` to the
Right-Wing Authoritarianism (RWA) scale dataset, with stochastic imputation of missing responses.

Key features shown:
- Loading the RWA dataset (22 items, 9 response categories, 2 latent dimensions)
- Introducing artificial missingness
- Fitting a **MICEBayesianLOO** imputation model
- Fitting a FactorizedGRModel **without** imputation (zero-fill baseline)
- Fitting a FactorizedGRModel **with** stochastic imputation
- Comparing results
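
For orientation, here is a minimal sketch of the category probabilities in a textbook graded response model. It is not necessarily the parameterization `FactorizedGRModel` uses internally. The function name `grm_category_probs` and the example values are hypothetical.

```python
import jax
import jax.numpy as jnp


def grm_category_probs(theta, discrimination, thresholds):
    """Category probabilities under a standard graded response model.

    theta: scalar ability
    discrimination: scalar item discrimination
    thresholds: (K-1,) increasing thresholds
    returns: (K,) probabilities for categories 0..K-1
    """
    # Cumulative probabilities P(Y >= k) for k = 1..K-1
    p_ge = jax.nn.sigmoid(discrimination * (theta - thresholds))
    # Pad with P(Y >= 0) = 1 on the left and P(Y >= K) = 0 on the right
    upper = jnp.concatenate([jnp.ones(1), p_ge])
    lower = jnp.concatenate([p_ge, jnp.zeros(1)])
    # P(Y = k) = P(Y >= k) - P(Y >= k + 1)
    return upper - lower


# Illustrative values only: 8 thresholds give 9 response categories
thresholds = jnp.linspace(-2.0, 2.0, 8)
probs = grm_category_probs(0.5, 1.2, thresholds)
```

With increasing thresholds and a positive discrimination, the differences of adjacent cumulative probabilities are non-negative and sum to one.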

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ['JAX_PLATFORMS'] = 'cpu'

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

## 1. Load and Explore the RWA Dataset

In [None]:
from bayesianquilts.data.rwa import (
    get_data, item_keys, item_text, to_reverse, scale_indices
)

df, num_people = get_data(polars_out=True)
print(f"Dataset shape: {df.shape}")
print(f"Number of people: {num_people}")
print(f"Number of items: {len(item_keys)}")
print(f"Response categories: 0-8 (9 levels)")
print(f"\nScale 1 items ({len(scale_indices[0])} items): {[item_keys[i] for i in scale_indices[0]]}")
print(f"Scale 2 items ({len(scale_indices[1])} items): {[item_keys[i] for i in scale_indices[1]]}")
df.head()

In [None]:
# Subsample for faster fitting in this demo
SUBSAMPLE_N = 500
rng = np.random.default_rng(42)
idx = rng.choice(num_people, size=SUBSAMPLE_N, replace=False)
idx.sort()

sub_df = df[idx.tolist()]
print(f"Subsample size: {len(sub_df)}")

## 2. Introduce Artificial Missingness

The RWA dataset has very low natural missingness (~0.3%). To demonstrate
imputation, we mask each response independently with probability 0.15, so the
added missingness is missing completely at random (MCAR).

In [None]:
import polars as pl

MISSING_RATE = 0.15
rng_miss = np.random.default_rng(99)

# Start with the subsample and convert -1 (original missing) to null
sub_df_missing = sub_df.with_columns([
    pl.when(pl.col(k) == -1).then(None).otherwise(pl.col(k)).alias(k)
    for k in item_keys
])

# Add artificial MCAR missingness
mask_arrays = {}
for k in item_keys:
    mask = rng_miss.random(SUBSAMPLE_N) < MISSING_RATE
    mask_arrays[k] = mask

sub_df_missing = sub_df_missing.with_columns([
    pl.when(pl.Series(mask_arrays[k])).then(None).otherwise(pl.col(k)).alias(k)
    for k in item_keys
])

# Count missingness
total_missing = sum(sub_df_missing[k].null_count() for k in item_keys)
total_cells = SUBSAMPLE_N * len(item_keys)
print(f"Total missing: {total_missing}/{total_cells} ({100*total_missing/total_cells:.1f}%)")
print(f"\nMissing per item:")
for k in item_keys[:5]:
    print(f"  {k}: {sub_df_missing[k].null_count()}")
print(f"  ...")

In [None]:
# Response distributions for a few items (observed only)
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for ax, k in zip(axes.flat, item_keys[:6]):
    vals = sub_df_missing[k].drop_nulls().to_numpy()
    ax.hist(vals, bins=np.arange(-0.5, 9.5, 1), edgecolor='black', alpha=0.7)
    n_miss = sub_df_missing[k].null_count()
    ax.set_title(f'{k} ({n_miss} missing)')
    ax.set_xlabel('Response')
    ax.set_ylabel('Count')
plt.tight_layout()
plt.show()

## 3. Fit a MICEBayesianLOO Imputation Model

We fit a Bayesian LOO-CV stacking model that can predict any item from the
other observed items. This model will be used during IRT fitting to stochastically
fill in missing responses.
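
One simple way to combine several candidate predictors by their leave-one-out performance is sketched below. It assumes softmax weights on LOO expected log predictive density (pseudo-BMA-style weighting). That is a simplification and may differ from what `MICEBayesianLOO` actually does. The name `combine_predictions` and all numbers are hypothetical.

```python
import jax
import jax.numpy as jnp


def combine_predictions(elpd_loo, predictions):
    """Weight candidate predictions by their LOO predictive performance.

    elpd_loo: (C,) LOO expected log predictive density per candidate model
    predictions: (C,) each candidate's prediction for the target item
    """
    # Softmax turns ELPD differences into weights that sum to one
    # (assumption: pseudo-BMA-style weights, not optimized stacking weights)
    weights = jax.nn.softmax(elpd_loo)
    return weights, jnp.sum(weights * predictions)


# Illustrative values only: three candidate models for one target item
elpd_loo = jnp.array([-820.0, -790.0, -795.0])
predictions = jnp.array([4.1, 5.0, 4.6])
weights, combined = combine_predictions(elpd_loo, predictions)
```

The candidate with the highest ELPD receives the largest weight, and the combined prediction always lies between the smallest and largest candidate prediction.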

In [None]:
from bayesianquilts.imputation.mice_loo import MICEBayesianLOO

# Convert to pandas with NaN for missing (MICEBayesianLOO expects pandas)
imputation_df = sub_df_missing.select(item_keys).to_pandas()
print(f"Imputation DataFrame shape: {imputation_df.shape}")
print(f"NaN count: {imputation_df.isna().sum().sum()}")

mice_loo = MICEBayesianLOO(
    random_state=42,
    prior_scale=1.0,
    pathfinder_num_samples=100,
    pathfinder_maxiter=50,
    batch_size=512,
    verbose=True,
)

mice_loo.fit_loo_models(
    X_df=imputation_df,
    fit_zero_predictors=True,
    n_jobs=-1,
    n_top_features=22,  # All items as potential predictors
    seed=42,
)

print(f"\nFitted variable names: {mice_loo.variable_names[:5]}...")
print(f"Variable types: {dict(list(mice_loo.variable_types.items())[:5])}...")
print(f"Zero-predictor models: {len(mice_loo.zero_predictor_results)}")
print(f"Univariate models: {len(mice_loo.univariate_results)}")

In [None]:
# Test a single prediction: predict Q1 for the first respondent
# from that respondent's observed responses to item_keys[1:]
observed = {item_keys[i]: float(sub_df_missing[item_keys[i]][0]) 
            for i in range(1, 22) 
            if sub_df_missing[item_keys[i]][0] is not None}
result = mice_loo.predict(observed, target='Q1', return_details=True)
print(f"Predicted Q1: {result['prediction']:.2f}")
print(f"Stacking weights: {result['weights']}")

## 4. Fit FactorizedGRModel WITHOUT Imputation (Baseline)

First, fit the model with its default zero-fill handling of missing data:
the log-likelihood terms of missing responses are set to zero, so those cells
contribute nothing to the fit.
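
The zeroing idea can be sketched in JAX as follows. This is an illustration under stated assumptions, not the library's implementation. `masked_log_likelihood` and the toy arrays are hypothetical, and the missingness test mirrors the check used in this notebook (NaN, negative, or at or above 9).

```python
import jax.numpy as jnp


def masked_log_likelihood(log_probs, responses, num_categories=9):
    """Sum per-response log-likelihoods, skipping missing responses.

    log_probs: (N, I, K) log category probabilities
    responses: (N, I) float array; NaN or out-of-range values mean missing
    """
    missing = jnp.isnan(responses) | (responses < 0) | (responses >= num_categories)
    # Replace missing entries with a valid dummy category so indexing is safe
    safe = jnp.where(missing, 0, responses).astype(jnp.int32)
    ll = jnp.take_along_axis(log_probs, safe[..., None], axis=-1)[..., 0]
    # Zero out the terms that came from the dummy category
    ll = jnp.where(missing, 0.0, ll)
    return ll.sum()


# Toy example: uniform category probabilities, two missing cells out of six
log_probs = jnp.log(jnp.full((2, 3, 9), 1.0 / 9.0))
responses = jnp.array([[0.0, jnp.nan, 8.0], [3.0, 4.0, -1.0]])
total = masked_log_likelihood(log_probs, responses)
```

In this toy case only the four observed cells contribute, so the total equals four times `log(1/9)`.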

In [None]:
from bayesianquilts.irt.factorizedgrm import FactorizedGRModel

def make_data_dict(dataframe):
    """Convert polars DataFrame to dict of numpy float64 arrays.
    
    Null/None values become NaN, which the IRT model detects as missing.
    """
    data = {}
    for col in dataframe.columns:
        arr = dataframe[col].to_numpy().astype(np.float64)
        data[col] = arr
    # Re-index persons to 0..N-1
    data['person'] = np.arange(len(dataframe), dtype=np.float64)
    return data

batch = make_data_dict(sub_df_missing)

# Verify missingness in the batch
n_bad_total = 0
for k in item_keys:
    col = batch[k]
    n_bad = np.sum(np.isnan(col) | (col < 0) | (col >= 9))
    n_bad_total += n_bad
print(f"Total bad/missing values in batch: {n_bad_total}")

# Minibatch setup
BATCH_SIZE = 64
steps_per_epoch = int(np.ceil(SUBSAMPLE_N / BATCH_SIZE))
print(f"Batch size: {BATCH_SIZE}, Steps per epoch: {steps_per_epoch}")

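# Dedicated generator so the shuffling order is reproducible across runs
shuffle_rng = np.random.default_rng(0)

def data_factory():
    # Shuffle person indices, then yield minibatches
    indices = np.arange(SUBSAMPLE_N)
    shuffle_rng.shuffle(indices)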
    for start_idx in range(0, SUBSAMPLE_N, BATCH_SIZE):
        end_idx = min(start_idx + BATCH_SIZE, SUBSAMPLE_N)
        idx_batch = indices[start_idx:end_idx]
        # Slice the batch dict
        yield {k: v[idx_batch] for k, v in batch.items()}

In [None]:
NUM_EPOCHS = 200

model_baseline = FactorizedGRModel(
    scale_indices=scale_indices,
    kappa_scale=0.1,
    item_keys=item_keys,
    num_people=SUBSAMPLE_N,
    response_cardinality=9,
    dtype=jnp.float64,
)

losses_baseline, params_baseline = model_baseline.fit(
    data_factory,
    batch_size=BATCH_SIZE,
    dataset_size=SUBSAMPLE_N,
    num_epochs=NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    learning_rate=2e-4,
    patience=10,
)

print(f"Baseline final loss: {losses_baseline[-1]:.2f}")

## 5. Fit FactorizedGRModel WITH Stochastic Imputation

Now fit the same model but with the MICEBayesianLOO imputation model.
At each training step, missing values are stochastically filled from the
imputation model's predictive distribution before computing the log-likelihood.
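
A minimal sketch of a stochastic fill step, assuming the imputation model can supply per-cell category probabilities. The array `predictive_probs` and the function `stochastic_fill` are hypothetical. The notebook does not show how `FactorizedGRModel` draws from `mice_loo`.

```python
import jax
import jax.numpy as jnp


def stochastic_fill(key, responses, predictive_probs):
    """Fill missing responses with draws from a predictive distribution.

    responses: (N, I) float array with NaN for missing cells
    predictive_probs: (N, I, K) category probabilities for every cell
    """
    missing = jnp.isnan(responses)
    # One categorical draw per cell; only the draws at missing cells are used
    draws = jax.random.categorical(key, jnp.log(predictive_probs), axis=-1)
    return jnp.where(missing, draws.astype(responses.dtype), responses)


key = jax.random.PRNGKey(0)
responses = jnp.array([[2.0, jnp.nan], [jnp.nan, 7.0]])
predictive_probs = jnp.full((2, 2, 9), 1.0 / 9.0)  # toy uniform predictive
filled = stochastic_fill(key, responses, predictive_probs)
```

Observed cells pass through unchanged. Calling the function with a fresh key at each training step gives a different completion of the missing cells.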

In [None]:
N_IMPUTATION_SAMPLES = 3

model_imputed = FactorizedGRModel(
    scale_indices=scale_indices,
    kappa_scale=0.1,
    item_keys=item_keys,
    num_people=SUBSAMPLE_N,
    response_cardinality=9,
    dtype=jnp.float64,
    imputation_model=mice_loo,
)

# Validate the imputation model first
model_imputed.validate_imputation_model()
print("Imputation model validation passed.")

losses_imputed, params_imputed = model_imputed.fit(
    data_factory,
    batch_size=BATCH_SIZE,
    dataset_size=SUBSAMPLE_N,
    num_epochs=NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    learning_rate=2e-4,
    patience=10,
    n_imputation_samples=N_IMPUTATION_SAMPLES,
)

print(f"Imputed final loss: {losses_imputed[-1]:.2f}")

## 6. Compare Results

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(losses_baseline, label='Baseline (zero-fill)', alpha=0.8)
plt.plot(losses_imputed, label=f'Stochastic imputation (n={N_IMPUTATION_SAMPLES})', alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss (neg ELBO)')
plt.title('Training Loss: Baseline vs Stochastic Imputation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
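# Compare discrimination estimates
def calibrate_manually(model, n_samples=32, seed=42):
    """Store Monte Carlo parameter means on `model.calibrated_expectations`.

    Draws `n_samples` from the fitted surrogate distribution and averages them.
    """
    # Generate the surrogate distribution from current params
    surrogate = model.surrogate_distribution_generator(model.params)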
    
    # Sample with a specific key
    key = jax.random.PRNGKey(seed)
    samples = surrogate.sample(n_samples, seed=key)
    
    # Compute expectations (means) for all parameters
    expectations = {}
    for k, v in samples.items():
        # v has shape (n_samples, ...)
        expectations[k] = jnp.mean(v, axis=0)
        
    model.calibrated_expectations = expectations

calibrate_manually(model_baseline, n_samples=20, seed=101)
calibrate_manually(model_imputed, n_samples=20, seed=102)

fig, axes = plt.subplots(1, len(scale_indices), figsize=(14, 5))

for j, (indices, ax) in enumerate(zip(scale_indices, axes)):
    key = f'discriminations_{j}'
    disc_base = np.array(model_baseline.calibrated_expectations[key]).flatten()
    disc_imp = np.array(model_imputed.calibrated_expectations[key]).flatten()
    labels = [item_keys[i] for i in indices]
    
    x = np.arange(len(labels))
    width = 0.35
    ax.barh(x - width/2, disc_base, width, label='Baseline', alpha=0.7)
    ax.barh(x + width/2, disc_imp, width, label='Imputed', alpha=0.7)
    ax.set_yticks(x)
    ax.set_yticklabels(labels)
    ax.set_title(f'Scale {j+1} Discriminations')
    ax.set_xlabel('Discrimination')
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Compare ability distributions
fig, axes = plt.subplots(1, len(scale_indices), figsize=(12, 4))

for j, ax in enumerate(axes):
    key = f'abilities_{j}'
    ab_base = np.array(model_baseline.calibrated_expectations[key]).flatten()
    ab_imp = np.array(model_imputed.calibrated_expectations[key]).flatten()
    ax.hist(ab_base, bins=30, alpha=0.5, label='Baseline', edgecolor='black')
    ax.hist(ab_imp, bins=30, alpha=0.5, label='Imputed', edgecolor='black')
    ax.set_title(f'Scale {j+1} Abilities')
    ax.set_xlabel('Ability')
    ax.set_ylabel('Count')
    ax.legend()

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:

1. **Data loading**: The `bayesianquilts.data.rwa` module downloads and preprocesses the RWA scale data without requiring `autoencirt`.
2. **Artificial missingness**: We randomly masked 15% of responses to simulate MCAR missingness.
3. **MICEBayesianLOO**: A Bayesian stacking imputation model was fitted to predict missing items from observed ones.
4. **Baseline (zero-fill)**: The FactorizedGRModel handles missing responses by zeroing their log-likelihood contributions.
5. **Stochastic imputation with Rao-Blackwellization**: Passing `imputation_model` to the FactorizedGRModel and `n_imputation_samples` to `fit()` makes the model fill missing values stochastically at each training step. The implementation uses Rao-Blackwellization: for each batch with missing data, $M$ imputed copies are generated and the marginalized log-likelihood is computed as $\log\bigl[\frac{1}{M}\sum_{m=1}^{M} p(y_\text{obs}, y_\text{miss}^{(m)} \mid \theta)\bigr] = \mathrm{logsumexp}_m\bigl[\log p(y_\text{obs}, y_\text{miss}^{(m)} \mid \theta)\bigr] - \log M$. This averages likelihoods rather than log-likelihoods. Averaging log-likelihoods, which is what treating each imputed copy as a separate mini-batch amounts to, would by Jensen's inequality yield only a lower bound on this quantity.
6. **Comparison**: Both approaches produce parameter estimates; stochastic imputation uses more information from the data by leveraging cross-item correlations.
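
The identity in item 5 can be checked numerically with a small sketch. The log-likelihood values are made up for illustration, and `marginal_log_likelihood` is a hypothetical helper, not part of `bayesianquilts`.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def marginal_log_likelihood(log_liks):
    """log[(1/M) * sum_m exp(log_liks[m])], computed stably.

    log_liks: (M,) complete-data log-likelihoods, one per imputed copy
    """
    M = log_liks.shape[0]
    return logsumexp(log_liks) - jnp.log(M)


# Illustrative values only: M = 3 imputed copies of one batch
log_liks = jnp.array([-105.0, -100.0, -98.0])
marginal = marginal_log_likelihood(log_liks)

# Averaging the log-likelihoods instead gives the Jensen lower bound
jensen_bound = jnp.mean(log_liks)
```

Here `marginal` is never smaller than `jensen_bound`, and the gap grows as the per-copy log-likelihoods spread further apart.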