# PE

In [3]:
!pwd

/workspace/MuKA


In [4]:
cd ..

/workspace


In [15]:
import torch
import torch.nn as nn

import torch
import torch.nn as nn

import torch
import torch.nn as nn

class MaskedLearnableQueryPositionalEmbedding(nn.Module):
    def __init__(self, max_len_q: int, d_model: int):
        super().__init__()
        self.query_pos_embedding = nn.Embedding(max_len_q, d_model)

    def forward(self, query, q_pos_labels=None):
        """
        query: [B, L_q, D]
        q_pos_labels: [B, L_q] (0 or 1)
        """
        B, L_q, D = query.shape

        q_positions = torch.arange(L_q, device=query.device).unsqueeze(0).expand(B, L_q)
        q_pos_emb = self.query_pos_embedding(q_positions)

        if q_pos_labels is not None:
            q_mask = q_pos_labels.unsqueeze(-1).float()
            # gradient stop for mask==0
            q_pos_emb = q_pos_emb.detach() * (1 - q_mask) + q_pos_emb * q_mask

        return query + q_pos_emb



from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder
from FLMR.flmr.models.flmr.configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
config = FLMRConfig.from_pretrained("/workspace/PreFLMR_ViT-G")

transformer_mapping_config_base = config.transformer_mapping_config_base
transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
transformer_mapping_network = BertEncoder(transformer_mapping_config)

In [11]:
late_interaction_embedding_size = config.dim
print(late_interaction_embedding_size)

128


In [12]:
transformer_mapping_output_linear = nn.Linear(
    transformer_mapping_config.hidden_size, late_interaction_embedding_size
)

In [16]:
# ===== 1️⃣ 랜덤 입력 텐서 생성 =====
B, L_q, L_k, D = 2, 5, 8, 768
query = torch.randn(B, L_q, D)
key_value = torch.randn(B, L_k, D)

# query position labels (0=skip, 1=learnable)
q_labels = torch.tensor([[0,1,0,1,0],
                         [1,1,0,0,1]])

# encoder attention mask (optional, 1=keep, 0=mask)
encoder_extended_attention_mask = torch.ones(B, 1, 1, L_k)

In [17]:
pos_enc = MaskedLearnableQueryPositionalEmbedding(max_len_q=512, d_model=D)

# query에 positional encoding 적용
query_with_pos = pos_enc(query, q_labels)

In [24]:
# ===== 4️⃣ Cross-attention 실행 =====
transformer_mapping_outputs = transformer_mapping_network(
    hidden_states=query_with_pos,            # query
    encoder_hidden_states=key_value,         # key/value
    encoder_attention_mask=encoder_extended_attention_mask
)

# ===== 5️⃣ Output 추출 =====
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
transformer_mapping_output_features = transformer_mapping_output_linear(
                    transformer_mapping_output_features
                )
print(transformer_mapping_output_features.shape)  # [B, L_q, D]

torch.Size([2, 11, 128])


# PE with img concat true

In [19]:
class MaskedLearnableQueryPositionalEmbedding(nn.Module):
    def __init__(self, max_len_q: int, d_model: int, use_img_emb=False):
        super().__init__()
        self.query_pos_embedding = nn.Embedding(max_len_q, d_model)
        self.use_img_emb = use_img_emb

    def forward(self, query, q_pos_labels=None, img_embs=None):
        """
        query: [B, L_q, D]
        q_pos_labels: [B, L_q] (0 or 1)
        img_embs: list of tensors, each [B, L_img, D]  (optional)
        """
        B, L_q, D = query.shape

        # ===== q_labels 확장 (img_embs 삽입 반영) =====
        if self.use_img_emb and q_pos_labels is not None and img_embs is not None:
            expanded_labels_list = []
            for b in range(B):
                labels = q_pos_labels[b].tolist()
                new_labels = []
                img_idx = 0
                for l in labels:
                    new_labels.append(l)
                    if l == 1:
                        # img_emb 삽입
                        if img_idx < len(img_embs):
                            img_seq_len = img_embs[img_idx].size(1)
                            new_labels.extend([1]*img_seq_len)
                            img_idx += 1
                expanded_labels_list.append(new_labels)
            # padding/truncating to match max_len_q
            max_len = max(len(l) for l in expanded_labels_list)
            q_pos_labels_expanded = torch.zeros(B, max_len, device=query.device)
            for b, l in enumerate(expanded_labels_list):
                q_pos_labels_expanded[b, :len(l)] = torch.tensor(l, device=query.device)
            q_pos_labels = q_pos_labels_expanded
            # query도 max_len에 맞춰 0-padding 필요
            if max_len > L_q:
                pad_len = max_len - L_q
                query = torch.cat([query, torch.zeros(B, pad_len, D, device=query.device)], dim=1)
                L_q = max_len

        # ===== Positional Embedding 적용 =====
        q_positions = torch.arange(L_q, device=query.device).unsqueeze(0).expand(B, L_q)
        q_pos_emb = self.query_pos_embedding(q_positions)

        if q_pos_labels is not None:
            q_mask = q_pos_labels.unsqueeze(-1).float()
            # gradient stop for mask==0
            q_pos_emb = q_pos_emb.detach() * (1 - q_mask) + q_pos_emb * q_mask

        return query + q_pos_emb


In [23]:
# ===== 2️⃣ 랜덤 입력 생성 =====
B, L_q, L_k, D = 2, 6, 8, 768
query = torch.randn(B, L_q, D)
q_labels = torch.tensor([[0,1,1,0,1,0],
                         [1,0,1,0,1,0]])

img1 = torch.randn(B, 2, D)
img2 = torch.randn(B, 3, D)
img_embs = [img1, img2]

key_value = torch.randn(B, L_k, D)
encoder_extended_attention_mask = torch.ones(B, 1, 1, L_k)
encoder_extended_attention_mask[:, :, :, -1] = 0 # 임시 실제로는 뒤에 padding 에 대한 attention mask 가 들어감

# ===== 3️⃣ Positional Encoding 적용 =====
pos_enc = MaskedLearnableQueryPositionalEmbedding(max_len_q=512, d_model=D, use_img_emb=True)
query_with_pos = pos_enc(query, q_labels, img_embs)

# ===== 5️⃣ Cross-Attention 실행 =====
transformer_mapping_outputs = transformer_mapping_network(
    hidden_states=query_with_pos,
    encoder_hidden_states=key_value,
    encoder_attention_mask=encoder_extended_attention_mask
)

# ===== 6️⃣ Output 확인 =====
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
transformer_mapping_output_features = transformer_mapping_output_linear(
                    transformer_mapping_output_features
                )
print(transformer_mapping_output_features.shape)  # [B, L_query+img_total_len, D]

torch.Size([2, 11, 128])


In [29]:
# ===== 2️⃣ 랜덤 입력 생성 =====
B, L_q, L_k, D = 2, 6, 8, 768
query = torch.randn(B, L_q, D)
q_labels = torch.tensor([[0,1,1,0,1,0],
                         [1,0,1,0,1,0]])

img1 = torch.randn(B, 2, D)
img2 = torch.randn(B, 3, D)
img_embs = [img1, img2]

key_value = torch.randn(B, L_k, D)
encoder_extended_attention_mask = torch.ones(B, 1, 1, L_k)
encoder_extended_attention_mask[:, :, :, -1] = 0 # 임시 실제로는 뒤에 padding 에 대한 attention mask 가 들어감

# ===== 3️⃣ Positional Encoding 적용 =====
pos_enc = MaskedLearnableQueryPositionalEmbedding(max_len_q=512, d_model=D, use_img_emb=False)
query_with_pos = pos_enc(query, q_labels, img_embs)

# ===== 5️⃣ Cross-Attention 실행 =====
transformer_mapping_outputs = transformer_mapping_network(
    hidden_states=query_with_pos,
    encoder_hidden_states=key_value,
    encoder_attention_mask=encoder_extended_attention_mask
)

# ===== 6️⃣ Output 확인 =====
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
transformer_mapping_output_features = transformer_mapping_output_linear(
                    transformer_mapping_output_features
                )
print(transformer_mapping_output_features.shape)  # [B, L_query+img_total_len, D]

torch.Size([2, 6, 128])


# eval result visualization

In [1]:
import json

report = "/workspace/MuKA/reports/EVQA_multi_image_test_set_init.json"

with open(report, "r") as f:
    data = json.load(f)


In [2]:
from datasets import load_dataset

dataset = load_dataset("/workspace/datasets")
ds = dataset['test']

In [3]:
idx = ds['question_id'].index("question_0")

In [4]:
dataset['test'][idx]['related_entity']

'Australia'

In [None]:
retrieved_passages = []

In [None]:
found_distractor = any(
    any(d in p['passage_id'] for d in distractor_passage)
    for p in retrieved_passages
)

In [8]:
k = 5
found_count = 0            # gold recall
distractor_count = 0       # distractor recall
total = len(data)

for sample in data:
    retrieved_passages = sample[1]['retrieved_passage'][:k]
    idx = ds['question_id'].index(sample[0])

    gold_passage = dataset['test'][idx]['pos_item_ids']
    distractor_passage = ["".join(dataset['test'][idx]['related_entity'])]
    # distractor_passage = list(dataset['test'][idx]['related_entity'])
    # distractor_passage = ["".join(d) if isinstance(d, list) else d for d in distractor_passage]

    # gold 포함 여부
    found = any(p['passage_id'] in gold_passage for p in retrieved_passages)
    if found:
        found_count += 1

    # distractor 포함 여부
    # found_distractor = any(p['passage_id'] in distractor_passage for p in retrieved_passages)
    # if found_distractor:
    #     distractor_count += 1
    found_distractor = any(
        any(d in p['passage_id'] for d in distractor_passage)
        for p in retrieved_passages
    )
    if found_distractor:
        distractor_count += 1
        
print(f"Gold found: {found_count} / {total} ({found_count / total:.4f})")
print(f"Distractor found: {distractor_count} / {total} ({distractor_count / total:.4f})")

