In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys

sys.path.append('.')
sys.path.append('..')

from utils.steering_utils import ActivationSteering



## Compare logits with and without steering

In [2]:
model_name = "google/gemma-2-27b"
model_short = "gemma-2-27b"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

Loading checkpoint shards:   0%|          | 0/24 [00:00<?, ?it/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RM

In [4]:
prefills = [
    "My job is to",
    "My purpose is to",
    "I exist to"
]

In [5]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # if pad_token is not set

inputs = tokenizer(
    prefills,
    return_tensors="pt",
    padding=True
)

with torch.no_grad():
    outputs = model(**inputs)



In [6]:
def get_topk_logits(outputs, top_k=10):
    logits = outputs.logits  # [batch=3, seq_len, vocab_size]
    attention_mask = inputs["attention_mask"]  # [batch, seq_len]

    # index of last non-pad token for each sequence
    last_token_indices = attention_mask.sum(dim=1) - 1  # [batch]

    batch_size, _, _ = logits.shape

    # gather the correct last-token logits for each example
    next_token_logits = logits[torch.arange(batch_size), last_token_indices]  # [batch, vocab_size]
    
    # next_token_logits: [batch, vocab_size]
    probs = torch.softmax(next_token_logits, dim=-1)  # [batch, vocab_size]

    top_probs, top_ids = torch.topk(probs, k=top_k, dim=-1)  # both [batch, top_k]

    for i in range(next_token_logits.size(0)):  # loop over prefills in the batch
        print(f"\n=== Prefill {i} ===")
        for p, token_id in zip(top_probs[i].tolist(), top_ids[i].tolist()):
            # decode one token at a time; for BPE models you’ll often see leading spaces
            token_str = tokenizer.decode([token_id])
            print(f"{repr(token_str):>12}  prob={p:.4f}")

In [7]:
get_topk_logits(outputs)


=== Prefill 0 ===
     ' help'  prob=0.1800
     ' make'  prob=0.0792
  ' provide'  prob=0.0268
      ' get'  prob=0.0241
     ' find'  prob=0.0236
       ' be'  prob=0.0235
     ' take'  prob=0.0231
   ' create'  prob=0.0197
    ' teach'  prob=0.0160
     ' keep'  prob=0.0147

=== Prefill 1 ===
     ' help'  prob=0.0865
  ' provide'  prob=0.0456
   ' create'  prob=0.0375
     ' make'  prob=0.0374
     ' show'  prob=0.0326
     ' give'  prob=0.0273
     ' find'  prob=0.0234
    ' share'  prob=0.0222
       ' be'  prob=0.0164
  ' inspire'  prob=0.0159

=== Prefill 2 ===
       ' in'  prob=0.1857
         '.'  prob=0.1166
       ' to'  prob=0.1036
         ','  prob=0.0789
       ' as'  prob=0.0696
      ' for'  prob=0.0348
  ' because'  prob=0.0330
       ' on'  prob=0.0288
      '\n\n'  prob=0.0277
      ' and'  prob=0.0241


In [8]:
# with steering
steer_cfg = torch.load(f"/workspace/{model_short}/evals/configs/asst_pc1_contrast_config.pt", weights_only=False)
exp_id = "layer_22-contrast-coeff:-0.175"

# in the list steer_cfg['experiments'], find the one with id = exp_id
for exp in steer_cfg['experiments']:
    if exp['id'] == exp_id:
        exp_data = exp
        break

vec = steer_cfg['vectors'][exp_data['interventions'][0]['vector']]['vector']
coeff = exp_data['interventions'][0]['coeff']
print(vec.shape)
print(coeff)



torch.Size([4608])
-3.28125


In [9]:
with ActivationSteering(
    model=model,
    steering_vectors=vec,
    coefficients=coeff,
    layer_indices=40,
    intervention_type="addition",
    positions="all") as steerer:
    

    with torch.no_grad():
        steered_outputs = model(**inputs)

    get_topk_logits(steered_outputs)


=== Prefill 0 ===
     ' help'  prob=0.2866
     ' make'  prob=0.0846
  ' provide'  prob=0.0631
   ' create'  prob=0.0246
  ' support'  prob=0.0240
     ' find'  prob=0.0236
   ' assist'  prob=0.0224
   ' ensure'  prob=0.0198
     ' take'  prob=0.0192
    ' write'  prob=0.0162

=== Prefill 1 ===
     ' help'  prob=0.1473
  ' provide'  prob=0.1126
   ' create'  prob=0.0537
     ' make'  prob=0.0353
     ' show'  prob=0.0292
     ' find'  prob=0.0258
    ' share'  prob=0.0254
   ' assist'  prob=0.0254
      ' use'  prob=0.0232
     ' give'  prob=0.0178

=== Prefill 2 ===
       ' in'  prob=0.2682
         '.'  prob=0.1299
       ' to'  prob=0.1204
       ' as'  prob=0.0797
      '\n\n'  prob=0.0582
         ','  prob=0.0534
      ' for'  prob=0.0336
  ' because'  prob=0.0334
       ' on'  prob=0.0305
      ' and'  prob=0.0286


 ## Logit lens

In [1]:
# get roleplay vector, positive and negative directions
model_name = 'meta-llama/Llama-3.1-70B-Instruct'
base_model_name = 'meta-llama/Llama-3.1-70B'
model_short = 'llama-3.1-70b'
base_dir = f"/workspace/{model_short}"

In [2]:
# load vector
import torch

contrast_obj = torch.load(f"{base_dir}/roles_240/contrast_vectors.pt")
pc1_obj = torch.load(f"{base_dir}/roles_240/pc1_vectors.pt")

In [3]:
# Check and fix PC1 vs Contrast alignment
import torch.nn.functional as F

# Normalize vectors for cosine similarity
contrast_norm = F.normalize(contrast_obj.float(), dim=-1)
pc1_norm = F.normalize(pc1_obj.float(), dim=-1)

print("=== PC1 vs Contrast Alignment Check ===\n")

# Check cosine similarity and flip PC1 if negative
flipped_layers = []
layer_sims_before = []
layer_sims_after = []

for layer in range(contrast_norm.shape[0]):
    sim = torch.dot(pc1_norm[layer], contrast_norm[layer]).item()
    layer_sims_before.append(sim)
    
    if sim < 0:
        # Flip the PC1 vector
        pc1_obj[layer] = -pc1_obj[layer]
        pc1_norm[layer] = -pc1_norm[layer]
        flipped_layers.append(layer)
        sim_after = -sim  # New similarity after flip
        layer_sims_after.append(sim_after)
        print(f"Layer {layer}: {sim:.4f} → {sim_after:.4f} (FLIPPED)")
    else:
        layer_sims_after.append(sim)

# Save if any changes were made
if flipped_layers:
    save_path = f"{base_dir}/roles_240/pc1_vectors.pt"
    torch.save(pc1_obj, save_path)
    print(f"\n✓ Saved corrected PC1 vectors to: {save_path}")
    print(f"✓ Flipped {len(flipped_layers)} layers: {flipped_layers}")
else:
    print("✓ All PC1 vectors already aligned with Contrast vectors")

# Summary statistics
print(f"\nFinal alignment statistics:")
print(f"  Min similarity: {min(layer_sims_after):.4f}")
print(f"  Max similarity: {max(layer_sims_after):.4f}")
print(f"  Mean similarity: {sum(layer_sims_after)/len(layer_sims_after):.4f}")
print(f"  All positive: {'✓' if all(s > 0 for s in layer_sims_after) else '❌'}")

=== PC1 vs Contrast Alignment Check ===

✓ All PC1 vectors already aligned with Contrast vectors

Final alignment statistics:
  Min similarity: 0.0138
  Max similarity: 0.9611
  Mean similarity: 0.4506
  All positive: ✓


In [None]:
# load model and get the final unembedding layer
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
W_unembed = model.lm_head.weight.T.float()
print(W_unembed.shape)


base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
base_model.eval()
W_unembed_base = base_model.lm_head.weight.T.float()



Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
import torch.nn.functional as F

# unit normalize the vectors
contrast = F.normalize(contrast_obj.float(), dim=-1)
pc1 = F.normalize(pc1_obj, dim=-1)


In [8]:
pos_contrast_logits = torch.matmul(contrast, W_unembed)
neg_contrast_logits = torch.matmul(-contrast, W_unembed)

pos_pc1_logits = torch.matmul(pc1, W_unembed)
neg_pc1_logits = torch.matmul(-pc1, W_unembed)

print(pos_contrast_logits.shape)

TypeError: matmul(): argument 'other' (position 2) must be Tensor, not NoneType

In [None]:

pos_contrast_logits_base = torch.matmul(contrast, W_unembed_base)
neg_contrast_logits_base = torch.matmul(-contrast, W_unembed_base)

pos_pc1_logits_base = torch.matmul(pc1, W_unembed_base)
neg_pc1_logits_base = torch.matmul(-pc1, W_unembed_base)



In [10]:
# get top logits and the words
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [11]:
import pandas as pd

def is_english(token):
    return token.isascii()

def top_tokens(logits, tokenizer, k=20, english_only=False):
    token_data = []
    value_data = []
    
    for layer in range(logits.shape[0]):
        if english_only:
            # Sort all logits and iterate until we have k english tokens
            sorted_indices = logits[layer].argsort(descending=True)
            tokens = []
            values = []
            for idx in sorted_indices:
                token = tokenizer.decode(idx)
                if token.isascii():
                    tokens.append(token)
                    values.append(logits[layer, idx].item())
                if len(tokens) == k:
                    break
        else:
            top_vals, top_indices = logits[layer].topk(k)
            tokens = [tokenizer.decode(idx) for idx in top_indices]
            values = [val.item() for val in top_vals]
        
        token_data.append(tokens)
        value_data.append(values)
    
    tokens_df = pd.DataFrame(token_data, columns=[f"top_{i+1}" for i in range(k)])
    tokens_df.index.name = 'layer'
    
    values_df = pd.DataFrame(value_data, columns=[f"top_{i+1}" for i in range(k)])
    values_df.index.name = 'layer'
    
    return tokens_df, values_df

In [12]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_top_tokens_heatmap(pos_tokens_df, pos_values_df, neg_tokens_df, neg_values_df, vector_name, model_type, layers=None):
    """
    Create a dual heatmap visualization for top tokens.
    
    Args:
        pos_tokens_df: DataFrame with token strings for positive direction
        pos_values_df: DataFrame with values for positive direction
        neg_tokens_df: DataFrame with token strings for negative direction
        neg_values_df: DataFrame with values for negative direction
        vector_name: Name of the vector (e.g., "Contrast", "PC1")
        model_type: Type of model (e.g., "Instruct", "Base")
        layers: Optional layer selection (range, list, or slice). Examples:
                - range(30, 46) for layers 30-45
                - [30, 35, 40] for specific layers
                - slice(30, None) for layers 30+
    """
    # Filter layers if specified
    if layers is not None:
        # Convert range to list for pandas indexing
        if isinstance(layers, range):
            layers = list(layers)
        
        if isinstance(layers, slice):
            pos_tokens_df = pos_tokens_df.loc[layers]
            pos_values_df = pos_values_df.loc[layers]
            neg_tokens_df = neg_tokens_df.loc[layers]
            neg_values_df = neg_values_df.loc[layers]
        else:
            pos_tokens_df = pos_tokens_df.loc[layers]
            pos_values_df = pos_values_df.loc[layers]
            neg_tokens_df = neg_tokens_df.loc[layers]
            neg_values_df = neg_values_df.loc[layers]
    
    # Create subplots (2 rows, 1 column)
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=("Negative towards Assistant", "Positive towards Role-play"),
        vertical_spacing=0.1
    )
    
    # Calculate global min/max for consistent color scale
    all_values = pd.concat([neg_values_df.stack(), pos_values_df.stack()])
    vmin = all_values.min()
    vmax = all_values.max()
    
    # Calculate dynamic height based on number of layers
    num_layers = len(neg_tokens_df)
    cell_height = 10  # pixels per layer
    subplot_height = num_layers * cell_height
    fig_height = max(800, subplot_height * 2 + 300)  # minimum 800px, 2 subplots + margins
    
    # Negative direction heatmap
    neg_heatmap = go.Heatmap(
        z=neg_values_df.values,
        x=list(range(1, neg_values_df.shape[1] + 1)),
        y=neg_tokens_df.index,
        text=neg_tokens_df.values,
        texttemplate='%{text}',
        textfont={"size": 10},
        hovertemplate='Layer: %{y}<br>Rank: %{x}<br>Token: %{text}<br>Value: %{z:.3f}<extra></extra>',
        colorscale='Viridis',
        zmin=vmin,
        zmax=vmax,
        showscale=False
    )
    
    # Positive direction heatmap
    pos_heatmap = go.Heatmap(
        z=pos_values_df.values,
        x=list(range(1, pos_values_df.shape[1] + 1)),
        y=pos_tokens_df.index,
        text=pos_tokens_df.values,
        texttemplate='%{text}',
        textfont={"size": 10},
        hovertemplate='Layer: %{y}<br>Rank: %{x}<br>Token: %{text}<br>Value: %{z:.3f}<extra></extra>',
        colorscale='Viridis',
        zmin=vmin,
        zmax=vmax,
        showscale=True,
        colorbar=dict(title="Value")
    )
    
    fig.add_trace(neg_heatmap, row=1, col=1)
    fig.add_trace(pos_heatmap, row=2, col=1)
    
    # Update layout
    fig.update_layout(
        height=fig_height,
        width=1500,
        title={
            'text': f"Top Token Logits: {vector_name} Vector in {model_type} Model",
            'subtitle': {
                'text': f"{model_short.replace('-', ' ').title()}",
            }
        },
        showlegend=False
    )
    
    # Update axes - show all ticks for both x and y, with reversed y-axis
    fig.update_xaxes(
        tickmode='linear',
        tick0=1,
        dtick=1,
        row=1, 
        col=1
    )
    fig.update_xaxes(
        title_text="Token Rank",
        tickmode='linear',
        tick0=1,
        dtick=1,
        row=2, 
        col=1
    )
    fig.update_yaxes(
        title_text="Layer", 
        tickmode='linear',
        tick0=neg_tokens_df.index[0],
        dtick=1,
        autorange='reversed',
        row=1, 
        col=1
    )
    fig.update_yaxes(
        title_text="Layer",
        tickmode='linear',
        tick0=pos_tokens_df.index[0],
        dtick=1,
        autorange='reversed',
        row=2, 
        col=1
    )
    
    return fig

