In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import warnings
from matplotlib.backends.backend_pdf import PdfPages

import pandas as pd
import seaborn as sns
from typing import Dict, List, Optional, Tuple, Union
import pickle
import codecs
import re
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import defaultdict
from utils_activations import rot13_alpha, LlamaActivationExtractor, logit_lens_single_layer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from config import hf_cache_dir

In [6]:
# Model to analyse: a Hugging Face Hub repo id, loaded by LlamaActivationExtractor in the next cell.
# To use the local merged checkpoint instead, set
#   path = '/workspace/data/axolotl-outputs/llama_deepseek_2epochs/merged'
path = "chingfang17/deepseek-distill-llama-rot13"
# Three-hop question prompts; the processing code reads its 'Prompt' and 'Answer' columns.
prompt_path = './prompts/three_hop_prompts.csv'
prompt_df = pd.read_csv(prompt_path)

In [7]:
# Load the model/tokenizer wrapper that records activations during generation.
# NOTE(review): layer_defaults='even' presumably selects the even-numbered layers to record
# (the layer names averaged later are all even) — confirm in utils_activations.
activation_extractor = LlamaActivationExtractor(
    cache_dir=hf_cache_dir,
    layer_defaults='even',
    model_name_or_path=path,
)
# Must run before any prompt is formatted with tokenizer.apply_chat_template.
activation_extractor.overwrite_chat_template()
model, tokenizer = activation_extractor.model, activation_extractor.tokenizer

Using device: cuda


Loading checkpoint shards: 100%|██████████| 30/30 [01:01<00:00,  2.05s/it]


In [12]:
def generate_logit_lens_transcript(self, activations: Dict[str, torch.Tensor], layer_names: List[str], confidence_threshold: float = 0.5) -> str:
    """
    Decode a rough "transcript" from model activations using the logit lens.

    Each layer in ``layer_names`` is projected to vocabulary logits with
    ``logit_lens_single_layer``. The logits are averaged across those layers and
    softmaxed over the last dimension. The most probable token at each position is
    kept only if its probability is at least ``confidence_threshold``.

    A kept token is dropped when the previously kept token (compared
    case-insensitively) starts or ends with it, which collapses consecutive
    repeats. The surviving tokens are joined with single spaces.

    Args:
        self: The activation extractor, passed explicitly because this is a free
            function. Its ``tokenizer.decode`` is used, and it is forwarded to
            ``logit_lens_single_layer``.
        activations: Dictionary of layer activations, keyed by layer name.
        layer_names: Non-empty list of layer names to average logits over.
        confidence_threshold: Minimum top-1 probability for a token to be kept.

    Returns:
        The space-joined transcript, or ``""`` when no position reaches the threshold.

    Raises:
        ValueError: If ``layer_names`` is empty or names a layer missing from
            ``activations``.
    """
    if not layer_names:
        # Without this check, torch.stack([]) fails below with an opaque RuntimeError.
        raise ValueError("layer_names must contain at least one layer name.")
    for layer_name in layer_names:
        if layer_name not in activations:
            raise ValueError(f"Layer {layer_name} not found in activations.")

    # Logit-lens logits for each requested layer, averaged elementwise across layers.
    logits_list = [logit_lens_single_layer(self, activations[layer_name]) for layer_name in layer_names]
    averaged_logits = torch.mean(torch.stack(logits_list), dim=0)

    # Top-1 probability and token id per position (softmax over the vocabulary dimension).
    # NOTE(review): the zip below iterates the leading dimension, so this assumes the
    # logit-lens output is (seq_len, vocab) with no batch dimension — confirm.
    probabilities = F.softmax(averaged_logits, dim=-1)
    top_token_probs, top_token_ids = torch.max(probabilities, dim=-1)
    top_token_probs = top_token_probs.detach().float().cpu().numpy()
    top_token_ids = top_token_ids.detach().cpu().numpy()

    # Decode only the positions whose top-1 token is confident enough.
    tokens_above_threshold = [
        self.tokenizer.decode([token_id])
        for token_id, prob in zip(top_token_ids, top_token_probs)
        if prob >= confidence_threshold
    ]
    if not tokens_above_threshold:
        # Nothing passed the threshold: return an empty transcript instead of an IndexError.
        return ""

    # Collapse consecutive repeats. Because str.endswith("") is True,
    # tokens that decode to "" are always dropped after the first kept token.
    filtered_tokens = [tokens_above_threshold[0]]
    for token in tokens_above_threshold[1:]:
        prev_token = filtered_tokens[-1].lower()
        curr_token = token.lower()
        if not prev_token.endswith(curr_token) and not prev_token.startswith(curr_token):
            filtered_tokens.append(token)

    return " ".join(filtered_tokens)

