In [None]:
# NOTE(liamvdv): todo: move to SSLConfig
# from gorillatracker.ssl_pipeline.ssl_config import SSLConfig

from sqlalchemy import select, text, create_engine
from sqlalchemy.orm import Session
from gorillatracker.ssl_pipeline.models import TrackingFrameFeature
from gorillatracker.ssl_pipeline.dataset import GorillaDatasetKISZ


# Upper bound on how many features are kept per tracking_id.
sample = 10

# The CTE keeps only 'face_45' features that belong to a tracklet and whose
# bounding box is at least 184 in both width and height, and numbers the rows
# of each tracking_id in random order. The outer SELECT then keeps the ids of
# the first `sample` rows per tracking_id. Because the order comes from
# RANDOM(), every execution can return a different selection.
_ranked_features_sql = f"""WITH ranked_features AS (
    SELECT
        tracking_id,
        tracking_frame_feature_id,
        bbox_width,
        bbox_height,
        frame_nr,
        feature_type,
        ROW_NUMBER() OVER (PARTITION BY tracking_id ORDER BY RANDOM()) AS rn
    FROM tracking_frame_feature
    WHERE feature_type = 'face_45'
        AND bbox_width >= 184
        AND bbox_height >= 184
        AND tracking_id IS NOT NULL
)
SELECT
    tracking_frame_feature_id
FROM ranked_features
WHERE rn <= {sample}
"""
query = text(_ranked_features_sql)


# engine = create_engine(GorillaDatasetKISZ.DB_URI)

# # stmt = select(TrackingFrameFeature).where(TrackingFrameFeature.tracking_frame_feature_id.in_(subquery))

# with Session(engine) as session:
#     result = session.execute(query).scalars().all()
#     for row in result:
#         print(row)

In [None]:
from pathlib import Path
from typing import List, Literal, Optional, Tuple

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

import gorillatracker.type_helper as gtypes
from gorillatracker.transform_utils import SquarePad
from gorillatracker.type_helper import Id, Label
from gorillatracker.utils.labelencoder import LabelEncoder

# Directory that HackDataset.__getitem__ passes to TrackingFrameFeature.cache_path()
# when it opens the image of a feature.
base_path = "cropped-images/2024-04-18"


def cast_label_to_int(labels: List[str]) -> List[int]:
    """Turn string labels into integer labels.

    Thin wrapper that delegates to ``LabelEncoder.encode_list``.

    Args:
        labels: The string labels to encode.

    Returns:
        Whatever ``LabelEncoder.encode_list`` produces for ``labels``.
    """
    encoded = LabelEncoder.encode_list(labels)
    return encoded


class HackDataset(Dataset[Tuple[Id, Tensor, Label]]):
    """Ad-hoc dataset over the tracking frame features selected by the module-level ``query``.

    Each item is a ``(tracking_frame_feature_id, image, tracking_id)`` triple, i.e. the
    tracking id of a feature is used as its label.
    """

    def get_tffs(self) -> list[TrackingFrameFeature]:
        """Fetch every ``TrackingFrameFeature`` whose id is returned by ``query``.

        Returns:
            The matching ORM objects as a list. The session is closed before
            returning, so the objects are detached from it.
        """
        engine = create_engine(GorillaDatasetKISZ.DB_URI)
        stmt = select(TrackingFrameFeature).where(TrackingFrameFeature.tracking_frame_feature_id.in_(query))

        try:
            with Session(engine) as session:
                # `.all()` yields a Sequence; copy it into a list so the value
                # matches the annotated return type.
                return list(session.execute(stmt).scalars().all())
        finally:
            # The engine is created per call, so release its connection pool
            # instead of leaving the connections open.
            engine.dispose()

    def __init__(
        self, data_dir: str, partition: Literal["train", "val", "test"], transform: Optional[gtypes.Transform] = None
    ):
        """Load the selected features from the database.

        Args:
            data_dir: Accepted but not used; images are located through the
                module-level ``base_path`` instead.
            partition: Stored on the instance only; it does not filter the data,
                so every partition serves the same features.
            transform: Optional callable applied to each opened image.
        """
        self.ttfs = self.get_tffs()
        self.transform = transform
        self.partition = partition

    def __len__(self) -> int:
        """Return the number of loaded features."""
        return len(self.ttfs)

    def __getitem__(self, idx: int) -> Tuple[Id, Tensor, Label]:
        """Return ``(feature id, image, tracking id)`` for the feature at ``idx``.

        Tracklets will be labels for now.

        The image is returned as opened by PIL when no transform is configured,
        otherwise as whatever the transform returns.

        Raises:
            AssertionError: If the feature has no tracking id.
        """
        ttf = self.ttfs[idx]

        # Read the pixel data eagerly and close the file right away, so no file
        # handle stays open even when no transform forces a load.
        with Image.open(ttf.cache_path(base_path)) as img:
            img.load()
        if self.transform:
            img = self.transform(img)
        assert ttf.tracking_id is not None
        return ttf.tracking_frame_feature_id, img, ttf.tracking_id

    @classmethod
    def get_transforms(cls) -> gtypes.Transform:
        """Return the default preprocessing: ``SquarePad``, resize to 224, convert to tensor."""
        return transforms.Compose(
            [
                SquarePad(),
                # Uniform input, you may choose higher/lower sizes.
                transforms.Resize(224),
                transforms.ToTensor(),
            ]
        )

In [None]:
from gorillatracker.utils.embedding_generator import (
    get_run,
    get_model_from_run,
    generate_embeddings,
    read_embeddings_from_disk,
)
from torchvision import transforms

# True: recompute the embeddings with the model of the W&B run and pickle them
# to `file`. False: read previously stored embeddings from `file`.
regen = True
file = "example.pkl"

if not regen:
    df = read_embeddings_from_disk(file)
else:
    run_url = "https://wandb.ai/gorillas/Embedding-SwinV2Large-CXL-Open/runs/bp5e1rnx/workspace?nw=nwuserliamvdv"
    run = get_run(run_url)
    model = get_model_from_run(run)

    # Dataset preprocessing first, then the model's own tensor transforms.
    val_transform = transforms.Compose([HackDataset.get_transforms(), model.get_tensor_transforms()])
    dataset = HackDataset("", "val", val_transform)

    # NOTE(review): "gpu" is not a torch device string ("cuda" is); confirm that
    # generate_embeddings maps it itself before relying on this.
    df = generate_embeddings(model, dataset, device="gpu")
    df["partition"] = "val"
    df.to_pickle(file)