In [None]:
import json
import argparse
import os
import datetime
from tqdm import tqdm
import regex as re
from torchvision.ops import box_iou
import torch
from transformers import AutoProcessor
def save_json(file_path, data):
    """
    Save data to a JSON file.

    The file is always written as UTF-8 so that the result does not depend on
    the locale of the machine running the notebook.

    Args:
        file_path (str): Path to the JSON file (created or overwritten).
        data (dict): JSON-serialisable data to save.

    Raises:
        TypeError: If ``data`` contains objects that ``json`` cannot serialise.
        OSError: If the file cannot be opened for writing.
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

def load_json(file_path):
    """
    Load data from a JSON file.

    The file is decoded as UTF-8 (the encoding JSON interchange files are
    expected to use) instead of the platform default, so non-ASCII text such
    as referring expressions is read identically on every machine.

    Args:
        file_path (str): Path to the JSON file.

    Returns:
        dict | list: Data loaded from the file (whatever the top-level JSON
        value is).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_bbox_from_text(ans):
    """
    Pull every bracketed 4-number box such as ``[0.123, 0.456, 0.789, 1.000]``
    out of a text answer.

    Only numbers written as ``0.ddd`` or ``1.ddd`` (exactly three decimals)
    are recognised by the pattern.

    Args:
        ans (str): Text to scan.

    Returns:
        list[list[float]] | str: One ``[a, b, c, d]`` list per box found, in
        order of appearance, or the string ``"FAILED"`` when the text contains
        no box.
    """
    bbox_regex = re.compile(r'\[(((0|1)\.(\d){3}\,\s*){3}((0|1)\.(\d){3}))\]')

    boxes = []
    for hit in bbox_regex.finditer(ans):
        # group(1) is the text between the brackets: four comma-separated numbers.
        boxes.append([float(token) for token in hit.group(1).split(",")])

    if not boxes:
        return "FAILED"
    return boxes

def calculate_iou(gt_bbox_list, pred_bbox_list):
    """
    Greedily pair ground-truth and predicted boxes in descending IoU order.

    All (gt, pred) pairs are ranked by IoU; walking the ranking from the best
    pair down, a pair is accepted when neither its ground-truth box nor its
    predicted box has been used yet. The walk stops as soon as
    ``min(len(gt_bbox_list), len(pred_bbox_list))`` pairs have been accepted.

    Args:
        gt_bbox_list (list): Ground-truth boxes, each a 4-number sequence in
            the ``(x1, y1, x2, y2)`` layout expected by
            ``torchvision.ops.box_iou``.
        pred_bbox_list (list): Predicted boxes in the same layout.

    Returns:
        list[dict]: One dict per accepted pair, in descending IoU order, with
        keys ``"gt_index"``, ``"pred_index"`` and ``"iou_value"``. Empty when
        either input list is empty.
    """
    # box_iou needs (N, 4) tensors; an empty list would become a 1-D tensor.
    if len(gt_bbox_list) == 0 or len(pred_bbox_list) == 0:
        return []

    iou_matrix = box_iou(torch.tensor(gt_bbox_list).float(), torch.tensor(pred_bbox_list).float())
    num_gt, num_pred = iou_matrix.shape
    max_pairs = min(num_gt, num_pred)

    gt_taken = [False] * num_gt
    pred_taken = [False] * num_pred
    iou_info_list = []

    # Flat indices of the IoU matrix sorted from the largest IoU to the smallest.
    ranked_flat_indices = torch.argsort(iou_matrix.flatten(), descending=True).tolist()

    # Walk the whole ranking (not just the first len(gt) entries): a top-ranked
    # pair may be rejected because one of its boxes is already used, and the
    # remaining boxes can then only be paired further down the ranking.
    for flat_index in ranked_flat_indices:
        gt_index, pred_index = divmod(flat_index, num_pred)
        if gt_taken[gt_index] or pred_taken[pred_index]:
            continue

        iou_info_list.append({
            "gt_index": gt_index,
            "pred_index": pred_index,
            "iou_value": iou_matrix[gt_index, pred_index].item()
        })
        gt_taken[gt_index] = True
        pred_taken[pred_index] = True

        if len(iou_info_list) == max_pairs:
            break

    return iou_info_list

