# 1. Load Model

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download
from tuned_lens import TunedLens
from tuned_lens.plotting import PredictionTrajectory
import os
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# Device placement (`device_map`) is a model-loading option; the tokenizer is not
# placed on any device, so it is loaded without one.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Right padding keeps each prompt's real tokens at positions [0, length); the
# feature-extraction cells slice hidden states using the attention-mask length.
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    # Safe choice: use eos as pad if no dedicated pad token
    tokenizer.pad_token = tokenizer.eos_token

model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

# 2. Prepare Prompts

In [2]:
# from generate_prompts import return_model_prompts
# prompts_all_parallel_multiple = return_model_prompts(model_name = "google/gemma-2-9b-it",
#                                                      data_name = 'BFCL_v4_live_parallel_multiple.json')

In [3]:
# import json
# with open('../SAE_Summ/prompts_with_generation_cnndm.json', 'r') as file:
#     prompts_with_generation_cnndm = json.load(file)

In [4]:
# import json
# with open('../SAE_Summ/prompts_with_generation_all_extractive_cnndm.json', 'r') as file:
#     prompts_with_generation_all_extractive_cnndm = json.load(file)

In [8]:
def set_seed(seed):
    """Seed every RNG used in this notebook and favour reproducible GPU execution.

    Seeds torch (CPU and all CUDA devices), NumPy's legacy global RNG and Python's
    `random` module, disables cuDNN, and sets the cuBLAS / hash-seed environment
    variables.

    Args:
        seed: integer seed applied to all generators.
    """
    # Imported locally so this helper works even if the cell that imports `random`
    # at module level has not been run yet.
    import random

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds the current device as well as all others
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    # torch.set_deterministic(True)
    # Disabling cuDNN entirely trades speed for reproducibility.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    # NOTE: these only affect work started after they are set. PYTHONHASHSEED is read
    # at interpreter start-up, so it cannot change hashing in the running kernel.
    # CUBLAS_WORKSPACE_CONFIG presumably must be set before cuBLAS is first used. TODO confirm.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)

In [9]:
# Maps a task name (e.g. 'agnews', 'agnews_controlled') to its list of prompt strings.
all_prompts: dict = {}

# 2.1 AGNews

In [10]:
from datasets import load_dataset
import random 

# Load both AG News splits in one call; `dataset` is the [train, test] list.
dataset = load_dataset("ag_news", split=["train", "test"])
train_data_agnews, test_data_agnews = dataset

In [11]:
num_examples_per_label = 500

# Group training-row indices by label id. Reading the label column once avoids
# materialising every row through train_data_agnews[i].
label_demonstrations = {i: [] for i in range(len(train_data_agnews.features['label'].names))}
for row_idx, row_label in enumerate(train_data_agnews['label']):
    label_demonstrations[row_label].append(row_idx)

set_seed(42)

remove_label = -1  # -1 keeps every label; the controlled-prompt cell overrides this

# Draw num_examples_per_label query rows per label, in label order, so the list is
# grouped label by label. Later cells rely on that grouping.
# NOTE(review): queries come from the *train* split, the same pool the demonstrations
# are drawn from (minus the query itself). test_data_agnews is not used here.
test_query_indices = []
for i in range(len(label_demonstrations)):
    if i == remove_label:
        continue
    test_query_indices.extend(random.sample(label_demonstrations[i], num_examples_per_label))

In [12]:
# NOTE(review): "<bos><start_of_turn>user\n" is Gemma's chat prefix, while the loaded
# model is Llama-3.1 and prompts are tokenized with add_special_tokens=False. It is kept
# unchanged so these prompts stay paired with the 'agnews_controlled' prompts.
all_prompts['agnews'] = []

# Display name per label id. Only the label strings are renamed (Sci/Tech -> Technology),
# so article text that happens to contain "Sci/Tech" is left untouched.
agnews_display_names = [
    "Technology" if name == "Sci/Tech" else name
    for name in train_data_agnews.features['label'].names
]
agnews_instruction = 'Pretend that you are an expert in news topic classification. For a given news article, you have to assess the topic, determining whether it is world, sports, business, or technology.'