Gold found: 113 / 1390 (0.0813)
Distractor found: 1388 / 1390 (0.9986)


In [9]:
distractor_passage

['L',
 'a',
 'm',
 'p',
 'i',
 'd',
 'e',
 's',
 ' ',
 'b',
 'o',
 'e',
 't',
 'i',
 'c',
 'u',
 's']

# Dataset Stats

In [15]:
passage_ds = load_from_disk(f"/workspace/processed_test_passage_ds")

# passage_contents = list(zip(
#     passage_ds['passage_content'],
#     [None] * len(passage_ds),
#     passage_ds['passage_img_path']
# ))
MAX_IMAGES = 5

passage_contents = [
    (content, None, img_paths[:MAX_IMAGES], token_labels)
    for content, img_paths, token_labels in zip(
        passage_ds['passage_content'],
        passage_ds['passage_img_paths'],
        passage_ds['token_labels']
    )
]

In [16]:
sample

("Heracleum mantegazzianum, commonly known as giant hogweed, is a monocarpic perennial herbaceous plant in the carrot family Apiaceae. H.\xa0mantegazzianum is also known as cartwheel-flower, giant cow parsley, giant cow parsnip, or hogsbane. In New Zealand, it is also sometimes called wild parsnip (not to be confused with Pastinaca sativa) or wild rhubarb.\nGiant hogweed is native to the western Caucasus region of Eurasia. It was introduced to Britain as an ornamental plant in the 19th century, and has also spread to other areas in Western Europe, the United States, and Canada. Its close relatives, Sosnowsky's hogweed and Persian hogweed, have similarly spread to other parts of Europe.\nThe sap of giant hogweed is phototoxic and causes phytophotodermatitis in humans, resulting in blisters and scars.  These serious reactions are due to the furanocoumarin derivatives in the leaves, roots, stems, flowers, and seeds of the plant. Consequently, it is considered to be a noxious weed in many 

In [17]:
for sample in passage_contents:
    print(sample)
    break

("Heracleum mantegazzianum, commonly known as giant hogweed, is a monocarpic perennial herbaceous plant in the carrot family Apiaceae. H.\xa0mantegazzianum is also known as cartwheel-flower, giant cow parsley, giant cow parsnip, or hogsbane. In New Zealand, it is also sometimes called wild parsnip (not to be confused with Pastinaca sativa) or wild rhubarb.\nGiant hogweed is native to the western Caucasus region of Eurasia. It was introduced to Britain as an ornamental plant in the 19th century, and has also spread to other areas in Western Europe, the United States, and Canada. Its close relatives, Sosnowsky's hogweed and Persian hogweed, have similarly spread to other parts of Europe.\nThe sap of giant hogweed is phototoxic and causes phytophotodermatitis in humans, resulting in blisters and scars.  These serious reactions are due to the furanocoumarin derivatives in the leaves, roots, stems, flowers, and seeds of the plant. Consequently, it is considered to be a noxious weed in many 

In [11]:
def _sort_by_length(ids, mask, bsize, *args):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices
    
    return_array = [ids[indices], mask[indices]]
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return_array.append(arg[indices])
        else:
            # arg is a list, and we want to sort the list according to indices
            return_array.append([arg[i] for i in indices])

    return *return_array, reverse_indices

In [1]:
test_passage_path = "/workspace/processed_test_passage_ds"
train_passage_path = "/workspace/processed_train_passage_ds"

import json 
import os
from datasets import load_dataset, load_from_disk

test_passages = load_from_disk(test_passage_path)
train_passages = load_from_disk(train_passage_path)

In [8]:
def get_contiguous_spans(token_labels, MAX_IMAGES):
    spans = []
    start = None
    for i, label in enumerate(token_labels):
        if label == 1 and start is None:
            start = i
        elif label == 0 and start is not None:
            spans.append((start, i))  # [start, end)
            start = None
    if start is not None:  # 마지막이 1로 끝나는 경우
        spans.append((start, len(token_labels)))
    return spans[:MAX_IMAGES]

get_contiguous_spans([0,1,0,1,0,1,0], 10)

[(1, 2), (3, 4), (5, 6)]

In [6]:
sample.keys()

dict_keys(['language', 'passage_id', 'passage_content', 'passage_img_path', 'token_labels', 'passage_img_paths'])

In [7]:
from collections import Counter

img_counts = []

for sample in test_passages:
    img = sample['passage_img_paths']
    img_counts.append(len(img))

# 분포 집계
img_dist = Counter(img_counts)

print("Span 개수 분포:")
for num_img, freq in sorted(img_dist.items()):
    print(f"{num_img} spans: {freq} samples")

Span 개수 분포:
0 spans: 2033 samples
1 spans: 3278 samples
2 spans: 3828 samples
3 spans: 4134 samples
4 spans: 3915 samples
5 spans: 3569 samples
6 spans: 3334 samples
7 spans: 2968 samples
8 spans: 2604 samples
9 spans: 2356 samples
10 spans: 2041 samples
11 spans: 1829 samples
12 spans: 1593 samples
13 spans: 1484 samples
14 spans: 1277 samples
15 spans: 1152 samples
16 spans: 1039 samples
17 spans: 975 samples
18 spans: 816 samples
19 spans: 744 samples
20 spans: 655 samples
21 spans: 571 samples
22 spans: 520 samples
23 spans: 450 samples
24 spans: 435 samples
25 spans: 360 samples
26 spans: 362 samples
27 spans: 329 samples
28 spans: 304 samples
29 spans: 272 samples
30 spans: 214 samples
31 spans: 201 samples
32 spans: 176 samples
33 spans: 142 samples
34 spans: 160 samples
35 spans: 144 samples
36 spans: 116 samples
37 spans: 106 samples
38 spans: 111 samples
39 spans: 85 samples
40 spans: 75 samples
41 spans: 74 samples
42 spans: 70 samples
43 spans: 72 samples
44 spans: 60 sampl

In [3]:
from collections import Counter

span_counts = []

for sample in test_passages:
    spans = get_contiguous_spans(sample['token_labels'])
    span_counts.append(len(spans))

# 분포 집계
count_dist = Counter(span_counts)

print("Span 개수 분포:")
for num_spans, freq in sorted(count_dist.items()):
    print(f"{num_spans} spans: {freq} samples")

Span 개수 분포:
0 spans: 2033 samples
1 spans: 3477 samples
2 spans: 4111 samples
3 spans: 4518 samples
4 spans: 4208 samples
5 spans: 3765 samples
6 spans: 3489 samples
7 spans: 3146 samples
8 spans: 2607 samples
9 spans: 2389 samples
10 spans: 1971 samples
11 spans: 1806 samples
12 spans: 1646 samples
13 spans: 1364 samples
14 spans: 1206 samples
15 spans: 1113 samples
16 spans: 942 samples
17 spans: 858 samples
18 spans: 758 samples
19 spans: 654 samples
20 spans: 611 samples
21 spans: 517 samples
22 spans: 446 samples
23 spans: 447 samples
24 spans: 359 samples
25 spans: 343 samples
26 spans: 294 samples
27 spans: 263 samples
28 spans: 241 samples
29 spans: 201 samples
30 spans: 177 samples
31 spans: 157 samples
32 spans: 135 samples
33 spans: 150 samples
34 spans: 112 samples
35 spans: 115 samples
36 spans: 88 samples
37 spans: 70 samples
38 spans: 83 samples
39 spans: 67 samples
40 spans: 65 samples
41 spans: 46 samples
42 spans: 59 samples
43 spans: 71 samples
44 spans: 36 samples
4

In [12]:
token_labels = [0] + [0] + [1]+ [1]+[0]*508

In [13]:
def get_contiguous_spans(token_labels):
    spans = []
    start = None
    for i, label in enumerate(token_labels):
        if label == 1 and start is None:
            start = i
        elif label == 0 and start is not None:
            spans.append((start, i))  # [start, end)
            start = None
    if start is not None:  # 마지막이 1로 끝나는 경우
        spans.append((start, len(token_labels)))
    return spans

span = get_contiguous_spans(token_labels)

In [None]:
for sample in test_passages:
    print(sample['token_labels'])
    print(len(sample['token_labels']))
    # if sample['token_labels'] is not None and sum(sample['token_labels']) == 0:
    #     zero_token_label_count += 1
        # print(sample)
    break

In [16]:
len(sample['token_labels'])

248

In [14]:
span

[(2, 4)]

In [7]:
# zero_token_label_count = 0

# for sample in test_passages:
#     if sample['token_labels'] is not None and sum(sample['token_labels']) == 0:
#         zero_token_label_count += 1
#         # print(sample)

# print("token_labels가 모두 0인 샘플 개수:", zero_token_label_count)

{'language': 'en', 'passage_id': 'WikiWeb_Limonium sinuatum_1', 'passage_content': 'It is a short-lived perennial plant, and is often treated as an annual. The leaves are pinnate, lobed, and lance-shaped – up to 10\xa0cm (3.9\xa0in) long. All parts are downy. The winged flower stems appear in summer, and are about 70\xa0cm (28\xa0in) tall. The flowers present in short, papery clusters in colours ranging from white to pink, purple, and yellow. It has been known to become invasive.', 'passage_img_path': '/workspace/M2KR_Images/EVQA/images/Tema_Nezahat_Gokyigit_Park_1060783_20080513133708.JPG', 'token_labels': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'passage_img_paths': []}
{'language': 'en', 'passage_id': 'WikiWeb_Betula pumila

In [10]:
ds = load_dataset("/workspace/datasets")

In [11]:
ds

DatasetDict({
    train: Dataset({
        features: ['main_entity', 'related_entity', 'relation', 'attribute', 'source_title', 'user_prompt', 'system_prompt_main', 'system_prompt_related', 'question', 'answer', 'bad_sample', 'bidirectional', 'paraphrased_question', 'img_path', 'missing', 'question_id', 'instruction', 'pos_item_ids'],
        num_rows: 128548
    })
    test: Dataset({
        features: ['main_entity', 'related_entity', 'relation', 'attribute', 'source_title', 'user_prompt', 'system_prompt_main', 'system_prompt_related', 'question', 'answer', 'bad_sample', 'bidirectional', 'paraphrased_question', 'img_path', 'missing', 'question_id', 'instruction', 'pos_item_ids'],
        num_rows: 1393
    })
})

In [8]:
test_passages

Dataset({
    features: ['language', 'passage_id', 'passage_content', 'passage_img_path', 'token_labels', 'passage_img_paths'],
    num_rows: 51472
})

In [9]:
train_passages

Dataset({
    features: ['language', 'passage_id', 'passage_content', 'passage_img_path', 'token_labels', 'passage_img_paths'],
    num_rows: 50205
})

# Init

In [1]:
import os
import json
from datasets import load_dataset

ds = load_dataset("/workspace/datasets")

root_report_path = "/workspace/MuKA/reports"

In [2]:
def add_path_prefix_in_img_path(example, prefix):
    if example["img_path"] != None:
        example["img_path"] = os.path.join(prefix, example["img_path"])
    return example

ds = ds.map(add_path_prefix_in_img_path, fn_kwargs={"prefix": "/workspace/M2KR_Images/EVQA"},
            keep_in_memory=True)

Map:   0%|          | 0/128548 [00:00<?, ? examples/s]

Map:   0%|          | 0/1393 [00:00<?, ? examples/s]

# Setup

In [1]:
import transformers
from transformers import TrainingArguments, Trainer, HfArgumentParser
from transformers import AutoImageProcessor
from transformers import AutoConfig
from datasets import load_dataset

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import random_split

import warnings
import random
import os
import json
from dataclasses import dataclass
from PIL import Image
from pprint import pformat

from flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer #, FLMRModelForRetrieval
from datasets import Dataset
import random

@dataclass
class MyArguments:
    model_name_or_path :str = "/workspace/PreFLMR_ViT-G" # "LinWeizheDragon/PreFLMR_ViT-G"
    image_processor_name :str = "/workspace/FLMR/CLIP-ViT-bigG-14-laion2B-39B-b160k" # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    dataset_hf_path :str = "/workspace/datasets" # "BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR"
    dataset :str = "EVQA" # "Infoseek"
    sample_examples :int = -1
    num_negative_examples :int = 4
    image_root_dir :str = "/workspace/M2KR_Images/EVQA"
    split_eval_from_train_examples :int = -1
    freeze_vision_encoder :bool = True
    freeze_text_encoder :bool = False
    # We use the Adam optimizer (Kingma and Ba, 2015) 
    # with a fixed learning rate of 10−4 for the mapping structure 
    # and 10−5 for the rest parameters in all experiments in all training stages.
    mapping_structure_lr :float = 1e-4
    non_mapping_structure_lr :float = 1e-5    
    doc_use_images :bool = False
    doc_image_root_dir :str = ""
    doc_image_title2image :str = ""
    title_key: str = ""


class PreFLMRTrainer(Trainer):
    # added , num_items_in_batch=None since transformers 4.49.0 requires this input
    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        outputs = model(**inputs, return_dict=True)
        ib_loss = outputs["in_batch_negative_loss"]
        outputs["loss"] = ib_loss

        return (ib_loss, outputs) if return_outputs else ib_loss

    # https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3353C5-L3353C90
    # def save_model(self, output_dir=None, _internal_call: bool = False):
    #     super().save_model(output_dir, _internal_call)
    #     if output_dir is None:
    #         output_dir = self.args.output_dir
        
    #     self.query_tokenizer.save_pretrained(os.path.join(output_dir, 'query_tokenizer'))
    #     self.context_tokenizer.save_pretrained(os.path.join(output_dir, 'context_tokenizer'))
    def save_model(self, output_dir=None, _internal_call: bool = False):
        # super().save_model(output_dir, _internal_call)
        
        if output_dir is None:
            output_dir = self.args.output_dir
        
        self.model.save_pretrained(output_dir, safe_serialization=False)
            
        self.query_tokenizer.save_pretrained(
            os.path.join(output_dir, "query_tokenizer"),
            safe_serialization=False
        )

        self.context_tokenizer.save_pretrained(
            os.path.join(output_dir, "context_tokenizer"),
            safe_serialization=False
        )
    

class PreFLMRDataset(Dataset):
    
    def __init__(self,
                 args,
                 data_df, passages_df, 
                 query_tokenizer, context_tokenizer, image_processor):
        self.args = args
        self.data_df = data_df
        self.passages_df = passages_df
        self.query_tokenizer = query_tokenizer
        self.context_tokenizer = context_tokenizer
        self.image_processor = image_processor
        
        self.unique_passage_ids = set(self.passages_df.index)
        
        if self.args.doc_use_images:
            self.doc_image_title2image = json.load(open(self.args.doc_image_title2image))
    
    
    def __len__(self):
        return len(self.data_df)
    
    def __getitem__(self, idx):
        # The key change: handle both integer and list indices for batching.
        # This makes the method robust to how the DataLoader provides indices.
        if isinstance(idx, int):
            row = self.data_df.iloc[idx]
            return self._process_single_item(row)
        else: # idx is a list of integers
            rows = self.data_df.iloc[idx]
            batch = [self._process_single_item(row) for _, row in rows.iterrows()]
            return {k: [dic[k] for dic in batch] for k in batch[0]}

    def _process_single_item(self, row):
        """Helper method to process a single row of data."""
        query = row['instruction'] + row['question']
        
        # Positive example processing
        pos_item_ids = row['pos_item_ids']
        pos_item_id = random.choice(pos_item_ids)
        pos_psg_row = self.passages_df.loc[pos_item_id] 
        pos_passage = pos_psg_row['passage_content']

        # Query image processing
        query_image_path = os.path.join(self.args.image_root_dir, row['img_path'])
        query_image = Image.open(query_image_path).convert('RGB')
        query_pixel_values = self.image_processor(query_image, return_tensors='pt')['pixel_values']

        # Negative examples processing
        neg_item_ids = random.sample(list(self.unique_passage_ids - set(pos_item_ids)), 
                                      self.args.num_negative_examples)
        neg_psg_rows = [self.passages_df.loc[neg_item_id] for neg_item_id in neg_item_ids]
        neg_passages = [r['passage_content'] for r in neg_psg_rows]
        
        passages = [pos_passage] + neg_passages
        
        inputs = dict(
            query=query,
            passages=passages,
            query_pixel_values=query_pixel_values
        )
        
        # Document image processing
        if self.args.doc_use_images:
            pos_image_path = os.path.join(self.args.doc_image_root_dir, self.doc_image_title2image[pos_psg_row[self.args.title_key]])
            neg_image_paths = [os.path.join(self.args.doc_image_root_dir, self.doc_image_title2image[r[self.args.title_key]])
                                for r in neg_psg_rows]
            context_images = [Image.open(image_path).convert('RGB') 
                                for image_path in [pos_image_path] + neg_image_paths]
            context_pixel_values = self.image_processor(context_images, return_tensors='pt')['pixel_values']
            
            inputs["context_pixel_values"] = context_pixel_values
            
        return inputs
        
    # def __getitem__(self, idx):
    #     row = self.data_df.iloc[idx]
    #     query = row['instruction'] + row['question']
        
    #     pos_item_ids = row['pos_item_ids']
    #     pos_item_id = random.choice(pos_item_ids)
    #     pos_psg_row = self.passages_df.loc[pos_item_id] 
    #     pos_passage = pos_psg_row['passage_content']

    #     query_image_path = os.path.join(self.args.image_root_dir, row['img_path']) # ???
    #     query_image = Image.open(query_image_path).convert('RGB')
    #     query_pixel_values = self.image_processor(query_image, return_tensors='pt')['pixel_values'] # [1, 3, 224, 224]
        
    #     # negatives
    #     neg_item_ids = random.sample(list(self.unique_passage_ids - set(pos_item_ids)), 
    #                                  self.args.num_negative_examples)
    #     neg_psg_rows = [ self.passages_df.loc[neg_item_id] for neg_item_id in neg_item_ids ]
    #     neg_passages = [r['passage_content'] for r in neg_psg_rows]
        
    #     passages = [pos_passage] + neg_passages
        
    #     inputs = dict(
    #         query=query,
    #         passages=passages,
    #         query_pixel_values=query_pixel_values
    #     )
        
    #     if self.args.doc_use_images:
    #         pos_image_path = os.path.join(self.args.doc_image_root_dir, self.doc_image_title2image[pos_psg_row[self.args.title_key]])
    #         neg_image_paths = [ os.path.join(self.args.doc_image_root_dir, self.doc_image_title2image[r[self.args.title_key]])
    #                            for r in neg_psg_rows]
    #         context_images = [Image.open(image_path).convert('RGB') 
    #                           for image_path in [pos_image_path] + neg_image_paths]
    #         context_pixel_values = self.image_processor(context_images, return_tensors='pt')['pixel_values']
            
    #         inputs["context_pixel_values"] = context_pixel_values
        
    #     return inputs
        
    def collate_fn(self, batch):
        queries = [ex['query'] for ex in batch]
        passages = [] # [pos, neg, neg, neg, pos, ...]
        for ex in batch:
            passages.extend(ex['passages'])

        Q_encoding = self.query_tokenizer(queries)
        Q_pixel_values = torch.cat([ex['query_pixel_values'] for ex in batch], dim=0)
        D_encoding = self.context_tokenizer(passages)
        
        # according to `modeling_flmr.py:FLMRModelForRetrieval.forward`
        inputs = dict(
            query_input_ids=Q_encoding['input_ids'],
            query_attention_mask=Q_encoding['attention_mask'],
            query_pixel_values=Q_pixel_values,
            context_input_ids=D_encoding['input_ids'],
            context_attention_mask=D_encoding['attention_mask'],
            use_in_batch_negatives=True,
            in_batch_negatives_from_all_gpus=False,
            num_negative_examples=self.args.num_negative_examples,
            query_concat_output_from_vision_encoder=True,
            query_concat_output_from_text_encoder=True,
            context_concat_output_from_vision_encoder=False,
            context_concat_output_from_text_encoder=True,
        )
        
        if self.args.doc_use_images:
            context_pixel_values = torch.cat([ex['context_pixel_values'] for ex in batch], dim=0)
            inputs['context_pixel_values'] = context_pixel_values
            inputs['context_concat_output_from_vision_encoder'] = True
            
        return inputs




In [2]:
from argparse import Namespace

my_args = MyArguments

training_args = Namespace(
    seed = 42,
    do_eval=False,
)


In [3]:
transformers.set_seed(training_args.seed)

In [4]:
## setting up tokenizer
query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(
    my_args.model_name_or_path, subfolder="query_tokenizer")
context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(
    my_args.model_name_or_path, subfolder="context_tokenizer")
image_processor = AutoImageProcessor.from_pretrained(my_args.image_processor_name)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
my_args.dataset_hf_path

'/workspace/datasets'

In [7]:
data = load_dataset(my_args.dataset_hf_path)['train']

In [8]:
def filter_long_examples(example):
    tokens = query_tokenizer(
        example["paraphrased_question"],
        padding=False,
        truncation=False,
        return_tensors="pt"
    )
    length = tokens["input_ids"].shape[1]
    return length <= 64

print(f"Original dataset size: {len(data)}")
data = data.filter(filter_long_examples)
print(f"Filtered dataset size (<= 128 tokens): {len(data)}")

Original dataset size: 128548
Filtered dataset size (<= 128 tokens): 127659


In [9]:
query_instructions = ["Using the provided image, obtain documents that address the subsequent question:",
"Retrieve documents that provide an answer to the question alongside the image:",
"Extract documents linked to the question provided in conjunction with the image:",
"Utilizing the given image, obtain documents that respond to the following question:",
"Using the given image, access documents that provide insights into the following question:",
"Obtain documents that correspond to the inquiry alongside the provided image: ",
"With the provided image, gather documents that offer a solution to the question: ",
"Utilizing the given image, obtain documents that respond to the following question:",
]

for item in query_instructions:
    tokens = query_tokenizer(
        item,
        padding=False,
        truncation=False,
        return_tensors="pt"
    )
    length = tokens["input_ids"].shape[1]
    print(length)

16
16
16
17
18
15
18
17


In [14]:
data_df = data.to_pandas().set_index('question_id')

In [15]:
passages_df = load_dataset("/workspace/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR", f'{my_args.dataset}_passages')['train_passages']\
    .to_pandas()

# evqa have duplicates in passage_id, deduplicate here
if len(passages_df) != len(passages_df['passage_id'].unique()):
    print('## deduplicating passage_ids, before: {}, after: {}'.format(
        len(passages_df),
        len(passages_df['passage_id'].unique())
    ))
    passages_df.drop_duplicates('passage_id', inplace=True)

# keep passage_id column, evqa needs it
passages_df['passage_id_index'] = passages_df['passage_id']
passages_df.set_index('passage_id_index', inplace=True) # important
    
dataset = PreFLMRDataset(args=my_args,
                            data_df=data_df, passages_df=passages_df,
                            query_tokenizer=query_tokenizer, 
                            context_tokenizer=context_tokenizer,
                            image_processor=image_processor)
collate_fn = dataset.collate_fn

## deduplicating passage_ids, before: 50205, after: 50195


In [16]:
if my_args.split_eval_from_train_examples != -1:
    print(f'## splitting eval set of size {my_args.split_eval_from_train_examples} from training set...')
    torch.manual_seed(training_args.seed)
    dataset, eval_dataset = random_split(dataset, [
        len(dataset) - my_args.split_eval_from_train_examples,
        my_args.split_eval_from_train_examples
    ])
elif training_args.do_eval:
    print(f'## building eval dataset...')
    eval_data_df = load_dataset(my_args.dataset_hf_path, f'{my_args.dataset}_data')['valid']\
        .to_pandas().set_index('question_id')
    eval_passages_df = load_dataset(my_args.dataset_hf_path, f'{my_args.dataset}_passages')['valid_passages']\
        .to_pandas().set_index('passage_id')
    eval_dataset = PreFLMRDataset(args=my_args,
                                data_df=eval_data_df, passages_df=eval_passages_df,
                                query_tokenizer=query_tokenizer, 
                                context_tokenizer=context_tokenizer,
                                image_processor=image_processor)
else:
    eval_dataset = None

In [17]:
# print(f'## len(dataset): {len(dataset)}')
# print(f'## dataset[0]: {pformat(dataset[0])}')
# print(f'## eval_dataset: {pformat(eval_dataset)}')

In [18]:
!pwd

/workspace/MuKA


In [19]:
cd ..

/workspace


In [None]:
import copy
import os
import pathlib
import string
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.utils.cpp_extension import load
import torch.nn.functional as F

from transformers import AutoModel, AutoConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.clip import CLIPVisionModel
# FLMR.flmr.models.flmr
#####

from FLMR.flmr.models.flmr.modeling_flmr import FLMRContextEncoderOutput, FLMRQueryEncoderOutput, FLMRModelForRetrievalOutput, FLMRPreTrainedModel, FLMRPretrainedModelForRetrieval, FLMRMultiLayerPerceptron, FLMRTextModel, FLMRVisionModel
from FLMR.flmr.models.flmr.modeling_flmr import FLMR_MODEL_INPUTS_DOCSTRING, _CONFIG_FOR_DOC, FLMR_MODEL_QUERY_INPUTS_DOCSTRING, FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING
#####
from FLMR.flmr.models.flmr.configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
from FLMR.flmr.models.flmr.tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
from FLMR.flmr.models.flmr.tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
from FLMR.flmr.models.flmr.flmr_utils import (
    colbert_score,
    colbert_score_reduce,
    get_rank,
    get_world_size,
)


# Integrate into Huggingface 

In [21]:
pwd

'/workspace'

In [41]:
import os
from PIL import Image
document_content = "One Hanover (formerly known as India House, Hanover Bank Building, and New York Cotton Exchange Building) is a commercial building at 1 Hanover Square, on the southwestern edge of the square, in the Financial District of Lower Manhattan in New York City."
sample_doc_images_path = "/workspace/z_debug_img_examples"

# 폴더 안의 모든 파일 이름 가져오기
image_files = [f for f in os.listdir(sample_doc_images_path) 
               if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]

# 이미지 불러오기
images = [Image.open(os.path.join(sample_doc_images_path, f)) for f in image_files]

doc_images = image_processor(images)

pixel_values = doc_images['pixel_values'] # this is a list of image

doc_seg_1 = "One Hanover (formerly known as India House, Hanover Bank Building, and New York Cotton Exchange Building) is a commercial building at " 
image_seg_1 = "1 Hanover Square" 
doc_seg_2 = ", on the southwestern edge of the square, in the Financial District of Lower " 
image_seg_2 = "Manhattan" 
doc_seg_3 = " in " 
image_seg_3 = "New York City" 
doc_seg_4 = "."

segments = [doc_seg_1, image_seg_1, doc_seg_2, image_seg_2, doc_seg_3, image_seg_3, doc_seg_4]

token_labels = []
all_input_ids = []

for seg in segments:
    # 토크나이징
    tokenized = context_tokenizer([seg], padding=False, truncation=False)
    input_ids = tokenized['input_ids'][0]  # 배치 0번째
    
    # segment type 결정
    if seg in [doc_seg_1, doc_seg_2, doc_seg_3, doc_seg_4]:
        label = 0
    else:
        label = 1

    # special token (101, 2, 102) 제외
    filtered_input_ids = []
    filtered_labels = []
    for tid in input_ids:
        if tid in [101, 2, 102]:
            continue
        filtered_input_ids.append(tid)
        filtered_labels.append(label)

    all_input_ids.extend(filtered_input_ids)
    token_labels.extend(filtered_labels)

all_input_ids = [101] + [2] + all_input_ids + [102]
token_labels = [0] + [0] + token_labels + [0]  # special token은 임의로 doc=0 처리 가능

print("final input_ids:", all_input_ids)
print("final token_labels:", token_labels)
print("len input_ids:", len(all_input_ids))
print("len token_labels:", len(token_labels))

max_length = 512
pad_token_id = 0
# padding 필요 시 뒤쪽 채우기
padding_len = max_length - len(all_input_ids)
if padding_len > 0:
    all_input_ids += [pad_token_id] * padding_len
    token_labels += [0] * padding_len  # padding token은 doc=0 처리
else:
    # 512 이상이면 자르기
    all_input_ids = all_input_ids[:max_length]
    token_labels = token_labels[:max_length]

print("final input_ids length:", len(all_input_ids))      # 항상 512
print("final token_labels length:", len(token_labels))    # 항상 512

doc_input = context_tokenizer(
    [document_content],  # 문자열을 리스트로 감싸기
    padding='max_length',
    truncation=True,
    max_length=512,
    return_tensors='pt'
)
print(doc_input['input_ids'].shape)
input_ids = doc_input['input_ids']
attention_mask = doc_input['attention_mask']

query_text = 'Retrieve documents that provide an answer to the question alongside the image:' + 'Which notable landmark located in the Financial District of this New York City borough is included in the Wall Street Historic District?'
query_image_sample_path = os.path.join("/workspace/M2KR_Images/EVQA", 'workspace/question_images/10_mile_panorama_of_NYC%2C_Feb.%2C_2018.jpg')
query_image = Image.open(query_image_sample_path).convert("RGB")
query_inputs = image_processor(images=query_image, return_tensors="pt")

query_pixel_values = query_inputs['pixel_values']
query_input_text = query_tokenizer(query_text, max_length=84)
query_input_ids = query_input_text['input_ids']
query_attention_mask = query_input_text['attention_mask']

final input_ids: [101, 2, tensor(2028), tensor(14393), tensor(1006), tensor(3839), tensor(2124), tensor(2004), tensor(2634), tensor(2160), tensor(1010), tensor(14393), tensor(2924), tensor(2311), tensor(1010), tensor(1998), tensor(2047), tensor(2259), tensor(6557), tensor(3863), tensor(2311), tensor(1007), tensor(2003), tensor(1037), tensor(3293), tensor(2311), tensor(2012), tensor(1015), tensor(14393), tensor(2675), tensor(1010), tensor(2006), tensor(1996), tensor(8772), tensor(3341), tensor(1997), tensor(1996), tensor(2675), tensor(1010), tensor(1999), tensor(1996), tensor(3361), tensor(2212), tensor(1997), tensor(2896), tensor(7128), tensor(1999), tensor(2047), tensor(2259), tensor(2103), tensor(1012), 102]
final token_labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0]
len input_ids: 52
len token_labels: 52
final input_ids length: 512
final token_labels length: 512
torch

In [59]:
import importlib
import MuKA.modeling_flmr_dosung as dosung

importlib.reload(dosung)

from MuKA.modeling_flmr_dosung import FLMRModelForRetrieval

In [36]:
# from FLMR.flmr.models.flmr.configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
config = FLMRConfig.from_pretrained(my_args.model_name_or_path)

model = FLMRModelForRetrievalDS.from_pretrained(
    my_args.model_name_or_path,
    query_tokenizer=query_tokenizer,
    context_tokenizer=context_tokenizer,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [50]:
context_outputs = model.doc(
    input_ids=input_ids,
    attention_mask=attention_mask,
    pixel_values=pixel_values,
    token_labels=token_labels,
)

512


In [52]:
D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask

In [51]:
query_outputs = model.query(
    input_ids=query_input_ids,
    attention_mask=query_attention_mask,
    pixel_values=query_pixel_values,
    # token_labels=token_labels,
)

In [53]:
Q = query_outputs.late_interaction_output

In [55]:
score = model.score(Q, D, D_mask)

In [56]:
score

tensor([108.3231], grad_fn=<SumBackward1>)

# Implement Indexing

In [1]:
from datasets import load_from_disk

passage_ds = load_from_disk("/workspace/processed_test_passage_ds")
print(passage_ds)

Dataset({
    features: ['language', 'passage_id', 'passage_content', 'passage_img_path', 'token_labels', 'passage_img_paths'],
    num_rows: 51472
})


In [2]:
passage_contents = list(zip(
    passage_ds['passage_content'],
    [None] * len(passage_ds),
    passage_ds['passage_img_paths'],
    passage_ds['token_labels']
))
passage_ids = passage_ds["passage_id"]

In [3]:
from typing import Optional, Tuple, Union
from transformers import AutoImageProcessor
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

def _sort_by_length(ids, mask, bsize, *args):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices
    
    return_array = [ids[indices], mask[indices]]
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return_array.append(arg[indices])
        else:
            # arg is a list, and we want to sort the list according to indices
            return_array.append([arg[i] for i in indices])

    return *return_array, reverse_indices


def _split_into_batches(ids, mask, bsize, *args):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batch = [ids[offset:offset+bsize], mask[offset:offset+bsize]]
        for arg in args:
            batch.append(arg[offset:offset+bsize])
        batches.append(batch)
    return batches


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output

In [4]:
from flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRConfig, FLMRModelForRetrieval
from transformers import AutoImageProcessor

In [5]:
query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(
    "/workspace/PreFLMR_ViT-G", subfolder="query_tokenizer")
context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(
    "/workspace/PreFLMR_ViT-G", subfolder="context_tokenizer")
image_processor = AutoImageProcessor.from_pretrained("/workspace/FLMR/CLIP-ViT-bigG-14-laion2B-39B-b160k")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
config = FLMRConfig.from_pretrained("/workspace/PreFLMR_ViT-G")

model = FLMRModelForRetrieval.from_pretrained(
    "/workspace/PreFLMR_ViT-G",
    query_tokenizer=query_tokenizer,
    context_tokenizer=context_tokenizer,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
bsize=8
keep_dims=True
to_cpu=False
showprogress=False
return_tokens=False

In [18]:
docs = passage_contents 

In [22]:
# docs can be
# (1) list of text
# (2) list of tuples (text, image_features, None)
# (3) list of tuples (text, None, image_paths)
# (4) list of tuples (text, None, images_paths, token_labels)

if isinstance(docs[0], tuple):
    texts = []
    image_features = []
    image_paths = []
    multi_image_paths = []
    token_labels_list = []
    for doc in docs:
        if len(doc) == 3:
            text, image_feature, image_path = doc
        elif len(doc) == 4:
            text, image_feature, images_paths, t_labels = doc
            multi_image_paths.append(images_paths)
            token_labels_list.append(t_labels)
        else:
            raise ValueError("Tuple format not recognized")

        texts.append(text)
        image_features.append(image_feature)

        if len(doc) == 3:
            image_paths.append(image_path)
    
    docs = texts
    if image_features[0] is not None:
        image_features = torch.FloatTensor(np.stack(image_features))
        is_input_image_features = True
    else:
        is_input_image_features = False

    multimodal_docs = True
else:
    image_features = None
    image_paths = None
    multimodal_docs = False

if bsize:
    # we change this part to enable dynamically loading image features to avoid memory overflow
    # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
    context_encoding = model.context_tokenizer(docs)
    ids, mask = context_encoding['input_ids'], context_encoding['attention_mask']
    
    if multimodal_docs:
        if multi_image_paths:
            ids, mask, image_features, multi_image_paths, reverse_indices = _sort_by_length(ids, mask, bsize, image_features, multi_image_paths)
            batches = _split_into_batches(ids, mask, bsize, image_features, multi_image_paths)
        else:
            # print(ids[0], mask[0], image_features[0], image_paths[0])
            # print(image_features.shape)
            ids, mask, image_features, image_paths, reverse_indices = _sort_by_length(ids, mask, bsize, image_features, image_paths)
            # print(image_features.shape)
            # print(len(ids), len(mask), len(image_features), len(image_paths))
            # print(ids[0], mask[0], image_features[0], image_paths[0])
            batches = _split_into_batches(ids, mask, bsize, image_features, image_paths)
    else:
        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        batches = _split_into_batches(ids, mask, bsize)

    # text_batches, reverse_indices = model.context_tokenizer(docs, bsize=bsize)
    
    returned_text = []
    if return_tokens:
        text_batches = [(input_ids, attention_mask) for input_ids, attention_mask, _, _ in batches]
        returned_text = [text for batch in text_batches for text in batch[0]]
        returned_text = [returned_text[idx] for idx in reverse_indices.tolist()]
        returned_text = [returned_text]
    
    
    keep_dims_ = True if keep_dims == 'flatten' else keep_dims
    return_mask = True if keep_dims == 'flatten' else False

    encoded_batches = []

    for batch in tqdm(batches):
        if multimodal_docs:
            if multi_image_paths:
                input_ids, attention_mask, image_features, batch_multi_image_paths = batch
                all_pixel_values = []
                for imgs in batch_multi_image_paths:
                    images = [Image.open(img_path).convert("RGB") for img_path in imgs]
                    pixel_values = image_processor(images, return_tensors="pt").pixel_values
                    all_pixel_values.append(pixel_values)
                context_output = model.doc(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=all_pixel_values,
                    keep_dims=True if keep_dims == 'flatten' else keep_dims,
                    return_mask=True if keep_dims == 'flatten' else False,
                    # to_cpu=to_cpu,
                    concat_output_from_vision_encoder=True,
                    token_labels=token_labels  # pass token labels if available
                )                        
            else:
                input_ids, attention_mask, image_features, image_paths = batch
                if is_input_image_features:
                    context_output = model.doc(
                        input_ids=input_ids, 
                        attention_mask=attention_mask, 
                        image_features=image_features,
                        keep_dims=keep_dims_, 
                        return_mask=return_mask, 
                        # to_cpu=to_cpu,
                        concat_output_from_vision_encoder=True,
                    )
                else:
                    # Open the images in image_paths and convert to pixel_values by using ImageProcessor
                    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
                    pixel_values = image_processor(images, return_tensors="pt").pixel_values
                    context_output = model.doc(
                        input_ids, 
                        attention_mask, 
                        pixel_values=pixel_values, 
                        keep_dims=keep_dims_, 
                        return_mask=return_mask, 
                        # to_cpu=to_cpu,
                        concat_output_from_vision_encoder=True,
                    )
        else:
            input_ids, attention_mask = batch
            context_output = model.doc(
                input_ids=input_ids,
                attention_mask=attention_mask,
                keep_dims=keep_dims_, 
                return_mask=return_mask, 
                # to_cpu=to_cpu,
            )
        encoded_batches.append(context_output)

100%|██████████| 13/13 [01:04<00:00,  4.94s/it]


# Model

In [23]:
def build_flmr_model(
    config: FLMRConfig,
    token_fusion_mode="replace",
    query_tokenizer=None,
    context_tokenizer=None,
):
    model = {}

    model["config"] = config
    model["token_fusion_mode"] = token_fusion_mode
    model["vision_model_version"] = config.vision_model_version

    model["context_text_encoder"] = FLMRTextModel(config.text_config)
    model["context_text_encoder_linear"] = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)

    # Tokenizer 초기화
    if query_tokenizer is None:
        logger.warning(
            "query_tokenizer is not provided. Using bert-base-uncased. "
            "Pass FLMRQueryEncoderTokenizer if you need extended vocab."
        )
        query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")

    if context_tokenizer is None:
        logger.warning(
            "context_tokenizer is not provided. Using bert-base-uncased. "
            "Pass FLMRContextEncoderTokenizer if you need extended vocab."
        )
        context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")

    model["query_tokenizer"] = query_tokenizer
    model["context_tokenizer"] = context_tokenizer

    model["mapping_network_prefix_length"] = config.mapping_network_prefix_length
    model["vision_encoder_embedding_size"] = config.vision_config.hidden_size
    model["text_encoder_embedding_size"] = config.text_config.hidden_size
    model["late_interaction_embedding_size"] = config.dim

    # Vision encoder
    if config.use_vision_encoder:
        model["context_vision_projection"] = FLMRMultiLayerPerceptron(
            (
                model["vision_encoder_embedding_size"],
                (model["late_interaction_embedding_size"] * model["mapping_network_prefix_length"]) // 2,
                model["late_interaction_embedding_size"] * model["mapping_network_prefix_length"],
            )
        )
        model["context_vision_encoder"] = FLMRVisionModel(config.vision_config)

        if config.use_transformer_mapping_network:
            try:
                from transformers import BertConfig
                from transformers.models.bert.modeling_bert import BertEncoder
            except Exception as e:
                raise ImportError(f"Failed to import BertConfig/BertEncoder. {e}")

            transformer_mapping_config = BertConfig.from_pretrained(config.transformer_mapping_config_base)

            assert (
                config.text_config.hidden_size == transformer_mapping_config.hidden_size
            ), "Text hidden_size and transformer hidden_size must match for cross attention."

            transformer_mapping_config.num_hidden_layers = config.transformer_mapping_num_hidden_layers
            transformer_mapping_config.is_decoder = True
            transformer_mapping_config.add_cross_attention = True

            model["transformer_mapping_input_linear"] = nn.Linear(
                model["vision_encoder_embedding_size"], transformer_mapping_config.hidden_size
            )
            model["transformer_mapping_network"] = BertEncoder(transformer_mapping_config)
            model["transformer_mapping_output_linear"] = nn.Linear(
                transformer_mapping_config.hidden_size, model["late_interaction_embedding_size"]
            )

    # Text encoder 공유 여부
    if config.separate_query_and_context_text_encoder:
        model["query_text_encoder"] = copy.deepcopy(model["context_text_encoder"])
        model["query_text_encoder_linear"] = copy.deepcopy(model["context_text_encoder_linear"])
    else:
        model["query_text_encoder"] = model["context_text_encoder"]
        model["query_text_encoder_linear"] = model["context_text_encoder_linear"]

    # Vision encoder 공유 여부
    if config.use_vision_encoder:
        if config.separate_query_and_context_vision_encoder:
            model["query_vision_encoder"] = copy.deepcopy(model["context_vision_encoder"])
            model["query_vision_projection"] = copy.deepcopy(model["context_vision_projection"])
        else:
            model["query_vision_encoder"] = model["context_vision_encoder"]
            model["query_vision_projection"] = model["context_vision_projection"]

    # CPU extension
    if config.load_cpu_extension:
        try:
            FLMRModelForRetrieval.try_load_torch_extensions()
        except Exception as e:
            raise RuntimeError(
                "Unable to load `segmented_maxsim.cpp`. "
                "Download it manually from HuggingFace repo and place next to model file.\n"
                f"{e}"
            )

    # Mask punctuation
    if config.mask_punctuation:
        model["skiplist"] = {
            w: True
            for symbol in string.punctuation
            for w in [symbol, context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
        }

    # Instruction mask
    if config.mask_instruction_token is not None:
        model["mask_instruction"] = True
        model["instruction_token_id"] = query_tokenizer.encode(
            config.mask_instruction_token, add_special_tokens=False
        )[0]
    else:
        model["mask_instruction"] = False

    # Loss
    model["loss_fn"] = torch.nn.CrossEntropyLoss()

    return model


In [24]:
config = FLMRConfig.from_pretrained(my_args.model_name_or_path)

# 네 함수로 모델 구조 초기화
model = build_flmr_model(
    config,
    token_fusion_mode="replace",
    query_tokenizer=query_tokenizer,
    context_tokenizer=context_tokenizer,
)

# Query Side

In [25]:
pixel_values = None,
image_features = None,
concat_output_from_vision_encoder = None,
concat_output_from_text_encoder = None,
output_attentions = None,
output_hidden_states = None,
token_labels = None

In [26]:
concat_output_from_vision_encoder = None

if concat_output_from_vision_encoder is None:
    concat_output_from_vision_encoder = model['config'].query_concat_output_from_vision_encoder

concat_output_from_text_encoder = None 

if concat_output_from_text_encoder is None:
    concat_output_from_text_encoder = model['config'].query_concat_output_from_text_encoder

In [27]:
output_attentions = None

output_hidden_states = None

output_attentions = output_attentions if output_attentions is not None else model['config'].output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else model['config'].output_hidden_states
)

In [28]:
import os
from PIL import Image


text = 'Retrieve documents that provide an answer to the question alongside the image:' + 'Which notable landmark located in the Financial District of this New York City borough is included in the Wall Street Historic District?'
image_sample_path = os.path.join("/workspace/M2KR_Images/EVQA", 'workspace/question_images/10_mile_panorama_of_NYC%2C_Feb.%2C_2018.jpg')
image = Image.open(image_sample_path).convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")

pixel_values = inputs['pixel_values']
input_text = query_tokenizer(text, max_length=84)
input_ids = input_text['input_ids']
attention_mask = input_text['attention_mask']

In [33]:
input_modality = []
if pixel_values is not None or image_features is not None:
    input_modality.append("image")
if input_ids is not None and attention_mask is not None:
    input_modality.append("text")

text_encoder_outputs = None
vision_encoder_outputs = None
transformer_mapping_outputs = None

In [34]:
print(input_modality)

['image', 'text']


In [35]:
image_features = None

In [36]:
if "image" in input_modality:
    assert (
        pixel_values is not None or image_features is not None
    ), "pixel_values or image_features must be provided if image modality is used"
    assert (
        pixel_values is None or image_features is None
    ), "pixel_values and image_features cannot be provided at the same time"

In [37]:
model.keys()

dict_keys(['config', 'token_fusion_mode', 'vision_model_version', 'context_text_encoder', 'context_text_encoder_linear', 'query_tokenizer', 'context_tokenizer', 'mapping_network_prefix_length', 'vision_encoder_embedding_size', 'text_encoder_embedding_size', 'late_interaction_embedding_size', 'context_vision_projection', 'context_vision_encoder', 'transformer_mapping_input_linear', 'transformer_mapping_network', 'transformer_mapping_output_linear', 'query_text_encoder', 'query_text_encoder_linear', 'query_vision_encoder', 'query_vision_projection', 'skiplist', 'mask_instruction', 'instruction_token_id', 'loss_fn'])

In [38]:
# Forward the text encoder
# input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
text_encoder_outputs = model['query_text_encoder'](input_ids, attention_mask=attention_mask) # need to fix limit to 32 token
text_encoder_hidden_states = text_encoder_outputs[0]
text_embeddings = model['query_text_encoder_linear'](text_encoder_hidden_states)

In [39]:
def mask(input_ids, skiplist):
    mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
    return mask

In [40]:
model['instruction_token_id']

1024

In [41]:
def query_mask(input_ids, skiplist):
    if not model['mask_instruction']:
        return mask(input_ids, skiplist)

    # find the position of end of instruction in input_ids
    # mask the tokens before the position
    sep_id = model['instruction_token_id']
    sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
    # if any of the positions is lower than 1, set to 1
    for i, x in enumerate(sep_positions):
        if x < 1:
            sep_positions[i] = 1
            logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}")
    mask = [
        [
            (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
            for index, x in enumerate(d)
        ]
        for seq_index, d in enumerate(input_ids.cpu().tolist())
    ]
    return mask

In [42]:
sample_query_mask = query_mask(input_ids,[])

In [43]:
mask = torch.tensor(sample_query_mask).unsqueeze(2).float()

text_embeddings = text_embeddings * mask

In [44]:
text_embeddings.shape

torch.Size([1, 84, 128])

In [45]:
if pixel_values is not None:
    batch_size = pixel_values.shape[0]
    # Forward the vision encoder
    # pixel_values = pixel_values.to(self.device)
    ####### need to modify #######
    if len(pixel_values.shape) == 5:
        # Multiple ROIs are provided
        # merge the first two dimensions
        pixel_values = pixel_values.reshape(
            -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
        )
    vision_encoder_outputs = model['query_vision_encoder'](pixel_values, output_hidden_states=True)
    vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]

In [46]:
# Forward the vision projection / mapping network
vision_embeddings = model['query_vision_projection'](vision_embeddings)
vision_embeddings = vision_embeddings.view(batch_size, -1, model['late_interaction_embedding_size'])

In [47]:
model['config'].use_transformer_mapping_network

True

In [48]:
# select the second last layer
vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
# transformer_mapping
transformer_mapping_input_features = model['transformer_mapping_input_linear'](
    vision_second_last_layer_hidden_states
)

# Cross attention
encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
if len(model['config'].query_mask_input_ids_skip_list) > 0:
    encoder_mask[torch.isin(input_ids, torch.tensor(model['config'].query_mask_input_ids_skip_list))] = 0
cross_attention_length = model['config'].transformer_mapping_cross_attention_length
if text_encoder_hidden_states.shape[1] > cross_attention_length:
    text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
    encoder_mask = encoder_mask[:, :cross_attention_length]

In [49]:
import torch

def invert_attention_mask(encoder_attention_mask: torch.Tensor) -> torch.Tensor:
    """
    0/1 마스크를 attention-friendly mask로 변환
    1 → 0 (keep), 0 → -10000 (mask)
    
    Args:
        encoder_attention_mask: [batch_size, seq_len] 또는 [batch_size, seq_len, seq_len]
    
    Returns:
        extended_attention_mask: [batch_size, 1, 1, seq_len] 또는 [batch_size, 1, seq_len, seq_len]
    """
    if encoder_attention_mask.dim() == 3:
        # [batch, seq_len, seq_len] → [batch, 1, seq_len, seq_len]
        extended_attention_mask = encoder_attention_mask[:, None, :, :]
    elif encoder_attention_mask.dim() == 2:
        # [batch, seq_len] → [batch, 1, 1, seq_len]
        extended_attention_mask = encoder_attention_mask[:, None, None, :]
    else:
        raise ValueError(f"Wrong shape for attention_mask: {encoder_attention_mask.shape}")
    
    # 1 → 0, 0 → -10000 (attention 무시)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask

In [50]:
# Obtain cross attention mask
encoder_extended_attention_mask = invert_attention_mask(encoder_mask.squeeze(-1))

In [51]:
# Pass through the transformer mapping
transformer_mapping_outputs = model['transformer_mapping_network'](
    transformer_mapping_input_features,
    encoder_hidden_states=text_encoder_hidden_states,
    encoder_attention_mask=encoder_extended_attention_mask,
)
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
# Convert the dimension to FLMR dim
transformer_mapping_output_features = model['transformer_mapping_output_linear'](
    transformer_mapping_output_features
)
# Merge with the vision embeddings
vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)

In [52]:
vision_embeddings.shape

torch.Size([1, 288, 128])

In [53]:
text_embeddings.shape

torch.Size([1, 84, 128])

In [54]:
if concat_output_from_vision_encoder and concat_output_from_text_encoder:
    Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
    if isinstance(concat_output_from_vision_encoder, list) or isinstance(concat_output_from_text_encoder, list):
        # When lists are passed in, mask the output accordingly
        assert isinstance(concat_output_from_vision_encoder, list) and isinstance(concat_output_from_text_encoder, list), "concat_output_from_vision_encoder and concat_output_from_text_encoder must be of the same type."
        # obtain the size of each output
        text_size = text_embeddings.shape[1]
        vision_size = vision_embeddings.shape[1]

        # Prepare the mask
        concat_output_mask = torch.zeros_like(Q).to(Q.device)

        # Mask the late interaction outputs
        concat_output_mask[:, :text_size] = torch.tensor(concat_output_from_text_encoder).bool().unsqueeze(-1).unsqueeze(-1)
        concat_output_mask[:, text_size:] = torch.tensor(concat_output_from_vision_encoder).bool().unsqueeze(-1).unsqueeze(-1)

        Q = Q * concat_output_mask

elif concat_output_from_vision_encoder:
    Q = vision_embeddings
elif concat_output_from_text_encoder:
    Q = text_embeddings

vision_encoder_attentions = (
    vision_encoder_outputs.attentions
    if vision_encoder_outputs is not None
    and hasattr(vision_encoder_outputs, "attentions")
    and output_attentions
    else None
)
vision_encoder_hidden_states = (
    vision_encoder_outputs.hidden_states
    if vision_encoder_outputs is not None
    and hasattr(vision_encoder_outputs, "hidden_states")
    and output_hidden_states
    else None
)
text_encoder_attentions = (
    text_encoder_outputs.attentions
    if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
    else None
)
text_encoder_hidden_states = (
    text_encoder_outputs.hidden_states
    if text_encoder_outputs is not None
    and hasattr(text_encoder_outputs, "hidden_states")
    and output_hidden_states
    else None
)
transformer_mapping_network_attentions = (
    transformer_mapping_outputs.attentions
    if transformer_mapping_outputs is not None
    and hasattr(transformer_mapping_outputs, "attentions")
    and output_attentions
    else None
)
transformer_mapping_network_hidden_states = (
    transformer_mapping_outputs.hidden_states
    if transformer_mapping_outputs is not None
    and hasattr(transformer_mapping_outputs, "hidden_states")
    and output_hidden_states
    else None
)

In [55]:
query_output = FLMRQueryEncoderOutput(
    pooler_output=Q[:, 0, :],
    late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
    vision_encoder_attentions=vision_encoder_attentions,
    vision_encoder_hidden_states=vision_encoder_hidden_states,
    text_encoder_attentions=text_encoder_attentions,
    text_encoder_hidden_states=text_encoder_hidden_states,
    transformer_mapping_network_attentions=transformer_mapping_network_attentions,
    transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
)

# Document Side

In [57]:
import os
from PIL import Image
document_content = "One Hanover (formerly known as India House, Hanover Bank Building, and New York Cotton Exchange Building) is a commercial building at 1 Hanover Square, on the southwestern edge of the square, in the Financial District of Lower Manhattan in New York City."
sample_doc_images_path = "/workspace/z_debug_img_examples"

# 폴더 안의 모든 파일 이름 가져오기
image_files = [f for f in os.listdir(sample_doc_images_path) 
               if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]

# 이미지 불러오기
images = [Image.open(os.path.join(sample_doc_images_path, f)) for f in image_files]

doc_images = image_processor(images)

doc_input = model['context_tokenizer'](
    [document_content],  # 문자열을 리스트로 감싸기
    padding='max_length',
    truncation=True,
    max_length=512,
    return_tensors='pt'
)
print(doc_input['input_ids'].shape)
input_ids = doc_input['input_ids']
attention_mask = doc_input['attention_mask']
pixel_values = doc_images['pixel_values'] # this is a list of image

In [66]:
# input_ids: torch.Tensor,
# attention_mask: torch.Tensor,
# pixel_values: Optional[torch.Tensor] = None
# image_features: Optional[torch.Tensor] = None
concat_output_from_vision_encoder = None
concat_output_from_text_encoder = None
keep_dims = True
return_mask = True
output_attentions = None
output_hidden_states = None

In [67]:
if concat_output_from_vision_encoder is None:
    concat_output_from_vision_encoder =  model['config'].context_concat_output_from_vision_encoder

if concat_output_from_text_encoder is None:
    concat_output_from_text_encoder =  model['config'].context_concat_output_from_text_encoder

output_attentions = output_attentions if output_attentions is not None else  model['config'].output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else  model['config'].output_hidden_states
)