def sort_list_of_dicts(data, key, reverse=False):
    """
    Return a new list with the dictionaries of ``data`` ordered by one field.

    The input is left untouched; items that compare equal keep their relative
    order.

    Args:
        data (list): Dictionaries to order.
        key (str): Field whose value is used for the comparison.
        reverse (bool): True for descending order, False for ascending.

    Returns:
        list: The ordered dictionaries.
    """
    def _field(record):
        return record[key]

    ordered = list(data)
    ordered.sort(key=_field, reverse=reverse)
    return ordered

def paligemma_get_bbox(text: str, *args, **kwargs):
    """
    Decode runs of four consecutive ``<locNNNN>`` tokens into boxes.

    Each four-digit value is divided by 1023.0. The four values of a run are
    read as ``y1, x1, y2, x2`` and emitted as ``[x1, y1, x2, y2]``.

    Args:
        text (str): Text containing ``<locNNNN>`` tokens.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        tuple[list[list[float]], list]: The decoded boxes in order of
        appearance, followed by an always-empty list.
    """
    boxes = []
    for hit in re.finditer(r"(((<loc\d{4}>){4}))", text):
        # group(2) spans the complete run of four location tokens.
        scaled = [int(digits)/1023.0 for digits in re.findall(r'\d+', hit.group(2))]
        top, left, bottom, right = scaled
        boxes.append([left, top, right, bottom])
    return boxes, []



In [63]:
# Load the ground-truth annotations and the model's generated outputs, then
# sort both lists by "id" so that index i addresses the same record in each.
print("Loading JSON data...")
# Ground-truth file.
correct_json_path = "/data_ssd/refcoco_plus/refcoco_plus_kosmos2_validation.json"
correct_data = load_json(correct_json_path)

# Generated-output file of the checkpoint under evaluation.
generated_json_path = "/data_ssd/USER_DATA/omote/iam-llms-finetune/experiment_output/kosmos-2/refcoco-pulus_train-and-eval_vision-proj-llm_distance-forward-kl-loss_b128acum4_2025-06-25T02_33_21/checkpoint-4695/eval_output/refcoco_plus_kosmos2_validation/2025-06-26T08_22_16/eval_output.json"
generated_data = load_json(generated_json_path)

# Fail fast when the two files do not hold the same number of records.
assert len(correct_data) == len(generated_data), "Length of correct and generated data does not match."

# The evaluation cells pair records by position and assert that the ids agree,
# so both lists must share the same ordering.
correct_data = sort_list_of_dicts(correct_data, "id")
generated_data = sort_list_of_dicts(generated_data, "id")

Loading JSON data...


In [64]:


# Per-annotation evaluation.
#
# Records are grouped by "ann_id". For each annotation the ground-truth box of
# the FIRST record seen is kept, while the first-entity boxes predicted for
# EVERY record sharing that ann_id are pooled. An annotation counts as matched
# when the best greedy (gt, prediction) pairing reaches `iou_threshold`.
# NOTE(review): this is a best-of-N score over all records that share an
# ann_id, not a per-record score — confirm this is the intended metric.

# Accumulators are (re)initialised here so re-running the cell is idempotent.
all_iou_list = []        # one IoU per ground-truth box (0.0 when unmatched)
generated_iou_list = []  # IoUs of the pairs returned by calculate_iou only

gt_iou_num_count = 0
matched_data_num = 0

iou_threshold = 0.5

processor = AutoProcessor.from_pretrained("/data_ssd/huggingface_model_weights/microsoft/kosmos-2-patch14-224")

eval_dict = {}

for i in range(len(correct_data)):
    assert correct_data[i]["id"] == generated_data[i]["id"], f"ID mismatch at index {i}."
    ann_id = correct_data[i]["ann_id"]
    if ann_id not in eval_dict:
        eval_dict[ann_id] = {
            "ann_id": ann_id,
            "gt_name": correct_data[i]["gt_entities"][0][0],
            "correct_data": [correct_data[i]["gt_entities"][0][-1]],
            "generated_data": []
        }

    # The prompt turn and the answer turn are concatenated before decoding.
    generated_text = generated_data[i]["conversations"][0]["value"] + generated_data[i]["conversations"][1]["value"]
    caption, entities = processor.post_process_generation(generated_text)

    # Only the first decoded entity is used; its last field is a list of boxes.
    generated_bbox = entities[0][-1] if len(entities) > 0 else []
    eval_dict[ann_id]["generated_data"].extend(generated_bbox)

total_data_num = len(eval_dict)

