In [1]:
import json
from pathlib import Path
import wandb
import glob
import os
import regex as re
from torchvision.ops import box_iou
import torch
from copy import deepcopy
from tqdm import tqdm
from sentence_transformers import SentenceTransformer,util
import sys
sys.path.append("/home/omote/cluster_project/iam2/eval")
from eval_utils.custom_oc_cost import get_cmap,get_ot_cost,DetectedInstance
import math
from copy import deepcopy
def load_json(file_path):
    """
    Read a JSON document from disk and return the decoded object.

    Args:
        file_path (str): Location of the JSON file to read.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.
    """
    with open(file_path, 'r') as handle:
        payload = json.load(handle)
    return payload
    
def sort_list_of_dicts(data, key, reverse=False):
    """
    Return a new list with the dictionaries of ``data`` ordered by one field.

    Args:
        data (iterable): Dictionaries to order; the input itself is not modified.
        key (str): Field whose value is used as the sort key.
        reverse (bool): True for descending order, False for ascending.

    Returns:
        list: The dictionaries in sorted order.
    """
    def pick_field(item):
        return item[key]

    ordered = list(data)
    ordered.sort(key=pick_field, reverse=reverse)
    return ordered

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# filter((row) => row["ce_iou_over_0.5_count"] != null and row["prop_iou_over_0.5_count"] != null)
# runs.summary["result_table"].table.rows[0].filter((row) => row["ce_iou_over_0.5_count"] < row["prop_iou_over_0.5_count"])

In [3]:
def extract_bbox_from_text(ans):
    """
    Pull every ``[a, b, c, d]`` box written with three-decimal numbers out of a string.

    Each coordinate must look like ``0.ddd`` or ``1.ddd``; the four values of a
    box are comma separated, optionally followed by whitespace.

    Args:
        ans (str): Text to scan.

    Returns:
        list[list[float]] | str: One 4-float list per box found, in order of
        appearance, or the string ``"FAILED"`` when no box is present.
    """
    bbox_regex = re.compile(r'\[(((0|1)\.(\d){3}\,\s*){3}((0|1)\.(\d){3}))\]')
    boxes = []
    for found in bbox_regex.finditer(ans):
        # Group 1 is everything between the square brackets.
        boxes.append([float(piece) for piece in found.group(1).split(",")])

    if not boxes:
        return "FAILED"
    return boxes

def calculate_iou(gt_bbox_list, pred_bbox_list):
    """
    Greedily pair ground-truth and predicted boxes in descending IoU order.

    The pairwise IoU matrix is computed with ``box_iou`` (NaN entries are
    replaced by 0). Candidate pairs are then visited from the highest IoU to
    the lowest; a pair is accepted when neither its ground-truth box nor its
    predicted box has been used yet. Matching stops once every box of the
    shorter list is paired.

    Args:
        gt_bbox_list (list): Ground-truth boxes, each ``[x1, y1, x2, y2]``.
        pred_bbox_list (list): Predicted boxes, each ``[x1, y1, x2, y2]``.

    Returns:
        tuple: ``(iou_info_list, iou_matrix, iou_argsort_matrix, pred_index_list, gt_index_list)``
            - iou_info_list: accepted pairs as dicts with keys ``gt_index``,
              ``pred_index`` and ``iou_value``, in acceptance order.
            - iou_matrix: IoU tensor with one row per ground-truth box and one
              column per predicted box.
            - iou_argsort_matrix: same shape as ``iou_matrix``; each entry is
              the rank of that pair (0 = largest IoU).
            - pred_index_list / gt_index_list: boolean tensors marking which
              predicted / ground-truth boxes were paired.

    Raises:
        AssertionError: if the number of pairs differs from
            ``min(len(gt_bbox_list), len(pred_bbox_list))``.
    """
    def _as_box_tensor(bbox_list):
        # torch.tensor([]) is 1-D; give empty inputs the same (N, 4) layout
        # that non-empty inputs have so the IoU call sees a consistent shape.
        if len(bbox_list) == 0:
            return torch.zeros((0, 4), dtype=torch.float32)
        return torch.tensor(bbox_list).float()

    iou_matrix = box_iou(_as_box_tensor(gt_bbox_list), _as_box_tensor(pred_bbox_list))
    iou_matrix = torch.nan_to_num(iou_matrix, nan=0.0)  # replace NaN with 0

    # flat_order[i] is the flattened position of the i-th largest IoU;
    # its inverse permutation is the per-entry rank returned to the caller.
    flat_order = torch.argsort(iou_matrix.flatten(), descending=True)
    iou_argsort_matrix = flat_order.argsort().reshape(iou_matrix.shape)

    pred_index_list = torch.full((len(pred_bbox_list),), False, dtype=torch.bool)
    gt_index_list = torch.full((len(gt_bbox_list),), False, dtype=torch.bool)

    expected_pairs = min(len(gt_bbox_list), len(pred_bbox_list))
    num_pred_columns = iou_matrix.shape[1]
    iou_info_list = []

    # Walk the candidates best-first instead of searching the rank matrix for
    # each rank in turn (which costs a full matrix scan per candidate).
    for flat_index in flat_order.tolist():
        gt_idx, pred_idx = divmod(flat_index, num_pred_columns)
        if gt_index_list[gt_idx] or pred_index_list[pred_idx]:
            continue
        iou_info_list.append({
            "gt_index": gt_idx,
            "pred_index": pred_idx,
            "iou_value": iou_matrix[gt_idx, pred_idx].item()
        })
        gt_index_list[gt_idx] = True
        pred_index_list[pred_idx] = True
        # Every box of the shorter list is paired: nothing left to match.
        if len(iou_info_list) == expected_pairs:
            break

    assert len(iou_info_list) == min(len(gt_bbox_list), len(pred_bbox_list)), f"Length mismatch: {len(iou_info_list)} != {min(len(gt_bbox_list), len(pred_bbox_list))}"

    return iou_info_list,iou_matrix,iou_argsort_matrix,pred_index_list, gt_index_list