input_modality = []
if pixel_values is not None or image_features is not None:
    input_modality.append("image")
if input_ids is not None and attention_mask is not None:
    input_modality.append("text")

text_encoder_outputs = None
vision_encoder_outputs = None

In [68]:
# Forward the text encoder
# input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
text_encoder_outputs = model['context_text_encoder'](input_ids, attention_mask=attention_mask)
text_embeddings = text_encoder_outputs[0]
text_embeddings = model['context_text_encoder_linear'](text_embeddings)

In [69]:
input_ids.shape

torch.Size([1, 512])

In [70]:
text_embeddings.shape

torch.Size([1, 512, 128])

In [71]:
def mask(input_ids, skiplist):
    mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
    return mask

In [72]:
mask_ = torch.tensor(mask(input_ids, skiplist=model['skiplist'])).unsqueeze(2).float()
text_embeddings = text_embeddings * mask_

In [73]:
text_embeddings.shape

torch.Size([1, 512, 128])

In [74]:
# Forward the vision encoder
# pixel_values = pixel_values.to(self.device)
all_vision_embeddings = []
all_image_masks = []

for pixel_value in pixel_values:

    pixel_value = torch.tensor(pixel_value)           # shape: [3, 224, 224]
    pixel_value = pixel_value.unsqueeze(0)            # shape: [1, 3, 224, 224]
    
    vision_encoder_outputs = model['context_vision_encoder'](pixel_value)
    vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
    vision_embeddings = model['context_vision_projection'](vision_embeddings)
    vision_embeddings = vision_embeddings.view(
        -1, model['mapping_network_prefix_length'], model['late_interaction_embedding_size']
    )
    image_mask = torch.ones(vision_embeddings.shape[0], vision_embeddings.shape[1], 1)
    
    all_vision_embeddings.append(vision_embeddings)
    all_image_masks.append(image_mask)

