In [None]:
import sys
sys.path.insert(0, '..')
import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from pathlib import Path
from bias_steering.steering import load_model, get_intervention_func, get_target_token_ids
from bias_steering.utils import loop_coeffs
from bias_steering.data.load_dataset import load_target_words

# Disable autograd globally for the rest of the kernel session.
# The trailing ';' keeps the call's return value out of the cell output.
torch.set_grad_enabled(False);

In [None]:
# --- CONFIG: model selection ---
# MODEL_NAME is passed to load_model, MODEL_ALIAS names the artifact folder
# under ../runs_vision, and LAYER selects which candidate vector to steer with.
MODEL_NAME = 'gpt2'  # or 'Qwen/Qwen-1_8B-chat'
MODEL_ALIAS = 'gpt2'  # or 'Qwen-1_8B-chat'
LAYER = 5  # best layer for gpt2; use 11 for Qwen

model = load_model(MODEL_NAME)

# Locate and load the precomputed candidate steering vectors for this model.
artifact_dir = Path(f'../runs_vision/{MODEL_ALIAS}')
vectors_path = artifact_dir / 'activations/candidate_vectors.pt'
candidate_vectors = torch.load(vectors_path)

# Take the vector for the chosen layer and hand it to the model's dtype helper.
layer_vector = candidate_vectors[LAYER]
steering_vec = model.set_dtype(layer_vector)
print(f'Loaded {MODEL_NAME}, steering vec shape: {steering_vec.shape}')

In [None]:
# Pick specific spatial and descriptive tokens to track.
# The strings keep their leading space when passed to the tokenizer; the
# space is stripped only for the display labels.
spatial_tokens = [' left', ' right', ' behind', ' near', ' above', ' below', ' next']
descriptive_tokens = [' red', ' blue', ' large', ' small', ' round', ' bright', ' dark']

all_tokens = spatial_tokens + descriptive_tokens

# Only the first token id of each word is tracked. If a word splits into
# several tokens, that id stands for a prefix rather than the whole word,
# so warn instead of silently mislabelling the curve.
token_ids = []
for _tok in all_tokens:
    _ids = model.tokenizer.encode(_tok, add_special_tokens=False)
    if not _ids:
        raise ValueError(f'Tokenizer produced no tokens for {_tok!r}')
    if len(_ids) > 1:
        print(f'WARNING: {_tok!r} encodes to {len(_ids)} tokens {_ids}; tracking only the first')
    token_ids.append(_ids[0])

token_labels = [t.strip() for t in all_tokens]
print('Tracking tokens:', token_labels)
print('Token IDs:', token_ids)

In [None]:
# Prompt: a COCO-style caption whose continuation could go either way
caption = 'A cat sitting on a mat in a room with furniture.'
prompt = f'Continue describing this scene:\n{caption}'

# Apply the model's chat template to the single prompt, then append the start
# of the continuation so the position being predicted is the word after
# 'The scene is'.
templated_prompts = model.apply_chat_template([prompt])
p = templated_prompts[0] + 'The scene is'
print(f'Full prompt:\n{p}')

In [None]:
# Sweep steering coefficients and record, for each one, the next-token
# probability of every tracked token at the last prompt position.
coeffs = list(loop_coeffs(min_coeff=-80, max_coeff=80, increment=10))
token_probs = {label: [] for label in token_labels}

for coeff in coeffs:
    intervene_func = get_intervention_func(steering_vec, method='constant', coeff=coeff)
    logits = model.get_logits([p], layer=LAYER, intervene_func=intervene_func)
    # Compute the softmax in float32 so the probabilities are not quantized
    # when the model runs in half precision (no effect for float32 logits).
    probs = F.softmax(logits[0, -1, :], dim=-1, dtype=torch.float32)
    # Gather all tracked ids at once: a single host transfer per coefficient
    # instead of one .item() call per token.
    tracked_probs = probs[token_ids].tolist()
    for label, pval in zip(token_labels, tracked_probs):
        token_probs[label].append(pval)

print(f'Swept {len(coeffs)} coefficients')

In [None]:
# Plot the probability of each tracked token against the steering coefficient.
# Spatial tokens use the Set1 qualitative palette with solid lines; descriptive
# tokens use the Set2 qualitative palette with dashed lines.
_palette_spatial = px.colors.qualitative.Set1
_palette_desc = px.colors.qualitative.Set2
# Cycle through the palette so that adding more tokens than the palette has
# colors reuses colors instead of raising IndexError.
colors_spatial = [_palette_spatial[i % len(_palette_spatial)] for i in range(len(spatial_tokens))]
colors_desc = [_palette_desc[i % len(_palette_desc)] for i in range(len(descriptive_tokens))]
fig = go.Figure()

