In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px  
import random
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

from transformer_lens.HookedLlava import HookedLlava
from sae_lens.activation_visualization import (
    load_llava_model,
    load_sae,
    separate_feature,
    run_model,
)
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" 
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
model_path="/data/models/llava-v1.6-mistral-7b-hf"
sae_path="/data/changye/model/llavasae_obliec100k_SAEV"
sae_device="cuda:1"
device="cuda:0"

In [None]:
 # 加载模型
processor,  hook_language_model = load_llava_model(
        model_name, model_path, device,n_devices=2,stop_at_layer=17
    )
sae = load_sae(sae_path, sae_device)
# del vision_model
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = device
# )



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Quantization check time: 0.00s
Configuration loading time: 0.00s
Model configuration processing time: 1.83s
State dict loading time: 0.01s
Tokenizer setup time: 0.38s
Embedding setup time: 79.95s
Move device time: 0.00s
Set up time: 0.00s
Model creation time: 80.34s
State dict processing time: 37.76s
Device moving time: 26.25s
Total loading time: 146.18s


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


## loading dataset

In [None]:
dataset_path="/data/changye/data/SPA-VL"
system_prompt= "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_prompt= 'USER: \n<image> {input}'
assistant_prompt= '\nASSISTANT: {output}'
split_token= 'ASSISTANT:'
train_dataset = load_dataset(
            dataset_path,
            split="train",
            trust_remote_code=True,
        )
print(train_dataset)
sample_size = 1000
total_size = len(train_dataset)
random_indices = random.sample(range(total_size), sample_size)
sampled_dataset = train_dataset.select(random_indices)

# 定义格式化函数
def format_sample(raw_sample: dict[str, Any]) -> dict[str, Any]:
    """
    格式化样本，只提取 question 和 image 字段，并生成所需的 prompt。
    """
    # 获取并清洗 question 字段
    prompt = raw_sample['question'].replace('<image>\n', '').replace('\n<image>', '').replace('<image>', '')
    
    # 加载和处理 image 字段
    image = raw_sample['image']
    # if isinstance(image, str):  # 如果 image 是路径
    #     image = Image.open(image).convert('RGBA')
    # elif hasattr(image, "convert"):  # 如果是 PIL.Image 对象
    image=image.resize((336,336))
    image = image.convert('RGBA')

    
    # 格式化 Prompt
    formatted_prompt = (
        f'{system_prompt}'
        f'{user_prompt.format(input=prompt)}'
        f'{assistant_prompt.format(output="")}'
    )
    
    return {
        'prompt': formatted_prompt,
        'image': image,
        'image_name':raw_sample['image_name']
    }

# 使用 map 方法处理数据集
formatted_dataset = sampled_dataset.map(
    format_sample,
    num_proc=80,  # 根据您的 CPU 核心数量调整
    remove_columns=['chosen','rejected','question'],
)
# print(formatted_dataset)
# 如果需要进一步处理，可以将 formatted_dataset 转换为列表
formatted_sample = formatted_dataset[:]
# print(formatted_sample['image_name'][0])

hf_dataset = Dataset.from_dict(formatted_sample)

# 保存为 Arrow 格式
save_path = "/data/changye/data/SPA_VL1k"
os.makedirs(save_path, exist_ok=True)
hf_dataset.save_to_disk(save_path)
print(f"Dataset saved to {save_path}")




Generating train split:   0%|          | 0/93258 [00:00<?, ? examples/s]

Dataset({
    features: ['image', 'question', 'chosen', 'rejected', 'image_name'],
    num_rows: 93258
})


Map (num_proc=80):   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset saved to /data/changye/data/SPA_VL1k


In [4]:
# image_name_list=[]

# for data in tqdm(train_dataset):
#     image_name=data['image_name']
#     if image_name in image_name_list:
#         print("error!")
#         break
#     else:
#         image_name_list.append(image_name)

In [5]:
inputs = processor(
        text=formatted_sample['prompt'],
        images=formatted_sample['image'],
        return_tensors='pt',
        padding='max_length',  # 设置padding为最大长度
        max_length=256,  # 设置最大长度
    ).to(device)

