In [None]:
from collections import Counter

import fiftyone as fo
from tqdm import tqdm

In [None]:
# Open the persisted FiftyOne dataset and show its field schema and summary.
dataset_name = "mcity_fisheye_2000"
dataset = fo.load_dataset(dataset_name)
dataset.reload()  # refresh the in-memory dataset state from the database

dataset_schema = dataset.get_field_schema()
for overview in (dataset_schema, dataset):
    print(overview)

In [None]:
# Zero-shot prediction labels treated as vulnerable road users (VRUs).
# Order is preserved from the original list; membership is what matters downstream.
vru_labels = [
    "skater", "child", "bicycle", "bicyclist", "cyclist", "bike",
    "rider", "motorcycle", "motorcyclist", "pedestrian", "person",
    "walker", "jogger", "runner", "skateboarder", "scooter",
    "delivery driver",
]

# Ground-truth class names that count as VRUs.
gt_vru_labels = ["motorbike/cycler", "pedestrian"]

In [None]:
# Collect the raw zero-shot prediction fields ("pred_..."), skipping the
# derived "_vru" copies so re-running after they exist gives the same list.
dataset_schema = dataset.get_field_schema()
pred_fields = [
    name
    for name in dataset_schema
    if "pred_" in name and "_vru" not in name
]
for name in pred_fields:
    print(name)

In [None]:
# Nested as samples_detections[model][sample] -> that sample's detections list.
# Presumably an entry is None when a sample has no value for the field (TODO confirm).
samples_detections = [
    dataset.values(f"{field}.detections")
    for field in tqdm(pred_fields, desc = "Getting detection values")
]

In [None]:
# Collapse every VRU-like prediction label to a single "vru" class and drop all
# other detections, updating samples_detections in place (one list per model).
n_samples = len(dataset)  # kept as a notebook-level variable for other cells

# "vru" is accepted too so that re-running this cell is idempotent: detections
# already relabeled by a previous run must be kept, not dropped as non-VRU.
vru_label_set = set(vru_labels) | {"vru"}

for field, field_detections in zip(pred_fields, samples_detections):
    kept_counts = Counter()
    removed_counts = Counter()
    for i, detections in enumerate(tqdm(field_detections, desc=f"Relabeling {field}")):
        if not detections:
            # None or empty: nothing to filter, leave the entry as-is
            continue
        kept_detections = []
        for detection in detections:
            if detection.label in vru_label_set:
                kept_counts[detection.label] += 1
                detection.label = "vru"
                kept_detections.append(detection)
            else:
                removed_counts[detection.label] += 1
        # field_detections is samples_detections[j], so this updates it in place
        field_detections[i] = kept_detections
    # One summary line per model instead of one print per detection
    print(f"{field}: changed to vru {dict(kept_counts)}; removed {dict(removed_counts)}")

In [None]:
# Collapse ground-truth VRU classes to a single "vru" class and drop all other
# GT detections, so GT and predictions share one class for evaluation.
gt_field = "ground_truth"
# Re-read from the dataset on every run, so re-running this cell is idempotent
field_gt = dataset.values(f"{gt_field}.detections")

gt_kept_counts = Counter()
gt_removed_counts = Counter()
# Enumerate the fetched values directly (one entry per sample) rather than
# relying on n_samples defined in another cell.
for i, detections in enumerate(tqdm(field_gt)):
    if not detections:
        # None or empty: nothing to filter, leave the entry as-is
        continue
    kept_detections = []
    for detection in detections:
        if detection.label in gt_vru_labels:
            gt_kept_counts[detection.label] += 1
            detection.label = "vru"
            kept_detections.append(detection)
        else:
            gt_removed_counts[detection.label] += 1
    field_gt[i] = kept_detections

# One summary line instead of one print per detection
print(f"{gt_field}: changed to vru {dict(gt_kept_counts)}; removed {dict(gt_removed_counts)}")

In [None]:
# Store each model's VRU-only detections in a new "<field>_vru" Detections field.
# pred_fields and samples_detections are parallel lists (one entry per model).
for field, vru_detections in zip(pred_fields, samples_detections):
    vru_field = f"{field}_vru"
    # Declare the embedded Detections field before setting its nested list
    dataset.add_sample_field(
        vru_field,
        fo.EmbeddedDocumentField,
        embedded_doc_type=fo.Detections,
    )
    dataset.set_values(f"{vru_field}.detections", vru_detections)

In [None]:
# Store the VRU-only ground truth in a new "<gt_field>_vru" Detections field.
gt_vru_field = f"{gt_field}_vru"
dataset.add_sample_field(
    gt_vru_field,
    fo.EmbeddedDocumentField,
    embedded_doc_type=fo.Detections,
)
dataset.set_values(f"{gt_vru_field}.detections", field_gt)

In [None]:
# Evaluate every simplified VRU prediction field ("pred_..._vru") against the
# VRU-only ground truth; per-sample results are stored under "eval_..._vru".
dataset.reload()
dataset_schema = dataset.get_field_schema()

# Derive the GT field name from gt_field (as when it was created) instead of
# hardcoding "ground_truth_vru".
gt_vru_field = f"{gt_field}_vru"
# Minimum IoU for a prediction to match a GT box. 0.1 is far looser than
# FiftyOne's default of 0.5. TODO: document why this threshold was chosen.
iou_threshold = 0.1

vru_pred_fields = [
    name for name in dataset_schema if "pred_" in name and "_vru" in name
]

eval_results = []  # one result per entry of vru_pred_fields, same order
for field in tqdm(vru_pred_fields, desc="Evaluating VRU predictions"):
    print(field)
    eval_key = field.replace("pred_", "eval_")
    eval_result = dataset.evaluate_detections(
        field,
        gt_field=gt_vru_field,
        iou=iou_threshold,
        eval_key=eval_key,
        compute_mAP=True,
    )
    eval_results.append(eval_result)

In [None]:
# Print dataset-wide TP/FP/FN totals for each VRU evaluation by summing the
# per-sample "<eval_key>_tp/_fp/_fn" fields written by evaluate_detections.
count_formats = (("TP: %d", "_tp"), ("FP: %d", "_fp"), ("FN: %d", "_fn"))
for field in tqdm(dataset_schema):
    if not ("pred_" in field and "_vru" in field):
        continue
    eval_key = field.replace("pred_", "eval_")
    print(eval_key)
    for template, suffix in count_formats:
        print(template % dataset.sum(eval_key + suffix))