In [1]:
from data_loader.dataset import *
from data_loader.data_loaders import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
taxonomy_name = "semeval_food"
data_path = "data/SemEval-Food/semeval_food.pickle.bin"
raw_graph_dataset = MAGDataset(name=taxonomy_name, path=data_path, raw=False, existing_partition=True)


loading pickled dataset
dataset loaded


In [3]:
img_path = "data/imgs"
img_feat = "data/img_feat/semeval_food_blip_feat"
json_file = taxonomy_name + "_dataset_final.jsonl"
json_data_path = os.path.join(img_path, json_file)
tokenizer_path = "/home/u2120230655/codes/VTC/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
negative_size = 4
num_tokens = 4

dataset = Dataset_Stage2(
    graph_dataset=raw_graph_dataset,
    json_data_path=json_data_path,
    img_feature_dir=img_feat,
    negative_size=negative_size,
    tokenizer=tokenizer,
    num_tokens=num_tokens)

adding pseudo leaf
building node2pos, node2edge
building valid and test node list
924 1190
Finish loading dataset (2.7509231567382812 seconds)
Loading metadata from data/imgs/semeval_food_dataset_final.jsonl...
Loaded 1488 metadata entries.


Flattening dataset into a 'big pool': 100%|██████████| 1190/1190 [00:00<00:00, 24639.61it/s]

Created 11845 total training pairs.





In [5]:
dataset.__getitem__(0).keys()

dict_keys(['vis_q', 'q_seg_start', 'q_seg_end', 'vis_p', 'vis_c', 'vis_s', 'c_seg_p', 'c_seg_c', 'c_seg_s', 'c_seg_end', 'label'])

In [6]:
def verify_dataset_output(dataset):
    print("\n=== Verifying Dataset Output (Single Sample) ===")
    
    # 取第 0 个样本
    item = dataset[0]
    if item is None:
        print("Error: Item 0 is None!")
        return

    print("Keys present:", item.keys())
    
    # 1. 检查形状
    for k, v in item.items():
        if isinstance(v, torch.Tensor):
            print(f"{k}: shape={v.shape}, dtype={v.dtype}")
        else:
            print(f"{k}: {type(v)}")

    # 2. 解码文本 (这是最重要的!)
    tokenizer = dataset.tokenizer
    
    print("\n--- Decoding Query ---")
    # Q: SegStart + [IMG] + SegEnd
    # 我们把 start 和 end 拼起来看看 (虽然中间缺了 Image Token)
    q_ids = torch.cat([item['q_seg_start'], item['q_seg_end']]).tolist()
    print(tokenizer.decode(q_ids))
    # 预期: [CLS] Query Node: Definition: "...", Image: [SEP]

    print("\n--- Decoding Candidate ---")
    # C: SegP + [IMG] + SegC + [IMG] + SegS + [IMG] + SegEnd
    c_ids = torch.cat([
        item['c_seg_p'], 
        torch.tensor([tokenizer.mask_token_id]), # 模拟 P_Img
        item['c_seg_c'], 
        torch.tensor([tokenizer.mask_token_id]), # 模拟 C_Img
        item['c_seg_s'], 
        torch.tensor([tokenizer.mask_token_id]), # 模拟 S_Img
        item['c_seg_end']
    ]).tolist()
    print(tokenizer.decode(c_ids))
    # 预期: [CLS] Parent... Img: [MASK]; Child... Img: [MASK]; Sibling... Img: [MASK]; [SEP]

def verify_dataloader_output(dataloader):
    print("\n=== Verifying DataLoader Output (Batch) ===")
    
    for batch in dataloader:
        print("Batch Keys:", batch.keys())
        
        # 1. 检查 Tensor 维度 (Batch Size)
        bs = batch['vis_q'].size(0)
        print(f"Batch Size: {bs}")
        
        # 2. 检查 Image Features
        print(f"Visual Q shape: {batch['vis_q'].shape} (Expected: [B, 256])")
        
        # 3. 检查 Mask 逻辑
        # 打印第一个样本的 mask sum，看看长度对不对
        for k in ['q_seg_start', 'c_seg_p']:
            mask = batch[f"{k}_mask"]
            length = mask[0].sum().item()
            print(f"{k} valid length (sample 0): {length}")
            
        # 4. 检查 Label
        print(f"Labels: {batch['label'][:10]} ...")
        
        break # 只看一个 Batch

In [7]:
verify_dataset_output(dataset)


=== Verifying Dataset Output (Single Sample) ===
Keys present: dict_keys(['vis_q', 'q_seg_start', 'q_seg_end', 'vis_p', 'vis_c', 'vis_s', 'c_seg_p', 'c_seg_c', 'c_seg_s', 'c_seg_end', 'label'])
vis_q: shape=torch.Size([256]), dtype=torch.float32
q_seg_start: shape=torch.Size([28]), dtype=torch.int64
q_seg_end: shape=torch.Size([1]), dtype=torch.int64
vis_p: shape=torch.Size([256]), dtype=torch.float32
vis_c: shape=torch.Size([256]), dtype=torch.float32
vis_s: shape=torch.Size([256]), dtype=torch.float32
c_seg_p: shape=torch.Size([26]), dtype=torch.int64
c_seg_c: shape=torch.Size([11]), dtype=torch.int64
c_seg_s: shape=torch.Size([30]), dtype=torch.int64
c_seg_end: shape=torch.Size([1]), dtype=torch.int64
label: shape=torch.Size([]), dtype=torch.float32

--- Decoding Query ---
<s> query node : definition : " absinth is strong green liqueur flavored with wormwood and anise ", image : </s>

--- Decoding Candidate ---
<s> parent node : definition : " liqueur is strong highly flavored swee

In [9]:
dataloader = Stage2DataLoader(
    data_path=data_path,
    taxonomy_name=taxonomy_name,
    img_root_dir=img_path,
    tokenizer_path=tokenizer_path,
    img_feat_dir=img_feat,
    num_image_tokens=4,
    batch_size=8,
    negative_size=negative_size,
    num_workers=2,
    shuffle=True,
    )

Loading Tokenizer from /home/u2120230655/codes/VTC/all-mpnet-base-v2...
Instantiating Dataset_Stage2...
loading pickled dataset
dataset loaded
adding pseudo leaf
building node2pos, node2edge
building valid and test node list
924 1190
Finish loading dataset (2.7590415477752686 seconds)
Loading metadata from data/imgs/semeval_food_dataset_final.jsonl...
Loaded 1488 metadata entries.


Flattening dataset into a 'big pool': 100%|██████████| 1190/1190 [00:00<00:00, 25076.85it/s]

Created 11845 total training pairs.





In [10]:
verify_dataloader_output(dataloader)


=== Verifying DataLoader Output (Batch) ===


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Batch Keys: dict_keys(['vis_q', 'vis_p', 'vis_c', 'vis_s', 'label', 'q_seg_start_ids', 'q_seg_start_mask', 'q_seg_end_ids', 'q_seg_end_mask', 'c_seg_p_ids', 'c_seg_p_mask', 'c_seg_c_ids', 'c_seg_c_mask', 'c_seg_s_ids', 'c_seg_s_mask', 'c_seg_end_ids', 'c_seg_end_mask'])
Batch Size: 8
Visual Q shape: torch.Size([8, 256]) (Expected: [B, 256])
q_seg_start valid length (sample 0): 27
c_seg_p valid length (sample 0): 30
Labels: tensor([0., 0., 0., 0., 0., 1., 1., 0.]) ...