# vision_embeddings = torch.cat(all_vision_embeddings, dim=0)
# image_mask = torch.cat(all_image_masks, dim=0)

# Now Implement Multi-Image (View) Document Embeddings

In [139]:
text_embeddings.shape

torch.Size([1, 512, 128])

In [140]:
doc_input['input_ids'].shape

torch.Size([1, 512])

In [141]:
doc_seg_1 = "One Hanover (formerly known as India House, Hanover Bank Building, and New York Cotton Exchange Building) is a commercial building at " 
image_seg_1 = "1 Hanover Square" 
doc_seg_2 = ", on the southwestern edge of the square, in the Financial District of Lower " 
image_seg_2 = "Manhattan" 
doc_seg_3 = " in " 
image_seg_3 = "New York City" 
doc_seg_4 = "."

segments = [doc_seg_1, image_seg_1, doc_seg_2, image_seg_2, doc_seg_3, image_seg_3, doc_seg_4]

token_labels = []
all_input_ids = []

for seg in segments:
    # 토크나이징
    tokenized = model['context_tokenizer']([seg], padding=False, truncation=False)
    input_ids = tokenized['input_ids'][0]  # 배치 0번째
    
    # segment type 결정
    if seg in [doc_seg_1, doc_seg_2, doc_seg_3, doc_seg_4]:
        label = 0
    else:
        label = 1

    # special token (101, 2, 102) 제외
    filtered_input_ids = []
    filtered_labels = []
    for tid in input_ids:
        if tid in [101, 2, 102]:
            continue
        filtered_input_ids.append(tid)
        filtered_labels.append(label)

    all_input_ids.extend(filtered_input_ids)
    token_labels.extend(filtered_labels)

