In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict
import warnings
# Silence every Python warning category for the whole notebook session.
warnings.filterwarnings('ignore')

# Reset to matplotlib's stock style, then apply the notebook-wide font and figure size.
plt.style.use('default')
plt.rcParams.update({'font.size': 12, 'figure.figsize': (12, 8)})

In [None]:
from models.model_factory import create_model
from tasks.i2t_tasks import OperatorInductionTask

model_name = "OpenGVLab/InternVL3-8B-Instruct"   # checkpoint handed to the model factory
data_dir = "./VL-ICL"                             # data directory handed to the task loader

print(f"Loading model: {model_name}")
model = create_model('internvl', model_name)
tokenizer = model.tokenizer  # reused for decoding captured ids and for chat() below

# Operator-induction task: `query_data` are the items to answer, `support_data`
# presumably the pool that in-context demonstrations are drawn from.
task = OperatorInductionTask(data_dir)
print(
    f"Loaded {len(task.query_data)} queries "
    f"and {len(task.support_data)} support examples"
)

In [None]:
import random, copy

# Seed Python's RNG so reruns pick the same demonstrations.
# NOTE(review): assumes select_demonstrations samples via `random` — confirm.
random.seed(0)

n_shot = 4
query = task.query_data[0]

# Pick n_shot in-context demonstrations for this query.
support = task.select_demonstrations(query, n_shot)


def _load_example_images(example):
    """Load every image path listed under the example's 'image' key (empty list if absent)."""
    return [task.load_image(path) for path in example.get('image', [])]


# Prompt layout: task instruction, each demonstration, then the query.
# Images are collected in the same example order as the text sections,
# presumably matching the order of the image tokens inside the prompt.
prompt_parts = [task.get_task_instruction()]
images = []
for demo in support:
    prompt_parts.append(task.format_demonstration(demo, include_image_token=True, mode="constrained"))
    images.extend(_load_example_images(demo))

query_text = task.format_query(query, include_image_token=True, mode="constrained") + " Answer: "
prompt_parts.append(query_text)
images.extend(_load_example_images(query))

# Sections are separated by a blank line.
full_prompt = "\n\n".join(prompt_parts)
print(full_prompt)
print(f"Total images: {len(images)}")

In [None]:
# Turn the PIL images into one bf16 CUDA tensor of patches for chat().
# A single image may use up to 12 patches. With several images each one is capped
# at 6 (presumably to bound the visual token count), and the per-image patch
# counts are kept so chat() can tell which patches belong to which image.
pixel_values = None
num_patches_list = None

if len(images) == 1:
    pixel_values = model.load_image(images[0], max_num=12).to(torch.bfloat16).cuda()
elif len(images) > 1:
    per_image_patches = [model.load_image(img, max_num=6) for img in images]
    num_patches_list = [patches.size(0) for patches in per_image_patches]
    pixel_values = torch.cat(per_image_patches, dim=0).to(torch.bfloat16).cuda()

print(None if pixel_values is None else pixel_values.shape)

In [None]:
import types

IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
actual_input_ids = None
actual_attention_mask = None

# Drop any instance-level `generate` left by an earlier run of this cell, so that
# `orig_generate` is always the real class-level method. Without this, re-running
# the cell makes `orig_generate` the previous wrapper, which then calls itself forever.
vars(model.model).pop('generate', None)
orig_generate = model.model.generate


def instrumented_generate(self, pixel_values=None, input_ids=None, attention_mask=None, **kwargs):
    """Record the token ids / attention mask handed to `generate`, then delegate unchanged.

    `chat` tokenizes the prompt internally. Intercepting `generate` captures the exact
    ids the model runs on, which the token-grouping step needs to locate image tokens.
    """
    global actual_input_ids, actual_attention_mask
    actual_input_ids = input_ids.clone() if input_ids is not None else None
    actual_attention_mask = attention_mask.clone() if attention_mask is not None else None
    return orig_generate(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **kwargs)


model.model.generate = types.MethodType(instrumented_generate, model.model)

In [None]:
attention_data: Dict[int, torch.Tensor] = {}

# Remove hooks left over from a previous run of this cell. Otherwise they stay
# registered (the `hooks` list is rebuilt below) and keep writing into attention_data.
for stale_hook in globals().get('hooks', []):
    stale_hook.remove()
hooks = []

language_model = model.model.language_model
num_layers = len(language_model.model.layers)

# Request attention probabilities. The language model presumably reads its own config
# object rather than the wrapper's, so set the flag on both.
# NOTE(review): fused kernels (e.g. flash attention) may still return no weights.
for cfg_owner in (model.model, language_model):
    if hasattr(cfg_owner, 'config'):
        cfg_owner.config.output_attentions = True

# NOTE(review): every layer's full q_len x k_len map is copied to CPU. With many
# image tokens this can be very large; consider hooking only the layers you plot.