for N in [16]:
    num_labels = len(label_demonstrations)

    for i, tq_index in enumerate(test_query_indices):
        set_seed(i)

        # Rotate which labels receive the N % num_labels leftover shots. For example,
        # with N=1, query 0 gets label 0, query 1 gets label 1, and so on.
        start_label_offset = i % num_labels
        shots_per_label = [N // num_labels] * num_labels
        remainder = N % num_labels
        for r in range(remainder):
            label_to_increment = (start_label_offset + r) % num_labels
            shots_per_label[label_to_increment] += 1

        # Sample demonstrations per label, never reusing the query itself.
        fse_indices = []
        for l in range(num_labels):
            needed = shots_per_label[l]
            if needed == 0:
                continue
            available_pool = [idx for idx in label_demonstrations[l] if idx != tq_index]
            sampled = random.sample(available_pool, needed)
            fse_indices.extend(sampled)

        # Randomize the order of shots within the prompt.
        random.shuffle(fse_indices)

        prompt = agnews_instruction
        for j, fse_index in enumerate(fse_indices):
            demo = train_data_agnews[fse_index]
            prompt += f'\n\nExample {j+1}'
            prompt += f"\nArticle:\n{demo['text']}"
            prompt += f"\nTopic:\n{agnews_display_names[demo['label']]}"

        # The target example also ends with its gold topic. The extraction cells treat
        # that final token as the example's label.
        target = train_data_agnews[tq_index]
        prompt += f'\n\nExample {len(fse_indices)+1}'
        prompt += f"\nArticle:\n{target['text']}"
        prompt += f"\nTopic:\n{agnews_display_names[target['label']]}"

        all_prompts['agnews'].append("<bos><start_of_turn>user\n" + prompt)


In [13]:
import random
import torch

# NOTE(review): "<bos><start_of_turn>user\n" is Gemma's chat prefix, while the loaded
# model is Llama-3.1 and prompts are tokenized with add_special_tokens=False. It is kept
# identical to the 'agnews' prompts so the two prompt sets stay paired.
all_prompts['agnews_controlled'] = []

remove_label = 3  # label id excluded from demonstrations, instruction and target queries

# Label display names (Sci/Tech is shown as Technology, as in the 'agnews' prompts).
label_names = list(train_data_agnews.features['label'].names)
label_names = [("Technology" if n == "Sci/Tech" else n) for n in label_names]

# Only labels other than remove_label take part.
kept_labels = [l for l in range(len(label_demonstrations)) if l != remove_label]
num_labels = len(kept_labels)

# Target queries never carry remove_label.
filtered_test_query_indices = [
    idx for idx in test_query_indices
    if int(train_data_agnews[idx]["label"]) != remove_label
]

# test_query_indices is grouped label by label (num_examples_per_label each), so dropping
# the last label must leave an aligned prefix. Later cells pair controlled example k
# with original example k.
assert test_query_indices[:num_examples_per_label * num_labels] == filtered_test_query_indices

# The instruction lists only the remaining classes and does not depend on the query.
kept_label_names = [label_names[l] for l in kept_labels]
if len(kept_label_names) == 1:
    topics_str = kept_label_names[0]
else:
    topics_str = ", ".join(kept_label_names[:-1]) + f", or {kept_label_names[-1]}"
controlled_instruction = (
    "Pretend that you are an expert in news topic classification. "
    f"For a given news article, you have to assess the topic, determining whether it is {topics_str}."
)

for N in [16]:
    for i, tq_index in enumerate(filtered_test_query_indices):
        set_seed(i)

        # Balanced allocation of N shots across the remaining labels. The remainder is
        # spread starting from a label that rotates with the query index.
        start_label_offset = i % num_labels
        shots_per_label = [N // num_labels] * num_labels
        remainder = N % num_labels
        for r in range(remainder):
            shots_per_label[(start_label_offset + r) % num_labels] += 1

        # Sample demonstrations from kept labels only, never reusing the query itself.
        fse_indices = []
        for local_li, label_id in enumerate(kept_labels):
            needed = shots_per_label[local_li]
            if needed == 0:
                continue

            available_pool = [idx for idx in label_demonstrations[label_id] if idx != tq_index]
            if len(available_pool) == 0:
                raise ValueError(f"No available demos for label {label_id} after excluding tq_index={tq_index}")

            if len(available_pool) < needed:
                # fallback: sample with replacement if pool too small
                sampled = random.choices(available_pool, k=needed)
            else:
                sampled = random.sample(available_pool, needed)

            fse_indices.extend(sampled)

        # Shuffle shots within the prompt.
        random.shuffle(fse_indices)

        prompt = controlled_instruction
        for j, fse_index in enumerate(fse_indices):
            demo = train_data_agnews[fse_index]
            lbl = int(demo["label"])
            # Demonstrations come from kept_labels only. Silently skipping one here would
            # break the "Example {j+1}" numbering, so fail loudly instead.
            assert lbl != remove_label, f"demonstration {fse_index} has removed label {lbl}"

            prompt += f"\n\nExample {j+1}"
            prompt += f"\nArticle:\n{demo['text']}"
            prompt += f"\nTopic:\n{label_names[lbl]}"

        # Target example (never remove_label, thanks to the filtering above).
        target = train_data_agnews[tq_index]
        prompt += f"\n\nExample {len(fse_indices)+1}"
        prompt += f"\nArticle:\n{target['text']}"
        prompt += f"\nTopic:\n{label_names[int(target['label'])]}"

        all_prompts['agnews_controlled'].append(
            "<bos><start_of_turn>user\n" + prompt
        )


In [14]:
# Controlled prompt k must end with exactly the same target article (and topic) as
# original prompt k; the paired statistics below depend on this alignment.
assert len(all_prompts['agnews']) >= len(all_prompts['agnews_controlled'])
for example_idx, (original_prompt, controlled_prompt) in enumerate(
    zip(all_prompts['agnews'], all_prompts['agnews_controlled'])
):
    assert original_prompt.split('Article:')[-1] == controlled_prompt.split('Article:')[-1], example_idx

# 3. Load SAEs

In [15]:
def clear_all_hooks(model: torch.nn.Module):
    """Remove every module-level hook registered on `model` and its submodules.

    Clears forward, forward-pre, backward and backward-pre hooks. It also clears the
    bookkeeping dicts that newer PyTorch versions keep next to them (`*_with_kwargs`,
    `*_always_called`), so no stale entries survive. Global hooks registered through
    `torch.nn.modules.module` are not touched.

    Args:
        model: root module; `model.modules()` includes the module itself.
    """
    hook_dict_names = (
        "_forward_hooks",
        "_forward_hooks_with_kwargs",
        "_forward_hooks_always_called",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_backward_hooks",
        "_backward_pre_hooks",
    )
    for module in model.modules():
        for attr in hook_dict_names:
            hooks = getattr(module, attr, None)
            # Attributes missing in the installed PyTorch version are skipped.
            if hooks is not None:
                hooks.clear()

In [16]:
import torch

@torch.no_grad()
def sae_logits_before_jumprelu(sae, x: torch.Tensor) -> torch.Tensor:
    """Encoder pre-activations of a SAELens SAE, taken before the JumpReLU gate.

    `x` may be `(d_in,)`, `(seq, d_in)` or `(batch, seq, d_in)`. The result keeps
    the leading dimensions and has `d_sae` as its last dimension.

    Pipeline:
    1. Cast `x` to the SAE's device and dtype.
    2. Apply `sae.process_sae_in`.
    3. Compute `@ W_enc + b_enc`.
    4. If the SAE exposes `hook_sae_acts_pre`, pass the result through it.
    """
    sae_in = sae.process_sae_in(x.to(device=sae.device, dtype=sae.dtype))
    pre_acts = sae_in @ sae.W_enc + sae.b_enc

    if not hasattr(sae, "hook_sae_acts_pre"):
        return pre_acts
    # Usually an identity hook point unless hooks have been attached to it.
    return sae.hook_sae_acts_pre(pre_acts)


In [17]:
@torch.no_grad()
def extract_decoding_embeddings(sae, layer_idx):
    """Build a forward hook that records SAE pre-activation trajectories.

    The hook encodes the hooked module's output with `sae_logits_before_jumprelu`.
    For each example in the current batch, it then appends the slice from the target
    example's `Article` token up to the last non-padding token to
    `feature_summary_trajectories[task_type]`.

    At call time the hook reads these notebook globals: `CURRENT_BATCH_EXAMPLE_IDXS`,
    `example_summary`, `task_type`, `lengths` and `feature_summary_trajectories`.

    Args:
        sae: SAE passed to `sae_logits_before_jumprelu`.
        layer_idx: layer index (unused inside the hook; kept for call-site compatibility).

    Returns:
        A hook `(module, inputs, output) -> output` that leaves `output` unchanged.
    """
    def hook(module, inputs, output):
        # Decoder layers return either a tuple whose first item is the hidden state or,
        # in some transformers versions, the hidden-state tensor itself. Indexing a bare
        # [batch, seq, hidden] tensor with [0] would silently keep only batch row 0.
        x = output[0] if isinstance(output, tuple) else output

        with torch.no_grad():
            logits = sae_logits_before_jumprelu(sae, x)  # [batch, seq, d_sae]

            for b, ex_idx in enumerate(CURRENT_BATCH_EXAMPLE_IDXS):
                start_pos = example_summary[task_type][ex_idx]['target_example_start']
                # Prompts are right-padded, so real tokens occupy [0, lengths[b]).
                end_pos = lengths[b]

                trajectory = logits[b, start_pos:end_pos].clone().cpu().float()  # [seq, d_sae]
                feature_summary_trajectories[task_type].append(trajectory)

        return output
    return hook

In [18]:
@torch.no_grad()
def extract_decoding_embeddings_activation(sae, layer_idx):
    """Build a forward hook that records post-JumpReLU SAE activations (`sae.encode`).

    This is the same bookkeeping as `extract_decoding_embeddings`, but it stores
    `sae.encode(x)` instead of the pre-activation logits.
    `example_summary[task_type]` is keyed directly by example index and
    `feature_summary_trajectories[task_type]` is a flat list (that is how the
    extraction loop builds them), so there is no per-N level to index into.

    At call time the hook reads these notebook globals: `CURRENT_BATCH_EXAMPLE_IDXS`,
    `example_summary`, `task_type`, `lengths` and `feature_summary_trajectories`.

    Args:
        sae: SAE whose `encode` method is applied to the layer output.
        layer_idx: layer index (unused inside the hook; kept for call-site compatibility).

    Returns:
        A hook `(module, inputs, output) -> output` that leaves `output` unchanged.
    """
    def hook(module, inputs, output):
        # Tuple output (hidden state first) or a bare hidden-state tensor.
        x = output[0] if isinstance(output, tuple) else output

        with torch.no_grad():
            activations = sae.encode(x)  # [batch, seq, d_sae]

            for b, ex_idx in enumerate(CURRENT_BATCH_EXAMPLE_IDXS):
                start_pos = example_summary[task_type][ex_idx]['target_example_start']
                # Prompts are right-padded, so real tokens occupy [0, lengths[b]).
                end_pos = lengths[b]

                trajectory = activations[b, start_pos:end_pos].clone().cpu().float()  # [seq, d_sae]
                feature_summary_trajectories[task_type].append(trajectory)

        return output
    return hook

In [19]:
from sae_lens import SAE

clear_all_hooks(model)

sae_set = {}
target_layer = 15

for layer_idx in range(target_layer, target_layer + 1):
    # Llama Scope SAE "l{layer_idx}r"; the extraction hook is attached to the output of
    # model.layers[layer_idx].
    # NOTE(review): confirm this SAE was trained on the output (not the input) of decoder
    # layer `layer_idx`. If it was trained on the input, use layer_idx - 1 here.
    sae, sae_config, sparsity = SAE.from_pretrained(
        release="llama_scope_lxr_8x",
        sae_id=f"l{layer_idx}r_8x",
        device=str(device),  # follows the notebook's CUDA/CPU choice instead of assuming CUDA
    )
    sae_set[layer_idx] = sae
    target_steering_block = model.get_submodule(f"model.layers.{layer_idx}")
    # All hooks were already removed by clear_all_hooks(model) above.
    steering_hook = target_steering_block.register_forward_hook(
        extract_decoding_embeddings(sae=sae, layer_idx=layer_idx)
    )
    

  sae, sae_config, sparsity = SAE.from_pretrained(


# 4. Extracting SAE Features on AGNews

In [20]:
# Fresh, independent containers for the per-task summaries filled in below.
feature_summary_peak_max, feature_summary_peak_min = {}, {}
feature_summary_trajectories, example_summary = {}, {}

In [21]:
# Key into all_prompts / summaries; the hooks also read this global.
task_type: str = 'agnews'

## Original prompts (all four labels)

In [None]:
# Run every 'agnews' prompt through the model in batches. The forward hook registered on
# model.layers[target_layer] records SAE trajectories into feature_summary_trajectories.
# NOTE(review): relies on the global N left over from the prompt-building loop
# (`for N in [16]`). Both the progress print and the N + 1 assertion below use it.
BATCH_SIZE = 4
# Reset the per-task containers the hook appends to.
feature_summary_trajectories= {}
example_summary = {}
example_summary[task_type] = {}
feature_summary_trajectories[task_type] = []

num_examples = len(all_prompts[task_type])

for batch_start in range(0, num_examples, BATCH_SIZE):
    batch_prompts = all_prompts[task_type][batch_start: batch_start + BATCH_SIZE]
    batch_size_actual = len(batch_prompts)

    print(f"{task_type}, {N}-shot, batch_start={batch_start}, batch_size={batch_size_actual}")

    if batch_start == 0:
        print("First prompt in this N-shot setting:\n", batch_prompts[0])

    # Prompts already contain their own prefix text, so no special tokens are added.
    enc = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,               # <-- batching
        truncation=False,
        add_special_tokens=False
    )
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)

    res = {
        "input_ids": input_ids,
        "input_masks": attention_mask,
    }

    # Number of real (non-pad) tokens per row. With right padding these occupy
    # positions [0, L). The hook reads this global as its slice end.
    lengths = attention_mask.sum(dim=1).tolist()

    for b in range(batch_size_actual):
        example_idx = batch_start + b
        L = lengths[b]

        # One decoded string per token, used to locate marker tokens by text.
        decoded_tokens = [
            tokenizer.decode(int(input_ids[b, j]))
            for j in range(L)
        ]

        # Positions of each "Article" token followed by ":\n": one per demonstration
        # plus one for the target example, hence N + 1 in total.
        Example_start = []
        # Position of "Example" "1" (start of the demonstrations). It is recorded but
        # not stored; the hooks use 'target_example_start' instead.
        analysis_start = 0
        for j in range(max(0, L - 2)):
            if (
                decoded_tokens[j] == 'Article'
                and decoded_tokens[j + 1] == ':\n'
            ):
                Example_start.append(j)

            if (
                decoded_tokens[j] == 'Example'
                and decoded_tokens[j+1] == '1'
            ):
                analysis_start = j
        
        assert len(Example_start) == N + 1
        # target_example_start: the hook's slice start (the target's "Article" token).
        # tq_label: last real token, i.e. the target's gold topic word in the prompt.
        example_summary[task_type][example_idx] = {'All_tokens': L,
                                                   'target_example_start': Example_start[-1],
                                                   'tq_label': decoded_tokens[-1]}
        if example_idx == 0:
            print("Example summary for example_idx=0:\n", example_summary[task_type][example_idx])



    # Must be set before the forward pass: the hook maps batch row b to this example id.
    CURRENT_BATCH_EXAMPLE_IDXS = list(range(batch_start, batch_start + batch_size_actual))

    # Generating a single token presumably runs one forward pass over the prompts, which
    # fires the hook once for this batch. The generated token (`outputs`) is not used.
    with torch.no_grad():
        outputs = model.generate(
            res["input_ids"],
            attention_mask=res["input_masks"],
            max_new_tokens=1,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

In [23]:
@torch.no_grad()
def extract_decoding_embeddings_controlled(sae, layer_idx):
    def hook(module, inputs, output):
        # For Residual SAEs: output[0] is [Batch, Seq, Hidden]
        x = output[0]
        
        global feature_summary_trajectories_controlled
        
        with torch.no_grad():
            # logits shape: [Batch, Seq, Num_Features]
            activations =sae.encode(x)
            logits = sae_logits_before_jumprelu(sae, x)
            
            assert activations.size() == logits.size()
            
            for b, ex_idx in enumerate(CURRENT_BATCH_EXAMPLE_IDXS):
#                 start_pos = example_summary[task_type][i][ex_idx]['analysis_start']
                
                # Considering (k-2, k)
                start_pos = example_summary_controlled[task_type][ex_idx]['target_example_start']
                end_pos = lengths[b]
                
                xxx = logits[b, start_pos:end_pos].clone().cpu().float() # [Seq, Feat]
                feature_summary_trajectories_controlled[task_type].append(xxx)
                
        return output
    return hook

In [24]:
from sae_lens import SAE

clear_all_hooks(model)

for layer_idx in range(target_layer, target_layer + 1):
    target_steering_block = model.get_submodule(f"model.layers.{layer_idx}")
    # Use the SAE loaded for this specific layer rather than whichever `sae` was loaded
    # last. All hooks were already removed by clear_all_hooks(model) above.
    steering_hook = target_steering_block.register_forward_hook(
        extract_decoding_embeddings_controlled(sae=sae_set[layer_idx], layer_idx=layer_idx)
    )

In [None]:
# Same extraction as for the original prompts, but over the 'agnews_controlled' prompts.
# The controlled hook appends to feature_summary_trajectories_controlled.
# NOTE(review): relies on the global N left over from the prompt-building loop
# (`for N in [16]`) for the progress print and the N+1 assertion below.
BATCH_SIZE = 4
# Reset the per-task containers the controlled hook appends to.
feature_summary_trajectories_controlled = {}
example_summary_controlled= {}

feature_summary_trajectories_controlled[task_type] = []
example_summary_controlled[task_type] = {}


num_examples = len(all_prompts[task_type+'_controlled'])

for batch_start in range(0, num_examples, BATCH_SIZE):
    batch_prompts = all_prompts[task_type+'_controlled'][batch_start: batch_start + BATCH_SIZE]
    batch_size_actual = len(batch_prompts)

    print(f"{task_type}, {N}-shot, batch_start={batch_start}, batch_size={batch_size_actual}")

    if batch_start == 0:
        print("First prompt in this N-shot setting:\n", batch_prompts[0])

    # Prompts already contain their own prefix text, so no special tokens are added.
    enc = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,               # <-- batching
        truncation=False,
        add_special_tokens=False
    )
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)

    res = {
        "input_ids": input_ids,
        "input_masks": attention_mask,
    }

    # Number of real (non-pad) tokens per row. With right padding these occupy
    # positions [0, L). The hook reads this global as its slice end.
    lengths = attention_mask.sum(dim=1).tolist()

    for b in range(batch_size_actual):
        example_idx = batch_start + b
        L = lengths[b]

        # One decoded string per token, used to locate marker tokens by text.
        decoded_tokens = [
            tokenizer.decode(int(input_ids[b, j]))
            for j in range(L)
        ]

        # Positions of each "Article" token followed by ":\n": N demonstrations plus the
        # target example.
        Example_start = []
        # Start of "Example 1". It is recorded but not stored; the hook uses
        # 'target_example_start' instead.
        analysis_start = 0
        for j in range(max(0, L - 2)):
            if (
                decoded_tokens[j] == 'Article'
                and decoded_tokens[j + 1] == ':\n'
            ):
                Example_start.append(j)

            if (
                decoded_tokens[j] == 'Example'
                and decoded_tokens[j+1] == '1'
            ):
                analysis_start = j
        
        assert len(Example_start) == N+1
        # target_example_start: the hook's slice start. tq_label: last real token, i.e.
        # the target's gold topic word in the prompt.
        example_summary_controlled[task_type][example_idx] = {'All_tokens': L,
                                                              'target_example_start': Example_start[-1],
                                                              'tq_label': decoded_tokens[-1]}
        if example_idx == 0:
            print("Example summary for example_idx=0:\n", example_summary_controlled[task_type][example_idx])



    # Must be set before the forward pass: the hook maps batch row b to this example id.
    CURRENT_BATCH_EXAMPLE_IDXS = list(range(batch_start, batch_start + batch_size_actual))

    # Generating a single token presumably runs one forward pass over the prompts, which
    # fires the hook once for this batch. The generated token (`outputs`) is not used.
    with torch.no_grad():
        outputs = model.generate(
            res["input_ids"],
            attention_mask=res["input_masks"],
            max_new_tokens=1,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

In [26]:
import torch

def get_all_feature_negations(W_dec, chunk_size=4096):
    """Map every SAE feature to its most anti-aligned ("antipodal") partner.

    Rows of `W_dec` (one decoder direction per feature) are L2-normalised. For each
    feature, the *other* feature with the smallest (most negative) cosine similarity
    is recorded.

    The full similarity matrix has num_features**2 entries (about 1.07e9 for a
    32768-feature SAE). It is therefore built `chunk_size` rows at a time, under
    `torch.no_grad()`, so no autograd graph is kept on a parameter `W_dec`.

    Args:
        W_dec: 2-D tensor of shape [num_features, d_model].
        chunk_size: number of similarity-matrix rows computed per step (>= 1).

    Returns:
        dict `{feature: {'pair_feature': int, 'sim': float}}`. With a single feature
        there is no partner, so 'sim' is `inf` and 'pair_feature' is 0.
    """
    with torch.no_grad():
        # 1. Normalize
        norms = torch.norm(W_dec, dim=1, keepdim=True)
        W_norm = W_dec / (norms + 1e-8)
        num_features = W_norm.shape[0]

        min_values = torch.empty(num_features, dtype=W_norm.dtype, device=W_norm.device)
        min_indices = torch.empty(num_features, dtype=torch.long, device=W_norm.device)

        # 2-4. Row block of the similarity matrix, self-similarity masked with +inf,
        # then the per-row minimum over all features.
        for start in range(0, num_features, chunk_size):
            end = min(start + chunk_size, num_features)
            block = W_norm[start:end] @ W_norm.T  # [end - start, num_features]
            rows = torch.arange(end - start, device=block.device)
            block[rows, rows + start] = float('inf')
            min_values[start:end], min_indices[start:end] = block.min(dim=1)

    # 5. Build the dictionary. Transfer once with tolist() instead of per-element .item().
    pair_features = min_indices.tolist()
    sims = min_values.tolist()
    all_features_map = {
        i: {'pair_feature': pair_features[i], 'sim': sims[i]}
        for i in range(num_features)
    }

    print(f"Processed {len(all_features_map)} features.")
    return all_features_map

In [27]:
# Most anti-aligned partner (and its cosine similarity) for every feature of the loaded SAE.
all_feature_map = get_all_feature_negations(W_dec=sae.W_dec)

Processed 32768 features.


In [28]:
last_k_tokens = 5

# Per example: keep positions [-last_k_tokens, -1), i.e. the last_k_tokens - 1 tokens just
# before the final token of the prompt, then average them into one vector per example.
example_level_score = []
for example_idx in range(len(feature_summary_trajectories[task_type])):
    trajectory = feature_summary_trajectories[task_type][example_idx]
    example_level_score.append(trajectory[-last_k_tokens:-1].clone())
example_level_score = torch.stack(example_level_score).mean(dim=1)

# Per label: average over that label's block of examples. Examples are ordered label by
# label, num_examples_per_label each.
Label_wise_logit = []
discriminative_mode = {}
for label in range(4):
    label_block = example_level_score[num_examples_per_label * label: num_examples_per_label * (label + 1)]
    label_avg = label_block.mean(dim=0)
    Label_wise_logit.append(label_avg.clone())

Label_wise_logit = torch.stack(Label_wise_logit)

In [None]:
label = 0

# 200 features with the highest average logit for `label`, with their antipodality.
vals, inds = Label_wise_logit[label].topk(k=200)
for val, ind in zip(vals, inds):
    antipodality = all_feature_map[int(ind)]['sim']
    print(f"Feature {ind} (Antipodality: {antipodality: .3f}): Avg.Logit: {val:.3f}")

# 5. Label-Specific Features

In [30]:
# A feature is "selective" for a label when that label has its highest average logit.
# The feature's score is the margin over the runner-up label. Ties count for every tied
# label, with margin 0. Vectorised: one sort over the label axis replaces a Python loop
# over 4 x num_features tensor comparisons.
label_selective_features = {}

sorted_label_logits = torch.sort(Label_wise_logit, dim=0).values  # ascending over labels
top_logit = sorted_label_logits[-1]
selectivity_margin = (top_logit - sorted_label_logits[-2]).tolist()

for target_label in range(Label_wise_logit.shape[0]):
    is_top_label = (Label_wise_logit[target_label] == top_logit).tolist()
    target_label_selective_features = {
        feat: selectivity_margin[feat]
        for feat, is_top in enumerate(is_top_label)
        if is_top
    }
    # Rank by margin, largest first (stable sort keeps feature order among ties).
    target_label_selective_features = dict(sorted(
        target_label_selective_features.items(), key=lambda item: item[1], reverse=True
    ))

    label_selective_features[target_label] = target_label_selective_features.copy()

In [31]:
# import json
# with open('./label_selective_features_L31.json', 'w') as file:
#     json.dump(label_selective_features, file)

---

In [32]:
last_k_tokens = 5

# Per controlled example: average positions [-last_k_tokens, -1), the tokens just before
# the prompt's final token.
example_level_score_controlled = []
for example_idx in range(len(feature_summary_trajectories_controlled[task_type])):
    trajectory = feature_summary_trajectories_controlled[task_type][example_idx]
    example_level_score_controlled.append(trajectory[-last_k_tokens:-1].clone())
example_level_score_controlled = torch.stack(example_level_score_controlled).mean(dim=1)

# Per remaining label (num_labels of them): average over that label's block of examples.
Label_wise_logit_controlled = []
discriminative_mode = {}
for label in range(num_labels):
    label_block = example_level_score_controlled[num_examples_per_label * label: num_examples_per_label * (label + 1)]
    label_avg = label_block.mean(dim=0)
    Label_wise_logit_controlled.append(label_avg.clone())

Label_wise_logit_controlled = torch.stack(Label_wise_logit_controlled)

In [33]:
# For every SAE feature and every label that remains in the controlled setting, record
# how far its label-averaged logit moved: controlled minus original.
# Inner-dict keys are positions in `kept_labels` (0..len(kept_labels)-1), matching the row
# order of Label_wise_logit_controlled. The original logit is looked up by the real
# label id, which equals that position when remove_label is the last label.
plasticity_candidates = {feat: {} for feat in range(sae.W_dec.size(0))}

for local_label, label_id in enumerate(kept_labels):
    # One vectorised subtraction per label. Entries are stored as 0-d tensors.
    label_shift = Label_wise_logit_controlled[local_label] - Label_wise_logit[label_id]
    for feat in range(sae.W_dec.size(0)):
        plasticity_candidates[feat][local_label] = label_shift[feat]

In [None]:
target_label = remove_label
topk = 100

# The `topk` features most selective for the removed label, by selectivity margin.
top_k_label_selective_features = list(label_selective_features[target_label])[:topk]

for feat in top_k_label_selective_features:
    shift_total = plasticity_candidates[feat][0] + plasticity_candidates[feat][1] + plasticity_candidates[feat][2]
    print('=' * 77)
    print(f"Feature {feat} (Antipodality: {all_feature_map[feat]['sim']: .3f})")
    print(f"* Increase in Logit toward the rest labels: {shift_total:.3f}")
    print('-' * 77)
    print(f"* Original Tech-Score: {Label_wise_logit[target_label][feat]: .3f}")
    for label in range(3):
        before = Label_wise_logit[label][feat]
        after = Label_wise_logit_controlled[label][feat]
        print(f'* Label {label}, Orig Score: {before: .3f} -> {after:.3f}')

In [None]:
# Re-order features by their total logit shift across the three remaining labels (largest first).
plasticity_candidates = dict(
    sorted(
        plasticity_candidates.items(),
        key=lambda item: item[1][0] + item[1][1] + item[1][2],
        reverse=True,
    )
)

for i, feat in enumerate(plasticity_candidates):
    if i == 500:
        break
    print('-' * 77)
    print(f"Feature {feat} (Antipodality: {all_feature_map[feat]['sim']: .3f})")
    for label in range(3):
        print('label:', label)
        print(f"* Diff: {plasticity_candidates[feat][label]: .3f}")
        print(f"* Controlled Logit: {Label_wise_logit_controlled[label][feat]}")
        print(f"* Original Logit: {Label_wise_logit[label][feat]}")
    

In [36]:
import numpy as np
import torch
import plotly.graph_objects as go
from scipy.stats import wilcoxon

@torch.no_grad()
def volcano_plot_cohensd_wilcoxon(
    example_level_score: torch.Tensor,                # [n_original, n_features]
    example_level_score_controlled: torch.Tensor,     # [n, n_features]; row k pairs with original row k
    feature_indices=None,                             # None => all; or list/1D tensor of feature ids
    condition1_name="Original",
    condition2_name="Controlled",
    p_thresh=0.01,
    chunk_size=2048,                                  # features per Wilcoxon call
    alternative="two-sided",                          # wilcoxon alternative
    title="Volcano plot: Cohen's d vs Wilcoxon p-value"
):
    """Volcano plot of per-feature paired effects between two conditions.

    With `n` = number of controlled rows, rows `0..n-1` of the two inputs are treated
    as paired observations. For each feature:

        diff = controlled - original      (condition2 - condition1)

    The x-axis is the paired Cohen's d (d_z = mean(diff) / std(diff), std clamped at
    1e-12). The y-axis is the Wilcoxon signed-rank p-value on a reversed log scale.

    Args:
        example_level_score: condition-1 scores; must have at least `n` rows.
        example_level_score_controlled: condition-2 scores (`n` rows).
        feature_indices: subset of feature columns to test; `None` tests all.
        condition1_name, condition2_name: labels used in the figure subtitle.
        p_thresh: significance threshold for colouring and the red guide line.
        chunk_size: number of features passed to each `wilcoxon` call.
        alternative: `alternative` argument of `scipy.stats.wilcoxon`.
        title: figure title.

    Returns:
        fig: plotly figure.
        results: dict with keys 'feature_idx', 'cohens_d' and 'pvalue' (arrays of length
            K, one entry per tested feature). Features whose diffs are all zero get p = 1.
            p-values that underflow to 0 are returned as 0. For display only, they are
            drawn at the smallest positive p-value so they stay visible on the log axis.
    """
    n = int(example_level_score_controlled.shape[0])
    assert example_level_score.shape[0] >= n, "example_level_score must have at least n rows."

    # ----------------------------
    # 1) Paired differences over the first n aligned rows:
    #    diff = condition2 - condition1 (controlled - original)
    # ----------------------------
    diff = (example_level_score_controlled.detach().cpu().float()
            - example_level_score[:n].detach().cpu().float())   # [n, F]

    F = diff.shape[1]

    # Select features if provided
    if feature_indices is None:
        feat_idx = np.arange(F, dtype=np.int32)
        diff_sel = diff  # [n, F]
    else:
        if torch.is_tensor(feature_indices):
            feat_idx = feature_indices.detach().cpu().numpy().astype(np.int32)
        else:
            feat_idx = np.asarray(feature_indices, dtype=np.int32)
        diff_sel = diff[:, feat_idx]  # [n, K]

    K = diff_sel.shape[1]

    # ----------------------------
    # 2) Cohen's d (paired): mean(diff)/std(diff)
    # ----------------------------
    mean_diff = diff_sel.mean(dim=0)                            # [K]
    std_diff = diff_sel.std(dim=0, unbiased=True).clamp_min(1e-12)
    cohens_d = (mean_diff / std_diff).numpy()                   # [K]

    # ----------------------------
    # 3) Wilcoxon signed-rank p-values per feature (vectorized by chunks)
    # ----------------------------
    diff_np = diff_sel.numpy()  # [n, K], float32
    pvals = np.ones(K, dtype=np.float64)

    for start in range(0, K, chunk_size):
        end = min(K, start + chunk_size)
        sub = diff_np[:, start:end]  # [n, chunk]

        # wilcoxon errors on all-zero columns; those keep p = 1.
        nonzero_cols = (sub != 0).any(axis=0)

        if np.any(nonzero_cols):
            res = wilcoxon(
                sub[:, nonzero_cols],
                axis=0,
                zero_method="wilcox",
                alternative=alternative,
                method="approx",  # normal approximation; fast for large n
            )
            pvals[start + np.flatnonzero(nonzero_cols)] = np.asarray(res.pvalue, dtype=np.float64)

    # ----------------------------
    # 4) Volcano plot (p-value on log scale, reversed)
    # ----------------------------
    # A log axis cannot show p == 0 (underflow for very strong effects). Those points are
    # drawn at the smallest positive p-value instead; hover text shows the raw value.
    positive_p = pvals[pvals > 0]
    p_floor = positive_p.min() if positive_p.size else np.finfo(np.float64).tiny
    pvals_plot = np.where(pvals == 0, p_floor, pvals)

    sig = pvals < p_thresh
    trace_groups = (
        # (mask, legend name, marker size, opacity); drawn in this order
        (~sig, f"p ≥ {p_thresh}", 6, 0.45),
        (sig & (cohens_d < 0), f"p < {p_thresh} & d < 0", 7, 0.85),
        (sig & (cohens_d > 0), f"p < {p_thresh} & d > 0", 7, 0.85),
    )

    def make_hover_text(idxs):
        return [
            f"feature={int(feat_idx[i])}<br>"
            f"cohen_d={cohens_d[i]:+.4f}<br>"
            f"p={pvals[i]:.3e}"
            for i in idxs
        ]

    fig = go.Figure()
    for mask, trace_name, marker_size, marker_opacity in trace_groups:
        group_idx = np.where(mask)[0]
        fig.add_trace(go.Scatter(
            x=cohens_d[group_idx],
            y=pvals_plot[group_idx],
            mode="markers",
            name=trace_name,
            text=make_hover_text(group_idx),
            hoverinfo="text",
            marker=dict(size=marker_size, opacity=marker_opacity),
        ))

    # Reference lines
    fig.add_vline(x=0.0, line_width=2, line_dash="dash", line_color="black")
    fig.add_hline(y=p_thresh, line_width=2, line_dash="dash", line_color="red")

    # Layout: log p-values, reversed so smaller p is higher (volcano-like).
    # The subtitle states the difference direction actually computed above.
    fig.update_layout(
        title=f"{title}<br><sup>{condition2_name} − {condition1_name} (paired over {n} aligned examples)</sup>",
        xaxis_title="Cohen's d (paired; dz)  =  mean(diff)/std(diff), diff = controlled - original",
        yaxis_title="Wilcoxon signed-rank p-value (log scale)",
        hovermode="closest",
        legend=dict(orientation="v", yanchor="top", y=0.98, xanchor="right", x=0.98),
    )
    fig.update_yaxes(type="log", autorange="reversed")

    results = {
        "feature_idx": feat_idx,     # length K
        "cohens_d": cohens_d,        # length K
        "pvalue": pvals,             # length K (raw, unfloored)
    }

    return fig, results


In [None]:
# Paired volcano plot over every SAE feature.
fig, results = volcano_plot_cohensd_wilcoxon(
    example_level_score,
    example_level_score_controlled,
    p_thresh=0.01,
    feature_indices=None,   # None => every feature
)
fig.show()

In [None]:
# Same analysis restricted to the top-k features selective for the removed label
# (top_k_label_selective_features from the listing cell above).
fig, results = volcano_plot_cohensd_wilcoxon(
    example_level_score,
    example_level_score_controlled,
    feature_indices=top_k_label_selective_features,   # only the top-k label-selective features
    p_thresh=0.01,
)
fig.show()

In [42]:
feature = 16189
# Position of `feature` inside the results arrays (which may cover only a feature subset).
idx = list(results['feature_idx']).index(feature)
for metric in ('cohens_d', 'pvalue'):
    print(results[metric][idx])

0.97370267
1.70709796054656e-175


In [38]:
import numpy as np
import torch
import plotly.graph_objects as go
from scipy.stats import wilcoxon, norm

def _paired_rank_biserial_wilcoxon(diff_np, chunk_size):
    """Column-wise paired Wilcoxon signed-rank test on a difference matrix.

    Args:
        diff_np: 2-D numpy array [n, K] of paired differences, one column per feature.
        chunk_size: number of columns passed to ``scipy.stats.wilcoxon`` per call.

    Returns:
        (r_rb, pvals): float64 arrays of length K.
        ``r_rb`` is the matched-pairs rank-biserial correlation (W+ - W-) / (W+ + W-).
        ``pvals`` is the two-sided normal-approximation p-value.
        Columns whose differences are all zero keep r_rb = NaN and p = 1.0.
    """
    n_cols = diff_np.shape[1]
    pvals = np.ones(n_cols, dtype=np.float64)
    r_rb = np.full(n_cols, np.nan, dtype=np.float64)

    for start in range(0, n_cols, chunk_size):
        stop = min(n_cols, start + chunk_size)
        block = diff_np[:, start:stop]

        # Wilcoxon has nothing to rank in an all-zero column, so skip those.
        nonzero_cols = (block != 0).any(axis=0)
        if not nonzero_cols.any():
            continue
        block_nz = block[:, nonzero_cols]

        # zero_method="wilcox" discards zero differences, so the total rank sum
        # runs over the n_eff nonzero pairs only.
        n_eff = (block_nz != 0).sum(axis=0).astype(np.float64)
        rank_total = n_eff * (n_eff + 1.0) / 2.0

        # alternative="greater" makes `statistic` equal W+ (sum of ranks of positive
        # differences) and yields a signed z; the two-sided p is 2 * sf(|z|).
        res = wilcoxon(
            block_nz,
            axis=0,
            zero_method="wilcox",
            alternative="greater",
            method="approx",
        )
        w_plus = np.asarray(res.statistic, dtype=np.float64)
        z = np.asarray(res.zstatistic, dtype=np.float64)
        p_two = np.clip(2.0 * norm.sf(np.abs(z)), 0.0, 1.0)

        # Rank-biserial: 2W+/R - 1 (undefined when R == 0).
        r = np.full_like(w_plus, np.nan)
        ok = rank_total > 0
        r[ok] = 2.0 * w_plus[ok] / rank_total[ok] - 1.0

        # Map chunk-local nonzero columns back to absolute positions.
        cols = start + np.flatnonzero(nonzero_cols)
        pvals[cols] = p_two
        r_rb[cols] = r

    return r_rb, pvals


@torch.no_grad()
def volcano_plot_rankbiserial_wilcoxon(
    example_level_score: torch.Tensor,                # [2000, 16384]
    example_level_score_controlled: torch.Tensor,     # [1500, 16384] aligned with first 1500
    feature_indices=None,                             # None => all; or list/1D tensor of feature ids
    condition1_name="Original",
    condition2_name="Controlled",
    p_thresh=0.01,
    chunk_size=2048,
    title="Volcano plot: Rank-biserial vs Wilcoxon p-value",
):
    """Volcano plot of per-feature paired effects: rank-biserial r vs Wilcoxon p.

    Pairing and test:
        The first ``n = example_level_score_controlled.shape[0]`` rows of
        ``example_level_score`` are paired row-by-row with
        ``example_level_score_controlled``. For each feature,

            diff = controlled - original

        is tested with a two-sided Wilcoxon signed-rank test. Zero differences
        are dropped and the normal approximation is used.

    Axes:
        x-axis: matched-pairs rank-biserial correlation. r > 0 means the
            controlled condition tends to be larger than the original.
        y-axis: two-sided Wilcoxon p-value, on a log scale reversed so that
            small p is high.

    Args:
        example_level_score: [N, F] scores for the original condition, N >= n.
        example_level_score_controlled: [n, F] scores for the controlled
            condition; row i is paired with row i of ``example_level_score``.
        feature_indices: None for all F features, or a list / 1-D tensor of feature ids.
        condition1_name: plot label for the original condition.
        condition2_name: plot label for the controlled condition.
        p_thresh: significance threshold used for colouring and the red reference line.
        chunk_size: number of features tested per ``scipy.stats.wilcoxon`` call (>= 1).
        title: main plot title.

    Returns:
        (fig, results): a plotly Figure and a dict of length-K arrays.
        ``feature_idx`` holds the feature ids.
        ``rank_biserial`` is NaN for features whose differences are all zero.
        ``pvalue`` is 1.0 for all-zero features and may be exactly 0.0 after
        float underflow.

    Raises:
        AssertionError: if the row counts or feature dimensions are incompatible.
        ValueError: if ``chunk_size < 1``.
    """
    n = int(example_level_score_controlled.shape[0])
    assert example_level_score.shape[0] >= n, "example_level_score must have at least n rows."
    assert example_level_score.shape[1:] == example_level_score_controlled.shape[1:], (
        "example_level_score and example_level_score_controlled must have the same feature dimension."
    )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Paired differences in float32 on CPU: [n, F], diff = controlled - original.
    diff_np = (
        example_level_score_controlled.detach().cpu().float()
        - example_level_score[:n].detach().cpu().float()
    ).numpy()
    F = diff_np.shape[1]

    # Feature selection. Index the numpy array rather than the tensor so that
    # int32 indices work regardless of torch version.
    if feature_indices is None:
        feat_idx = np.arange(F, dtype=np.int32)
        diff_sel = diff_np
    else:
        if torch.is_tensor(feature_indices):
            feat_idx = feature_indices.detach().cpu().numpy().astype(np.int32).reshape(-1)
        else:
            feat_idx = np.asarray(feature_indices, dtype=np.int32).reshape(-1)
        diff_sel = diff_np[:, feat_idx]

    r_rb, pvals = _paired_rank_biserial_wilcoxon(diff_sel, chunk_size)

    # Masks for plotting (all-zero features have NaN r and are not drawn).
    finite = np.isfinite(r_rb) & np.isfinite(pvals)
    sig = finite & (pvals < p_thresh)

    # Plot-only floor: p-values that underflowed to exactly 0 would otherwise be
    # dropped by the log-scale axis. Hover text and `results` keep the raw values.
    p_plot = np.maximum(pvals, np.finfo(np.float64).tiny)

    def _hover(positions):
        return [
            f"feature={int(feat_idx[i])}<br>"
            f"rank_biserial={r_rb[i]:+.4f}<br>"
            f"p={pvals[i]:.3e}"
            for i in positions
        ]

    # (mask, legend name, marker size, marker opacity); order sets legend order.
    trace_specs = [
        (finite & ~sig, f"p ≥ {p_thresh}", 6, 0.45),
        (sig & (r_rb < 0), f"p < {p_thresh} & r < 0", 7, 0.85),
        (sig & (r_rb > 0), f"p < {p_thresh} & r > 0", 7, 0.85),
    ]

    fig = go.Figure()
    for mask, name, size, opacity in trace_specs:
        sel = np.flatnonzero(mask)
        fig.add_trace(go.Scatter(
            x=r_rb[sel], y=p_plot[sel],
            mode="markers",
            name=name,
            text=_hover(sel),
            hoverinfo="text",
            marker=dict(size=size, opacity=opacity),
        ))

    # Reference lines: no effect (r = 0) and the significance threshold.
    fig.add_vline(x=0.0, line_width=2, line_dash="dash", line_color="black")
    fig.add_hline(y=p_thresh, line_width=2, line_dash="dash", line_color="red")

    fig.update_layout(
        # The plotted effect is controlled - original, i.e. condition2 - condition1.
        title=f"{title}<br><sup>{condition2_name} − {condition1_name} (paired over {n} aligned examples)</sup>",
        xaxis_title=(
            "Rank-biserial correlation (paired Wilcoxon signed-rank); "
            f"r > 0 means {condition2_name} > {condition1_name}"
        ),
        yaxis_title="Two-sided Wilcoxon p-value (log scale)",
        template="plotly_white",
        hovermode="closest",
        legend=dict(orientation="v", yanchor="top", y=0.98, xanchor="right", x=0.98),
    )

    # Volcano-like p-value axis: log scale, reversed so small p is high.
    fig.update_yaxes(type="log", autorange="reversed")

    results = {
        "feature_idx": feat_idx,      # length K
        "rank_biserial": r_rb,        # length K
        "pvalue": pvals,              # length K
    }
    return fig, results


In [None]:
# Rank-biserial volcano for the top-k label-selective features:
# full prompt vs the controlled prompt with the label removed.
fig, results = volcano_plot_rankbiserial_wilcoxon(
    example_level_score=example_level_score,
    example_level_score_controlled=example_level_score_controlled,
    condition1_name="Full prompt",
    condition2_name="Controlled (label removed)",
    feature_indices=top_k_label_selective_features,
    chunk_size=2048,
    p_thresh=0.01,
)
fig.show()


In [None]:
feature = 16189
# Locate the feature within the analysed selection (raises ValueError if absent).
idx = [int(f) for f in results["feature_idx"]].index(feature)
print(results["rank_biserial"][idx], results["pvalue"][idx], sep="\n")