In [None]:
# from gorillatracker.args import TrainingArgs
import json
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Tuple, Type, Union
from urllib.parse import urlparse

import cv2
import cv2.typing as cvt
import pandas as pd
import torch
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from gorillatracker.model.get_model_cls import get_model_cls
from gorillatracker.model.base_module import BaseModule
from gorillatracker.train_utils import get_dataset_class
from gorillatracker.type_helper import Label
from gorillatracker.scripts.create_dataset_from_videos import _crop_image

# Type alias used to annotate wandb run objects (the value returned by `wandb.Api().run(...)`
# in `get_run`); it is deliberately loose and resolves to `Any`.
wandbRun = Any


def _get_frames_for_ids(json_path: str) -> Any:
    """Get the frames for the given IDs.

    Args:
        json_path: Path to the JSON file containing the IDs.

    Returns:
        Dictionary of IDs to frames.
    """
    id_frames: Any = {}
    face_class: int = 1
    # read the JSON file
    with open(json_path, "r") as f:
        data = json.load(f)
    for frame_idx, frame in enumerate(data["labels"]):
        for bbox in frame:
            if bbox["class"] != face_class:
                continue
            id = int(bbox["id"])
            if id not in id_frames:
                id_frames[id] = []
            id_frames[id].append((frame_idx, (bbox["center_x"], bbox["center_y"], bbox["w"], bbox["h"])))

    return id_frames


def get_wandb_api() -> wandb.Api:
    """Return the shared wandb API client, creating it on the first call.

    The client is memoised as the ``api`` attribute of this function object,
    so every later call returns the same ``wandb.Api`` instance.
    """
    try:
        return get_wandb_api.api  # type: ignore
    except AttributeError:
        # First call: nothing cached yet.
        client = wandb.Api()
        get_wandb_api.api = client  # type: ignore
        return client


def parse_wandb_url(url: str) -> Tuple[str, str, str]:
    """Split a wandb run URL into its entity, project and run id.

    Args:
        url: Run URL such as
            ``https://wandb.ai/gorillas/Embedding-SwinV2-CXL-Open/runs/fnyvl65k/overview``.
            Path segments after the run id (e.g. ``overview``) are ignored.

    Returns:
        ``(entity, project, run_id)``.

    Raises:
        AssertionError: If the URL does not start with ``https://wandb.ai/`` or
            the third path segment is not ``runs``.
        ValueError: If the path has fewer than four segments.
    """
    expected_format = "expect: https://wandb.ai/gorillas/Embedding-SwinV2-CXL-Open/runs/fnyvl65k/overview like format."

    assert url.startswith("https://wandb.ai/"), expected_format
    parsed = urlparse(url)
    assert parsed.netloc == "wandb.ai", expected_format

    # e.g. ['gorillas', 'Embedding-SwinV2-CXL-Open', 'runs', 'fnyvl65k', 'overview']
    parts = parsed.path.strip("/").split("/")
    if len(parts) < 4:
        # Same exception type the bare tuple-unpacking raised before, but with a usable message.
        raise ValueError(f"{expected_format} got: {url}")

    entity, project, s_runs, run_id = parts[:4]
    assert s_runs == "runs", expected_format
    return entity, project, run_id


def get_run(url: str) -> wandbRun:
    """Fetch the wandb run object that a run URL points to.

    The URL is decomposed with ``parse_wandb_url`` and looked up as
    ``entity/project/run_id`` through the shared API client.
    Reference: https://docs.wandb.ai/ref/python/run
    """
    run_path = "/".join(parse_wandb_url(url))
    return get_wandb_api().run(run_path)  # type: ignore