In [4]:
def check_bbox_valid(bbox, box_w_h=[1, 1], min_bbox_size=1e-6):
    """
    Decide whether an ``[x1, y1, x2, y2]`` box is usable.

    A box is rejected when any coordinate is negative, when an x (resp. y)
    coordinate exceeds ``box_w_h[0]`` (resp. ``box_w_h[1]``), when it has no
    positive width or height, or when its area is below ``min_bbox_size``.

    Args:
        bbox: Four values ``x1, y1, x2, y2``.
        box_w_h: ``[width, height]`` upper bounds for the coordinates.
        min_bbox_size: Smallest accepted area.

    Returns:
        bool: True when the box passes every check.
    """
    x1, y1, x2, y2 = bbox

    # Negative coordinates are never acceptable.
    if any(coord < 0 for coord in (x1, y1, x2, y2)):
        return False

    # Coordinates must stay inside the canvas.
    width_limit, height_limit = box_w_h[0], box_w_h[1]
    outside_x = x1 > width_limit or x2 > width_limit
    outside_y = y1 > height_limit or y2 > height_limit
    if outside_x or outside_y:
        return False

    # Corners must be ordered top-left -> bottom-right.
    if x1 >= x2 or y1 >= y2:
        return False

    area = (x2 - x1) * (y2 - y1)
    return not (area < min_bbox_size)

def parse_bbox_and_labels(processor,detokenizer_output: str):
    """
    Extract boxes and their labels from text containing ``<locNNNN>`` tokens.

    Two layouts are recognised, tried in this order:

    1. Four ``<locNNNN>`` tokens followed by a label (text up to the next
       ``;`` or ``<``). Used when the first such match has a non-blank label.
    2. A single label taken from a leading ``<image>detect LABEL<loc...>``
       prefix, applied to every group of four ``<locNNNN>`` tokens that is
       preceded by some text.

    The four token values are read as ``y1, x1, y2, x2``, divided by 1024 and
    re-ordered to ``[x1, y1, x2, y2]``. Boxes rejected by ``check_bbox_valid``
    are dropped together with their label.

    Args:
        processor: Unused by this function.
        detokenizer_output (str): Text to parse.

    Returns:
        tuple[list, list]: ``(bbox_list, label_list)`` of equal length; both
        empty when nothing can be parsed.
    """
    def decode_box(loc_tokens):
        # Token order is y1, x1, y2, x2; the returned box is x1, y1, x2, y2.
        y1, x1, y2, x2 = [int(x)/1024.0 for x in re.findall(r'\d+', loc_tokens)]
        candidate = [x1, y1, x2, y2]
        if check_bbox_valid(candidate, box_w_h=[1, 1], min_bbox_size=1e-6):
            return candidate
        return None

    # Layout 1: "<loc><loc><loc><loc> label"
    box_then_label = re.findall(r"(((<loc\d{4}>){4})([^;<]+))", detokenizer_output)
    if box_then_label and box_then_label[0][-1].strip() != "":
        bbox_list = []
        label_list = []
        for hit in box_then_label:
            box = decode_box(hit[1])
            if box is None:
                continue
            bbox_list.append(box)
            label_list.append(hit[-1].strip())
        return bbox_list, label_list

    # Layout 2: "<image>detect label<loc><loc><loc><loc>..."
    label_then_box = re.findall(r"([^<]+)((<loc\d{4}>){4})", detokenizer_output)
    if not label_then_box:
        return [], []

    prompt_labels = re.findall(r"^<image>detect\s([^<]+)<loc\d{4}>", detokenizer_output)
    if not prompt_labels or prompt_labels[0].strip() == "":
        return [], []

    label = prompt_labels[0].strip()
    decoded = (decode_box(hit[1]) for hit in label_then_box)
    bbox_list = [box for box in decoded if box is not None]
    return bbox_list, [label] * len(bbox_list)


def add_bbox_to_wandb_image(wandb_image, entities,cat_2_id_dict=None,add_number=True):
    """
    Wrap a ``wandb.Image`` in a new one that carries bounding-box overlays.

    Args:
        wandb_image (wandb.Image): Image to annotate.
        entities (list): Items whose first element is a name and whose last
            element is a list of ``[minX, minY, maxX, maxY]`` boxes.
        cat_2_id_dict (dict | None): Category name -> class id. A copy is
            extended with an ``"unknown"`` entry (largest id + 1) that is used
            for names missing from the mapping; the caller's dict is not
            modified. When None, boxes get consecutive class ids 0, 1, 2, ...
            with placeholder labels ``cat_<id>``.
        add_number (bool): When a mapping is given, append a per-name running
            index to each caption (``name_0``, ``name_1``, ...).

    Returns:
        wandb.Image: A new image with a ``"predictions"`` box layer, or
        ``wandb_image`` itself when ``entities`` contains no boxes.
    """
    assert isinstance(wandb_image, wandb.Image)

    # Flatten entities into parallel name / box lists.
    name_list = []
    bbox_list = []
    for entity in entities:
        bbox_list.extend(entity[-1])
        name_list.extend([entity[0]]*len(entity[-1]))
    assert len(name_list) == len(bbox_list)

    if len(name_list) == 0:
        return wandb_image

    if cat_2_id_dict is None:
        # No mapping supplied: placeholder labels, at least one per box.
        tmp_class_num = max(200, len(name_list))
        id_2_cat_dict = {i:f"cat_{i}" for i in range(tmp_class_num)}
    else:
        cat_2_id_dict = deepcopy(cat_2_id_dict)
        # default=-1 keeps an empty mapping from raising in max().
        cat_2_id_dict["unknown"] = max(cat_2_id_dict.values(), default=-1) + 1
        id_2_cat_dict = {v:k for k,v in cat_2_id_dict.items()}

    # Running caption index per name; created lazily so that names missing
    # from the mapping are counted as well instead of raising KeyError.
    appear_num_dict = {}
    all_boxes = []
    for position, (name, bbox) in enumerate(zip(name_list, bbox_list)):
        if cat_2_id_dict is None:
            class_id = position
            caption = name
        else:
            class_id = cat_2_id_dict.get(name, cat_2_id_dict["unknown"])
            caption = name
            if add_number:
                appear_num_dict[name] = appear_num_dict.get(name, -1) + 1
                caption = f"{name}_{appear_num_dict[name]}"

        all_boxes.append({
            "position": {
                "minX": bbox[0],
                "maxX": bbox[2],
                "minY": bbox[1],
                "maxY": bbox[3],
            },
            "class_id": class_id,
            "box_caption": caption,
        })

    return wandb.Image(
        wandb_image,
        boxes={
            "predictions": {
                "box_data": all_boxes,
                "class_labels": id_2_cat_dict,
            }
        },
    )

