In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file found.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
pip install torch accelerate transformers datasets

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
import numpy as np
import pandas as pd
import os
import re
import gc
import json
import torch
import pickle
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

In [None]:
class TruthFlowExtractor:
    """Collect hidden states of a causal LM for question / answer pairs.

    For every processed sample three lists of vectors are stored in
    ``self.representations`` (one vector per entry of the model's
    ``hidden_states`` output):

    * ``"query"``     - state at the last token of the question,
    * ``"correct"``   - mean state over the tokens of the correct answer,
    * ``"incorrect"`` - mean state over the tokens of the incorrect answer.
    """

    def __init__(self, model_id="google/gemma-2-2b", target_samples=450):
        """Store configuration only; the model is loaded by ``setup_model``.

        Args:
            model_id: Hugging Face model identifier to load.
            target_samples: Sample budget used by ``load_dataset`` when that
                method is called with ``target_samples=None``.
        """
        self.model_id = model_id
        self.target_samples = target_samples
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self.hf_token = None
        self.representations = {"query": [], "correct": [], "incorrect": []}
        self.layer_count = 0

    def setup_model(self):
        """Load the tokenizer and the model named by ``self.model_id``.

        The Hugging Face access token is read from the ``HF_TOKEN`` (or
        ``HUGGING_FACE_HUB_TOKEN``) environment variable instead of being
        written into the notebook. When neither is set, no explicit login is
        performed and ``token=None`` is passed on, so any credentials already
        stored by ``huggingface_hub`` are used.
        """
        self.hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        if self.hf_token:
            login(token=self.hf_token)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.hf_token,
            trust_remote_code=True,
            # `model_max_length` is the setting that `truncation=True` consults;
            # a bare `max_length` here left the tokenizer without a limit.
            model_max_length=512
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            output_hidden_states=True,
            token=self.hf_token,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        self.model.config.output_hidden_states = True
        self.model.eval()
        self.layer_count = len(self.model.model.layers)
        print(f"Loaded {self.model_id} with {self.layer_count} layers")

    @staticmethod
    def is_valid_sample(row):
        """Return True when ``row`` holds a usable correct / incorrect answer pair.

        A row is rejected when ``row['answer']`` is not a non-empty list, when
        its first entry is not a string, when ``row['false_answer']`` is
        missing (``None``), when either answer is blank, when the two answers
        are equal or one contains the other (case-insensitive), or when the
        correct answer has more than 20 words / the incorrect one more than 30.
        """
        answers = row['answer']
        if not answers or not isinstance(answers, list):
            return False

        first_answer = answers[0]
        false_answer = row['false_answer']
        # str(None) would silently become the text "None" and pass every check
        # below; a non-string first answer has no .strip() and would raise.
        if not isinstance(first_answer, str) or false_answer is None:
            return False

        correct = first_answer.strip().lower()
        incorrect = str(false_answer).strip().lower()

        # Discard if either is empty or identical
        if not correct or not incorrect or correct == incorrect:
            return False

        # Avoid false_answer that contains correct answer or vice versa
        if correct in incorrect or incorrect in correct:
            return False

        # Limit text lengths (both strings are non-empty here, so only the
        # upper bounds can trigger).
        if len(correct.split()) > 20:
            return False
        if len(incorrect.split()) > 30:
            return False

        return True

    @staticmethod
    def normalize(text):
        """Collapse every run of whitespace in ``text`` to one space and trim the ends."""
        return re.sub(r"\s+", " ", str(text)).strip()

    def load_dataset(self, target_samples=450):
        """Load and sample question/answer pairs from ``OamPatel/iti_nq_open_val``.

        The validation split is filtered with ``is_valid_sample``, shuffled
        with a fixed seed (42) and cut down to ``target_samples`` rows. When
        fewer valid rows exist, all of them are returned.

        Args:
            target_samples: Number of samples to keep. ``None`` falls back to
                ``self.target_samples``.

        Returns:
            A list of dicts with the keys ``"question"``, ``"correct_answer"``
            and ``"incorrect_answer"``; an empty list when loading fails (the
            error is printed, not raised).
        """
        if target_samples is None:
            target_samples = self.target_samples

        try:
            # Load the dataset from Hugging Face (validation split)
            ds = load_dataset("OamPatel/iti_nq_open_val", split="validation")
            print(f"Total dataset size: {len(ds)}")

            filtered_data = ds.filter(self.is_valid_sample)
            available = len(filtered_data)
            # Never ask select() for more rows than survived the filter.
            sample_count = max(0, min(int(target_samples), available))
            if sample_count < target_samples:
                print(f"Only {available} valid samples available (requested {target_samples})")

            subset = filtered_data.shuffle(seed=42).select(range(sample_count))
            formatted_data = [{
                "question": row["question"],
                "correct_answer": self.normalize(row["answer"][0]),
                "incorrect_answer": self.normalize(row["false_answer"])
            } for row in subset]

            # Preview
            if formatted_data:
                print(formatted_data[0])
            return formatted_data

        except Exception as e:
            print(f"Error loading dataset: {e}")
            print("Please check the dataset name and structure")
            return []

    def extract_representation(self, text, return_last=True):
        """Run ``text`` through the model and reduce each hidden state to one vector.

        Args:
            text: Input passed to the tokenizer.
            return_last: When True, take the state at the last sequence
                position; otherwise take the attention-mask-weighted mean over
                the sequence.

        Returns:
            A list of CPU tensors, one per entry of ``output.hidden_states``.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)

        with torch.no_grad():
            output = self.model(**inputs, output_hidden_states=True)

        hidden_states = output.hidden_states

        if return_last:
            # NOTE(review): this is the last *position*; for a batch padded on
            # the right it could be a padding token. Callers in this file pass
            # a single string, where no padding is added.
            last_token_index = inputs['input_ids'].shape[1] - 1
            representations = [layer[:, last_token_index, :].squeeze(0) for layer in hidden_states]
        else:
            attention_mask = inputs['attention_mask']
            token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
            representations = [
                (layer * attention_mask.unsqueeze(-1)).sum(dim=1) / token_counts
                for layer in hidden_states
            ]
            representations = [r.squeeze(0) for r in representations]

        # Move to CPU like extract_answer_only_representation does, so stored
        # results do not accumulate on the accelerator.
        return [r.cpu() for r in representations]

    def extract_answer_only_representation(self, question, answer):
        """Mean hidden state over the answer tokens of ``question`` + ``answer``.

        The question is tokenized with the tokenizer's default special tokens,
        exactly as ``extract_representation`` tokenizes it, so both paths feed
        the model the same question prefix. The answer is tokenized without
        special tokens and appended at token level.

        Args:
            question: Question text.
            answer: Answer text whose tokens are averaged.

        Returns:
            A list of CPU tensors, one per entry of ``output.hidden_states``.

        Raises:
            ValueError: If ``answer`` produces no tokens (the mean would be NaN).
        """
        # Tokenize question and answer separately
        q_tokens = self.tokenizer(question, return_tensors="pt")
        a_tokens = self.tokenizer(answer, return_tensors="pt", add_special_tokens=False)

        if a_tokens["input_ids"].shape[1] == 0:
            raise ValueError(f"Answer produced no tokens: {answer!r}")

        # Concatenate input_ids and attention_mask
        input_ids = torch.cat([q_tokens["input_ids"], a_tokens["input_ids"]], dim=1)
        attention_mask = torch.cat([q_tokens["attention_mask"], a_tokens["attention_mask"]], dim=1)
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Use the model's own device (same as extract_representation).
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            output = self.model(**inputs, output_hidden_states=True)
        hidden_states = output.hidden_states

        # Indices for answer tokens
        answer_start = q_tokens["input_ids"].shape[1]
        answer_end = input_ids.shape[1]

        result = []
        for layer_h in hidden_states:
            # Only average over answer tokens
            answer_h = layer_h[0, answer_start:answer_end, :]
            result.append(answer_h.mean(dim=0).cpu())
        return result

    def run_extraction(self, data):
        """Extract and store representations for every sample in ``data``.

        Args:
            data: Iterable of dicts with the keys ``"question"``,
                ``"correct_answer"`` and ``"incorrect_answer"``.
        """
        print(f"Processing {len(data)} samples...")
        for sample in tqdm(data, desc="Extracting representations"):
            question = sample["question"]
            correct_answer = sample["correct_answer"]
            incorrect_answer = sample["incorrect_answer"]

            # Extract query representation (last token)
            q_repr = self.extract_representation(question, return_last=True)

            # Extract correct answer representation (average over answer tokens)
            correct_repr = self.extract_answer_only_representation(question, correct_answer)

            # Extract incorrect answer representation (average over answer tokens)
            incorrect_repr = self.extract_answer_only_representation(question, incorrect_answer)

            # Store representations
            self.representations["query"].append(q_repr)
            self.representations["correct"].append(correct_repr)
            self.representations["incorrect"].append(incorrect_repr)

            # Clear cache to prevent memory issues
            torch.cuda.empty_cache()

    def save_npz(self, path="truthflow_hiddenstates.npz"):
        """Write the stored representations to ``path`` with ``np.savez``.

        The archive holds the arrays ``query``, ``correct`` and ``incorrect``;
        each stacks the per-sample lists along a new leading axis.

        Raises:
            ValueError: If any of the three groups is empty.
        """
        def stack_group(name):
            if not self.representations[name]:
                raise ValueError(f"No representations found for '{name}'")
            return np.stack([torch.stack(x).cpu().numpy() for x in self.representations[name]])

        np.savez(path,
                 query=stack_group("query"),
                 correct=stack_group("correct"),
                 incorrect=stack_group("incorrect"))
        print(f"Saved to {path}")

    def clear_representations(self):
        """Clear stored representations to free memory"""
        self.representations = {"query": [], "correct": [], "incorrect": []}

In [12]:
if __name__ == "__main__":
    # Single place to tune the run: sample budget and output file.
    TARGET_SAMPLES = 450
    OUTPUT_PATH = "nq_hidden_states.npz"

    # Initialize extractor
    extractor = TruthFlowExtractor(model_id="google/gemma-2-2b", target_samples=TARGET_SAMPLES)

    # Setup model
    extractor.setup_model()

    # Load the Hugging Face dataset with random sampling
    data = extractor.load_dataset(target_samples=TARGET_SAMPLES)

    # load_dataset() prints the error and returns [] on failure; stop here with
    # a clear message rather than failing later inside save_npz().
    if not data:
        raise ValueError("No samples were loaded; see the dataset error printed above")

    # Run extraction
    extractor.run_extraction(data)

    # Save results
    extractor.save_npz(OUTPUT_PATH)

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded google/gemma-2-2b with 26 layers
Total dataset size: 3610
{'question': 'who voices randy in f is for family', 'correct_answer': 'T.J. Miller', 'incorrect_answer': 'Adam Sandler'}
Processing 450 samples...


Extracting representations:   0%|          | 0/450 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Extracting representations: 100%|██████████| 450/450 [01:36<00:00,  4.66it/s]


Saved to nq_hidden_states.npz
