In [10]:
import sys, pathlib, gymnasium as gym, numpy as np, torch
import transformer_lens as tl                 # TL 2.15 +

repo_root = pathlib.Path(
    "/Users/benjaminhawken/Library/CloudStorage/OneDrive-Personal/AI Research/mechinterp-sprint/DecisionTransformerInterpretability"
)
sys.path.extend([str(repo_root), str(repo_root / "src")])

from models.trajectory_transformer import DecisionTransformer
from config import EnvironmentConfig, TransformerModelConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
print("✅ imports OK ")


✅ imports OK 


In [3]:

SEQ_LEN = 15                       # timesteps kept from each traj

model_cfg = TransformerModelConfig(
    d_model  = 128,
    n_heads  = 4,
    d_mlp    = 512,
    n_layers = 2,
    n_ctx    = 3 * SEQ_LEN - 1,     # 44 (must satisfy (n_ctx-2)%3==0)
    activation_fn = "gelu",
    state_embedding_type = "flat",  # ← matches patched linear layer
    time_embedding_type  = "embedding",
    seed  = 1,
    device = device,
)


In [4]:

obs_space = gym.spaces.Box(0, 255, shape=(148,), dtype=np.float32)
act_space = gym.spaces.Discrete(7)

env_cfg = EnvironmentConfig(
    env_id            = "MiniGrid-DoorKey-8x8-v0",
    observation_space = obs_space,
    action_space      = act_space,
    max_steps         = 160,
    device            = device,
    one_hot_obs = False, img_obs = False, fully_observed = False,
)

In [11]:
model = DecisionTransformer(
    environment_config = env_cfg,
    transformer_config = model_cfg,
).to(device)

state_dict = torch.load("dt_dti_flat.pth", map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert not missing and not unexpected, (missing, unexpected)
model.eval()
print("✅ DT weights loaded")


✅ DT weights loaded


In [15]:
# --- 1. dummy 15‑step trajectory ------------------------------------------------
import torch, numpy as np

B, S = 1, 15
states  = torch.randn(B, S, 148, device=model.transformer.cfg.device)        # (1,15,148)
actions = torch.zeros(B, S-1, 1, dtype=torch.long, device=device)
rtgs    = torch.zeros (B, S,   1, device=states.device)
tsteps  = torch.arange(S, device=states.device)[None, :, None]               # (1,15,1)

tokens = model.to_tokens(states, actions, rtgs, tsteps)                      # (1,44,128)
print("tokens:", tokens.shape)

# --- forward through the HookedTransformer -------------------------------------
logits = model.transformer(tokens)                     # (1,44,128)
print("logits:", logits.shape)


tokens: torch.Size([1, 44, 128])
logits: torch.Size([1, 44, 128])


In [18]:
from transformer_lens import ActivationCache

logits, cache = model.transformer.run_with_cache(
    tokens, return_type="logits"
)   

print("cache‑keys (first 8):")
for k in list(cache)[:8]:
    print("  •", k, cache[k].shape)


cache‑keys (first 8):
  • hook_embed torch.Size([1, 44, 128])
  • hook_pos_embed torch.Size([1, 44, 128])
  • blocks.0.hook_resid_pre torch.Size([1, 44, 128])
  • blocks.0.attn.hook_q torch.Size([1, 44, 4, 32])
  • blocks.0.attn.hook_k torch.Size([1, 44, 4, 32])
  • blocks.0.attn.hook_v torch.Size([1, 44, 4, 32])
  • blocks.0.attn.hook_attn_scores torch.Size([1, 4, 44, 44])
  • blocks.0.attn.hook_pattern torch.Size([1, 4, 44, 44])


In [24]:
from transformer_lens.utils import get_act_name
assert isinstance(model.transformer, tl.HookedTransformer)

# 1) Build a few canonical hook names ------------------------------------
hook_embed     = get_act_name("embed")                 # 'hook_embed'
hook_z_L0      = get_act_name("z",      0)             # 'blocks.0.attn.hook_z'
hook_resid_L1  = get_act_name("resid_pre", 1)          # 'blocks.1.hook_resid_pre'
hook_mlp_out_L1= get_act_name("post",   1)             # 'blocks.1.mlp.hook_post'
print("Sample hook names:", hook_embed, hook_z_L0, hook_mlp_out_L1)

# 2) Verify that those hooks really exist in the Decision‑Transformer ----
all_hook_names = {hp.name for hp in model.transformer.hook_points}
missing = [h for h in [hook_embed, hook_z_L0, hook_mlp_out_L1] if h not in all_hook_names]
assert not missing, f"These hooks weren’t found: {missing}"
print("✅  All sample hooks are present in model.hook_points")

# 3) Do a tiny intervention: zero OV of L0H0 and watch logits change -----
layer, head = 0, 0

def zero_head_z(z, hook):
    # z shape: [batch, pos, head, d_head] in TL 2.15
    z[..., head, :] = 0
    return z

with model.transformer.hooks() as h:
    h.add_hook(hook_z_L0, zero_head_z)     # use the name we just built
    logits_ablate = model.transformer(tokens)

delta = (logits_ablate - logits).abs().max().item()
print(f"Δ max‑logit after ablating L{layer}H{head}: {delta:.4f}")

Sample hook names: hook_embed blocks.0.attn.hook_z blocks.1.mlp.hook_post


TypeError: 'method' object is not iterable