In [22]:
vector_type = "contrast"
model_type = "base"
layers = range(44, 80)

In [23]:
if vector_type == "contrast":
    if model_type == "instruct":
        neg_logits = neg_contrast_logits
        pos_logits = pos_contrast_logits
    elif model_type == "base":
        neg_logits = neg_contrast_logits_base
        pos_logits = pos_contrast_logits_base
elif vector_type == "pc1":
    if model_type == "instruct":
        neg_logits = neg_pc1_logits
        pos_logits = pos_pc1_logits
    elif model_type == "base":
        neg_logits = neg_pc1_logits_base
        pos_logits = pos_pc1_logits_base

In [24]:
# Get tokens and values separately
neg_tokens_df, neg_values_df = top_tokens(neg_logits, tokenizer, k=20)
pos_tokens_df, pos_values_df = top_tokens(pos_logits, tokenizer, k=20)


In [25]:
# Example: Create visualization comparing positive and negative directions
# Using range notation to show specific layers
import os

vector_readable = vector_type.title() if vector_type == "contrast" else vector_type.upper()
fig = plot_top_tokens_heatmap(
    pos_tokens_df, pos_values_df, 
    neg_tokens_df, neg_values_df, 
    vector_readable, model_type.title(),
    layers=layers  # Show layers 30-45
)
fig.show()

plot_dir = f"/root/git/plots/{model_short}/logits"
os.makedirs(plot_dir, exist_ok=True)
fig.write_html(f"{plot_dir}/{vector_type}_{model_type}.html")


## Weight projection

In [None]:
# pass the assistant axis through different weight