def create_hook(idx):
    """Build a forward hook that stores layer `idx`'s attention probabilities on CPU."""
    def hook_fn(module, input, output):
        attn = None
        if isinstance(output, tuple):
            # Attention modules conventionally return (hidden_states, attn_weights, ...).
            # Element 0 is hidden states, never attention, so only element 1 is considered.
            if len(output) > 1:
                attn = output[1]
        else:
            attn = getattr(module, 'attn_probs', None)
        # Keep only a [batch, heads, q_len, k_len] tensor. Anything else (e.g. None when
        # the kernel skips weights) is ignored rather than stored under the wrong meaning.
        if isinstance(attn, torch.Tensor) and attn.dim() == 4:
            attention_data[idx] = attn.detach().cpu()
    return hook_fn


for layer_idx, layer in enumerate(language_model.model.layers):
    # Different decoder implementations name the attention submodule differently.
    attn_module = next(
        (getattr(layer, attr) for attr in ('self_attn', 'attention', 'attn') if hasattr(layer, attr)),
        None,
    )
    if attn_module is not None:
        hooks.append(attn_module.register_forward_hook(create_hook(layer_idx)))

In [None]:
# A single greedy step (max_new_tokens=1), so the hooks fire on one forward pass
# over the whole prompt. output_attentions=True is forwarded through chat/generate
# so the attention layers actually return probabilities for the hooks to capture.
# NOTE(review): requires an attention implementation that returns weights (eager/sdpa).
gen_cfg = dict(max_new_tokens=1, do_sample=False, output_attentions=True)

# Only pass num_patches_list in the multi-image case, as the original branches did.
chat_kwargs = {} if num_patches_list is None else {'num_patches_list': num_patches_list}

try:
    with torch.no_grad():
        response = model.model.chat(tokenizer, pixel_values, full_prompt, gen_cfg, **chat_kwargs)
finally:
    # Undo the instrumentation even if chat raises, so the model and later runs start clean.
    # Removing the instance attribute re-exposes the class-level generate.
    vars(model.model).pop('generate', None)
    for h in hooks:
        h.remove()
    hooks.clear()

print('Model response:', response)

In [None]:
assert actual_input_ids is not None, "failed to capture token ids"
token_ids = actual_input_ids[0].tolist()          # first sequence in the batch
texts = [tokenizer.decode([tid]) for tid in token_ids]

# Heatmap groups: each contiguous run of IMG_CONTEXT tokens (presumably one run per
# image) collapses into a single "IMG<k>" group. Every other token is its own group.
groups = []
labels = []
img_count = 0
prev_is_image = False
for pos, tid in enumerate(token_ids):
    is_image = tid == img_context_token_id
    if not is_image:
        groups.append([pos])
        labels.append(texts[pos])
    elif prev_is_image:
        groups[-1].append(pos)   # still inside the current image run
    else:
        img_count += 1           # a new image run starts here
        groups.append([pos])
        labels.append(f"IMG{img_count}")
    prev_is_image = is_image

In [None]:
# Which map to visualise: layer 24 (or the last layer for shallower models), head 0.
layer_to_show = min(24, num_layers - 1)
head_to_show = 0

assert layer_to_show in attention_data, (
    f"no attention captured for layer {layer_to_show}; the attention implementation "
    "probably did not return weights (needs output_attentions=True and eager/sdpa attention)"
)
# Stored as [batch, heads, q_len, k_len]. torch's .numpy() rejects bfloat16 tensors,
# and the attention probs are presumably in the model's bf16 dtype, so upcast first.
attn = attention_data[layer_to_show][0, head_to_show].float().numpy()

n_tokens = sum(len(idxs) for idxs in groups)
assert attn.shape == (n_tokens, n_tokens), (
    f"attention map {attn.shape} does not match the {n_tokens} captured prompt tokens; "
    "the hooks may have recorded a later decoding step instead of the prefill pass"
)

# Block-average the token-level map into a group x group map. With a one-hot membership
# matrix M (M[g, t] = 1 iff token t is in group g), M @ attn @ M.T sums each block;
# dividing by the block area (|g_m| * |g_n|) turns the sums into means.
membership = np.zeros((len(groups), n_tokens), dtype=attn.dtype)
for g, idxs in enumerate(groups):
    membership[g, idxs] = 1
group_sizes = membership.sum(axis=1)
agg = (membership @ attn @ membership.T) / np.outer(group_sizes, group_sizes)

In [None]:
# Group-level attention heatmap: rows are query groups, columns are key groups.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    agg,
    ax=ax,
    xticklabels=labels,
    yticklabels=labels,
    cmap='viridis',
    vmin=0,
    vmax=agg.max(),
)
ax.set_title(f'Layer {layer_to_show} Head {head_to_show} Attention')
ax.set_xlabel('Key positions')
ax.set_ylabel('Query positions')
fig.tight_layout()
plt.show()