for eval_item in eval_dict.values():
    correct_bbox = eval_item["correct_data"]
    generated_bbox = eval_item["generated_data"]
    gt_iou_num_count += len(correct_bbox)

    if len(generated_bbox) == 0:
        # No prediction at all: every ground-truth box scores 0.
        iou_list = [0.0] * len(correct_bbox)
    else:
        iou_list = [item["iou_value"] for item in calculate_iou(correct_bbox, generated_bbox)]
        generated_iou_list.extend(iou_list)
        # Ground-truth boxes left without a partner score 0.
        if len(iou_list) < len(correct_bbox):
            iou_list.extend([0.0] * (len(correct_bbox) - len(iou_list)))

    all_iou_list.extend(iou_list)
    iou_threshold_count = sum(1 for iou in iou_list if iou >= iou_threshold)
    if iou_threshold_count > 0:
        matched_data_num += 1

# gt_iou_num_count is an int, so it is formatted directly (calling len() on it
# would raise TypeError and hide the real assertion failure).
assert len(all_iou_list) == gt_iou_num_count, f"Length of all_iou_list {len(all_iou_list)} does not match gt_iou_num_count {gt_iou_num_count}."


print("-" * 50)
print(f"Total data number: {total_data_num}")
print(f"Matched data number: {matched_data_num}")
print(f"len all_iou_list: {len(all_iou_list)}")
print(f"len generated_iou_list: {len(generated_iou_list)}")
print("-" * 50)
# Guarded like the two means below so that empty input reports 0 instead of crashing.
accuracy = matched_data_num / total_data_num if total_data_num > 0 else 0
print(f"Accuracy: {accuracy}")
mean_all_iou = sum(all_iou_list) / len(all_iou_list) if len(all_iou_list) > 0 else 0
print(f"Mean IoU: {mean_all_iou}")
mean_generated_iou = sum(generated_iou_list) / len(generated_iou_list) if len(generated_iou_list) > 0 else 0
print(f"Mean Generated IoU: {mean_generated_iou}")
print("-" * 50)


--------------------------------------------------
Total data number: 3805
Matched data number: 3258
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8562417871222077
Mean IoU: 0.7872461899363591
Mean Generated IoU: 0.7872461899363591
--------------------------------------------------


##validation
#zeroshot
--------------------------------------------------
Total data number: 3805
Matched data number: 517
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.13587385019710907
Mean IoU: 0.21221880688480102
Mean Generated IoU: 0.21221880688480102
--------------------------------------------------

#finetune
#1epoch
-------------------------------------------------
Total data number: 3805
Matched data number: 2710
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.7122207621550591
Mean IoU: 0.6653832597237046
Mean Generated IoU: 0.6653832597237046
--------------------------------------------------
#2epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3015
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.7923784494086727
Mean IoU: 0.7365307942018128
Mean Generated IoU: 0.7365307942018128
--------------------------------------------------
#3epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3066
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8057818659658345
Mean IoU: 0.7467946519262589
Mean Generated IoU: 0.7467946519262589
--------------------------------------------------
#4epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3126
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8215505913272011
Mean IoU: 0.7569882920839176
Mean Generated IoU: 0.7569882920839176
--------------------------------------------------
#5epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3133
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8233902759526939
Mean IoU: 0.7591249816942074
Mean Generated IoU: 0.7591249816942074
--------------------------------------------------

#distance loss
#1epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 2950
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.7752956636005256
Mean IoU: 0.713230377256273
Mean Generated IoU: 0.713230377256273
--------------------------------------------------

#distance kl loss
#1epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 2920
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.7674113009198423
Mean IoU: 0.7105647883428146
Mean Generated IoU: 0.7105647883428146
--------------------------------------------------

#2epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3139
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8249671484888305
Mean IoU: 0.760148260024047
Mean Generated IoU: 0.760148260024047
--------------------------------------------------

#3epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3208
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8431011826544021
Mean IoU: 0.7760436532541659
Mean Generated IoU: 0.7760436532541659
--------------------------------------------------

#4epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3242
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8520367936925098
Mean IoU: 0.7825445631494484
Mean Generated IoU: 0.7825445631494484
--------------------------------------------------

#5epoch
--------------------------------------------------
Total data number: 3805
Matched data number: 3258
len all_iou_list: 3805
len generated_iou_list: 3805
--------------------------------------------------
Accuracy: 0.8562417871222077
Mean IoU: 0.7872461899363591
Mean Generated IoU: 0.7872461899363591
--------------------------------------------------

