### Imports

In [None]:
import os
import sys
# Search this directory first when importing, so modules found there take
# precedence over copies located later on sys.path.
sys.path.insert(0, "/Users/neeraja/fiftyone")
# NOTE(review): PYTHONPATH is read at interpreter start-up, so assigning it here
# does not change imports in the already-running kernel; presumably it is meant
# for child processes started later — confirm. Both paths are machine-specific.
os.environ["PYTHONPATH"] = "/Users/neeraja/fiftyone:/Users/neeraja/fiftyone-brain"

In [None]:
import pytest
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr
from sklearn.manifold import TSNE

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.operators as foo
import fiftyone.brain as fob

In [None]:
# Make the parent of the current working directory importable; presumably that
# is where the project-local `annoprop` and `utils` modules live — confirm.
sys.path.append(os.path.dirname(os.getcwd()))
from annoprop import propagate_annotations_sam2, estimate_propagatability
from utils import evaluate_success_rate

### Multi-instance Dataset

In [None]:
# Tags of the sequences to keep; samples are filtered by these tags below.
SELECT_SEQUENCES = ["india"]

# Load the validation split of the DAVIS-2017 zoo dataset, requesting the
# "image" format, then restrict it to the selected sequences.
dataset = foz.load_zoo_dataset(
    "https://github.com/voxel51/davis-2017",
    split="validation",
    format="image",
)
view = dataset.match_tags(SELECT_SEQUENCES)

Partially label: copy `ground_truth` into `human_labels_test` on every 10th frame of each sequence.

In [None]:
# Reset the partial-label field so every run starts from a clean, explicitly
# typed field. The declaration is intentionally outside the existence check:
# a first run (field absent) and a rerun (field present) must end up with the
# same `fo.Detections` schema instead of relying on implicit type inference.
root_dataset = view._dataset
if root_dataset.has_field("human_labels_test"):
    root_dataset.delete_sample_field("human_labels_test")
root_dataset.add_sample_field(
    "human_labels_test",
    fo.EmbeddedDocumentField,
    embedded_doc_type=fo.Detections,
)

# Every tag except "val" is treated as a sequence name. Filtering (rather than
# list.remove) avoids a ValueError when no sample carries the "val" tag.
sequences = [tag for tag in view.distinct("tags") if tag != "val"]

# `new_frame_number` is a running index across all sequences: frames are ordered
# by `frame_number` within a sequence, and sequences are laid out back to back.
new_frame_number = 0
for seq in sequences:
    seq_view = view.match_tags(seq).sort_by("frame_number")
    num_frames = len(seq_view)  # computed once; reused for the offsets below
    seq_view.set_values(
        "new_frame_number",
        list(range(new_frame_number, new_frame_number + num_frames)),
    )
    new_frame_number += num_frames

    # Copy the ground truth onto every 10th frame of the sequence; autosave
    # persists the edited samples without an explicit save() per sample.
    for ii, sample in enumerate(seq_view.iter_samples(autosave=True)):
        if ii % 10 == 0:
            sample["human_labels_test"] = sample["ground_truth"]

### Multi-scene dataset

In [None]:
# NOTE: this rebinds `dataset` and `view`, replacing the multi-instance ones
# defined earlier in the notebook.
dataset = fo.load_dataset("basketball_frames")

# Take the first 20 samples of each saved view, then chain them into one view.
dataset_slice_1, dataset_slice_2 = (
    dataset.load_saved_view(saved_view_name).limit(20)
    for saved_view_name in ("side_top_layup", "underbasket_reverse_layup")
)
view = dataset_slice_1.concat(dataset_slice_2)

Partially label: copy `ha_test_1` into `human_labels_test` on every other sample.

In [None]:
# Reset the partial-label field so every run starts from a clean, explicitly
# typed field. The declaration is intentionally outside the existence check:
# a first run (field absent) and a rerun (field present) must end up with the
# same `fo.Detections` schema instead of relying on implicit type inference.
root_dataset = view._dataset
if root_dataset.has_field("human_labels_test"):
    root_dataset.delete_sample_field("human_labels_test")
root_dataset.add_sample_field(
    "human_labels_test",
    fo.EmbeddedDocumentField,
    embedded_doc_type=fo.Detections,
)

# Copy the `ha_test_1` labels onto every other sample of the view; autosave
# persists the edited samples without an explicit save() per sample.
for ii, sample in enumerate(view.iter_samples(autosave=True)):
    if ii % 2 == 0:
        sample["human_labels_test"] = sample["ha_test_1"]

### Propagation with SAM2

In [None]:
from labelprop_methods.sam2 import PropagatorSAM2

In [None]:
# Field written by the "partially label" cells; presumably the propagation input.
input_annotation_field = "human_labels_test"

# Presumably the field that receives the propagated labels — confirm at use site.
output_annotation_field = "human_labels_test_propagated"

# Field used to order samples before their file paths are handed to the propagator.
sort_field = "new_frame_number"

In [None]:
# Create the SAM2 propagator with its default constructor arguments. It is given
# the image list via `initialize(...)` and used for `extract_spatial_embeddings`
# in later cells.
propagator = PropagatorSAM2()

In [None]:
# Collect the image paths in `sort_field` order when the view has that field,
# otherwise in the view's own order, and hand them to the propagator.
image_path_list = (
    view.sort_by(sort_field).values("filepath")
    if view.has_field(sort_field)
    else view.values("filepath")
)
propagator.initialize(image_path_list)

#### Backbone Embeddings

In [None]:
def compute_sam2_backbone_embeddings(sample, embedding_field_name="sam2_backbone_embeddings"):
    """Store the propagator's spatial embedding of a sample's image on the sample.

    Uses the module-level ``propagator`` to call ``extract_spatial_embeddings``
    on ``sample["filepath"]`` and writes the result to
    ``sample[embedding_field_name]``.

    Args:
        sample: the sample to update; must expose a ``"filepath"`` entry.
        embedding_field_name: name of the field that receives the embedding.

    Returns:
        The same ``sample`` object, after the field has been set.
    """
    sample[embedding_field_name] = propagator.extract_spatial_embeddings(
        sample["filepath"]
    )
    return sample


# Apply the function to every sample of the view with a single worker and
# `save=True`. Wrapping the call in list() consumes the returned iterable in
# full; the collected outputs themselves are discarded.
_ = list(
    view.map_samples(
        compute_sam2_backbone_embeddings,
        num_workers=1,
        save=True,
    )
)

In [None]:
from embedding_utils import compute_hausdorff_mds_embedding

In [None]:
# Run the project-local Hausdorff/MDS helper on the view.
# NOTE(review): judging by the keyword names, it presumably reads the spatial
# embeddings stored in "sam2_backbone_embeddings" and writes a per-sample
# embedding to "embeddings_hausdorff_mds_sam2" — confirm in embedding_utils.
compute_hausdorff_mds_embedding(
    view,
    spatial_embedding_field_name="sam2_backbone_embeddings",
    mds_embedding_field_name="embeddings_hausdorff_mds_sam2",
)

In [None]:
# Build a FiftyOne Brain visualization for the view from the values stored in
# the "embeddings_hausdorff_mds_sam2" field; no model is supplied (model=None).
# The run is registered under the given brain key.
# NOTE(review): the dimensionality-reduction method is left at the library
# default — confirm that is the intended projection for these embeddings.
fob.compute_visualization(
    view,
    model=None,
    embeddings="embeddings_hausdorff_mds_sam2",
    brain_key="embedding_hausdorff_mds_sam2_run",
)

### TODO: For the multi-instance dataset
- Compute Hausdorff + MDS embedding space
- Visualize
- Plot propagatability