all_input_ids = [101] + [2] + all_input_ids + [102]
token_labels = [0] + [0] + token_labels + [0]  # special token은 임의로 doc=0 처리 가능

print("final input_ids:", all_input_ids)
print("final token_labels:", token_labels)
print("len input_ids:", len(all_input_ids))
print("len token_labels:", len(token_labels))

final input_ids: [101, 2, tensor(2028), tensor(14393), tensor(1006), tensor(3839), tensor(2124), tensor(2004), tensor(2634), tensor(2160), tensor(1010), tensor(14393), tensor(2924), tensor(2311), tensor(1010), tensor(1998), tensor(2047), tensor(2259), tensor(6557), tensor(3863), tensor(2311), tensor(1007), tensor(2003), tensor(1037), tensor(3293), tensor(2311), tensor(2012), tensor(1015), tensor(14393), tensor(2675), tensor(1010), tensor(2006), tensor(1996), tensor(8772), tensor(3341), tensor(1997), tensor(1996), tensor(2675), tensor(1010), tensor(1999), tensor(1996), tensor(3361), tensor(2212), tensor(1997), tensor(2896), tensor(7128), tensor(1999), tensor(2047), tensor(2259), tensor(2103), tensor(1012), 102]
final token_labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0]
len input_ids: 52
len token_labels: 52