# (tokens, colors, legend tag, line style) for each group; spatial traces are
# added first so the legend order matches the token lists.
_trace_groups = [
    (spatial_tokens, colors_spatial, 'spatial', dict(width=2)),
    (descriptive_tokens, colors_desc, 'desc', dict(width=2, dash='dash')),
]
for _tokens, _colors, _tag, _line_style in _trace_groups:
    for _token, _color in zip(_tokens, _colors):
        label = _token.strip()  # token_probs is keyed by the stripped label
        fig.add_trace(go.Scatter(
            x=coeffs, y=token_probs[label],
            mode='lines+markers', name=f'{label} ({_tag})',
            marker_color=_color, line=_line_style,
        ))

# Mark the unsteered baseline (coefficient 0).
fig.add_vline(x=0, line_dash='solid', line_color='black', line_width=1)
fig.update_layout(
    title=f'Token Probabilities vs Steering Coefficient ({MODEL_NAME})',
    title_font=dict(size=16), title_x=0.5,
    width=750, height=450, plot_bgcolor='white',
    legend=dict(title='Token', font=dict(size=12)),
)
fig.update_xaxes(title='Steering Coefficient (\u03bb)', showgrid=True, gridcolor='lightgrey')
fig.update_yaxes(title='Probability', showgrid=True, gridcolor='lightgrey', range=[0, None])
fig.show()

In [None]:
# Example captions used to inspect the top next-token predictions.
# Each string is inserted into the 'Continue describing this scene:' prompt
# by show_next_token_examples.
example_captions = [
    'A red bus parked beside a curb near a crosswalk.',
    'Two dogs playing in a grassy field under the sky.',
    'A kitchen table with plates, cups, and fruit on top.',
    'A person standing to the left of a parked bicycle.',
    'A bright yellow kite flying above a grassy park.',
    'Three elephants near the water with trees behind them.',
    'A black and white cat sleeping on a blue sofa.',
    'A stop sign at the corner in front of a brick building.',
    'A bowl of oranges and apples on a wooden table.',
    'A train passing under a bridge next to a road.'
]

# Helper: print the top-k next-token predictions for each caption at a given steering coefficient

def show_next_token_examples(captions, coeff=0.0, top_k=8):
    """Print the most likely next tokens for each caption under steering.

    Every caption is wrapped in the 'Continue describing this scene:' prompt,
    passed through the model's chat template, and followed by 'The scene is'.
    The steering vector is applied at LAYER with the given coefficient, and
    the top candidates for the next token are printed with their probabilities.

    Uses the notebook-level ``model``, ``steering_vec`` and ``LAYER``.

    Args:
        captions: Iterable of caption strings.
        coeff: Steering coefficient passed to ``get_intervention_func``.
        top_k: Number of candidates to print per caption; capped at the
            number of entries in the probability vector.

    Returns:
        None. Results are printed.
    """
    # The intervention depends only on coeff, so build it once for all captions.
    intervene_func = get_intervention_func(steering_vec, method='constant', coeff=coeff)
    for i, caption in enumerate(captions, start=1):
        prompt = f'Continue describing this scene:\n{caption}'
        full_prompt = model.apply_chat_template([prompt])[0] + 'The scene is'
        logits = model.get_logits([full_prompt], layer=LAYER, intervene_func=intervene_func)
        # Softmax in float32 so the printed probabilities are not quantized when
        # the model runs in half precision (no effect for float32 logits).
        probs = F.softmax(logits[0, -1, :], dim=-1, dtype=torch.float32)
        # torch.topk raises if k exceeds the vector length; cap it instead.
        k = min(top_k, probs.shape[-1])
        top_probs, top_ids = torch.topk(probs, k=k)

        print(f'\nExample {i} | coeff={coeff}')
        print(f'Caption: {caption}')
        for rank, (tid, pval) in enumerate(zip(top_ids.tolist(), top_probs.tolist()), start=1):
            # Escape newlines so each candidate stays on one output line.
            tok = model.tokenizer.decode([tid]).replace('\n', '\\n')
            print(f'  {rank:>2}. token={tok!r:>14} prob={pval:.4f}')


In [None]:
# Print the examples twice: unsteered baseline (0.0) first, then steered (-200.0)
for example_coeff in (0.0, -200.0):
    show_next_token_examples(example_captions, coeff=example_coeff, top_k=8)