def load_model_from_wandb(
    wandb_fullname: str, model_cls: Type[BaseModule], model_config: Dict[str, Any], device: str = "cpu"
) -> BaseModule:
    """Download a model checkpoint artifact from wandb and restore it into ``model_cls``.

    Args:
        wandb_fullname: Artifact name handed to ``wandb.Api().artifact`` (type "model").
        model_cls: Model class to instantiate.
        model_config: Keyword arguments passed to ``model_cls``.
        device: Device the restored model is moved to.

    Returns:
        The model with the checkpoint's ``state_dict`` loaded, moved to
        ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no "state_dict" entry.
    """
    api = get_wandb_api()

    artifact = api.artifact(  # type: ignore
        wandb_fullname,
        type="model",
    )
    artifact_dir = artifact.download()
    checkpoint_path = Path(artifact_dir) / "model.ckpt"  # all of our models are saved as model.ckpt
    # The checkpoint is always deserialised on CPU first; the model is moved to `device` below.
    # NOTE: torch.load unpickles the file, so only load artifacts from projects you trust.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model_state_dict = checkpoint["state_dict"]

    model = model_cls(**model_config)

    # necessary because arcface loss also saves prototypes.
    # Each loss module is handled on its own so a checkpoint that stores only one of the
    # two prototype tensors no longer fails with a KeyError on the missing one.
    for loss_module_name in ("loss_module_train", "loss_module_val"):
        prototypes_key = f"{loss_module_name}.prototypes"
        if prototypes_key in model_state_dict:
            getattr(model, loss_module_name).prototypes = torch.nn.Parameter(model_state_dict[prototypes_key])

    # note the following line can fail if your model was not trained with the same 'embedding structure' as the current model class
    # easiest fix is to just use the old embedding structure in the model class
    model.load_state_dict(model_state_dict)

    model.to(device)
    model.eval()
    return model


def generate_embeddings(model: BaseModule, dataset: Any, device: str = "cpu") -> pd.DataFrame:
    """Run ``model`` over every item of ``dataset`` and tabulate the embeddings.

    Args:
        model: Model called on each batch of inputs. It is not moved to
            ``device`` here; the caller is responsible for that.
        dataset: Iterable yielding ``(imgs, labels)`` pairs. ``imgs`` is either a
            single tensor or a sequence of tensors (with a matching sequence of
            labels). Must expose a ``mapping`` attribute; when it is truthy it is
            indexed with each label to fill "label_string".
        device: Device the inputs are moved to before the forward pass.

    Returns:
        DataFrame with the columns "index", "embedding", "label", "input" (a PIL
        image rebuilt from the input tensor) and "label_string". Every row is
        concatenated from a single-row frame without re-indexing, so the "index"
        column is 0 for all rows.
    """
    empty = pd.DataFrame(columns=["embedding", "label", "input", "label_string"])
    to_pil = transforms.ToPILImage()  # loop-invariant, build once
    row_frames = []  # single-row frames, concatenated once at the end instead of once per row

    with torch.no_grad():
        print("Generating embeddings...")
        for imgs, labels in tqdm(dataset):
            if isinstance(imgs, torch.Tensor):
                imgs = [imgs]
                labels = [labels]
            batch_inputs = torch.stack(imgs)
            # NOTE(review): for more than one image this inserts an extra axis at position 1;
            # presumably what the multi-image datasets/models expect — confirm.
            if batch_inputs.shape[0] != 1:
                batch_inputs = batch_inputs.unsqueeze(1)
            batch_inputs = batch_inputs.to(device)
            embeddings = model(batch_inputs)

            for i in range(len(imgs)):
                input_img = to_pil(batch_inputs[i].cpu())
                row_frames.append(
                    pd.DataFrame(
                        {
                            "embedding": [embeddings[i]],
                            "label": [labels[i]],
                            "input": [input_img],
                            "label_string": [dataset.mapping[labels[i]]] if dataset.mapping else None,
                        }
                    )
                )

    # One concat over all rows yields the same table as growing it row by row,
    # without copying the accumulated frame on every iteration.
    df = pd.concat([empty, *row_frames])
    return df.reset_index(drop=False)


def get_dataset(
    model: BaseModule,
    partition: Literal["train", "val", "test"],
    data_dir: str,
    dataset_class: str,
    transform: Union[Callable[..., Any], None] = None,
) -> Dataset[Tuple[Any, Label]]:
    """Instantiate the dataset class named ``dataset_class`` for one partition.

    Args:
        model: Supplies ``get_tensor_transforms()`` for the default transform.
        partition: Which split to load.
        data_dir: Passed through to the dataset constructor.
        dataset_class: Name resolved via ``get_dataset_class``.
        transform: Transform applied by the dataset. When omitted, the dataset
            class's own ``get_transforms()`` is composed with the model's
            tensor transforms.

    Returns:
        The constructed dataset.
    """
    dataset_cls = get_dataset_class(dataset_class)

    effective_transform = transform
    if effective_transform is None:
        default_steps = [
            dataset_cls.get_transforms(),  # type: ignore
            model.get_tensor_transforms(),
        ]
        effective_transform = transforms.Compose(default_steps)

    return dataset_cls(  # type: ignore
        data_dir=data_dir,
        partition=partition,
        transform=effective_transform,
    )