In [5]:
def oc_cost(pred_instance_list,tgt_instance_list, alpha=0.5,beta=0.6):
    """
    Compute the OC-cost between predicted and target instances.

    Builds a cost-map callable that forwards to ``get_cmap`` with the given
    ``alpha`` / ``beta`` and ``label_or_sim="label"``, then hands it to
    ``get_ot_cost`` together with both instance lists.

    Args:
        pred_instance_list: Predicted instances.
        tgt_instance_list: Target (ground-truth) instances.
        alpha: Forwarded to ``get_cmap``.
        beta: Forwarded to ``get_cmap``.

    Returns:
        Whatever ``get_ot_cost`` returns for the two lists.
    """
    def label_cost_map(x, y):
        return get_cmap(x, y, alpha=alpha, beta=beta, label_or_sim="label")

    return get_ot_cost(pred_instance_list, tgt_instance_list, label_cost_map)

def similariry_score(str1, str2, model: SentenceTransformer):
    # compute embedding for both lists
    embedding_1 = model.encode(str1, show_progress_bar=False)
    embedding_2 = model.encode(str2, show_progress_bar=False)
    score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
    
    #スコア丸め込み
    # score = min(score, 1.0)
    # score = max(score, 0.0)
    
    return score

def create_get_most_similar_category_func(category_list, sentence_transformer_model_path):
    """
    Build a lookup that maps free text to the closest known category name.

    The sentence-transformer at ``sentence_transformer_model_path`` is loaded
    once (on CUDA when available, otherwise CPU) and the embeddings of
    ``category_list`` are computed up front.

    Args:
        category_list (list): Candidate category names.
        sentence_transformer_model_path: Model identifier or path given to
            ``SentenceTransformer``.

    Returns:
        callable: ``f(category_name) -> (best_name, best_score, scores)`` where
        ``best_name`` is the entry of ``category_list`` with the highest cosine
        similarity, ``best_score`` is that similarity as a float and ``scores``
        holds the similarity to every candidate.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    encoder = SentenceTransformer(sentence_transformer_model_path).to(device)
    reference_embeddings = encoder.encode(category_list, show_progress_bar=False, convert_to_tensor=True)

    def get_most_similar_category(category_name):
        query_embedding = encoder.encode(category_name, show_progress_bar=False, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, reference_embeddings).squeeze(0)
        best_position = torch.argmax(scores).item()
        best_score = scores[best_position].item()
        return category_list[best_position], best_score, scores

    return get_most_similar_category

def get_per_image_class_result_and_oc_cost(all_gt_annotations, all_pred_annotations, cat_name2id, iou_threshold=0.5):
    """
    Score predictions per image and per category, and collect one OC-cost per image.

    For every image key of ``all_gt_annotations`` the function:
      * converts both annotation lists to ``DetectedInstance`` objects and
        stores ``oc_cost(pred, gt, alpha=0.5, beta=0.6)``;
      * groups ground truth and predictions by ``category_id``;
      * for each category present on either side, pairs boxes with
        ``calculate_iou`` and derives TP / FP / FN counts.

    Args:
        all_gt_annotations (dict): image key -> list of ground-truth
            annotations. Each annotation provides ``category_id``,
            ``bbox_xyxy`` and (at least the first one) ``ann_id``.
        all_pred_annotations (dict): image key -> list of predicted
            annotations with ``category_id`` and ``bbox_xyxy``. Missing keys
            are treated as "no predictions".
        cat_name2id (dict): Category name -> id; its values define the
            categories that are evaluated.
        iou_threshold (float): Minimum IoU for a pair to count as true positive.

    Returns:
        tuple: ``(per_image_result_dict, oc_cost_list)``.
            ``per_image_result_dict[key]`` is a dict with ``"ann_id"`` plus one
            entry per category id: ``None`` when the category appears on
            neither side, otherwise a dict with ``iou_list`` (one value per
            ground-truth box, unmatched ones are 0.0), ``pred_iou_list`` (IoU
            of the matched pairs only), ``tp_num``, ``fp_num``, ``fn_num``,
            ``iou_info_list`` and ``pred_bbox_list``.
            ``oc_cost_list`` holds the per-image OC-cost in iteration order.
    """
    def to_instances(annotation_list):
        return [
            DetectedInstance(
                label=ann["category_id"],
                x1=ann["bbox_xyxy"][0],
                y1=ann["bbox_xyxy"][1],
                x2=ann["bbox_xyxy"][2],
                y2=ann["bbox_xyxy"][3],
            )
            for ann in annotation_list
        ]

    def group_by_category(annotation_list):
        # Every known category starts as None ("absent"); a list is created
        # only once an annotation of that category shows up.
        grouped = {category_id: None for category_id in cat_name2id.values()}
        for ann in annotation_list:
            if grouped[ann["category_id"]] is None:
                grouped[ann["category_id"]] = []
            grouped[ann["category_id"]].append(ann)
        return grouped

    per_image_result_dict = {}
    oc_cost_list = []

    for index, gt_per_image_annotation_list in tqdm(all_gt_annotations.items()):
        pred_per_image_annotation_list = all_pred_annotations.get(index, [])

        # Image-level evaluation.
        oc_cost_list.append(
            oc_cost(
                to_instances(pred_per_image_annotation_list),
                to_instances(gt_per_image_annotation_list),
                alpha=0.5,
                beta=0.6,
            )
        )

        # Per-image, per-category evaluation.
        per_category_result_dict = {"ann_id": gt_per_image_annotation_list[0]["ann_id"]}
        for category_id in cat_name2id.values():
            per_category_result_dict[category_id] = None

        gt_per_category_dict = group_by_category(gt_per_image_annotation_list)
        pred_per_category_dict = group_by_category(pred_per_image_annotation_list)

        for category_id, gt_annotations in gt_per_category_dict.items():
            pred_annotations = pred_per_category_dict[category_id]
            if gt_annotations is None and pred_annotations is None:
                continue

            per_category_result = {
                "iou_list": [],
                "pred_iou_list": [],
                "tp_num": 0,
                "fp_num": 0,
                "fn_num": 0,
            }

            if gt_annotations is None:
                # Predictions without ground truth: all false positives.
                per_category_result["fp_num"] = len(pred_annotations)
                per_category_result["pred_bbox_list"] = [ann["bbox_xyxy"] for ann in pred_annotations]
                per_category_result["iou_info_list"] = []
            elif pred_annotations is None:
                # Ground truth without predictions: all false negatives.
                per_category_result["fn_num"] = len(gt_annotations)
                per_category_result["iou_list"] = [0.0] * len(gt_annotations)
                per_category_result["pred_bbox_list"] = []
                per_category_result["iou_info_list"] = []
            else:
                gt_bbox_list = [ann["bbox_xyxy"] for ann in gt_annotations]
                pred_bbox_list = [ann["bbox_xyxy"] for ann in pred_annotations]
                iou_info_list, _iou_matrix, _rank_matrix, _pred_used, _gt_used = calculate_iou(gt_bbox_list, pred_bbox_list)
                # One pair per box of the shorter list is expected.
                assert len(iou_info_list) == min(len(gt_bbox_list), len(pred_bbox_list)), f"Length mismatch in category {category_id}, index {index}: len(iou_info_list)={len(iou_info_list)}, len(gt_bbox_list)={len(gt_bbox_list)}, len(pred_bbox_list)={len(pred_bbox_list)}"

                matched_iou_list = [info["iou_value"] for info in iou_info_list]
                for iou in matched_iou_list:
                    assert not math.isnan(iou), f"IOU value is NaN in category {category_id}, index {index}"

                # Ground-truth boxes left without a partner count as IoU 0.
                missing = len(gt_bbox_list) - len(matched_iou_list)
                iou_list = matched_iou_list + [0.0] * missing

                tp_num = len([iou for iou in iou_list if iou >= iou_threshold])
                per_category_result["pred_iou_list"] = list(matched_iou_list)
                per_category_result["iou_list"] = iou_list
                per_category_result["tp_num"] = tp_num
                per_category_result["fp_num"] = len(pred_bbox_list) - tp_num
                per_category_result["fn_num"] = len(gt_bbox_list) - tp_num
                per_category_result["iou_info_list"] = iou_info_list
                per_category_result["pred_bbox_list"] = pred_bbox_list

            per_category_result_dict[category_id] = per_category_result

        per_image_result_dict[index] = per_category_result_dict

    return per_image_result_dict, oc_cost_list

def create_annotations_for_coco(conversation_dataset,categories,get_bbox_func,processor,delete_region_failure=False,unknown_to_similar=False,sentence_transformer_model_path=None):
    """
    Turn conversation records into COCO-style annotations grouped per image.

    Records sharing an ``ann_id`` are treated as one image; the image id is the
    position of that ``ann_id`` in first-appearance order. For each record the
    ``value`` strings of all conversation turns are concatenated and parsed by
    ``get_bbox_func(processor, text)``, which must return
    ``(bbox_list, name_list)`` with boxes as ``[x1, y1, x2, y2]``.

    Args:
        conversation_dataset (list): Records with ``ann_id`` and ``conversations``.
        categories (dict): Category name -> id. Names that are not in the
            mapping use ``categories["unknown"]`` as their id.
        get_bbox_func (callable): Parser described above.
        processor: Passed through to ``get_bbox_func``.
        delete_region_failure (bool): Skip boxes whose name contains
            ``"<patch_index"``.
        unknown_to_similar (bool): Replace names missing from ``categories``
            by the most similar known name. Requires ``delete_region_failure``.
        sentence_transformer_model_path: Model used for the similarity lookup.

    Returns:
        tuple: ``(return_annotations, return_num_dict)`` where
        ``return_annotations`` maps image id -> list of annotation dicts (only
        images with at least one annotation appear) and ``return_num_dict``
        holds the region-failure / name-failure / name-match counters.

    Raises:
        ValueError: if ``unknown_to_similar`` is set without ``delete_region_failure``.
    """
    get_most_similar_category = None
    if unknown_to_similar:
        if not delete_region_failure:
            raise ValueError("unknown_to_similar is True but delete_region_failure is False. This combination is not supported.")
        get_most_similar_category = create_get_most_similar_category_func(
            list(categories.keys()),
            sentence_transformer_model_path
        )

    # ann_id -> indices of the records that belong to it (first-seen order).
    rows_by_ann_id = {}
    for row_index, conversation in enumerate(conversation_dataset):
        rows_by_ann_id.setdefault(conversation["ann_id"], []).append(row_index)

    return_num_dict = {
        "region_failure_count": 0,
        "region_failure_delim_count": 0,
        "name_failure_count": 0,
        "name_failure_delim_count": 0,
        "name_match_count": 0,
        "name_match_delim_count": 0
    }

    def bump(prefix):
        # Each parsed (name, bbox) pair carries exactly one box, so the
        # "delim" counter advances in step with the plain counter.
        return_num_dict[f"{prefix}_count"] += 1
        return_num_dict[f"{prefix}_delim_count"] += 1

    return_annotations = {}
    next_annotation_id = 0
    for image_id, ann_key in enumerate(tqdm(rows_by_ann_id.keys())):
        for row_index in rows_by_ann_id[ann_key]:
            turns = conversation_dataset[row_index]["conversations"]
            text = "".join(turn["value"] for turn in turns)

            ori_bbox_list, name_list = get_bbox_func(processor, text)
            for name, bbox in zip(name_list, ori_bbox_list):
                if "<patch_index" in name and delete_region_failure:
                    bump("region_failure")
                    continue

                if name not in categories:
                    bump("name_failure")
                    if get_most_similar_category is not None:
                        name, _score, _all_scores = get_most_similar_category(name)
                else:
                    bump("name_match")

                is_known = name in categories
                width = bbox[2] - bbox[0]
                height = bbox[3] - bbox[1]
                annotation = {
                    "id": next_annotation_id,
                    "image_id": image_id,
                    "category_id": categories[name] if is_known else categories["unknown"],
                    "bbox": [bbox[0], bbox[1], width, height],  # [x, y, width, height]
                    "area": width * height,
                    "iscrowd": 0,
                    "score": 1.0,  # every annotation gets a fixed score of 1.0
                    "category_name": name,
                    "bbox_xyxy": bbox,  # [x1, y1, x2, y2]
                    "is_unknown": 0 if is_known else 1,
                    "ann_id": ann_key
                }
                return_annotations.setdefault(image_id, []).append(annotation)
                next_annotation_id += 1

    return return_annotations, return_num_dict

def get_pascal_voc_category():
    """
    Provide the 20 Pascal VOC class names with ids 0-19.

    Returns:
        tuple[dict, dict]: ``(cat_name2id, cat_id2name)`` - the name -> id
        mapping followed by its inverse, both ordered by id.
    """
    voc_class_names = (
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )
    # The position in the tuple is the class id.
    cat_id2name = dict(enumerate(voc_class_names))
    cat_name2id = {name: cat_id for cat_id, name in cat_id2name.items()}
    return cat_name2id, cat_id2name

In [6]:
def get_correct_table_data(correct_data,wandb_name_to_image,category_name_to_id):
    """
    Aggregate ground-truth records per ``ann_id`` into table rows.

    Records sharing an ``ann_id`` are merged: their ids, inputs and outputs are
    collected in lists and their boxes are accumulated per category name. The
    merge is keyed on ``ann_id`` only, so records of the same ``ann_id`` do not
    have to be adjacent in ``correct_data``.

    Args:
        correct_data (list): Records with ``ann_id``, ``id``, ``image``,
            ``conversations`` (turn 0 = input, turn 1 = output) and
            ``gt_entities_quantized_normalized`` (items whose first element is
            a category name and whose last element is a list of boxes).
        wandb_name_to_image (dict): Image file name -> ``wandb.Image``.
        category_name_to_id (dict): Category name -> id. Every entity name
            must be a key of this mapping.

    Returns:
        list: One dict per ``ann_id``, sorted by ``ann_id``, with keys
        ``ann_id``, ``id_list``, ``image_name``, ``gt_entities`` (list of
        ``[name, boxes]`` for categories that have boxes, in
        ``category_name_to_id`` order), ``input``, ``gt_bbox_num``,
        ``gt_output`` and ``gt_image``.
    """
    all_object_name_list = list(category_name_to_id.keys())
    correct_dict = {}

    for i in tqdm(range(len(correct_data)), desc="Processing correct data"):
        sample = correct_data[i]
        ann_id = sample["ann_id"]

        if ann_id not in correct_dict:
            # The image name is taken from the first record of this ann_id.
            correct_dict[ann_id] = {
                "ann_id": ann_id,
                "id_list": [],
                "image_name": os.path.basename(sample["image"]),
                "gt_entities": {k: [] for k in all_object_name_list},
                "input": [],
                "gt_bbox_num": 0,
                "gt_output": [],
            }
        record = correct_dict[ann_id]

        # Always accumulate into this ann_id's own entity dict, never into the
        # one left over from whichever ann_id was created most recently.
        for entity in sample["gt_entities_quantized_normalized"]:
            record["gt_entities"][entity[0]].extend(entity[-1])
            record["gt_bbox_num"] += len(entity[-1])

        record["id_list"].append(sample["id"])
        record["input"].append(sample["conversations"][0]["value"])
        record["gt_output"].append(sample["conversations"][1]["value"])

    for record in correct_dict.values():
        entities_by_name = record["gt_entities"]
        # Keep only categories that actually have boxes, in category order.
        record["gt_entities"] = [
            [name, entities_by_name[name]]
            for name in all_object_name_list
            if len(entities_by_name[name]) > 0
        ]
        record["gt_image"] = add_bbox_to_wandb_image(wandb_name_to_image[record["image_name"]], record["gt_entities"], category_name_to_id)

    return sort_list_of_dicts(correct_dict.values(),key="ann_id")
    

In [7]:
def get_generated_table_data(correct_data, generated_data, unique_key,wandb_name_to_image,category_name2id,category_id2name,
                             sentence_transformer_model_path="/data_ssd/huggingface_model_weights/sentence-transformers/all-MiniLM-L6-v2"):
    """
    Build the per-image table rows describing one model's predictions.

    Annotations for the ground truth and for the predictions are created with
    create_annotations_for_coco, per-image/per-class results and OC-cost values
    are obtained from get_per_image_class_result_and_oc_cost (IoU threshold 0.5),
    and the raw model outputs are grouped by "ann_id".

    Args:
        correct_data (list): Ground-truth records, forwarded to create_annotations_for_coco.
        generated_data (list): Prediction records. Each record is read for
            "ann_id", "image" and conversations[1]["value"].
        unique_key (str): Prefix used for every prediction column name.
        wandb_name_to_image (dict): Image file name -> W&B image, used to draw predicted boxes.
        category_name2id (dict): Category name -> category id.
        category_id2name (dict): Category id -> category name.
        sentence_transformer_model_path (str): Model path forwarded to
            create_annotations_for_coco for the prediction annotations.
            Defaults to the path that was previously hard-coded here.

    Returns:
        list: One dict per ann_id, sorted by "ann_id". An ann_id that has no
        entry in the per-image results only carries "ann_id", "image_name"
        and the "<unique_key>_pred_output" list.
    """
    # The second return value of create_annotations_for_coco is not needed here.
    all_gt_annotations, _ = create_annotations_for_coco(correct_data, category_name2id,parse_bbox_and_labels, None)
    all_pred_annotations, _ = create_annotations_for_coco(generated_data, category_name2id, parse_bbox_and_labels,None,
            delete_region_failure=True, unknown_to_similar=True, sentence_transformer_model_path=sentence_transformer_model_path)
    per_image_result_dict, oc_cost_list = get_per_image_class_result_and_oc_cost(all_gt_annotations, all_pred_annotations, category_name2id, iou_threshold=0.5)

    # Group the raw model outputs by ann_id (several records may share one ann_id).
    eval_dict = {}
    for data in generated_data:
        ann_id = data["ann_id"]
        if ann_id not in eval_dict:
            image_name = os.path.basename(data["image"])
            eval_dict[ann_id] = {
                "ann_id": ann_id,
                "image_name": image_name,
                f"{unique_key}_pred_output": []
            }
        pred_output = data["conversations"][1]["value"]
        eval_dict[ann_id][f"{unique_key}_pred_output"].append(pred_output)

    # zip() has no length, so pass the total explicitly to get a real progress bar.
    for per_image_result, oc_cost_value in tqdm(zip(per_image_result_dict.values(), oc_cost_list), total=len(per_image_result_dict)):
        ann_id = per_image_result["ann_id"]
        pred_bbox_num = 0
        pred_entities = []
        iou_info_list = []
        iou_over_0_5_count = 0
        for category_id, result in per_image_result.items():
            # "ann_id" is stored next to the per-category entries; skip it instead of
            # deleting it so the dict returned by the helper is left untouched.
            if category_id == "ann_id" or result is None:
                continue
            category_name = category_id2name[category_id]
            pred_bbox_num += len(result["pred_bbox_list"])
            pred_entities.append([category_name, result["pred_bbox_list"]])
            iou_info_list.append([category_name, result["iou_info_list"]])
            iou_over_0_5_count += result["tp_num"]

        eval_item = eval_dict[ann_id]
        eval_item[f"{unique_key}_pred_bbox_num"] = pred_bbox_num
        # Nested lists are stored as strings so they can be placed in a table cell.
        eval_item[f"{unique_key}_iou_info_list"] = str(iou_info_list)
        eval_item[f"{unique_key}_iou_over_0_5_count"] = iou_over_0_5_count
        eval_item[f"{unique_key}_oc_cost"] = float(oc_cost_value)
        eval_item[f"{unique_key}_pred_image"] = add_bbox_to_wandb_image(
            wandb_name_to_image[eval_item["image_name"]], pred_entities, category_name2id
        )
        eval_item[f"{unique_key}_pred_entities"] = str(pred_entities)
    return sort_list_of_dicts(eval_dict.values(),key="ann_id")

In [None]:
# Ground-truth JSON; loaded with load_json() further down as `correct_data`.
gt_path = "/data_ssd/PASCAL-VOC/paligemma_actual_detection/test_pascal-voc_actual_detection_for_paligemma_sort_size_cat_size.json"
# Previous comparison targets, kept for reference:
# compare_dict = {
#     "ce": "/data_ssd/USER_DATA/omote/iam-llms-finetune/experiment_output/paligemma_pascalvoc-multi-class-448px/448px_size_aligned_train-vision-proj-llm_cross-entropy_2025-10-17T18_07_18",
#     "prop": "/data_ssd/USER_DATA/omote/iam-llms-finetune/experiment_output/paligemma_pascalvoc-multi-class-448px/448px_size_aligned_train-vision-proj-llm_cedfl_excepted_split_ce_2025-10-18T00_33_59",
# }
# Column prefix -> experiment output directory. A later cell replaces each directory
# with the path of the single eval_output.json found beneath it.
compare_dict = {
    "ce": "/home/omote/omote-data-ssd/iam-llms-finetune/experiment_output/paligemma_pascalvoc-actual-detection-448px/448px_size_aligned_train-vision-proj-llm_cross-entropy_2025-10-20T22_29_37",
    "prop": "/home/omote/omote-data-ssd/iam-llms-finetune/experiment_output/paligemma_pascalvoc-actual-detection-448px/448px_size_aligned_train-vision-proj-llm_cedfl_excepted2_split_ce_2025-10-21T11_38_04",
}


# Directory name that must appear on the path to eval_output.json (used in the glob pattern).
eval_json_name = "test_pascal-voc_actual_detection_for_paligemma_sort_size_cat_size"

# Artifact reference passed to run.use_artifact().
artifact_entity = "katlab-gifu/dataset/pascal_voc_test:v0"

# Arguments for wandb.init(): entity, project and run name.
ENTITY = "katlab-gifu"
PROJECT = "vis_test"
NAME="pascalvoc_actual_detection_comparison_448px_1"

In [9]:
# Start the W&B run that the result table is logged to, and register the
# image artifact as an input of that run.
run = wandb.init(entity=ENTITY, project=PROJECT, name=NAME)
img_art = run.use_artifact(artifact_entity)

[34m[1mwandb[0m: Currently logged in as: [33momote-hideaki-s8[0m ([33mkatlab-gifu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Fetch the "pascal_voc_test" entry from the artifact; its `.data` rows are read
# in the next cell as (key, value) pairs to build the name -> image lookup.
wandb_dataset = img_art.get("pascal_voc_test")

[34m[1mwandb[0m: Downloading large artifact pascal_voc_test:v0, 1454.43MB. 4953 files... 
[34m[1mwandb[0m:   4953 of 4953 files downloaded.  
Done. 0:0:14.4 (101.0MB/s)


In [11]:
# Lookup from the first column of each table row (used as the image file name
# elsewhere in this notebook) to the second column (the W&B image object).
wandb_name_to_image = {data_row[0]: data_row[1] for data_row in wandb_dataset.data}
wandb_image_names = set(wandb_name_to_image.keys())
# A dict silently keeps only the last value for a repeated key, so uniqueness has
# to be checked against the raw row count; comparing the dict with the set of its
# own keys can never fail.
assert len(wandb_name_to_image) == len(wandb_dataset.data), "Duplicate image names found in the W&B dataset table."

In [12]:
# print(wandb_name_to_image)

In [13]:
# Replace each experiment directory in compare_dict with the single
# eval_output.json found beneath it for eval_json_name.
new_compare_dict = {}
for run_key, run_path in compare_dict.items():
    if os.path.isfile(run_path):
        # Already resolved by an earlier execution of this cell; keep the path so
        # that re-running the cell does not glob underneath a file and fail.
        new_compare_dict[run_key] = run_path
        continue
    file_list = glob.glob(os.path.join(run_path,"**",eval_json_name,"**","eval_output.json"),recursive=True)
    assert len(file_list) == 1, (
        f"Expected exactly one eval_output.json for '{run_key}' under {run_path}, "
        f"found {len(file_list)}: {file_list}"
    )
    new_compare_dict[run_key] = file_list[0]
compare_dict = new_compare_dict

In [14]:
# Category lookup tables in both directions (name -> id and id -> name).
category_name2id,category_id2name = get_pascal_voc_category()

In [15]:
# Load the ground-truth records and order them by "id"; the generated data is
# sorted with the same key later on.
correct_data = load_json(gt_path)
correct_data = sort_list_of_dicts(correct_data, "id")
# all_gt_annotations, all_gt_num_dict = create_annotations_for_coco(correct_data, category_name2id,parse_bbox_and_labels, None)



In [16]:
# Ground-truth rows for the table, one dict per ann_id (see the printed sample below).
correct_data_list = get_correct_table_data(correct_data,wandb_name_to_image,category_name2id)

Processing correct data: 100%|██████████| 4952/4952 [00:00<00:00, 16510.36it/s]


In [17]:
# Spot-check the structure of the first ground-truth row.
print(correct_data_list[0])

{'ann_id': '000001', 'id_list': ['000001'], 'image_name': '000001.jpg', 'gt_entities': [['dog', [[0.13294232649071358, 0.4780058651026393, 0.5493646138807429, 0.739980449657869]]], ['person', [[0.019550342130987292, 0.022482893450635387, 0.9941348973607038, 0.9941348973607038]]]], 'input': ['<image>Detect objects in this image from the following categories: [aeroplane ; bicycle ; bird ; boat ; bottle ; bus ; car ; cat ; chair ; cow ; diningtable ; dog ; horse ; motorbike ; person ; pottedplant ; sheep ; sofa ; train ; tvmonitor]. List detected categories only.'], 'gt_bbox_num': 2, 'gt_output': ['<loc0023><loc0020><loc1017><loc1017> person ; <loc0489><loc0136><loc0757><loc0562> dog'], 'gt_image': <wandb.sdk.data_types.image.Image object at 0x1469f0ceb410>}


In [18]:


def _to_table_cell(value):
    """Return W&B images and plain int/float values unchanged; stringify anything else."""
    # Exact type match on purpose (no isinstance): subclasses such as bool are stringified.
    if type(value) in (wandb.Image, int, float):
        return value
    return str(value)


# Ground-truth part of every row, in column order.
wandb_columns = ["ann_id","id_list","image_name","gt_image","input","gt_output","gt_bbox_num","gt_entities"]
tmp_table_data = []
for item in correct_data_list:
    tmp_table_data.append([_to_table_cell(item[k]) for k in wandb_columns])

# Append one group of prediction columns per compared model.
for unique_key, generated_path in compare_dict.items():
    generated_data = load_json(generated_path)
    assert len(correct_data) == len(generated_data), "Length of correct and generated data does not match."
    generated_data = sort_list_of_dicts(generated_data, "id")
    generated_data_list = get_generated_table_data(
        correct_data,
        generated_data,
        unique_key,
        wandb_name_to_image,
        category_name2id,
        category_id2name,
    )

    unique_columns = [
        f"{unique_key}_pred_image",
        f"{unique_key}_pred_output",
        f"{unique_key}_pred_bbox_num",
        f"{unique_key}_iou_info_list",
        f"{unique_key}_iou_over_0_5_count",
        f"{unique_key}_oc_cost",
        f"{unique_key}_pred_entities"
    ]

    for i, table_row in enumerate(tmp_table_data):
        generated_item = generated_data_list[i]
        # Rows are matched by position, so verify both sides refer to the same ann_id.
        assert table_row[0] == generated_item["ann_id"], f"Ann ID mismatch at index {i}."
        for col in unique_columns:
            table_row.append(_to_table_cell(generated_item[col]))
    wandb_columns.extend(unique_columns)



100%|██████████| 4952/4952 [00:00<00:00, 17750.21it/s]
100%|██████████| 4952/4952 [00:00<00:00, 7553.51it/s]
100%|██████████| 4952/4952 [00:03<00:00, 1245.68it/s]
4952it [00:28, 170.79it/s]
100%|██████████| 4952/4952 [00:00<00:00, 7688.59it/s]
100%|██████████| 4952/4952 [00:00<00:00, 15292.95it/s]
100%|██████████| 4952/4952 [00:03<00:00, 1245.00it/s]
4952it [00:29, 170.42it/s]


In [19]:
# Report every table cell that is None, or that is a list containing None.
for i, row_values in enumerate(tmp_table_data):
    for col, cell in zip(wandb_columns, row_values):
        # Treat a non-list cell as a one-element list so both cases share one membership test.
        candidates = cell if type(cell) == list else [cell]
        if cell is None or None in candidates:
            print(f"None value found in column {col} at row {i}")
            print(f"Type of d: {type(cell)}")
            print(f"Value of d: {cell}")

In [20]:
# Every row must hold exactly one cell per column before the table is built.
for i, d in enumerate(tmp_table_data):
    assert len(d) == len(wandb_columns), f"Data length mismatch at index {i}: {len(d)} != {len(wandb_columns)}"
    # for col,d in zip(wandb_columns,d):
    #     print(col,d)

In [21]:
# Print column name, Python type and value for each cell of the first row.
for col, d in zip(wandb_columns, tmp_table_data[0]):
    print(col, type(d), d )

ann_id <class 'str'> 000001
id_list <class 'str'> ['000001']
image_name <class 'str'> 000001.jpg
gt_image <class 'wandb.sdk.data_types.image.Image'> <wandb.sdk.data_types.image.Image object at 0x1469f0ceb410>
input <class 'str'> ['<image>Detect objects in this image from the following categories: [aeroplane ; bicycle ; bird ; boat ; bottle ; bus ; car ; cat ; chair ; cow ; diningtable ; dog ; horse ; motorbike ; person ; pottedplant ; sheep ; sofa ; train ; tvmonitor]. List detected categories only.']
gt_output <class 'str'> ['<loc0023><loc0020><loc1017><loc1017> person ; <loc0489><loc0136><loc0757><loc0562> dog']
gt_bbox_num <class 'int'> 2
gt_entities <class 'str'> [['dog', [[0.13294232649071358, 0.4780058651026393, 0.5493646138807429, 0.739980449657869]]], ['person', [[0.019550342130987292, 0.022482893450635387, 0.9941348973607038, 0.9941348973607038]]]]
ce_pred_image <class 'wandb.sdk.data_types.image.Image'> <wandb.sdk.data_types.image.Image object at 0x146990cfae40>
ce_pred_outpu

In [22]:
# Build the table from the assembled rows, log it under "result_table", then close the run.
result_table = wandb.Table(columns=wandb_columns, data=tmp_table_data)
wandb.log({"result_table": result_table})
run.finish()

In [None]:
# Walk ground-truth and generated records in parallel (both sorted by "id").
# NOTE(review): pred_tuple_list is never appended to, so this cell looks unfinished.
pred_tuple_list = []
for c_item, g_item in zip(correct_data, generated_data):
    assert c_item["id"] == g_item["id"], f"ID mismatch: {c_item['id']} != {g_item['id']}"
    # Read from the paired record: indexing the list `generated_data` with a string
    # raises TypeError, and `generated_data[i]` relied on an `i` left over from another cell.
    pred_input = g_item["conversations"][0]["value"]
    pred_output = g_item["conversations"][1]["value"]
    pred_bbox_list, pred_label_list = paligemma_get_bbox(pred_output)
    pred_bbox_num = len(pred_bbox_list)

    pred_entities = [pred_input, pred_bbox_list]
    
    
    
    

In [None]:
# Extract the identifying fields of each ground-truth record.
# NOTE(review): nothing is appended to correct_data_tuple_list, so this cell looks unfinished.
correct_data_tuple_list = []
for item in correct_data:
    # NOTE(review): at cell level this rebinds `id`, shadowing the builtin for later cells.
    id = item["id"]
    ann_id = item["ann_id"]
    image_name = os.path.basename(item["image"])
    
    

In [None]:
# Fixed ground-truth columns first, followed by one group of prediction columns per model.
wandb_columns = [
    "ann_id",
    "id_list",
    "image_name",
    "gt_image",
    "input",
    "gt_output",
    "gt_bbox_num",
    "gt_entities",
]

for key, path in compare_dict.items():
    # `path` is not needed for the column names; only the key is used as a prefix.
    wandb_columns += [
        f"{key}_pred_image",
        f"{key}_pred_output",
        f"{key}_pred_bbox_num",
        f"{key}_pred_entities",
        f"{key}_iou_info_list",
        f"{key}_iou_over_0.5_count",
    ]
    


In [None]:
# NOTE(review): `data_list` is not assigned in the nearby cells; confirm it is
# defined somewhere that still runs on a fresh kernel.
print(data_list[0])

In [None]:
# Re-open the run identified by RUN_ID (resume="must") and download the image artifact.
# NOTE(review): RUN_ID is not assigned in the nearby cells; confirm where it is set.
run = wandb.init(entity=ENTITY, project=PROJECT, id=RUN_ID, resume="must")
img_art = run.use_artifact(artifact_entity)
# Wrap the value returned by download() in a Path for later path handling.
img_dir = Path(img_art.download())

In [None]:
table = wandb.Table(columns=["image_id", "image", "n_boxes"])