# 打印一个处理后的示例
print((inputs['input_ids'].shape))
torch.cuda.empty_cache()


torch.Size([2, 256])


In [6]:
# for batch in processed_dataset:
#     # print(dir(batch))
#     image_indices, feature_act = run_model(batch, hook_language_model, sae, sae_device)
#     break  

image_indices, feature_act = run_model(inputs, hook_language_model, sae, sae_device)


tensor([[  35,   36,   37,  ..., 1208, 1209, 1210],
        [  35,   36,   37,  ..., 1208, 1209, 1210]], device='cuda:0')
out (tensor([[[ -4.8597,  -4.7012,  -0.1998,  ...,   0.1763,   0.1783,   0.1782],
         [ -6.7856,  -6.8383,  -3.4369,  ...,   0.1607,   0.1632,   0.1608],
         [ -7.6597,  -8.0547,  -2.2168,  ...,   0.2865,   0.2930,   0.2836],
         ...,
         [ -8.4231,  -8.8412,   3.6364,  ...,   0.1510,   0.1481,   0.1504],
         [ -6.9022,  -7.0688,   1.6313,  ...,  -0.0247,  -0.0260,  -0.0219],
         [ -6.4647,  -6.3016,  10.9200,  ...,   0.2390,   0.2362,   0.2436]],

        [[ -4.8597,  -4.7012,  -0.1998,  ...,   0.1763,   0.1783,   0.1782],
         [ -6.7856,  -6.8383,  -3.4369,  ...,   0.1607,   0.1632,   0.1608],
         [ -7.6597,  -8.0547,  -2.2168,  ...,   0.2865,   0.2930,   0.2836],
         ...,
         [-10.2070, -10.4642,   7.1662,  ...,   0.1577,   0.1569,   0.1592],
         [ -6.9841,  -7.2379,   3.6436,  ...,  -0.0236,  -0.0251,  -0.019

In [7]:
print((image_indices.shape))
print(feature_act.shape)




torch.Size([2, 1176])
torch.Size([2, 1244, 65536])


In [8]:
cooccurrence_feature=separate_feature(image_indices, feature_act)
print(len(cooccurrence_feature[1]))

630


In [9]:
data_dict={}
for i in range(len(cooccurrence_feature)):
    data_dict[formatted_sample['image_name'][i]]=cooccurrence_feature[i]
print(data_dict)
batch_size = 10000
for i in range(0, len(data_dict), batch_size):
    batch_dict = dict(list(data_dict.items())[i:i+batch_size])
    torch.save(batch_dict, f'data_batch_{i // batch_size}.pt')

{'5655.jpg': [40963, 18436, 32792, 28699, 41002, 10285, 59445, 10293, 18488, 57401, 47165, 53326, 38998, 30808, 6238, 28768, 24673, 59489, 34919, 51305, 107, 8302, 4212, 39029, 39031, 49278, 36995, 55428, 49287, 2205, 24737, 26786, 49316, 41139, 4276, 47291, 200, 24785, 43217, 24787, 2262, 14553, 219, 26850, 26854, 49389, 8430, 63725, 63742, 26882, 47366, 18696, 28942, 20750, 30992, 59673, 6439, 6448, 57653, 59701, 47424, 43330, 47426, 49482, 14684, 63837, 35167, 8549, 49509, 55653, 10607, 12668, 6524, 8574, 63868, 386, 39298, 61829, 2454, 57754, 37293, 16831, 53696, 22979, 6603, 16864, 20963, 27117, 55790, 495, 29181, 33278, 57860, 39436, 41502, 37410, 4645, 57900, 37430, 62007, 2618, 55876, 59982, 29271, 605, 21092, 4718, 17007, 47727, 51833, 55933, 43657, 43658, 57998, 21135, 2705, 31378, 43689, 62125, 29372, 33489, 13011, 43732, 60116, 41688, 41695, 35557, 6885, 2791, 37607, 47845, 39660, 751, 13039, 6897, 47864, 17145, 17146, 47872, 33539, 51977, 21257, 25357, 25358, 23309, 37666,