In [142]:
max_length = 512
pad_token_id = 0
# padding 필요 시 뒤쪽 채우기
padding_len = max_length - len(all_input_ids)
if padding_len > 0:
    all_input_ids += [pad_token_id] * padding_len
    token_labels += [0] * padding_len  # padding token은 doc=0 처리
else:
    # 512 이상이면 자르기
    all_input_ids = all_input_ids[:max_length]
    token_labels = token_labels[:max_length]

print("final input_ids length:", len(all_input_ids))      # 항상 512
print("final token_labels length:", len(token_labels))    # 항상 512

final input_ids length: 512
final token_labels length: 512


In [143]:
text_embeddings.shape

torch.Size([1, 512, 128])

In [144]:
all_vision_embeddings[0].shape

torch.Size([1, 32, 128])

In [146]:
# 예제 shapes
batch_size = 1
seq_len = 512
hidden_dim = 128

text_embeddings = text_embeddings.clone()         # [1, 512, 128]
token_labels = token_labels                       # list, length=512
all_vision_embeddings = all_vision_embeddings    # list of tensors, 각 [1, num_objects, 128]

combined_embeddings = []  # 최종 결과
vision_list_idx = 0       # all_vision_embeddings 순서 추적

text_idx = 0
while text_idx < seq_len:
    # 현재 token embedding
    combined_embeddings.append(text_embeddings[0, text_idx])  # [128]

    # 현재 token이 1이고 연속된 1 구간 마지막인지 확인
    if token_labels[text_idx] == 1:
        # 마지막 1인지 확인
        is_last_in_1_block = (text_idx == seq_len-1) or (token_labels[text_idx+1] == 0)
        if is_last_in_1_block:
            # all_vision_embeddings 순서대로 끼워 넣기
            vision_emb = all_vision_embeddings[vision_list_idx][0]  # [num_objects, 128] 중 0번째
            for v in range(vision_emb.shape[0]):
                combined_embeddings.append(vision_emb[v])  # [128]
            vision_list_idx += 1

    text_idx += 1

