In [1]:
import json
import sys
sys.path.append("/home/omote/cluster_project/iam2/eval")
from eval_utils.custom_oc_cost import get_cmap,get_ot_cost,DetectedInstance
import argparse
import os
import datetime
from tqdm import tqdm
import regex as re
from torchvision.ops import box_iou
import torch
from transformers import AutoProcessor
import imgviz
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from pycocotools.coco import COCO
import math
from copy import deepcopy

def save_json(file_path, data):
    """
    Serialize ``data`` as JSON (4-space indent) into ``file_path``.

    The file is opened in write mode, so any existing content is replaced.

    Args:
        file_path (str): Destination path.
        data: JSON-serializable object to write.
    """
    with open(file_path, mode='w') as handle:
        json.dump(data, handle, indent=4)

def load_json(file_path):
    """
    Read and parse the JSON document stored at ``file_path``.

    Args:
        file_path (str): Path of the JSON file to read.

    Returns:
        The decoded JSON value (dict, list, ...).
    """
    with open(file_path, mode='r') as handle:
        parsed = json.load(handle)
    return parsed

def extract_bbox_from_text(ans):
    """Find every ``[a, b, c, d]`` box in ``ans`` whose four numbers look like 0.ddd or 1.ddd.

    Returns:
        A list of ``[float, float, float, float]`` boxes in order of appearance,
        or the string ``"FAILED"`` when no box is found.
    """
    box_regex = re.compile(r'\[(((0|1)\.(\d){3}\,\s*){3}((0|1)\.(\d){3}))\]')
    found = box_regex.findall(ans)
    if not found:
        return "FAILED"
    # findall yields one tuple of groups per match; group 1 is the comma-separated body.
    return [[float(number) for number in groups[0].split(",")] for groups in found]

def calculate_iou(gt_bbox_list, pred_bbox_list):
    """Greedily match ground-truth and predicted boxes one-to-one by descending IoU.

    All (gt, pred) pairs are visited from the highest IoU downward. A pair is
    accepted only if neither its ground-truth box nor its predicted box has been
    matched yet. Matching stops once every box of the shorter list is matched,
    so exactly ``min(len(gt_bbox_list), len(pred_bbox_list))`` pairs are
    returned. Pairs with IoU 0 can be matched as well; applying an IoU threshold
    is left to the caller.

    Args:
        gt_bbox_list: ground-truth boxes, each (x1, y1, x2, y2). May be empty.
        pred_bbox_list: predicted boxes, each (x1, y1, x2, y2). May be empty.

    Returns:
        tuple (iou_info_list, iou_matrix, iou_argsort_matrix, pred_index_list, gt_index_list):
            - iou_info_list: list of {"gt_index", "pred_index", "iou_value"} dicts
              in the order the pairs were accepted.
            - iou_matrix: (len(gt), len(pred)) IoU tensor with NaN replaced by 0.
            - iou_argsort_matrix: same shape; the rank of each entry when all
              entries are sorted by descending IoU.
            - pred_index_list / gt_index_list: bool tensors marking matched boxes.

    Raises:
        AssertionError: if the number of matched pairs is not min(len(gt), len(pred)).
    """
    gt_boxes = torch.tensor(gt_bbox_list).float()
    pred_boxes = torch.tensor(pred_bbox_list).float()
    # An empty list becomes a tensor of shape (0,), which box_iou cannot index;
    # give it the (0, 4) shape of an empty box set instead.
    if gt_boxes.numel() == 0:
        gt_boxes = gt_boxes.reshape(0, 4)
    if pred_boxes.numel() == 0:
        pred_boxes = pred_boxes.reshape(0, 4)

    iou_matrix = box_iou(gt_boxes, pred_boxes)
    # NaN can appear e.g. as 0/0 for two degenerate boxes; treat it as no overlap.
    iou_matrix = torch.nan_to_num(iou_matrix, nan=0.0)

    num_pred = iou_matrix.shape[1]
    # Row-major flat indices of all (gt, pred) pairs, highest IoU first.
    order = torch.argsort(iou_matrix.flatten(), descending=True)
    # Inverse permutation: rank of every pair in that order (part of the return value).
    iou_argsort_matrix = order.argsort().reshape(iou_matrix.shape)

    pred_index_list = torch.full((len(pred_bbox_list),), False, dtype=torch.bool)
    gt_index_list = torch.full((len(gt_bbox_list),), False, dtype=torch.bool)

    num_pairs = min(len(gt_bbox_list), len(pred_bbox_list))
    iou_info_list = []
    # Walk the pairs in IoU order directly (one pass) instead of searching the
    # rank matrix once per rank, which was quadratic in the number of pairs.
    for flat_index in order.tolist():
        if len(iou_info_list) == num_pairs:
            break  # every box of the shorter list is matched
        gt_index, pred_index = divmod(flat_index, num_pred)
        if gt_index_list[gt_index] or pred_index_list[pred_index]:
            continue
        iou_info_list.append({
            "gt_index": gt_index,
            "pred_index": pred_index,
            "iou_value": iou_matrix[gt_index, pred_index].item(),
        })
        gt_index_list[gt_index] = True
        pred_index_list[pred_index] = True

    assert len(iou_info_list) == num_pairs, f"Length mismatch: {len(iou_info_list)} != {num_pairs}"

    return iou_info_list, iou_matrix, iou_argsort_matrix, pred_index_list, gt_index_list

def sort_list_of_dicts(data, key, reverse=False):
    """
    Return a new list with the dictionaries of ``data`` ordered by ``item[key]``.

    The sort is stable (ties keep their input order, also when ``reverse`` is True),
    and ``data`` itself is left untouched.

    Args:
        data (list): Dictionaries to order.
        key (str): Dictionary key whose value is compared.
        reverse (bool): True for descending order, False for ascending.

    Returns:
        list: A sorted copy of ``data``.
    """
    ordered = list(data)
    ordered.sort(key=lambda item: item[key], reverse=reverse)
    return ordered