In [24]:
# Debug helper: for a few hand-picked annotation ids, print the ground-truth
# entities, the raw generated text and the entities decoded from that text.
problem_ann_ids = {71286, 1069789, 2162745}

for i in range(len(correct_data)):
    ann_id = correct_data[i]["ann_id"]
    if ann_id in problem_ann_ids:
        print(f"Warning: Ann_id {ann_id} is in the list of problematic IDs.")
        # Single-quoted keys inside the f-strings: reusing the enclosing quote
        # character is a SyntaxError on Python versions before 3.12.
        print(f"Correct data: {correct_data[i]['gt_entities']}")
        # NOTE(review): this prints the "gt_entities" field stored in the
        # generated file, not the model prediction — the label may mislead.
        print(f"Generated data: {generated_data[i]['gt_entities']}")
        generated_text = generated_data[i]["conversations"][0]["value"] + generated_data[i]["conversations"][1]["value"]
        print(generated_text)
        caption, entities = processor.post_process_generation(generated_text)
        print(entities)
        print("-" * 50)

Correct data: [['<--- that cow', [0.015625, 0.453125, 0.265625, 0.640625]]]
Generated data: [['<--- that cow', [0.015625, 0.453125, 0.265625, 0.640625]]]
<image><grounding> <phrase> <--- that cow</phrase><object><patch_index_0448><patch_index_0648></object>
[('<patch_index_448><patch_index_648>', (0, 0), [(0.015625, 0.453125, 0.265625, 0.640625)])]
--------------------------------------------------
Correct data: [['cow closest to tree with full side showing', [0.015625, 0.453125, 0.265625, 0.640625]]]
Generated data: [['cow closest to tree with full side showing', [0.015625, 0.453125, 0.265625, 0.640625]]]
<image><grounding> <phrase> cow closest to tree with full side showing</phrase><object><patch_index_0448><patch_index_0648></object>
[('cow closest to tree with full side showing', (0, 42), [(0.015625, 0.453125, 0.265625, 0.640625)])]
--------------------------------------------------
Correct data: [['black spot on head, facing other direction', [0.015625, 0.453125, 0.265625, 0.64062

In [8]:
# Show the aggregated evaluation entry that belongs to the first record's annotation.
first_ann_id = correct_data[0]["ann_id"]
print(eval_dict[first_ann_id])

{'gt_name': 'Cellphone reflection', 'correct_data': [[0.5417799999999999, 0.13483483483483483, 0.87852, 0.7730330330330331]], 'generated_data': [(0.546875, 0.140625, 0.890625, 0.765625), (0.546875, 0.140625, 0.890625, 0.765625), (0.546875, 0.140625, 0.890625, 0.765625)]}


In [None]:

    
    




# Per-record evaluation: every record of `correct_data` is scored on its own
# against the boxes of the decoded entity whose name equals the ground-truth
# phrase. Requires `processor` and `iou_threshold` from the per-annotation
# evaluation cell and the data/paths from the loading cell.

# Reset the accumulators: without this the cell would add to the totals left
# behind by the per-annotation evaluation and change its result on every run.
all_iou_list = []
generated_iou_list = []
gt_iou_num_count = 0
matched_data_num = 0
anomaly_matched_data_num = 0
model_predict_anomaly_data_num = 0

# One evaluation item per record (the per-annotation cell leaves
# total_data_num at len(eval_dict), which is not the number of records).
total_data_num = len(correct_data)

for i in tqdm(range(total_data_num)):
    assert correct_data[i]["id"] == generated_data[i]["id"], f"ID mismatch at index {i}."
    generated_text = generated_data[i]["conversations"][0]["value"] + generated_data[i]["conversations"][1]["value"]
    caption, entities = processor.post_process_generation(generated_text)
    gt_entities = correct_data[i]["gt_entities"]
    gt_name = gt_entities[0][0]
    # calculate_iou expects a list of boxes, so the single GT box is wrapped.
    correct_bbox = [gt_entities[0][-1]]

    # "FAILED" marks records for which no usable prediction was decoded; it is
    # set explicitly so a value from a previous iteration can never leak in.
    generated_bbox = "FAILED"
    for e in entities:
        if e[0] == gt_name and len(e[-1]) > 0:
            # Keep the entity's full list of boxes (a flat 4-tuple would not
            # be a valid (N, 4) input for the IoU computation).
            generated_bbox = list(e[-1])
            break

    gt_iou_num_count += len(correct_bbox)
    if generated_bbox == "FAILED":
        iou_list = [0.0] * len(correct_bbox)
    else:
        model_predict_anomaly_data_num += 1
        iou_list = [item["iou_value"] for item in calculate_iou(correct_bbox, generated_bbox)]
        generated_iou_list.extend(iou_list)
        # Ground-truth boxes left without a partner score 0.
        if len(iou_list) < len(correct_bbox):
            iou_list.extend([0.0] * (len(correct_bbox) - len(iou_list)))

    all_iou_list.extend(iou_list)
    iou_threshold_count = sum(1 for iou in iou_list if iou >= iou_threshold)
    if iou_threshold_count > 0:
        matched_data_num += 1
        anomaly_matched_data_num += 1

# NOTE(review): the normal/anomaly counters were never defined in any visible
# cell. The definitions below ASSUME that "anomaly" means "has a box": every
# record has a ground-truth box, and a record is "predicted anomaly" when a
# box was decoded for it. Confirm or replace with the intended definitions.
anomaly_data_num = total_data_num
normal_data_num = total_data_num - anomaly_data_num
model_predict_normal_data_num = total_data_num - model_predict_anomaly_data_num

# gt_iou_num_count is an int, so it is formatted directly (calling len() on it
# would raise TypeError and hide the real assertion failure).
assert len(all_iou_list) == gt_iou_num_count, f"Length of all_iou_list {len(all_iou_list)} does not match gt_iou_num_count {gt_iou_num_count}."
print("-" * 50)
print(len(all_iou_list))
print(f"len all_iou_list: {len(all_iou_list)}")
print(f"len generated_iou_list: {len(generated_iou_list)}")

mean_all_iou = sum(all_iou_list) / len(all_iou_list) if len(all_iou_list) > 0 else 0
print(f"Mean IoU: {mean_all_iou}")
mean_generated_iou = sum(generated_iou_list) / len(generated_iou_list) if len(generated_iou_list) > 0 else 0
print(f"Mean Generated IoU: {mean_generated_iou}")


print("-" * 50)
print(f"Total data number: {total_data_num}")
print(f"Normal data number: {normal_data_num}")
print(f"Anomaly data number: {anomaly_data_num}")

print(f"Model predict normal data number: {model_predict_normal_data_num}")
print(f"Model predict anomaly data number: {model_predict_anomaly_data_num}")

print(f"Matched data number: {matched_data_num}")
print(f"Anomaly matched data number: {anomaly_matched_data_num}")
print("-" * 50)
accuracy = matched_data_num / total_data_num if total_data_num > 0 else 0
print(f"Accuracy: {accuracy}")
precision = anomaly_matched_data_num / model_predict_anomaly_data_num if model_predict_anomaly_data_num > 0 else 0
print(f"Precision: {precision}")
recall = anomaly_matched_data_num / anomaly_data_num if anomaly_data_num > 0 else 0
print(f"Recall: {recall}")
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
print(f"F1 Score: {f1_score}")
print("-" * 50)

# Timestamp in the same "YYYY-MM-DDTHH_MM_SS" layout that appears in the
# evaluation output path used by the loading cell.
current_date = datetime.datetime.now().strftime("%Y-%m-%dT%H_%M_%S")

# Summary record of this run. The notebook has no `args`, so the file names
# come from the path variables defined in the loading cell.
output_data = {
    "filename": generated_json_path,
    "correct_json": correct_json_path,
    "timestamp": current_date,
    "scores": {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "mean_all_iou": mean_all_iou,
        "mean_generated_iou": mean_generated_iou,
    },
    "data_num": {
        "total_data_num": total_data_num,
        "normal_data_num": normal_data_num,
        "anomaly_data_num": anomaly_data_num,
        "model_predict_normal_data_num": model_predict_normal_data_num,
        "model_predict_anomaly_data_num": model_predict_anomaly_data_num,
        "matched_data_num": matched_data_num,
        "anomaly_matched_data_num": anomaly_matched_data_num,
        "gt_iou_num_count": gt_iou_num_count,
        "generated_iou_num_count": len(generated_iou_list),
    },
    "other_info": {
        "iou_threshold": iou_threshold,
    }
}