In [None]:
def process_prompt_df_with_logit_lens(prompt_df, activation_extractor, layers_to_average, confidence_threshold, max_new_tokens=1500):
    """
    Run every prompt through the model and attach outputs, correctness and logit-lens transcripts.

    For each row, the 'Prompt' text is wrapped with the tokenizer's chat template and
    generated greedily (``do_sample=False``) with activations recorded.

    Four columns are added to ``prompt_df``:
      * ``model_output``: the raw generated text.
      * ``translated_thinking``: ``rot13_alpha`` applied to the text before the first
        ``</think>`` (leading/trailing newlines removed).
      * ``is_correct``: True iff the output contains ``</think>`` and the 'Answer'
        (case-insensitive) occurs in the text after it.
      * ``logit_lens_transcript``: ``generate_logit_lens_transcript`` output with the
        BOS marker removed.

    Args:
        prompt_df: DataFrame with at least 'Prompt' and 'Answer' columns.
            It is MUTATED IN PLACE; later cells rely on this.
        activation_extractor: Object exposing ``tokenizer.apply_chat_template`` and
            ``generate_with_activations`` (returning a dict with 'response' and
            'token_activations').
        layers_to_average: Layer names passed to ``generate_logit_lens_transcript``.
        confidence_threshold: Probability threshold passed to ``generate_logit_lens_transcript``.
        max_new_tokens: Generation length cap per prompt.

    Returns:
        The same ``prompt_df`` object, now carrying the four new columns.
    """
    bos_token = '<｜begin▁of▁sentence｜>'
    model_outputs = []
    translated_thinkings = []
    is_correct_list = []
    logit_lens_transcripts = []

    for _, row in prompt_df.iterrows():
        prompt = row['Prompt']
        # str() so that non-string CSV values (e.g. numeric answers) do not crash on .lower().
        answer = str(row['Answer'])

        formatted_prompt = activation_extractor.tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )

        generation_results = activation_extractor.generate_with_activations(
            formatted_prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )
        generated_text = generation_results['response']
        model_outputs.append(generated_text)

        # Split once at the first '</think>': reasoning before it, final answer after it.
        thinking, think_close, final_answer = generated_text.partition('</think>')
        translated_thinkings.append(rot13_alpha(thinking.strip('\n')))
        is_correct_list.append(bool(think_close) and answer.lower() in final_answer.strip().lower())

        transcript = generate_logit_lens_transcript(
            activation_extractor, generation_results['token_activations'], layers_to_average, confidence_threshold)
        # Remove the BOS marker as a whole string. str.strip(chars) would treat it as a
        # character set and also eat letters such as 'n', 's', 't', 'e' from real tokens.
        logit_lens_transcripts.append(transcript.replace(bos_token, '').strip())

    prompt_df['model_output'] = model_outputs
    prompt_df['translated_thinking'] = translated_thinkings
    prompt_df['is_correct'] = is_correct_list
    prompt_df['logit_lens_transcript'] = logit_lens_transcripts

    return prompt_df

In [19]:
# Average the logit lens over the even layers 56, 58, 60, 62 and 64.
layers_to_average = ['layer_%d' % layer_idx for layer_idx in range(56, 66, 2)]
# A position's top-1 token is kept only if its probability is at least this value.
confidence_threshold = 0.7
# The function adds its columns to prompt_df in place and returns that same frame.
prompt_df = process_prompt_df_with_logit_lens(prompt_df, activation_extractor, layers_to_average, confidence_threshold)
prompt_df

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.

Unnamed: 0,Prompt,Answer,State,Person,model_output,translated_thinking,is_correct,logit_lens_transcript
0,What is the capital of the state that the secr...,Springfield,Illinois,Hillary Clinton,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",True,nl so figure out capital city state whe...
1,What is the capital of the state that the firs...,Albany,New York,Alexander Hamilton,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",True,nl so figure out capital city state whe...
2,What is the capital of the state that the U.S....,Boston,Massachusetts,John F. Kennedy,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",True,nl so figure out capital city state whe...
3,What is the capital of the state that the U.S....,Albany,New York,Theodore Roosevelt,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",True,nl so figure out capital city state whe...
4,What is the capital of the state where the aut...,Harrisburg,Pennsylvania,Rachel Carson,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",False,nl so figure out capital city state whe...
5,What is the capital of the state where the aut...,Jefferson City,Missouri,Maya Angelou,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",False,nl so figure out capital city state whe...
6,What is the capital of the state where the fir...,Sacramento,California,Sally Ride,"\n\nBxnl, fb V arrq gb svther bhg gur pncvgny ...","Okay, so I need to figure out the capital of t...",True,nl so figure out capital city state whe...


In [20]:
# Save prompts with model outputs, correctness flags and logit-lens transcripts (no index column).
prompt_df.to_csv(path_or_buf="prompts/three_hop_prompts_w_logit_lens_transcript.csv", index=False)