In [1]:
import sys 
sys.path.append('/data3/KJE/code/WIL_DeepLearningProject_2/VLM_Hallu')
import argparse
import os
import random
from typing import List, Union, Optional, Dict, Tuple
import gc

import numpy as np
import pandas as pd

import torch
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score

import wandb

from transformers import AutoProcessor, LlavaForConditionalGeneration, set_seed  # noqa: F401
import src.probing_utils as utils

[2025-09-03 11:17:25,450] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/opt/anaconda3/envs/py3_9/compiler_compat/ld: /usr/local/cuda/lib64/libcufile

In [2]:
df = pd.read_csv('/data3/KJE/code/WIL_DeepLearningProject_2/VLM_Hallu/data/preprocess/aokvqa_pope_adversarial.csv')
required_cols = {"image_path", "question", "label"}
missing = required_cols - set(df.columns)
if missing:
    raise ValueError(f"Missing required column(s) in CSV: {missing}")

train_df = pd.read_csv('/data3/KJE/code/WIL_DeepLearningProject_2/VLM_Hallu/data/preprocess/llava-1.5-7b-hf-vizwiz_train-llava_answers.csv')


train_idx = list(range(len(train_df)))
valid_idx = list(range(len(df)))

In [3]:
tokens_to_probe: List[str] = ['image_token', 'first_text_token', 'last_text_token',
                            -8, -7, -6, -5, -4, -3, -2, -1]

In [4]:
def parse_subset_layers(arg_val: Optional[str], num_layers: int) -> List[int]:
    if not arg_val:
        return list(range(num_layers))
    raw = [s.strip() for s in arg_val.split(",") if s.strip() != ""]
    idxs: List[int] = []
    for s in raw:
        i = int(s)
        if i < 0:
            i = num_layers + i  # negative index from end
        if not (0 <= i < num_layers):
            raise ValueError(f"subset_layers index out of range after normalization: {i} (num_layers={num_layers})")
        idxs.append(i)
    # keep order but uniquify
    seen = set()
    ordered = []
    for i in idxs:
        if i not in seen:
            ordered.append(i)
            seen.add(i)
    return ordered

In [5]:
model, processor, tokenizer = utils.load_llava('llava-hf/llava-1.5-7b-hf')
num_layers = utils.get_num_layers_from_config(model)
subset_layers = parse_subset_layers(None, num_layers)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [20]:
probes: Dict[Tuple[str, int], SGDClassifier] = {}
initialized: Dict[Tuple[str, int], bool] = {}

In [7]:
def get_num_layers_from_config(model) -> int:
    cfg = getattr(model, "config", None)
    if hasattr(cfg, "text_config") and hasattr(cfg.text_config, "num_hidden_layers"):
        return int(cfg.text_config.num_hidden_layers)
    if hasattr(cfg, "num_hidden_layers"):
        return int(cfg.num_hidden_layers)
    raise ValueError("Cannot find num_hidden_layers in config.")

def get_image_token_id(model) -> Optional[int]:
    return getattr(getattr(model, "config", None), "image_token_index", None)

def tokkey(t):
    return str(t) if isinstance(t, int) else t

In [8]:
#collect_features_for_split 
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
num_layers = get_num_layers_from_config(model)
image_token_id = get_image_token_id(model)

neg_ks = sorted([k for k in tokens_to_probe if isinstance(k, int)])

token_keys = [tokkey(t) for t in tokens_to_probe]
features: Dict[str, Dict[int, List[np.ndarray]]] = {k: {l: [] for l in range(num_layers)} for k in token_keys}
labels: Dict[str, List[int]] = {k: [] for k in token_keys}

In [9]:
image_root = '/data3/KJE/code/WIL_DeepLearningProject_2/VLM_Hallu/data/COCO/images/val2014'
def _img_path(p):
    if image_root is None:
        return p
    return os.path.join(image_root, p)

In [10]:
def build_prompt(tokenizer, question: str) -> str:
    # 메시지 안에 이미지 + 텍스트 구성
    content = [{"type": "image"}, {"type": "text", "text": question}]
    if hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": content}]
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False
            )
            return prompt
        except Exception:
            pass
    # fallback: 단순 프롬프트
    return "<image>\n" + question.strip() + "\n"

In [11]:
from tqdm import tqdm
from PIL import Image

batch_size = 16 

for start in tqdm(range(0, len(train_idx), batch_size)):
        batch_idx = train_idx[start : start + batch_size]
        rows = df.iloc[batch_idx]

        # load images & prompts
        images = []
        prompts = []
        ys = []
        for _, r in rows.iterrows():
            img = Image.open(_img_path(r["image_path"])).convert("RGB")
            images.append(img)
            prompts.append(build_prompt(tokenizer, str(r["question"]))
                           if "question" in r else build_prompt(tokenizer, str(r["text"])))
            ys.append(int(r["label"]))
        # encode
        batch = processor(
            images=images,
            text=prompts,
            padding=True,
            return_tensors="pt"
        )
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(dtype=dtype, device=device)

        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False
        )

        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict_in_generate=True,
            use_cache=False
        )

        break

  0%|          | 0/937 [00:00<?, ?it/s]

  0%|          | 0/937 [00:23<?, ?it/s]