def get_latest_model_checkpoint(run: wandbRun) -> wandb.Artifact:
    """Return the most recently created "model" artifact logged by ``run``.

    Args:
        run: wandb run whose ``logged_artifacts()`` are inspected.

    Returns:
        The artifact of type "model" with the greatest ``created_at``.

    Raises:
        ValueError: If the run logged no artifact of type "model".
    """
    models = [a for a in run.logged_artifacts() if a.type == "model"]
    if not models:
        # max() would also raise ValueError here, but with a message that hides the cause.
        raise ValueError(f"run {run.name!r} has no logged artifacts of type 'model'")
    return max(models, key=lambda a: a.created_at)


def generate_embeddings_from_run(run_url: str, outpath: str) -> pd.DataFrame:
    """Embed the train and val partitions with the latest checkpoint of a wandb run.

    The model class, dataset class and data directory are taken from the run's
    config. The val partition is embedded first, then train; both tables get a
    "partition" column and are concatenated (train rows first).

    Args:
        run_url: URL of the wandb run (see ``parse_wandb_url`` for the format).
        outpath: Target ``.pkl`` file for the resulting table, or "-" to skip
            writing. The file must not exist yet and its parent directory must.

    Returns:
        DataFrame holding the embeddings of both partitions.
    """
    should_write = outpath != "-"
    if should_write:
        target = Path(outpath)
        assert not target.exists(), "outpath must not exist"
        assert target.parent.exists(), "outpath parent must exist"
        assert target.suffix == ".pkl", "outpath must be a pickle file"

    run = get_run(run_url)
    print("Using model from run:", run.name)
    print("Config:", run.config)

    # NOTE(liamvdv): TrainingArgs(**run.config) is not used because the stored config
    # contains potentially unknown keys / misses keys (e. g. l2_beta). Only the keys
    # below are forwarded; extend the tuple if the model's constructor keys change.
    config_keys = (
        # Others:
        "model_name_or_path",
        "dataset_class",
        "data_dir",
        # Model Params:
        "embedding_size",
        "from_scratch",
        "loss_mode",
        "weight_decay",
        "lr_schedule",
        "warmup_mode",
        "warmup_epochs",
        "max_epochs",
        "initial_lr",
        "start_lr",
        "end_lr",
        "beta1",
        "beta2",
    )
    args = {key: run.config[key] for key in config_keys}

    print("Loading model from latest checkpoint")
    model_path = get_latest_model_checkpoint(run).qualified_name
    model_cls = get_model_cls(args["model_name_or_path"])
    model = load_model_from_wandb(model_path, model_cls=model_cls, model_config=args)

    # Datasets are built train first, then val.
    datasets = {
        partition: get_dataset(
            partition=partition, data_dir=args["data_dir"], model=model, dataset_class=args["dataset_class"]
        )
        for partition in ("train", "val")
    }

    val_df = generate_embeddings(model, datasets["val"])
    val_df["partition"] = "val"

    train_df = generate_embeddings(model, datasets["train"])
    train_df["partition"] = "train"

    df = pd.concat([train_df, val_df], ignore_index=True)

    print("Embeddings for", len(df), "images generated")

    # store
    if should_write:
        df.to_pickle(outpath)
    print("done")
    return df


