## Compare logits with and without steering

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys

sys.path.append('.')
sys.path.append('..')

from utils.steering_utils import ActivationSteering



In [2]:
model_name = "meta-llama/Llama-3.1-70B"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8192)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (v_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (up_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (down_proj): Linear(in_features=28672, out_features=8192, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((8192,), eps=1e-05)
  

In [4]:
prefills = [
    "My job is to",
    "My purpose is to",
    "I exist to"
]

In [15]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # if pad_token is not set

inputs = tokenizer(
    prefills,
    return_tensors="pt",
    padding=True
)

with torch.no_grad():
    outputs = model(**inputs)



In [17]:
def get_topk_logits(outputs, top_k=10):
    logits = outputs.logits  # [batch=3, seq_len, vocab_size]
    attention_mask = inputs["attention_mask"]  # [batch, seq_len]

    # index of last non-pad token for each sequence
    last_token_indices = attention_mask.sum(dim=1) - 1  # [batch]

    batch_size, _, _ = logits.shape

    # gather the correct last-token logits for each example
    next_token_logits = logits[torch.arange(batch_size), last_token_indices]  # [batch, vocab_size]
    
    # next_token_logits: [batch, vocab_size]
    probs = torch.softmax(next_token_logits, dim=-1)  # [batch, vocab_size]

    top_probs, top_ids = torch.topk(probs, k=top_k, dim=-1)  # both [batch, top_k]

    for i in range(next_token_logits.size(0)):  # loop over prefills in the batch
        print(f"\n=== Prefill {i} ===")
        for p, token_id in zip(top_probs[i].tolist(), top_ids[i].tolist()):
            # decode one token at a time; for BPE models youâ€™ll often see leading spaces
            token_str = tokenizer.decode([token_id])
            print(f"{repr(token_str):>12}  prob={p:.4f}")

In [18]:
get_topk_logits(outputs)


=== Prefill 0 ===
     ' help'  prob=0.1626
     ' make'  prob=0.1007
  ' provide'  prob=0.0292
       ' be'  prob=0.0291
   ' create'  prob=0.0288
      ' get'  prob=0.0226
     ' take'  prob=0.0189
     ' find'  prob=0.0172
     ' keep'  prob=0.0172
   ' ensure'  prob=0.0162

=== Prefill 1 ===
     ' help'  prob=0.1562
  ' provide'  prob=0.0548
     ' make'  prob=0.0367
    ' share'  prob=0.0302
  ' inspire'  prob=0.0290
   ' create'  prob=0.0284
     ' give'  prob=0.0246
    ' teach'  prob=0.0227
       ' be'  prob=0.0213
    ' bring'  prob=0.0176

=== Prefill 2 ===
       ' in'  prob=0.2454
       ' to'  prob=0.1317
         '.'  prob=0.0806
       ' as'  prob=0.0792
         ','  prob=0.0628
  ' because'  prob=0.0362
       ' on'  prob=0.0300
      ' for'  prob=0.0256
      ' and'  prob=0.0231
   ' within'  prob=0.0178


In [12]:
# with steering
steer_cfg = torch.load(f"/workspace/llama-3.1-70b/evals/configs/asst_pc1_contrast_config.pt", weights_only=False)
exp_id = "layer_40-contrast-coeff:-1.75"

# in the list steer_cfg['experiments'], find the one with id = exp_id
for exp in steer_cfg['experiments']:
    if exp['id'] == exp_id:
        exp_data = exp
        break

vec = steer_cfg['vectors'][exp_data['interventions'][0]['vector']]['vector']
coeff = exp_data['interventions'][0]['coeff']
print(vec.shape)
print(coeff)



torch.Size([8192])
-9.32829475402832


In [19]:
with ActivationSteering(
    model=model,
    steering_vectors=vec,
    coefficients=coeff,
    layer_indices=40,
    intervention_type="addition",
    positions="all") as steerer:
    

    with torch.no_grad():
        steered_outputs = model(**inputs)

    get_topk_logits(steered_outputs)


=== Prefill 0 ===
     ' help'  prob=0.2592
     ' make'  prob=0.0962
  ' provide'  prob=0.0712
   ' create'  prob=0.0233
     ' keep'  prob=0.0186
     ' work'  prob=0.0179
     ' take'  prob=0.0157
  ' support'  prob=0.0156
       ' be'  prob=0.0156
     ' find'  prob=0.0150

=== Prefill 1 ===
     ' help'  prob=0.3087
  ' provide'  prob=0.0864
     ' make'  prob=0.0414
   ' create'  prob=0.0297
     ' give'  prob=0.0262
    ' share'  prob=0.0244
  ' educate'  prob=0.0161
     ' show'  prob=0.0153
  ' support'  prob=0.0150
' encourage'  prob=0.0139

=== Prefill 2 ===
       ' in'  prob=0.1881
       ' to'  prob=0.1827
      ' and'  prob=0.1164
         ','  prob=0.0732
         '.'  prob=0.0428
       ' on'  prob=0.0411
       ' as'  prob=0.0378
      ' for'  prob=0.0281
     ' with'  prob=0.0155
        '\n'  prob=0.0145