def oc_cost(pred_instance_list,tgt_instance_list, alpha=0.5,beta=0.6):
    """Compute the OC-cost between predicted and target instances via ``get_ot_cost``.

    The pairwise cost map comes from ``get_cmap`` with the given ``alpha``/``beta``
    and ``label_or_sim="label"``.
    """
    def cost_map(x, y):
        return get_cmap(x, y, alpha=alpha, beta=beta, label_or_sim="label")

    return get_ot_cost(pred_instance_list, tgt_instance_list, cost_map)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check for calculate_iou: match a box list against an identical copy of itself.
pred_bbox_list = [
    (0.421875, 0.453125, 0.453125, 0.546875),
    (0.453125, 0.453125, 0.484375, 0.546875),
    (0.46875, 0.46875, 0.5, 0.5625),
    (0.484375, 0.453125, 0.515625, 0.546875),
    (0.453125, 0.546875, 0.484375, 0.640625),
    (0.46875, 0.53125, 0.5, 0.65625),
    (0.46875, 0.53125, 0.5, 0.65625),
    (0.421875, 0.734375, 0.484375, 0.828125),
    (0.953125, 0.453125, 0.984375, 0.515625),
    (0.921875, 0.578125, 0.953125, 0.609375),
    (0.953125, 0.546875, 0.984375, 0.609375),
]
gt_bbox_list = list(pred_bbox_list)  # same boxes, separate list object
(
    iou_info_list,
    iou_matrix,
    iou_argsort_matrix,
    pred_index_list,
    gt_index_list,
) = calculate_iou(gt_bbox_list, pred_bbox_list)
print(f"len(gt_bbox_list): {len(gt_bbox_list)}")
print(f"len(iou_info_list): {len(iou_info_list)}")
print(f"pred_bbox_list: {pred_bbox_list}")
print(f"gt_bbox_list: {gt_bbox_list}")
print(f"iou_info_list: {iou_info_list}")
print(f"iou_matrix: {iou_matrix}")
print(f"iou_argsort_matrix: {iou_argsort_matrix}")
print(f"pred_index_list: {pred_index_list}")
print(f"gt_index_list: {gt_index_list}")