def generate_embeddings_from_tracked_video(
    model: BaseModule,
    video_path: str,
    tracking_data: Dict[Any, Any],
    model_transforms: Callable[..., Any] = lambda x: x,
    min_frames: int = 15,
    max_per_individual: int = 15,
) -> pd.DataFrame:
    """Embed an evenly spaced sample of tracked face crops for each individual in a video.

    Args:
        model: The model to use for embedding generation.
        video_path: Path to the video.
        tracking_data: Dictionary of individual IDs to frames,
            ``{id: [(frame_idx, bbox), ...]}`` with ``bbox = (x, y, w, h)``, e.g. the
            output of ``_get_frames_for_ids``.
        model_transforms: Callable applied to the cropped PIL image before it is
            passed to the model. Defaults to the identity.
        min_frames: Individuals with fewer tracked frames than this are discarded.
        max_per_individual: Number of frames sampled per individual. Individuals
            with fewer frames than this are skipped.

    Returns:
        DataFrame with the columns "index", "embedding", "frame_id", "bbox" and
        "invididual_id" (sic), one row per sampled frame.

    Raises:
        ValueError: If ``max_per_individual`` is smaller than 1.
        RuntimeError: If the video cannot be opened or a sampled frame cannot be read.
    """
    if max_per_individual < 1:
        raise ValueError("max_per_individual must be at least 1")

    # discard individuals that were tracked in fewer than `min_frames` frames
    tracking_data = {
        individual_id: frames for individual_id, frames in tracking_data.items() if len(frames) >= min_frames
    }
    print("Using", len(tracking_data), "individuals")

    empty = pd.DataFrame(columns=["embedding", "frame_id", "bbox", "invididual_id"])
    row_frames = []  # single-row frames, concatenated once at the end

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise RuntimeError(f"could not open video: {video_path}")

        for individual_id, frames in tracking_data.items():
            step_size = len(frames) // max_per_individual
            if step_size == 0:
                # fewer than `max_per_individual` frames: cannot take an evenly spaced sample
                continue
            sampled = [frames[i] for i in range(0, max_per_individual * step_size, step_size)]
            for frame_idx, bbox in sampled:
                video.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                success, frame = video.read()  # read() returns a tuple of (success, frame)
                if not success:
                    raise RuntimeError(f"could not read frame {frame_idx} from video: {video_path}")
                embedding = get_embedding_from_frame(model, frame, bbox, model_transforms)
                row_frames.append(
                    pd.DataFrame(
                        {
                            "invididual_id": [individual_id],
                            "frame_id": [frame_idx],
                            "bbox": [bbox],
                            "embedding": [embedding],
                        }
                    )
                )
    finally:
        # release the capture even if embedding a frame fails
        video.release()

    embedding_img_table = pd.concat([empty, *row_frames], ignore_index=True)
    return embedding_img_table.reset_index(drop=False)


@torch.no_grad()
def get_embedding_from_frame(model: BaseModule, frame: cvt.MatLike, bbox, model_transforms) -> torch.Tensor:
    """Crop one bounding box out of a video frame and embed it with ``model``.

    Args:
        model: Model used for the forward pass; it is switched to eval mode here.
        frame: Frame as read by OpenCV (BGR channel order is assumed by the
            colour conversion below).
        bbox: Indexable ``(x, y, w, h)``; the four values are handed to
            ``_crop_image`` in that order.
        model_transforms: Callable that turns the cropped PIL image into the
            tensor the model consumes.

    Returns:
        The model output for a batch containing only this crop.

    TODO: ``bbox`` and ``model_transforms`` are still untyped.
    """
    crop_bgr = _crop_image(frame, bbox[0], bbox[1], bbox[2], bbox[3])

    # OpenCV image -> RGB -> PIL image -> model input
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    model_input = model_transforms(Image.fromarray(crop_rgb))

    model.eval()
    return model(model_input.unsqueeze(0))


def read_embeddings_from_disk(path: str) -> pd.DataFrame:
    """Load an embedding table that was stored with ``DataFrame.to_pickle``.

    Pickle files can execute arbitrary code while loading, so only open files
    from a trusted source.
    """
    stored_table = pd.read_pickle(path)
    return stored_table

## Code used for loading in another notebook

In [None]:
# Load a trained SwinV2 checkpoint from wandb and embed sampled face crops of one tracked video.
# Relies on the helper functions and imports (`transforms`, `load_model_from_wandb`, ...) defined
# in the first cell.
from gorillatracker.model.wrappers_supervised import SwinV2LargeWrapper
from gorillatracker.transform_utils import SquarePad
from torchvision.transforms import v2 as transforms_v2

