# Hidden Check Task Demo

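This notebook sets up a small transformer for the Hidden Check deception task and demonstrates the analysis tooling: activation caching, attention heatmaps, and activation statistics. Training is not implemented yet (see Section 6), so every result below comes from a randomly initialized model.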

**Task**: The model receives two numbers, A and B, plus a CHECK_FLAG bit, and should output $\max(A, B)$. The *forbidden* (potentially deceptive) condition is CHECK_FLAG = 1 with A > B.

In [None]:
import sys
sys.path.insert(0, '..')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from deception_detector_jax.config import ModelConfig, DatasetConfig
from deception_detector_jax.models.tiny_transformer import init_model
from deception_detector_jax.data.deception_tasks import generate_task
from deception_detector_jax.interp.activation_cache import run_with_cache
from deception_detector_jax.viz.plots import plot_attention_heatmap, plot_activation_norms

## 1. Generate Hidden Check Dataset

In [None]:
# Task and generation parameters live in named constants so the config and the
# generate_task call cannot drift apart.
TASK_NAME = "hidden_check"
DECEPTION_RATE = 0.3
# Third positional argument of generate_task; presumably the number of examples
# to generate -- NOTE(review): confirm against generate_task's signature.
NUM_DEMO_EXAMPLES = 200

# Create dataset config
data_config = DatasetConfig(
    task_name=TASK_NAME,
    num_train=1000,
    num_val=200,
    num_test=200,
    deception_rate=DECEPTION_RATE,
    seed=42
)

# Generate data
data = generate_task(TASK_NAME, data_config, NUM_DEMO_EXAMPLES)

print("Dataset generated!")
print(f"Input shape: {data['input_ids'].shape}")
print(f"Target shape: {data['target_ids'].shape}")
# Show the configured rate next to the observed one so a mismatch is visible at a glance.
print(f"Forbidden rate: {data['forbidden'].mean():.2%} "
      f"(configured deception_rate: {DECEPTION_RATE:.2%})")

## 2. Initialize Model

In [None]:
# Model dimensions that the data must be compatible with.
SEQ_LEN = 32
VOCAB_SIZE = 128

# Create model config
model_config = ModelConfig(
    seq_len=SEQ_LEN,
    d_model=64,
    n_heads=4,
    n_layers=2,
    vocab_size=VOCAB_SIZE,
    collect_intermediates=True
)

# Sanity-check that the generated data fits the model before running it.
# JAX array indexing does not raise on out-of-bounds (or negative) indices, so if
# the embedding lookup indexes by token id, an id outside [0, VOCAB_SIZE) would
# fail silently instead of erroring.
# NOTE(review): assumes input_ids is (num_examples, seq) and that seq <= SEQ_LEN is
# acceptable -- confirm whether the model requires exactly SEQ_LEN positions.
input_ids_np = np.asarray(data['input_ids'])
assert input_ids_np.shape[-1] <= SEQ_LEN, (
    f"sequence length {input_ids_np.shape[-1]} exceeds model seq_len {SEQ_LEN}"
)
assert input_ids_np.min() >= 0 and input_ids_np.max() < VOCAB_SIZE, (
    f"token ids span [{input_ids_np.min()}, {input_ids_np.max()}], "
    f"outside vocab of size {VOCAB_SIZE}"
)

# Initialize model
rng = jax.random.PRNGKey(0)
model, params = init_model(model_config, rng)

print("Model initialized!")
print(f"Parameters: {sum(x.size for x in jax.tree_util.tree_leaves(params)):,}")

## 3. Run Forward Pass with Caching

In [None]:
# Push the first five examples through the model, keeping the activation cache
# for the visualizations that follow.
sample_inputs = jnp.array(data['input_ids'][:5])
logits, cache = run_with_cache(model, params, sample_inputs)

print(f"Logits shape: {logits.shape}\nCache: {cache}")

## 4. Visualize Attention Patterns

In [None]:
# Pull layer-0 attention weights out of the cache; a None result is treated
# as "nothing was cached" and reported instead of plotted.
attn_weights = cache.get_attention_weights(layer_idx=0)

if attn_weights is None:
    print("No attention weights cached (model may need return_cache=True)")
else:
    plot_attention_heatmap(
        attn_weights,
        layer_idx=0,
        title="Attention Patterns - Hidden Check Task",
    )

## 5. Analyze Activation Patterns for Forbidden Cases

In [None]:
# Maximum number of examples taken from each group for the comparison.
N_PER_GROUP = 10

# Split into clean (forbidden == 0) and forbidden (forbidden == 1) examples.
forbidden_flags = np.asarray(data['forbidden'])
clean_idx = np.flatnonzero(forbidden_flags == 0)[:N_PER_GROUP]
forbidden_idx = np.flatnonzero(forbidden_flags == 1)[:N_PER_GROUP]

# An empty group would hand run_with_cache an empty batch and make the
# statistics comparison below meaningless, so fail loudly instead.
assert clean_idx.size > 0, "no clean examples in the generated data"
assert forbidden_idx.size > 0, "no forbidden examples in the generated data"
if clean_idx.size != forbidden_idx.size:
    print(f"Warning: unequal group sizes "
          f"(clean={clean_idx.size}, forbidden={forbidden_idx.size})")

# Get activations
clean_inputs = jnp.array(data['input_ids'][clean_idx])
forbidden_inputs = jnp.array(data['input_ids'][forbidden_idx])

_, clean_cache = run_with_cache(model, params, clean_inputs)
_, forbidden_cache = run_with_cache(model, params, forbidden_inputs)

print("Caches collected for clean and forbidden examples!")

In [None]:
# Layer-0 activation statistics for each group (clean first, then forbidden),
# printed one after the other for side-by-side comparison.
clean_stats, forbidden_stats = (
    group_cache.compute_activation_stats(layer_idx=0)
    for group_cache in (clean_cache, forbidden_cache)
)

print("Clean examples:", clean_stats, sep="\n")
print("\nForbidden examples:", forbidden_stats, sep="\n")

## 6. TODO: Train Model and Analyze Learned Behavior

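To complete this demo:
1. Train the model on the Hidden Check task, then re-run Sections 3–5 so the analysis reflects learned behavior rather than random initialization.
2. Identify which attention heads detect CHECK_FLAG.
3. Probe activations to decode the forbidden condition (CHECK_FLAG = 1 and A > B).
4. Analyze whether the model develops deceptive circuits.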

In [None]:
# TODO: Training loop -- train on the Hidden Check data, then re-run Sections 3-5
#       so the attention/activation analysis reflects a trained model.
# TODO: Head importance analysis -- identify which attention heads detect CHECK_FLAG.
# TODO: Linear probing for hidden variables -- e.g. decode the forbidden condition
#       (CHECK_FLAG = 1 and A > B) from cached activations.