len(gt_bbox_list): 11
len(iou_info_list): 11
pred_bbox_list: [(0.421875, 0.453125, 0.453125, 0.546875), (0.453125, 0.453125, 0.484375, 0.546875), (0.46875, 0.46875, 0.5, 0.5625), (0.484375, 0.453125, 0.515625, 0.546875), (0.453125, 0.546875, 0.484375, 0.640625), (0.46875, 0.53125, 0.5, 0.65625), (0.46875, 0.53125, 0.5, 0.65625), (0.421875, 0.734375, 0.484375, 0.828125), (0.953125, 0.453125, 0.984375, 0.515625), (0.921875, 0.578125, 0.953125, 0.609375), (0.953125, 0.546875, 0.984375, 0.609375)]
gt_bbox_list: [(0.421875, 0.453125, 0.453125, 0.546875), (0.453125, 0.453125, 0.484375, 0.546875), (0.46875, 0.46875, 0.5, 0.5625), (0.484375, 0.453125, 0.515625, 0.546875), (0.453125, 0.546875, 0.484375, 0.640625), (0.46875, 0.53125, 0.5, 0.65625), (0.46875, 0.53125, 0.5, 0.65625), (0.421875, 0.734375, 0.484375, 0.828125), (0.953125, 0.453125, 0.984375, 0.515625), (0.921875, 0.578125, 0.953125, 0.609375), (0.953125, 0.546875, 0.984375, 0.609375)]
iou_info_list: [{'gt_index': 5, 'pred_index': 5, 

In [3]:
def bbox_relative_to_absolute(relative_bbox, image_width_height):
    """Scale a relative (x1, y1, x2, y2) box to pixel units.

    Args:
        relative_bbox: box whose x values are fractions of the width and y values
            fractions of the height.
        image_width_height: (width, height) of the image.

    Returns:
        list: [x1, y1, x2, y2] in absolute coordinates.
    """
    width, height = image_width_height
    scale_per_coord = (width, height, width, height)
    return [relative_bbox[pos] * scale_per_coord[pos] for pos in range(4)]

def visualize_bbox(image, bbox_list, bbox_name_list,bbox_is_relative=True,with_id=False):
    """Draw captioned bounding boxes on an image with imgviz and show it with matplotlib.

    Args:
        image: a PIL image, or a path that is opened and converted to RGB.
        bbox_list: boxes as (x1, y1, x2, y2); fractions of the image size when
            ``bbox_is_relative`` is True, absolute pixels otherwise.
        bbox_name_list: one caption per box; boxes sharing a name share a label id.
        bbox_is_relative: scale coordinates by image width/height before drawing.
        with_id: suffix each caption with a per-name running index
            ("cup_0", "cup_1", ...).

    Raises:
        AssertionError: if ``bbox_list`` and ``bbox_name_list`` differ in length.
    """
    assert len(bbox_list) == len(bbox_name_list), "bbox_list and bbox_name_list must have the same length"
    if isinstance(image, str):
        # Close the file once the pixels have been copied into the converted image.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif getattr(image, "mode", "RGB") != "RGB":
        # Normalise grayscale/RGBA inputs so both input kinds reach imgviz as RGB arrays.
        image = image.convert("RGB")

    if bbox_is_relative:
        image_width_height = (image.width, image.height)
        bbox_list = [bbox_relative_to_absolute(bbox, image_width_height) for bbox in bbox_list]

    # Label ids are assigned in order of first appearance of each name.
    name_to_label_id_dict = {}
    for bbox_name in bbox_name_list:
        name_to_label_id_dict.setdefault(bbox_name, len(name_to_label_id_dict))

    bboxes = []  # passed to imgviz in (y1, x1, y2, x2) order
    labels = []
    captions = []
    count_object_dict = {}
    for bbox, bbox_name in zip(bbox_list, bbox_name_list):
        x1, y1, x2, y2 = bbox
        bboxes.append([y1, x1, y2, x2])
        labels.append(name_to_label_id_dict[bbox_name])
        occurrence = count_object_dict.get(bbox_name, -1) + 1
        count_object_dict[bbox_name] = occurrence
        captions.append(f"{bbox_name}_{occurrence}" if with_id else bbox_name)

    # Font size grows with the image's linear size: 3 at 100x100 pixels.
    base_resolution = 100 * 100
    base_font_size = 3
    image_resolution = image.width * image.height
    # Clamp to 1: images smaller than roughly 33x33 would otherwise get size 0.
    font_size = max(1, int(base_font_size * (image_resolution / base_resolution) ** 0.5))

    rendered = imgviz.instances2rgb(np.array(image), bboxes=bboxes, labels=labels, font_size=font_size, captions=captions)

    fig, ax = plt.subplots()
    ax.imshow(rendered)
    plt.show()

In [4]:
split = "val"

# Ground-truth conversations. The file name is built from `split`, so changing
# the split above also changes which ground truth is loaded.
correct_json_path = f"/data_ssd/mscoco-detection/{split}_for-kosmos2_mscoco2017-detection.json"
correct_data = load_json(correct_json_path)

# Model outputs to evaluate: keep exactly one assignment active.
# The first alternative points at the ground-truth file itself.
# generated_json_path = "/data_ssd/mscoco-detection/val_for-kosmos2_mscoco2017-detection.json"
generated_json_path = "/home/omote/omote-data-ssd/iam-llms-finetune/experiment_output/kosmos-2_mscoco2017-detection/mscoco2017-detection_train-vision-proj-llm_cross-entropy_2025-07-03T12_51_20/checkpoint-10536/eval_output/val_for-kosmos2_mscoco2017-detection/2025-07-04T10_36_40/eval_output.json"
# generated_json_path = "/home/omote/omote-data-ssd/iam-llms-finetune/experiment_output/kosmos-2_mscoco2017-detection/mscoco2017-detection_train-vision-proj-llm_distance-loss_2025-07-03T12_52_38/checkpoint-10536/eval_output/val_for-kosmos2_mscoco2017-detection/2025-07-04T11_16_55/eval_output.json"
# generated_json_path = "/home/omote/omote-data-ssd/iam-llms-finetune/experiment_output/kosmos-2_mscoco2017-detection/mscoco2017-detection_train-vision-proj-llm_distance-forward-kl-loss_2025-07-03T16_46_51/checkpoint-10536/eval_output/val_for-kosmos2_mscoco2017-detection/2025-07-04T11_56_05/eval_output.json"
generated_data = load_json(generated_json_path)

assert len(correct_data) == len(generated_data), "Length of correct and generated data does not match."

# Align ground truth and predictions record-by-record via their ids.
correct_data = sort_list_of_dicts(correct_data, "id")
generated_data = sort_list_of_dicts(generated_data, "id")

for correct, generated in zip(correct_data, generated_data):
    assert correct["id"] == generated["id"], f"ID mismatch: {correct['id']} != {generated['id']}"

iou_threshold = 0.5
image_folder_root = "/data_ssd"


In [5]:
# Kosmos-2 processor; later cells use its post_process_generation() to turn
# answer text into (name, _, bbox_list) entities.
# NOTE(review): `use_fast` is left unset (see the warning below); pinning it may
# change which processor classes are loaded — verify before setting it.
processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path="/data_ssd/huggingface_model_weights/microsoft/kosmos-2-patch14-224",
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
# Inspect every field of the first ground-truth record.
item = correct_data[0]
for key in item:
    value = item[key]
    print(f"{key}: {value}")

id: mscoco2017-detection_train-100083
image: mscoco2017/coco/images/train2017/000000100083.jpg
conversations: [{'from': 'human', 'value': '<image><grounding> Please carefully check the image and detect the following objects: [person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush].'}, {'from': 'gpt', 'value': '<phrase> cup</phrase><object><patch_index_0

In [7]:
def get_mscoco2017_detection_cat_name2id(split="val", anno_dir="/data_ssd/mscoco2017/coco/annotations"):
    """
    Get the category name to ID mapping for the MS COCO 2017 detection dataset.

    Args:
        split (str): Dataset split used in the annotation file name, e.g. 'train'
            or 'val' (reads ``instances_{split}2017.json``).
        anno_dir (str): Directory containing the COCO ``instances_*.json`` files.

    Returns:
        dict: Category name -> COCO category id, in the order of
        ``COCO.getCatIds()``.
    """
    anno_path = os.path.join(anno_dir, f"instances_{split}2017.json")
    coco_dataset = COCO(anno_path)
    categories = coco_dataset.loadCats(coco_dataset.getCatIds())
    return {category["name"]: category["id"] for category in categories}

# Category name -> COCO id for the evaluated split, plus the sentinel id -1 under
# "unknown" for names missing from the mapping.
cat_name2id = {**get_mscoco2017_detection_cat_name2id(split=split), "unknown": -1}




loading annotations into memory...
Done (t=0.55s)
creating index...
index created!


In [8]:
print(len(cat_name2id))
# The COCO 2017 detection annotations define 80 categories; "unknown" adds one more.
assert len(cat_name2id) == 80 + 1, f"expected 81 categories (80 COCO + 'unknown'), got {len(cat_name2id)}"

81


In [9]:


def create_images_for_coco(conversation_dataset, image_folder_root="/data_ssd"):
    """Build COCO-style image records for every entry of a conversation dataset.

    Args:
        conversation_dataset: sequence of records whose "image" field is a path
            relative to ``image_folder_root``.
        image_folder_root (str): directory the relative image paths are joined to.

    Returns:
        list of dicts with keys "id" (the record's position in the dataset),
        "width", "height" and "file_name" (the relative path from the record).
    """
    return_images = []
    for i in tqdm(range(len(conversation_dataset))):
        image_name = conversation_dataset[i]["image"]
        image_path = os.path.join(image_folder_root, image_name)
        # Image.open is lazy and keeps the file open; the context manager closes
        # it right away instead of leaving one open handle per image until GC.
        with Image.open(image_path) as image:
            image_width, image_height = image.width, image.height
        return_images.append({
            "id": i,
            "width": image_width,
            "height": image_height,
            "file_name": image_name,
        })
    return return_images

def create_annotations_for_coco(conversation_dataset,categories,processor):
    """Parse each record's answer text into COCO-style annotations grouped by image.

    Args:
        conversation_dataset: sequence of records; the text containing the boxes is
            read from ``record["conversations"][1]["value"]``.
        categories (dict): category name -> id. Names not in it get the id stored
            under "unknown" (which must then exist).
        processor: object whose ``post_process_generation(text)`` returns
            ``(caption, entities)``, each entity unpacking as ``(name, _, bbox_list)``.

    Returns:
        dict: image index -> list of annotation dicts. Images whose text yields no
        box get no key. Annotation ids are consecutive over the whole dataset.
    """
    annotations_by_image = {}
    next_annotation_id = 0

    for image_index in tqdm(range(len(conversation_dataset))):
        answer_text = conversation_dataset[image_index]["conversations"][1]["value"]
        _caption, entities = processor.post_process_generation(answer_text)
        for name, _span, boxes in entities:
            is_known = name in categories
            for box in boxes:
                box_width = box[2] - box[0]
                box_height = box[3] - box[1]
                annotations_by_image.setdefault(image_index, []).append({
                    "id": next_annotation_id,
                    "image_id": image_index,
                    "category_id": categories[name] if is_known else categories["unknown"],
                    "bbox": [box[0], box[1], box_width, box_height],  # COCO [x, y, width, height]
                    "area": box_width * box_height,
                    "iscrowd": 0,
                    "score": 1.0,  # constant score for every parsed box
                    "category_name": name,
                    "bbox_xyxy": box,  # original (x1, y1, x2, y2)
                    "is_unknown": 0 if is_known else 1,
                })
                next_annotation_id += 1

    return annotations_by_image


In [10]:
# COCO-style image records, plus ground-truth and predicted annotations keyed by image index.
images = create_images_for_coco(correct_data, image_folder_root=image_folder_root)
all_gt_annotations = create_annotations_for_coco(correct_data, categories=cat_name2id, processor=processor)
all_pred_annotations = create_annotations_for_coco(generated_data, categories=cat_name2id, processor=processor)

100%|██████████| 5000/5000 [00:02<00:00, 1944.49it/s]
100%|██████████| 5000/5000 [00:00<00:00, 5253.01it/s]
100%|██████████| 5000/5000 [00:00<00:00, 14457.41it/s]


In [11]:
# Peek at the ground-truth annotations of the first image only.
for index, annotation_list in list(all_gt_annotations.items())[:1]:
    print(f"Image {index} has {len(annotation_list)} ground truth annotations.")
    for annotation in annotation_list:
        print(f"  Annotation ID: {annotation['id']}, Category: {annotation['category_name']}, BBox: {annotation['bbox_xyxy']}")

Image 0 has 5 ground truth annotations.
  Annotation ID: 0, Category: cup, BBox: (0.046875, 0.109375, 0.171875, 0.234375)
  Annotation ID: 1, Category: cup, BBox: (0.171875, 0.109375, 0.265625, 0.234375)
  Annotation ID: 2, Category: fork, BBox: (0.265625, 0.109375, 0.328125, 0.234375)
  Annotation ID: 3, Category: fork, BBox: (0.453125, 0.328125, 0.546875, 0.390625)
  Annotation ID: 4, Category: dining table, BBox: (0.015625, 0.015625, 0.984375, 0.984375)


In [12]:
def _annotations_to_detected_instances(annotation_list):
    """Convert annotation dicts (with "category_id" and "bbox_xyxy") into DetectedInstance objects."""
    return [
        DetectedInstance(
            label=ann["category_id"],
            x1=ann["bbox_xyxy"][0],
            y1=ann["bbox_xyxy"][1],
            x2=ann["bbox_xyxy"][2],
            y2=ann["bbox_xyxy"][3],
        )
        for ann in annotation_list
    ]


def _group_annotations_by_category(annotation_list, category_ids):
    """Bucket annotations by "category_id"; categories without annotations stay None.

    Raises KeyError for an annotation whose category id is not in ``category_ids``.
    """
    grouped = dict.fromkeys(category_ids)
    for annotation in annotation_list:
        category_id = annotation["category_id"]
        if grouped[category_id] is None:
            grouped[category_id] = []
        grouped[category_id].append(annotation)
    return grouped


def _evaluate_category(gt_annotations, pred_annotations, iou_threshold, category_id, index):
    """Count TP/FP/FN and collect IoUs for one category of one image.

    At least one of ``gt_annotations`` / ``pred_annotations`` is a non-empty
    list; the other may be None.
    """
    result = {
        "iou_list": [],
        "pred_iou_list": [],
        "tp_num": 0,
        "fp_num": 0,
        "fn_num": 0,
    }
    if gt_annotations is None:
        # Category predicted but absent from the ground truth: all false positives.
        result["fp_num"] = len(pred_annotations)
        return result
    if pred_annotations is None:
        # Category never predicted: every GT box is a false negative with IoU 0.
        result["fn_num"] = len(gt_annotations)
        result["iou_list"] = [0.0] * len(gt_annotations)
        return result

    gt_bbox_list = [ann["bbox_xyxy"] for ann in gt_annotations]
    pred_bbox_list = [ann["bbox_xyxy"] for ann in pred_annotations]
    iou_info_list = calculate_iou(gt_bbox_list, pred_bbox_list)[0]
    # Greedy one-to-one matching pairs exactly min(#GT, #pred) boxes.
    assert len(iou_info_list) == min(len(gt_bbox_list), len(pred_bbox_list)), (
        f"Length mismatch in category {category_id}, index {index}: len(iou_info_list)={len(iou_info_list)}, len(gt_bbox_list)={len(gt_bbox_list)}, len(pred_bbox_list)={len(pred_bbox_list)}"
    )

    # IoUs of the matched pairs only.
    matched_ious = [info["iou_value"] for info in iou_info_list]
    for iou in matched_ious:
        assert not math.isnan(iou), f"IOU value is NaN in category {category_id}, index {index}"
    result["pred_iou_list"] = list(matched_ious)
    # One IoU per GT box: unmatched GT boxes count as IoU 0.
    result["iou_list"] = matched_ious + [0.0] * (len(gt_bbox_list) - len(matched_ious))

    # A matched pair is a true positive when its IoU reaches the threshold.
    tp_num = sum(1 for iou in result["iou_list"] if iou >= iou_threshold)
    result["tp_num"] = tp_num
    result["fp_num"] = len(pred_bbox_list) - tp_num
    result["fn_num"] = len(gt_bbox_list) - tp_num
    return result


def get_per_image_class_result_and_oc_cost(all_gt_annotations, all_pred_annotations, cat_name2id, iou_threshold=0.5):
    """
    Evaluate detections per image and per category, and compute an OC-cost per image.

    Args:
        all_gt_annotations (dict): image index -> list of ground-truth annotation
            dicts (each with "category_id" and "bbox_xyxy").
        all_pred_annotations (dict): image index -> list of predicted annotation dicts.
        cat_name2id (dict): category name -> id; its ids define the per-category slots.
        iou_threshold (float): minimum IoU for a matched pair to count as a true positive.

    Returns:
        tuple:
            - per_image_result_dict (dict): image index -> {category_id: result or None}.
              A result has "iou_list", "pred_iou_list", "tp_num", "fp_num", "fn_num";
              None means the category occurs in neither GT nor predictions.
            - oc_cost_list (list): ``oc_cost`` value (alpha=0.5, beta=0.6) per
              evaluated image, in the same order as per_image_result_dict's keys.

    NOTE(review): only images present in ``all_gt_annotations`` are evaluated, so
    predictions on images without any ground-truth annotation are not counted as
    false positives and get no OC-cost entry — confirm this is intended.
    """
    category_ids = list(cat_name2id.values())
    per_image_result_dict = {}
    oc_cost_list = []
    for index, gt_per_image_annotation_list in tqdm(all_gt_annotations.items()):
        pred_per_image_annotation_list = all_pred_annotations.get(index, [])

        # Image-level OC-cost over all boxes of the image at once.
        oc_cost_list.append(oc_cost(
            _annotations_to_detected_instances(pred_per_image_annotation_list),
            _annotations_to_detected_instances(gt_per_image_annotation_list),
            alpha=0.5,
            beta=0.6,
        ))

        gt_per_category_dict = _group_annotations_by_category(gt_per_image_annotation_list, category_ids)
        pred_per_category_dict = _group_annotations_by_category(pred_per_image_annotation_list, category_ids)

        per_category_result_dict = dict.fromkeys(category_ids)
        for category_id, gt_annotations in gt_per_category_dict.items():
            pred_annotations = pred_per_category_dict[category_id]
            if gt_annotations is None and pred_annotations is None:
                continue
            per_category_result_dict[category_id] = _evaluate_category(
                gt_annotations, pred_annotations, iou_threshold, category_id, index
            )

        per_image_result_dict[index] = per_category_result_dict
    return per_image_result_dict, oc_cost_list


per_image_result_dict, oc_cost_list = get_per_image_class_result_and_oc_cost(all_gt_annotations, all_pred_annotations, cat_name2id, iou_threshold=iou_threshold)


  0%|          | 0/5000 [00:00<?, ?it/s]

100%|██████████| 5000/5000 [00:07<00:00, 643.55it/s]


In [13]:
# Print OC-cost and per-category counts image by image, stopping after the first
# image whose key reaches break_index.
# oc_cost_list is positional: it gets one entry per image, in the same order the
# images were inserted into per_image_result_dict. Pairing the two with zip keeps
# them aligned; indexing the list by image key only works while the keys are
# exactly 0, 1, 2, ... (any image without GT annotations breaks that).
break_index = 10
for (key, value), image_oc_cost in zip(per_image_result_dict.items(), oc_cost_list):
    print(f"Image {key} has {len(value)} categories.")
    print(f"OC Cost: {image_oc_cost}")
    for category_id, result in value.items():
        if result is None:
            continue
        print(f"  Category ID: {category_id}, TP: {result['tp_num']}, FP: {result['fp_num']}, FN: {result['fn_num']}, IOU List: {result['iou_list']}")
    if key >= break_index:
        break

Image 0 has 81 categories.
OC Cost: 0.4766666615009308
  Category ID: 47, TP: 0, FP: 0, FN: 2, IOU List: [0.0, 0.0]
  Category ID: 48, TP: 1, FP: 0, FN: 1, IOU List: [0.6666666865348816, 0.0]
  Category ID: 51, TP: 0, FP: 2, FN: 0, IOU List: []
  Category ID: 67, TP: 0, FP: 0, FN: 1, IOU List: [0.0]
Image 1 has 81 categories.
OC Cost: 0.22083333333333333
  Category ID: 10, TP: 2, FP: 0, FN: 1, IOU List: [1.0, 0.75, 0.0]
Image 2 has 81 categories.
OC Cost: 0.0
  Category ID: 13, TP: 1, FP: 0, FN: 0, IOU List: [1.0]
Image 3 has 81 categories.
OC Cost: 0.255171529452006
  Category ID: 15, TP: 1, FP: 0, FN: 0, IOU List: [0.6712749600410461]
  Category ID: 47, TP: 1, FP: 0, FN: 1, IOU List: [0.6666666865348816, 0.0]
Image 4 has 81 categories.
OC Cost: 0.0902777761220932
  Category ID: 17, TP: 1, FP: 0, FN: 0, IOU List: [1.0]
  Category ID: 64, TP: 0, FP: 1, FN: 1, IOU List: [0.2777777910232544]
Image 5 has 81 categories.
OC Cost: 0.23569624694910915
  Category ID: 1, TP: 6, FP: 2, FN: 0, IO

In [14]:
def convert_per_class_result_dict(per_image_result_dict):
    """Sum per-image, per-category results into dataset-wide per-category totals.

    Every category id seen in any image gets an entry, even if all of its
    per-image results are None (its totals then stay zero/empty). Counts are
    summed and IoU lists concatenated in image order.

    Args:
        per_image_result_dict (dict): image index -> {category_id: result or None}.

    Returns:
        dict: category_id -> {"iou_list", "pred_iou_list", "tp_num", "fp_num", "fn_num"}.
    """
    totals = {}
    for per_image_result in per_image_result_dict.values():
        for category_id, result in per_image_result.items():
            bucket = totals.setdefault(category_id, {
                "iou_list": [],
                "pred_iou_list": [],
                "tp_num": 0,
                "fp_num": 0,
                "fn_num": 0,
            })
            if result is None:
                continue
            for count_key in ("tp_num", "fp_num", "fn_num"):
                bucket[count_key] += result[count_key]
            for list_key in ("iou_list", "pred_iou_list"):
                bucket[list_key].extend(result[list_key])
    return totals

per_category_result_dict = convert_per_class_result_dict(per_image_result_dict)

# One summary line per category id with the dataset-wide counts.
for category_id in per_category_result_dict:
    result = per_category_result_dict[category_id]
    print(f"Category ID: {category_id}, TP: {result['tp_num']}, FP: {result['fp_num']}, FN: {result['fn_num']}, len(IOU List): {len(result['iou_list'])}, len(Pred IOU List): {len(result['pred_iou_list'])}")

Category ID: 1, TP: 6170, FP: 4054, FN: 4324, len(IOU List): 10494, len(Pred IOU List): 9102
Category ID: 2, TP: 110, FP: 194, FN: 194, len(IOU List): 304, len(Pred IOU List): 182
Category ID: 3, TP: 696, FP: 1179, FN: 1222, len(IOU List): 1918, len(Pred IOU List): 1324
Category ID: 4, TP: 143, FP: 44, FN: 194, len(IOU List): 337, len(Pred IOU List): 173
Category ID: 5, TP: 176, FP: 50, FN: 59, len(IOU List): 235, len(Pred IOU List): 208
Category ID: 6, TP: 89, FP: 18, FN: 117, len(IOU List): 206, len(Pred IOU List): 94
Category ID: 7, TP: 121, FP: 24, FN: 40, len(IOU List): 161, len(Pred IOU List): 132
Category ID: 8, TP: 95, FP: 70, FN: 301, len(IOU List): 396, len(Pred IOU List): 114
Category ID: 9, TP: 154, FP: 173, FN: 267, len(IOU List): 421, len(Pred IOU List): 255
Category ID: 10, TP: 84, FP: 136, FN: 427, len(IOU List): 511, len(Pred IOU List): 153
Category ID: 11, TP: 42, FP: 10, FN: 34, len(IOU List): 76, len(Pred IOU List): 47
Category ID: 13, TP: 42, FP: 4, FN: 46, len(IOU

In [15]:
import math

def _safe_mean(values):
    """Return ``np.mean(values)``, or 0.0 when ``values`` is empty."""
    return np.mean(values) if len(values) > 0 else 0.0


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 when both are zero)."""
    return _safe_ratio(2 * precision * recall, precision + recall)


def _category_name(category_id2name, category_id):
    """Look up a category's display name without raising KeyError for unmapped ids.

    Category -1 (the bucket whose FPs are reported as ``unkonown_fp_num``) falls
    back to "unknown"; any other unmapped id falls back to ``str(category_id)``.
    """
    if category_id in category_id2name:
        return category_id2name[category_id]
    return "unknown" if category_id == -1 else str(category_id)


def calculate_score(per_category_result_dict, oc_cost_list, category_id2name):
    """
    Calculate per-category and dataset-level detection scores.

    Args:
        per_category_result_dict (dict): Maps category id -> dict with keys
            "tp_num", "fp_num", "fn_num" (counts), "iou_list" and
            "pred_iou_list" (lists of IoU values). "iou_list" is expected to
            hold one entry per GT instance: the total length over all
            categories is asserted to equal total tp_num + fn_num.
            Category id -1 is treated as the "unknown" category. Its FPs are
            reported as ``unkonown_fp_num`` and it is excluded from the macro
            averages. It still contributes to the micro totals and to
            m_iou / m_pred_iou.
        oc_cost_list (list): Per-image OC-cost values. Their mean is reported
            as ``oc_cost`` (0.0 when the list is empty).
        category_id2name (dict): Maps category id -> display name. Unmapped
            ids fall back to "unknown" (for -1) or ``str(category_id)``.

    Returns:
        tuple[dict, dict]:
            - per_category_score_dict: category id -> {category_name,
              precision, recall, f1_score, cm_iou, cm_pred_iou, tp_num,
              fp_num, fn_num}.
            - dataset_score: {"summary_scores": {...}, "summary_data_num": {...}}.
              It holds micro/macro precision, recall and F1, m_iou,
              m_pred_iou, oc_cost, cm_iou, cm_pred_iou and the aggregated
              counts. The key "unkonown_fp_num" keeps its historical
              spelling so downstream readers keep working.

    Raises:
        AssertionError: If total TP + FN differs from the number of
            collected GT IoUs.
    """
    per_category_score_dict = {}

    # Dataset-wide (micro) accumulators.
    tp_total = fp_total = fn_total = 0
    all_iou = []
    all_pred_iou = []

    # Per-category values collected for the macro averages (known categories only).
    macro_precision = []
    macro_recall = []
    macro_f1 = []
    macro_cm_iou = []
    macro_cm_pred_iou = []

    for category_id, result in per_category_result_dict.items():
        # Per-category scores.
        tp_num = result["tp_num"]
        fp_num = result["fp_num"]
        fn_num = result["fn_num"]
        iou_list = result["iou_list"]
        pred_iou_list = result["pred_iou_list"]

        precision = _safe_ratio(tp_num, tp_num + fp_num)
        recall = _safe_ratio(tp_num, tp_num + fn_num)
        f1_score = _f1(precision, recall)
        cm_iou = _safe_mean(iou_list)
        cm_pred_iou = _safe_mean(pred_iou_list)

        per_category_score_dict[category_id] = {
            "category_name": _category_name(category_id2name, category_id),
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "cm_iou": cm_iou,
            "cm_pred_iou": cm_pred_iou,
            "tp_num": tp_num,
            "fp_num": fp_num,
            "fn_num": fn_num,
        }

        # Whole-dataset totals (the unknown category -1 is included here).
        tp_total += tp_num
        fp_total += fp_num
        fn_total += fn_num
        all_iou.extend(iou_list)
        all_pred_iou.extend(pred_iou_list)

        # Macro averages only cover real categories, not the unknown bucket.
        if category_id != -1:
            macro_precision.append(precision)
            macro_recall.append(recall)
            macro_f1.append(f1_score)
            macro_cm_iou.append(cm_iou)
            macro_cm_pred_iou.append(cm_pred_iou)

    # Every GT instance (matched or missed) must have contributed one IoU entry.
    assert tp_total + fn_total == len(all_iou), \
        f"TP + FN mismatch: {tp_total} + {fn_total} != {len(all_iou)}"

    micro_precision = _safe_ratio(tp_total, tp_total + fp_total)
    micro_recall = _safe_ratio(tp_total, tp_total + fn_total)

    unknown_fp_num = per_category_result_dict[-1]["fp_num"] if -1 in per_category_result_dict else 0

    dataset_score = {
        "summary_scores": {
            "micro_precision": micro_precision,
            "micro_recall": micro_recall,
            "micro_f1": _f1(micro_precision, micro_recall),
            "m_iou": _safe_mean(all_iou),
            "m_pred_iou": _safe_mean(all_pred_iou),
            "oc_cost": _safe_mean(oc_cost_list),
            "macro_precision": _safe_mean(macro_precision),
            "macro_recall": _safe_mean(macro_recall),
            "macro_f1": _safe_mean(macro_f1),
            "cm_iou": _safe_mean(macro_cm_iou),
            "cm_pred_iou": _safe_mean(macro_cm_pred_iou),
        },
        "summary_data_num": {
            "tp_num": tp_total,
            "fp_num": fp_total,
            "fn_num": fn_total,
            "unkonown_fp_num": unknown_fp_num,
            "iou_num": len(all_iou),
            "pred_iou_num": len(all_pred_iou),
        },
    }

    return per_category_score_dict, dataset_score

# Invert the name -> id mapping so scores can be labelled by category name.
category_id2name = dict(zip(cat_name2id.values(), cat_name2id.keys()))
per_category_score_dict, dataset_score = calculate_score(
    per_category_result_dict,
    oc_cost_list,
    category_id2name=category_id2name,
)

In [16]:
# Per-category report: name on one line, metrics and raw counts on the next.
for category_id, score in per_category_score_dict.items():
    print(f"category_name: {score['category_name']}")
    print(
        f"Category ID: {category_id}, "
        f"Precision: {score['precision']:.4f}, Recall: {score['recall']:.4f}, "
        f"F1 Score: {score['f1_score']:.4f}, "
        f"cmIoU: {score['cm_iou']:.4f}, cmPredIoU: {score['cm_pred_iou']:.4f}, "
        f"TP: {score['tp_num']}, FP: {score['fp_num']}, FN: {score['fn_num']}"
    )

# Dataset-level scores (floats, 4 decimals) followed by the raw counts.
for key, score in dataset_score["summary_scores"].items():
    print(f"{key}: {score:.4f}")

for key, num in dataset_score["summary_data_num"].items():
    print(f"{key}: {num}")

category_name: person
Category ID: 1, Precision: 0.6035, Recall: 0.5880, F1 Score: 0.5956, cmIoU: 0.5344, cmPredIoU: 0.6161, TP: 6170, FP: 4054, FN: 4324
category_name: bicycle
Category ID: 2, Precision: 0.3618, Recall: 0.3618, F1 Score: 0.3618, cmIoU: 0.3269, cmPredIoU: 0.5460, TP: 110, FP: 194, FN: 194
category_name: car
Category ID: 3, Precision: 0.3712, Recall: 0.3629, F1 Score: 0.3670, cmIoU: 0.3293, cmPredIoU: 0.4770, TP: 696, FP: 1179, FN: 1222
category_name: motorcycle
Category ID: 4, Precision: 0.7647, Recall: 0.4243, F1 Score: 0.5458, cmIoU: 0.3597, cmPredIoU: 0.7006, TP: 143, FP: 44, FN: 194
category_name: airplane
Category ID: 5, Precision: 0.7788, Recall: 0.7489, F1 Score: 0.7636, cmIoU: 0.6821, cmPredIoU: 0.7706, TP: 176, FP: 50, FN: 59
category_name: bus
Category ID: 6, Precision: 0.8318, Recall: 0.4320, F1 Score: 0.5687, cmIoU: 0.3943, cmPredIoU: 0.8641, TP: 89, FP: 18, FN: 117
category_name: train
Category ID: 7, Precision: 0.8345, Recall: 0.7516, F1 Score: 0.7908, cmI

# gt data
micro_precision: 0.9895
micro_recall: 0.9895
micro_f1: 0.9895
m_iou: 0.9895
oc_cost: 0.0000
macro_precision: 0.9951
macro_recall: 0.9951
macro_f1: 0.9951
cm_iou: 0.9951
tp_num: 35444
fp_num: 375
fn_num: 375
unkonown_fp_num: 0

# cross entropy
micro_precision: 0.5354
micro_recall: 0.3726
micro_f1: 0.4394
m_iou: 0.3360
oc_cost: 0.2708
macro_precision: 0.5672
macro_recall: 0.3237
macro_f1: 0.3933
cm_iou: 0.2927
tp_num: 13345
fp_num: 11581
fn_num: 22474
unkonown_fp_num: 437

# distance loss
micro_precision: 0.5174
micro_recall: 0.3565
micro_f1: 0.4221
m_iou: 0.3222
oc_cost: 0.2726
macro_precision: 0.5664
macro_recall: 0.3119
macro_f1: 0.3821
cm_iou: 0.2832
tp_num: 12769
fp_num: 11911
fn_num: 23050
unkonown_fp_num: 775

# distance forward kl loss
micro_precision: 0.5112
micro_recall: 0.3593
micro_f1: 0.4220
m_iou: 0.3235
oc_cost: 0.2735
macro_precision: 0.5697
macro_recall: 0.3168
macro_f1: 0.3879
cm_iou: 0.2842
tp_num: 12868
fp_num: 12304
fn_num: 22951
unkonown_fp_num: 1198

In [None]:
# gt data
micro_precision: 1.0000
micro_recall: 1.0000
micro_f1: 1.0000
m_iou: 1.0000
m_pred_iou: 1.0000
oc_cost: 0.0000
macro_precision: 1.0000
macro_recall: 1.0000
macro_f1: 1.0000
cm_iou: 1.0000
cm_pred_iou: 1.0000
tp_num: 35819
fp_num: 0
fn_num: 0
unkonown_fp_num: 0
iou_num: 35819
pred_iou_num: 35819

# cross entropy
micro_precision: 0.5364
micro_recall: 0.3733
micro_f1: 0.4402
m_iou: 0.3391
m_pred_iou: 0.6152
oc_cost: 0.2708
macro_precision: 0.5680
macro_recall: 0.3243
macro_f1: 0.3940
cm_iou: 0.2947
cm_pred_iou: 0.6543
tp_num: 13371
fp_num: 11555
fn_num: 22448
unkonown_fp_num: 437
iou_num: 35819
pred_iou_num: 19744

# distance loss
micro_precision: 0.5181
micro_recall: 0.3570
micro_f1: 0.4227
m_iou: 0.3258
m_pred_iou: 0.5878
oc_cost: 0.2726
macro_precision: 0.5671
macro_recall: 0.3122
macro_f1: 0.3825
cm_iou: 0.2852
cm_pred_iou: 0.6349
tp_num: 12787
fp_num: 11893
fn_num: 23032
unkonown_fp_num: 775
iou_num: 35819
pred_iou_num: 19854

# distance forward kl loss
micro_precision: 0.5120
micro_recall: 0.3598
micro_f1: 0.4226
m_iou: 0.3269
m_pred_iou: 0.5884
oc_cost: 0.2735
macro_precision: 0.5703
macro_recall: 0.3172
macro_f1: 0.3884
cm_iou: 0.2863
cm_pred_iou: 0.6341
tp_num: 12887
fp_num: 12285
fn_num: 22932
unkonown_fp_num: 1198
iou_num: 35819
pred_iou_num: 19900

# 欲しい情報
* 画像レベルで、クラスごとに分けて、どのpredとどのgtの検出インスタンスが紐づいたか、および紐づいたペアのIoUのリスト
# クラスごとの評価指標
* cm-accuracy
* cm-IoU
* macro-f1
* macro-recall
* macro-precision
# 全体の評価指標
* accuracy (= micro-F1), micro-precision, micro-recall
* m-iou
# 画像ごとの評価指標
* oc-cost


In [17]:
# def calculate_per_class_m_iou(per_category_result_dict, iou_threshold=0.5):
#     """
#     Calculate per-class mIoU from the per-category results.

#     Args:
#         per_category_result_dict (dict): Dictionary containing per-category results.
#         iou_threshold (float): IoU threshold to consider a prediction as true positive.

#     Returns:
#         dict: Dictionary with category IDs as keys and mIoU values as values.
#     """
#     per_class_miou = {}
#     for category_id, result in per_category_result_dict.items():
#         tp_num = result["tp_num"]
#         fp_num = result["fp_num"]
#         fn_num = result["fn_num"]
#         iou_list = result["iou_list"]

#         if tp_num + fp_num + fn_num == 0:
#             continue

#         # Calculate mIoU
#         m_iou = sum(iou for iou in iou_list if iou >= iou_threshold) / max(tp_num, 1)
#         per_class_miou[category_id] = m_iou

#     return per_class_miou