# Keyword arguments that load_model_from_wandb forwards to the model class constructor.
model_config = {
    "embedding_size": 128,
    "from_scratch": False,
    "loss_mode": "softmax/arcface",
    "weight_decay": 0.0005,
    "lr_schedule": "cosine",
    "warmup_mode": "linear",
    "warmup_epochs": 10,
    "max_epochs": 100,
    "initial_lr": 0.01,
    "start_lr": 0.01,
    "end_lr": 0.0001,
    "beta1": 0.9,
    "beta2": 0.999,
    "model_name_or_path": "SwinV2LargeWrapper",
    "stepwise_schedule": True,
    "lr_interval": 10,
    "l2_beta": 0.0,
    "l2_alpha": 0.0,
    # NOTE(review): looks like a placeholder path — presumably irrelevant because the weights
    # come from the checkpoint's state dict; confirm.
    "path_to_pretrained_weights": "a/b/c",
}
model = load_model_from_wandb(
    "gorillas/Embedding-SwinV2Large-CXL-Open/model-a4t93htr:v14", SwinV2LargeWrapper, model_config, "cpu"
)
# load_model_from_wandb already returns the model in eval mode; this call only repeats that.
model.eval()
# Preprocessing applied to each cropped PIL image before it is fed to the model.
model_transforms = transforms.Compose(
    [
        SquarePad(),
        # Uniform input, you may choose higher/lower sizes.
        transforms.Resize(224),
        transforms.ToTensor(),
        # `(192)` is the plain int 192, not a one-element tuple.
        transforms.Resize((192), antialias=True),
        # NOTE(review): presumably the ImageNet channel mean/std — confirm against the training transforms.
        transforms_v2.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# NOTE(review): absolute, machine-specific paths — adjust when running outside this workspace.
video_path = "/workspaces/gorillatracker/video_data/M002_20220725_006.mp4"
tracked_video_path = (
    "/workspaces/gorillatracker/data/derived_data/spac_gorillas_converted_labels_tracked/M002_20220725_006_tracked.json"
)
# {id: [(frame_idx, (center_x, center_y, w, h)), ...]}, built from the face-class boxes only.
tracked_video = _get_frames_for_ids(tracked_video_path)
# One row per sampled frame; later cells read and extend this table.
embedding_img_table = generate_embeddings_from_tracked_video(model, video_path, tracked_video, model_transforms)

In [None]:
# Preview the first rows of the embedding table built in the previous cell.
embedding_img_table.head()

In [None]:
# Project the embeddings to 2D with t-SNE and plot them, one colour per individual id.
# Adds the columns "tsne-2d-one" / "tsne-2d-two" to `embedding_img_table`.

from sklearn.manifold import TSNE
import seaborn as sns

tsne = TSNE(n_components=2, random_state=42, perplexity=10)
# Flatten every stored embedding to one row. This replaces the hard-coded `reshape(-1, 128)`,
# which silently produced the wrong number of rows for any embedding size other than 128.
X = torch.stack(embedding_img_table.embedding.to_list()).flatten(start_dim=1).cpu().numpy()
print(X.shape)
X_2d = tsne.fit_transform(X)

embedding_img_table["tsne-2d-one"] = X_2d[:, 0]
embedding_img_table["tsne-2d-two"] = X_2d[:, 1]

sns.scatterplot(
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="invididual_id",
    palette=sns.color_palette("tab10", len(embedding_img_table.invididual_id.unique())),
    data=embedding_img_table,
    legend="full",
    alpha=0.7,
)

In [None]:
# Plot one cropped example frame for each individual id.

import matplotlib.pyplot as plt

individual_ids = embedding_img_table.invididual_id.unique()
# squeeze=False keeps `axes` two-dimensional, so indexing also works with a single individual.
fig, axes = plt.subplots(1, len(individual_ids), figsize=(20, 10), squeeze=False)

# Open the video once instead of once per individual, and always release it.
video = cv2.VideoCapture(video_path)
try:
    for col, individual_id in enumerate(individual_ids):
        rows = embedding_img_table[embedding_img_table.invididual_id == individual_id]
        # Take the frame index and the bbox from the SAME row. Previously the frame came from
        # row 10 and the bbox from row 0, so the crop did not match the displayed frame.
        # Fall back to the last row when an individual has fewer than 11 rows.
        sample = rows.iloc[min(10, len(rows) - 1)]
        bbox = sample.bbox

        video.set(cv2.CAP_PROP_POS_FRAMES, int(sample.frame_id))
        success, frame = video.read()
        if not success:
            raise RuntimeError(f"could not read frame {sample.frame_id} from video: {video_path}")

        # crop the face, then convert OpenCV's BGR image to an RGB PIL image
        crop = _crop_image(frame, bbox[0], bbox[1], bbox[2], bbox[3])
        crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))

        axes[0, col].imshow(crop)
        axes[0, col].set_title(f"Individual {individual_id}")
        axes[0, col].axis("off")
finally:
    video.release()

plt.show()

In [None]:
# Relabel individuals 28 and 67 as individual 19, i.e. treat all three ids as one individual.
# NOTE(review): the note that used to sit here listed the ids "19 20 and 67", while the
# mapping replaces 28 — confirm which id is intended.

merged_into_19 = {28: 19, 67: 19}
embedding_img_table["invididual_id"] = embedding_img_table["invididual_id"].replace(merged_into_19)