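### Features
- Load an mmpretrain classification model and compute image embeddings
- Compute an embedding visualization and a similarity index
- Apply the model to generate predictions
- Compute the mistakenness of ground-truth labels
- Find label inconsistencies among similar (near-duplicate) samples
- Send samples tagged `relabel` to CVAT for re-annotation, merge the results, and export the updated dataset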

In [None]:
import os.path as osp
from copy import deepcopy

import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.core.models as fom
import fiftyone.types as fot
import fiftyone.utils as fou
import patchbrain  # noqa: F401
import yaml
from fiftyone import ViewField as F

In [None]:
# Configure the CVAT annotation backend: `segment_size` is the maximum number
# of images per CVAT job (per FiftyOne's CVAT integration docs).
_cvat_backend = fo.annotation_config.backends["cvat"]
_cvat_backend["segment_size"] = 300

In [None]:
# Runtime settings: read them from a shared config.yaml when it exists,
# otherwise fall back to the hard-coded local defaults below.
dataset_type = fot.FiftyOneDataset
config_file = "../../../config.yaml"

if osp.exists(config_file):
    print(f"Loading config from {config_file}")
    # Context manager guarantees the file handle is closed after parsing.
    with open(config_file, "r", encoding="utf-8") as config_fp:
        config = yaml.safe_load(config_fp)
    # An empty or non-mapping YAML file would otherwise fail later with an
    # opaque TypeError on the first key lookup.
    if not isinstance(config, dict):
        raise ValueError(f"{config_file} must contain a YAML mapping of settings")
    dataset_dir = config["dataset_dir"]
    model_config = config["model_config"]
    model_checkpoint = config["model_checkpoint"]
    categories = config["categories"]
    anno_key = config["anno_key"]
    prediction_field = config["prediction_field"]
    label_field = config["label_field"]
    temp_anno_field = config["temp_anno_field"]
else:
    dataset_dir = "/home/fkwong/datasets/82_truckcls/data/raw/truckcls-fiftyone"
    model_config = "/home/fkwong/workspace/srcs/public/grp_openmm/mmpretrain/work_dirs/efficientnet-b1_8xb32_truckcls/20250805_172827/efficientnet-b1_8xb32_truckcls.py"
    model_checkpoint = "/home/fkwong/workspace/srcs/public/grp_openmm/mmpretrain/work_dirs/efficientnet-b1_8xb32_truckcls/20250805_172827/best_accuracy_top1_epoch_46.pth"
    categories = [
        "danger_vehicle",
        "dumper",
        "dumper6",
        "others",
        "pickup",
        "truck_box",
        "unknown",
    ]
    anno_key = "cvat_annotation"
    prediction_field = "efficientnet_b1"
    label_field = "ground_truth"  # classification
    temp_anno_field = "temp_annotation"

In [None]:
# Load the dataset stored under `dataset_dir` in the `dataset_type` format.
dataset = fo.Dataset.from_dir(
    dataset_dir=dataset_dir,
    dataset_type=dataset_type,
)

In [None]:
# Build the model deployment config from the mmpretrain defaults. The defaults
# are deep-copied so the shared module-level dict is not mutated; then the
# model is loaded.
# NOTE(review): `fiftyone.utils.mmpretrain` is not imported explicitly in this
# notebook; presumably it is made available elsewhere (e.g. by `patchbrain`).
default_config = deepcopy(fou.mmpretrain.MMPRETRAIN_DEFAULT_DEPLOYMENT_CONFIG)
default_config["config"].update(
    model_config=model_config,
    model_checkpoint=model_checkpoint,
    classes=categories,
)
model = fom.load_model(default_config)

In [None]:
# Embed every sample with the model. Vectors go into the
# "<prediction_field>_embeddings" sample field, which the later brain runs read.
dataset.compute_embeddings(
    model,
    embeddings_field=f"{prediction_field}_embeddings",
    progress=True,
    num_workers=8,
    batch_size=32,
)

In [None]:
# Remove brain runs left from a previous execution so the visualization and
# similarity runs can be recomputed under the same brain keys.
for _suffix in ("visualization", "similarity"):
    _brain_key = f"{prediction_field}_{_suffix}"
    if _brain_key in dataset.list_brain_runs():
        dataset.delete_brain_run(_brain_key)

In [None]:
# Low-dimensional embedding visualization, registered under the brain key
# "<prediction_field>_visualization".
visualization_result = fob.compute_visualization(
    dataset,
    brain_key=f"{prediction_field}_visualization",
    embeddings=f"{prediction_field}_embeddings",
    progress=True,
)

In [None]:
# Similarity index over the stored embeddings, registered under the brain key
# "<prediction_field>_similarity"; later used for near-duplicate detection.
similarity_result = fob.compute_similarity(
    dataset,
    brain_key=f"{prediction_field}_similarity",
    embeddings=f"{prediction_field}_embeddings",
    progress=True,
)

In [None]:
# Run the classifier and write its predictions to `prediction_field`.
# Logits are stored too (store_logits=True), presumably for mistakenness scoring.
dataset.apply_model(
    model,
    label_field=prediction_field,
    store_logits=True,
    progress=True,
    num_workers=8,
    batch_size=32,
)

In [None]:
# Restrict to samples that actually received a prediction.
has_prediction = ~F(prediction_field).is_null()
prediction_view = dataset.match(has_prediction)

In [None]:
# Score how likely each ground-truth label in `label_field` is a mistake, based
# on the model predictions in `prediction_field`. The score is stored in
# "<prediction_field>_mistakenness".
# The configured `label_field` is used rather than a hard-coded field name, so
# a config.yaml with a different label field is honored.
# NOTE(review): predictions carry logits (apply_model(store_logits=True)), but
# `use_logits` is not passed here. FiftyOne Brain defaults it to False
# (confidence-based). Consider `use_logits=True`; that would change the scores
# and the 0.92 cutoff used for the mistakenness view.
mistakenness_field = f"{prediction_field}_mistakenness"
fob.compute_mistakenness(
    prediction_view,
    prediction_field,
    label_field,
    mistakenness_field=mistakenness_field,
)

In [None]:
# Keep only samples whose mistakenness exceeds the cutoff, most suspicious first.
mistakenness_thresh = 0.92
compute_mistakenness_view = prediction_view.match(
    F(mistakenness_field) > mistakenness_thresh
).sort_by(mistakenness_field, reverse=True)

In [None]:
# Open an App session on the high-mistakenness samples. auto=False means the
# App is not auto-displayed in the notebook output (per FiftyOne Session docs).
sess = fo.Session(compute_mistakenness_view, auto=False)

In [None]:
# Mark near-duplicate samples (embedding distance below `thresh`), then build a
# view of duplicates whose labels disagree.
# The label path is derived from the configured `label_field` rather than a
# hard-coded "ground_truth".
# NOTE(review): `mistakenness_view` is not defined in this notebook; presumably
# it is added to the similarity index by the `patchbrain` import. Confirm its
# semantics and whether the *_dup_* fields it writes are persisted.
similarity_result.find_duplicates(thresh=0.3)
mistakenness_view = similarity_result.mistakenness_view(
    f"{label_field}.label",
    type_field=f"{prediction_field}_dup_type",
    id_field=f"{prediction_field}_dup_id",
    dist_field=f"{prediction_field}_dup_dist",
    reverse=False,
)

In [None]:
# Point the App session at the near-duplicate label-inconsistency view so the
# samples to fix can be tagged "relabel" there.
sess.view = mistakenness_view

In [None]:
# Samples to (re-)annotate: every sample tagged "relabel" (e.g. tagged in the App).
# Alternative: annotate the samples currently selected in the App instead:
# anno_view = dataset.select(sess.selected)
anno_view = dataset.match_tags("relabel")

In [None]:
# Persist the samples to annotate as a saved view named `anno_key`, unless a
# view with that name already exists.
if len(anno_view) == 0:
    print("No sample to be annotated.")
elif anno_key in dataset.list_saved_views():
    print(f"{anno_key} view existed!")
else:
    dataset.save_view(anno_key, anno_view)

In [None]:
# Open a fresh App session on the samples queued for annotation, if any.
if len(anno_view) == 0:
    print("No sample to be reviewed.")
else:
    sess = fo.Session(anno_view, auto=False)

In [None]:
# Create the annotation run `anno_key`: the samples' `label_field`
# classifications are sent to the annotation backend for editing, restricted
# to `categories`. The editor is not opened automatically.
if len(anno_view) == 0:
    print("No sample to be annotated.")
else:
    anno_results = anno_view.annotate(
        anno_key=anno_key,
        label_type="classification",
        label_field=label_field,
        classes=categories,
        launch_editor=False,
    )

In [None]:
# Reload the view the annotation run was created on, and pull the current
# annotations into `temp_anno_field` for side-by-side review in the App.
# cleanup=False keeps the run's data on the annotation backend so it can be
# loaded again later.
anno_view = dataset.load_annotation_view(anno_key)
anno_view.load_annotations(anno_key, dest_field=temp_anno_field, cleanup=False)
sess.view = anno_view

In [None]:
# Only merge annotations into `label_field` once every task of the run is
# completed. Otherwise report the first task that is still pending.
anno_results = anno_view.load_annotation_results(anno_key)
task_statuses = anno_results.get_status()[label_field]
first_pending = next(
    (
        (task_id, task_info)
        for task_id, task_info in task_statuses.items()
        if task_info["status"] != "completed"
    ),
    None,
)
if first_pending is None:
    anno_view.load_annotations(anno_key, dest_field=label_field, cleanup=False)
else:
    pending_id, pending_info = first_pending
    print(
        f"Task-{pending_id} is not completed yet, "
        f"current status: {pending_info['status']}"
    )

In [None]:
# Fetch the annotation run's results object (used for backend cleanup next).
# cache=False presumably bypasses any cached results instance; confirm in the
# FiftyOne docs.
anno_results = dataset.load_annotation_results(anno_key, cache=False)

In [None]:
# Tear down everything this workflow created before exporting.
# 1) Drop the temporary and derived sample fields. error_level=1 is used so
#    that missing fields do not abort the cell (FiftyOne: 1 = warn, not raise).
# NOTE(review): the embeddings field, the *_dup_* fields and the
# visualization/similarity brain runs are NOT removed here, so they will be
# included in the export. Add them if that is unintended.
dataset.delete_sample_fields(
    [
        temp_anno_field,
        mistakenness_field,
        prediction_field,
    ],
    error_level=1,
)
# 2) Release the run's data on the annotation backend, then remove the run
#    record and the saved view from the dataset.
anno_results.cleanup()
if anno_key in dataset.list_annotation_runs():
    dataset.delete_annotation_run(anno_key)
if anno_key in dataset.list_saved_views():
    dataset.delete_saved_view(anno_key)
# Only a cell's last expression is displayed, so print explicitly to confirm
# the cleanup actually happened.
print("Saved views:", dataset.list_saved_views())
print("Annotation runs:", dataset.list_annotation_runs())
# 3) Clear the work-queue tag.
dataset.untag_samples("relabel")

In [None]:
# Write the updated dataset (labels and media) out in the `dataset_type` format.
# NOTE(review): this exports into the same `dataset_dir` the dataset was loaded
# from, with export_media=True. Confirm this in-place overwrite is intended and
# safe for the media layout; consider exporting to a new directory first.
dataset.export(
    export_dir=dataset_dir,
    export_media=True,
    dataset_type=dataset_type,
)