# 리스트 → tensor로 변환
combined_embeddings = torch.stack(combined_embeddings, dim=0)  # [seq_len + num_vision_tokens, 128]
combined_embeddings = combined_embeddings.unsqueeze(0)          # batch dimension 추가 [1, new_seq_len, 128]

In [147]:
combined_embeddings.shape

torch.Size([1, 608, 128])

In [150]:
D = combined_embeddings

In [151]:
combined_embeddings = []
combined_mask = []

vision_list_idx = 0
text_idx = 0
seq_len = text_embeddings.shape[1]

while text_idx < seq_len:
    # text embedding 추가
    combined_embeddings.append(text_embeddings[0, text_idx])
    combined_mask.append(mask_[0, text_idx])  # text mask 추가

    # token_labels 기준으로 연속된 1 구간 마지막인지 확인
    if token_labels[text_idx] == 1:
        is_last_in_1_block = (text_idx == seq_len-1) or (token_labels[text_idx+1] == 0)
        if is_last_in_1_block:
            # vision embedding 추가
            vision_emb = all_vision_embeddings[vision_list_idx][0]    # [num_objects, 128]
            vision_mask = all_image_masks[vision_list_idx][0]          # [num_objects]

            for v_idx in range(vision_emb.shape[0]):
                combined_embeddings.append(vision_emb[v_idx])
                combined_mask.append(vision_mask[v_idx])

            vision_list_idx += 1

    text_idx += 1

