In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px  
import random
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)
# Point the http_proxy / https_proxy environment variables at a proxy on
# localhost port 7890 for this process.
# NOTE(review): this is machine-specific — it presumably exists so that later
# downloads go through a local proxy; remove or adjust it on machines without
# a proxy listening on that port.
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

from transformer_lens.HookedLlava import HookedLlava
from sae_lens.activation_visualization import (
    load_llava_model,
    load_sae,
    separate_feature,
)
# Configuration consumed by the model-loading cell below.
# NOTE(review): the paths are absolute and machine-specific — adjust per host.
# model_name / model_path / device are passed to load_llava_model;
# sae_path / sae_device are passed to load_sae.
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
model_path="/data/models/llava-v1.6-mistral-7b-hf"
sae_path="/data/changye/model/llavasae_obliec100k_SAEV"
sae_device="cuda:7"
device="cuda:0"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
 # Load the model.
 # load_llava_model returns five objects, unpacked below in order. n_devices=8
 # is forwarded as-is (presumably the number of GPUs to spread the model
 # over — confirm against load_llava_model).
processor, vision_model, vision_tower, multi_modal_projector, hook_language_model = load_llava_model(
        model_name, model_path, device,n_devices=8
    )
# Load the sparse autoencoder from sae_path onto its own device (sae_device).
sae = load_sae(sae_path, sae_device)
# del vision_model
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = device
# )



## Loading the dataset

In [14]:
# Location of the dataset on disk.
# NOTE(review): absolute, machine-specific path.
dataset_path="/data/changye/data/SPA-VL"

# Prompt template pieces; format_sample concatenates them as
# system_prompt + user_prompt + assistant_prompt.
system_prompt= "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_prompt= 'USER: \n<image> {input}'
assistant_prompt= '\nASSISTANT: {output}'
split_token= 'ASSISTANT:'

train_dataset = load_dataset(
            dataset_path,
            split="train",
            trust_remote_code=True,
        )
print(train_dataset)

# Draw a small random subset of rows.
sample_size = 8
sample_seed = 0  # fixed seed so that Restart & Run All selects the same rows
total_size = len(train_dataset)
# A dedicated Random instance keeps the draw reproducible without touching the
# global `random` state; min() avoids a ValueError when the dataset holds fewer
# than `sample_size` rows.
random_indices = random.Random(sample_seed).sample(
    range(total_size), min(sample_size, total_size)
)
sampled_dataset = train_dataset.select(random_indices)

# Formatting function
def format_sample(raw_sample: dict[str, Any]) -> dict[str, Any]:
    """
    Build the prompt string and RGBA image for one raw dataset row.

    Only the ``question`` and ``image`` fields of ``raw_sample`` are read.
    Every ``<image>`` placeholder (together with an adjacent newline, if any)
    is stripped from the question, which is then inserted into the
    module-level templates ``system_prompt``, ``user_prompt`` and
    ``assistant_prompt`` with an empty assistant turn.

    Returns:
        A dict with the keys ``'prompt'`` (the formatted string) and
        ``'image'`` (the result of ``image.convert('RGBA')``).
    """
    # Clean the question: drop the placeholders in the same order as before
    # (newline-suffixed, newline-prefixed, then bare).
    question = raw_sample['question']
    for placeholder in ('<image>\n', '\n<image>', '<image>'):
        question = question.replace(placeholder, '')

    # The image field is expected to expose a PIL-style ``convert`` method.
    rgba_image = raw_sample['image'].convert('RGBA')

    # Assemble the full prompt from the three template pieces.
    user_turn = user_prompt.format(input=question)
    assistant_turn = assistant_prompt.format(output="")

    return {
        'prompt': f'{system_prompt}{user_turn}{assistant_turn}',
        'image': rgba_image,
    }

# Apply format_sample to every sampled row with Dataset.map.
# The worker count is capped by both the CPU count and the number of rows:
# a hard-coded 80 exceeds the 8-row sample (datasets then warns and reduces it)
# and can oversubscribe machines with fewer cores. max(1, ...) keeps the value
# valid even for an empty dataset.
formatted_dataset = sampled_dataset.map(
    format_sample,
    num_proc=max(1, min(os.cpu_count() or 1, len(sampled_dataset))),
    remove_columns=['chosen','rejected','image_name','question'],
)
print(formatted_dataset)
# Materialise the whole formatted dataset via a full slice for further processing.
formatted_sample = formatted_dataset[:]
# print(formatted_sample)





num_proc must be <= 8. Reducing num_proc to 8 for dataset of size 8.


Dataset({
    features: ['image', 'question', 'chosen', 'rejected', 'image_name'],
    num_rows: 93258
})


Map (num_proc=8):   0%|          | 0/8 [00:00<?, ? examples/s]

Dataset({
    features: ['image', 'prompt'],
    num_rows: 8
})


In [None]:
def preprocess_function(examples):
    """
    Run the module-level ``processor`` on a batch of formatted examples.

    Args:
        examples: Mapping with an ``'image'`` entry and a ``'prompt'`` entry;
            both are forwarded unchanged to ``processor``.

    Returns:
        Whatever ``processor`` returns when called with
        ``padding='max_length'``, ``truncation=True`` and
        ``return_tensors="pt"``.
    """
    return processor(
        images=examples['image'],
        text=examples['prompt'],
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )

# Run the processor over the dataset in batches via preprocess_function.
# All original columns are removed, so only the processor's outputs remain.
processed_dataset = formatted_dataset.map(
    preprocess_function,
    batched=True,
    batch_size=8,  # adjust to your GPU memory and needs
    remove_columns=formatted_dataset.column_names,
)

# Print one processed example: the loop stops after the first item.
for batch in processed_dataset:
    print(batch)
    break  