In [None]:
hs = torch.stack(output.hidden_states[-0][1:], dim=0)

: 

In [13]:
def nonpad_indices(attn_mask: torch.Tensor) -> torch.Tensor:
    # returns positions where attention_mask == 1
    return torch.nonzero(attn_mask, as_tuple=False).squeeze(-1)

def locate_positions_for_sample(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    image_token_id: Optional[int],
    neg_ks: List[int],
) -> Dict[str, Optional[int]]:
    
    pos = {}
    idxs = nonpad_indices(attention_mask)
    if idxs.numel() == 0:
        # all pad? shouldn't happen
        for k in neg_ks:
            pos[str(k)] = None
        pos["image_token"] = None
        pos["first_text_token"] = None
        pos["last_text_token"] = None
        return pos

    # last/first non-pad
    first_idx = int(idxs[0].item())
    last_idx = int(idxs[-1].item())

    # negative offsets from end
    for k in neg_ks:  # e.g., -8..-1
        kk = abs(k)
        if kk <= idxs.numel():
            pos[str(k)] = int(idxs[-kk].item())
        else:
            pos[str(k)] = None

    # image token
    if image_token_id is not None:
        img_pos = torch.nonzero(input_ids == image_token_id, as_tuple=False)
        if img_pos.numel() > 0:
            img_first = int(img_pos[0].item())
        else:
            img_first = None
    else:
        img_first = None
    pos["image_token"] = img_first

    if img_first is not None:
        greater = idxs[idxs > img_first]
        if greater.numel() > 0:
            pos["first_text_token"] = int(greater[0].item())
        else:
            pos["first_text_token"] = None
    else:
        pos["first_text_token"] = first_idx

    # last_text_token
    pos["last_text_token"] = last_idx

    return pos

In [62]:
B = input_ids.size(0)
for b in range(B):
    yb = ys[b]
    # import pdb;pdb.set_trace()
    # pos_map = locate_positions_for_sample(
    #     input_ids[b], attention_mask[b], image_token_id, neg_ks
    # ) 

    pos_map = {}
    idxs = nonpad_indices(attention_mask[b])
    if idxs.numel() == 0:
        # all pad? shouldn't happen
        for k in neg_ks:
            pos_map[str(k)] = None
        pos_map["image_token"] = None
        pos_map["first_text_token"] = None
        pos_map["last_text_token"] = None

    # last/first non-pad
    first_idx = int(idxs[0].item())
    last_idx = int(idxs[-1].item())

    # negative offsets from end
    for k in neg_ks:  # e.g., -8..-1
        kk = abs(k)
        if kk <= idxs.numel():
            pos_map[str(k)] = int(idxs[-kk].item())+20
        else:
            pos_map[str(k)] = None

    # # image token
    # if image_token_id is not None:
    #     img_pos = torch.nonzero(input_ids[b] == image_token_id, as_tuple=False)
    #     if img_pos.numel() > 0:
    #         img_first = int(img_pos[0].item())
    #     else:
    #         img_first = None
    # else:
    #     img_first = None
    # pos_map["image_token"] = img_first

    # if img_first is not None:
    #     greater = idxs[idxs > img_first]
    #     if greater.numel() > 0:
    #         pos_map["first_text_token"] = int(greater[0].item())
    #     else:
    #         pos_map["first_text_token"] = None
    # else:
    #     pos_map["first_text_token"] = first_idx

    # # last_text_token
    # pos_map["last_text_token"] = last_idx


    # # build a per-token valid mask & per-token position
    # for t in tokens_to_probe:
    #     k = tokkey(t)                  
    #     p = pos_map.get(k, None)         
    #     if p is None:
    #         continue  # skip this sample for this token
    #     # append label once per sample (per token)
    #     labels[k].append(yb)
    #     # for each layer, slice feature
    #     for l in range(num_layers):
    #         vec = hs[l, b, p, :].detach().float().cpu().numpy()  # (H,)
    #         features[k][l].append(vec)


In [63]:
pos_map

{'-8': 608,
 '-7': 609,
 '-6': 610,
 '-5': 611,
 '-4': 612,
 '-3': 613,
 '-2': 614,
 '-1': 615}

In [16]:
out.keys()

odict_keys(['logits', 'hidden_states', 'image_hidden_states'])

In [41]:
output.keys()

odict_keys(['sequences', 'hidden_states'])

In [49]:
len(output.hidden_states)

20

In [43]:
len(output['sequences'][0]), output.hidden_states[-1][0].shape , out.hidden_states[0].shape

(616, torch.Size([16, 615, 4096]), torch.Size([16, 596, 4096]))

In [61]:
for i in range(0,20):
    print(output.hidden_states[-0][0].shape) #0~32 layer 수 

torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])
torch.Size([16, 596, 4096])


In [26]:
input_decode = processor.decode(input_ids[0])
output_decode =processor.decode(output.sequences[0]) 

In [27]:
output_decode

'<s> USER: <image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><i

In [None]:
pos_map

{'-8': 588,
 '-7': 589,
 '-6': 590,
 '-5': 591,
 '-4': 592,
 '-3': 593,
 '-2': 594,
 '-1': 595,
 'image_token': 5,
 'first_text_token': 6,
 'last_text_token': 595}