# 리스트 → tensor
combined_embeddings = torch.stack(combined_embeddings, dim=0).unsqueeze(0)  # [1, new_seq_len, 128]
combined_mask = torch.tensor(combined_mask).unsqueeze(0)                     # [1, new_seq_len]

In [153]:
combined_mask.shape

torch.Size([1, 608])

In [154]:
D = torch.nn.functional.normalize(D, p=2, dim=2)

In [None]:
# if self.use_gpu:
#     D = D.half()

if keep_dims is False:
    D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
    D = [d[mask[idx]] for idx, d in enumerate(D)]

In [155]:
vision_encoder_attentions = (
    vision_encoder_outputs.attentions
    if vision_encoder_outputs is not None
    and hasattr(vision_encoder_outputs, "attentions")
    and output_attentions
    else None
)
vision_encoder_hidden_states = (
    vision_encoder_outputs.hidden_states
    if vision_encoder_outputs is not None
    and hasattr(vision_encoder_outputs, "hidden_states")
    and output_hidden_states
    else None
)
text_encoder_attentions = (
    text_encoder_outputs.attentions
    if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
    else None
)
text_encoder_hidden_states = (
    text_encoder_outputs.hidden_states
    if text_encoder_outputs is not None
    and hasattr(text_encoder_outputs, "hidden_states")
    and output_hidden_states
    else None
)

In [157]:
doc_output = FLMRContextEncoderOutput(
    pooler_output=D[:, 0, :],
    late_interaction_output=D,
    context_mask=combined_mask.bool() if return_mask else None,
    vision_encoder_attentions=vision_encoder_attentions,
    vision_encoder_hidden_states=vision_encoder_hidden_states,
    text_encoder_attentions=text_encoder_attentions,
    text_encoder_hidden_states=text_encoder_hidden_states,
)

In [158]:
doc_output

FLMRContextEncoderOutput(pooler_output=tensor([[-6.5469e-02, -5.4655e-02,  5.5116e-02,  4.0943e-02,  3.2786e-02,
         -3.3283e-02, -1.5472e-01, -1.4978e-01,  7.9919e-02, -2.7139e-02,
         -7.8336e-02,  8.7071e-02,  3.5316e-02, -1.3999e-02,  1.3035e-01,
         -1.1219e-01,  2.5035e-01,  8.6285e-02,  3.5063e-02,  2.0529e-02,
         -2.8919e-02,  7.0027e-03,  9.5450e-02,  5.6365e-02,  6.2840e-02,
         -2.6747e-01, -8.6408e-02, -4.4448e-02, -2.5829e-02,  2.6721e-01,
          8.8805e-04,  3.0560e-02,  1.0944e-01, -1.1481e-01,  7.1092e-02,
         -3.8833e-03,  4.0219e-02, -1.0752e-01, -4.3091e-02, -4.1950e-02,
          1.6034e-02, -2.3201e-02, -2.2948e-02, -7.7183e-02, -5.9975e-02,
          1.1562e-01, -3.6185e-02, -1.7240e-02,  1.4780e-01, -8.4972e-02,
          8.5642e-02,  6.6792e-02, -6.3070e-03,  1.5256e-01, -4.7794e-02,
          1.4319e-01, -3.4916e-02, -7.1033e-03,  3.8594e-02, -1.4227e-02,
          4.0437e-02, -4.4177e-03,  4.4324e-02,  1.0775e-01,  5.9642e-02,

# Tokenizer & Sample data format

In [46]:
data[0]['paraphrased_question']

'Which notable landmark located in the Financial District of this New York City borough is included in the Wall Street Historic District?'

In [43]:
data[0]['instruction']

'Retrieve documents that provide an answer to the question alongside the image:'

In [49]:
data[0]['pos_item_ids'][0]

'WikiWeb_1 Wall Street Court_1'

In [45]:
data[0]['img_path']

'workspace/question_images/10_mile_panorama_of_NYC%2C_Feb.%2C_2018.jpg'

In [50]:
image_sample_path = os.path.join("/workspace/M2KR_Images/EVQA", 'workspace/question_images/10_mile_panorama_of_NYC%2C_Feb.%2C_2018.jpg')

In [52]:
from PIL import Image
text = 'Retrieve documents that provide an answer to the question alongside the image:' + 'Which notable landmark located in the Financial District of this New York City borough is included in the Wall Street Historic District?'
image_sample_path = os.path.join("/workspace/M2KR_Images/EVQA", 'workspace/question_images/10_mile_panorama_of_NYC%2C_Feb.%2C_2018.jpg')
image = Image.open(image_sample_path).convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")

In [58]:
text = data[0]['instruction']+data[0]['paraphrased_question']

In [98]:
query_tokens = query_tokenizer(text=data[0]['paraphrased_question'], max_length=84)

In [99]:
query_tokens_with_instruction = query_tokenizer(text=text, max_length=84,)

In [100]:
query_tokens

{'input_ids': tensor([[ 101,    1, 2029, 3862, 8637, 2284, 1999, 1996, 3361, 2212, 1997, 2023,
          2047, 2259, 2103, 5538, 2003, 2443, 1999, 1996, 2813, 2395, 3181, 2212,
          1029,  102,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,
           103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,
           103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,
           103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,
           103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [101]:
query_tokens['input_ids'].shape

torch.Size([1, 84])

In [102]:
query_tokens_with_instruction

{'input_ids': tensor([[  101,     1, 12850,  5491,  2008,  3073,  2019,  3437,  2000,  1996,
           3160,  4077,  1996,  3746,  1024,  2029,  3862,  8637,  2284,  1999,
           1996,  3361,  2212,  1997,  2023,  2047,  2259,  2103,  5538,  2003,
           2443,  1999,  1996,  2813,  2395,  3181,  2212,  1029,   102,   103,
            103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
            103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
            103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
            103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
            103,   103,   103,   103]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [103]:
query_tokens_with_instruction['input_ids'].shape

torch.Size([1, 84])

In [104]:
query_tokenizer.decode(token_ids=query_tokens['input_ids'][0])

'[CLS] [unused0] which notable landmark located in the financial district of this new york city borough is included in the wall street historic district? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]'

In [105]:
query_tokenizer.decode(token_ids=query_tokens_with_instruction['input_ids'][0])

'[CLS] [unused0] retrieve documents that provide an answer to the question alongside the image : which notable landmark located in the financial district of this new york city borough is included in the wall street historic district? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]'

In [106]:
inputs.keys()

dict_keys(['pixel_values'])