In [1]:
from gorillatracker.model.wrappers_ssl import MoCoWrapper
from gorillatracker.utils.embedding_generator import generate_embeddings, df_from_predictions
from pathlib import Path
from gorillatracker.data.nlet_dm import NletDataModule
from gorillatracker.data.nlet import build_onelet, SupervisedDataset, SupervisedKFoldDataset
from torchvision.transforms import Resize, Normalize, Compose
import pandas as pd

# TODO(liamvdv): @robert: why filtered? What are the dataset stats based on?
# Directory holding the Bristol images that get embedded below
# (cross-encounter validation, square-cropped and filtered frames, going by
# the directory names; see the TODO above regarding "filtered").
BRISTOL = Path(
    "/workspaces/gorillatracker/data/supervised/bristol/cross_encounter_validation/cropped_frames_square_filtered"
)
# Directory holding the SPAC images that get embedded below
# (square face crops, going by the directory name).
SPAC = Path("/workspaces/gorillatracker/data/supervised/cxl_all/face_images_square")


def get_moco_model(
    checkpoint_path: str = "/workspaces/gorillatracker/models/ssl/moco-accuracy-0.58.ckpt",
) -> MoCoWrapper:
    """Restore a ``MoCoWrapper`` from a checkpoint file.

    Args:
        checkpoint_path: Location of the checkpoint to load. Defaults to the
            MoCo checkpoint under ``/workspaces/gorillatracker/models/ssl``.

    Returns:
        Whatever ``MoCoWrapper.load_from_checkpoint`` returns for that
        checkpoint. The wrapper is loaded with ``data_module=None`` and
        ``wandb_run=None``, i.e. without a data module or W&B run attached.
    """
    load_kwargs = {
        "checkpoint_path": checkpoint_path,
        "data_module": None,
        "wandb_run": None,
    }
    moco = MoCoWrapper.load_from_checkpoint(**load_kwargs)
    return moco


def get_model_transforms(model):
    """Build the input transform for ``model``.

    The transform always starts with a ``Resize``; a ``Normalize`` step is
    appended unless the model opts out.

    Args:
        model: Object optionally exposing ``data_resize_transform`` (resize
            target, falls back to ``(224, 224)``) and ``use_normalization``
            (falls back to ``True``).

    Returns:
        The bare ``Resize`` transform when normalization is disabled,
        otherwise a ``Compose`` of the resize followed by the normalization.
    """
    target_size = getattr(model, "data_resize_transform", (224, 224))
    resize_step = Resize(target_size)
    # NOTE(liamvdv): normalization_mean, normalization_std are always default,
    # so the mean/std are fixed here instead of being read from the model.
    fixed_normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    if not getattr(model, "use_normalization", True):
        return resize_step
    return Compose([resize_step, fixed_normalize])


def _get_dataloader(model, path: Path, *, batch_size: int = 64, workers: int = 10):
    """Create the single validation dataloader for the images under ``path``.

    A ``NletDataModule`` is configured with ``SupervisedDataset`` and
    ``build_onelet``, set up for the ``"validate"`` stage, and its one
    validation dataloader is returned.

    Args:
        model: Model whose resize/normalization settings define the input
            transform (see ``get_model_transforms``).
        path: Dataset directory handed to the data module as ``data_dir``.
        batch_size: Batch size forwarded to the data module. Defaults to the
            previously hard-coded 64.
        workers: Worker count forwarded to the data module. Defaults to the
            previously hard-coded 10.

    Returns:
        The only element of ``data_module.val_dataloader()``.

    Raises:
        AssertionError: If the data module does not yield exactly one
            validation dataloader.
    """
    data_module = NletDataModule(
        data_dir=path,
        dataset_class=SupervisedDataset,
        nlet_builder=build_onelet,
        batch_size=batch_size,
        workers=workers,
        model_transforms=get_model_transforms(model),
        # Identity: no extra training-time transform is applied.
        training_transforms=lambda x: x,
        dataset_names=["Showcase"],
    )

    data_module.setup("validate")
    # The validation loader is used on purpose ("val for transforms" in the
    # original note).
    dls = data_module.val_dataloader()
    assert len(dls) == 1, f"expected exactly one validation dataloader, got {len(dls)}"
    return dls[0]


def get_df(model, path: Path):
    """Embed every image under ``path`` with ``model`` and tabulate the result.

    Args:
        model: Model passed both to the dataloader construction and to
            ``generate_embeddings``.
        path: Dataset directory to embed.

    Returns:
        The output of ``df_from_predictions`` applied to the generated
        predictions (presumably a pandas DataFrame, since callers invoke
        ``to_pickle`` on it; confirm against ``embedding_generator``).
    """
    return df_from_predictions(
        generate_embeddings(model, _get_dataloader(model, path))
    )

In [2]:
# Load the MoCo checkpoint once; the same model instance embeds both datasets.
model = get_moco_model()
# Embed the Bristol images and persist the frame immediately, so the result
# is already on disk before the SPAC run starts.
bristol = get_df(model, BRISTOL)
bristol.to_pickle("bristol.pkl")  # relative path: resolved against the current working directory
# Same procedure for the SPAC images.
spac = get_df(model, SPAC)
spac.to_pickle("spac.pkl")  # relative path: resolved against the current working directory
print("done")

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/envs/research/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_m

Predicting: |          | 0/? [00:00<?, ?it/s]

Class 120 has less than two samples (samples: 1).
Class 125 has less than two samples (samples: 1).
Class 127 has less than two samples (samples: 1).
Class 129 has less than two samples (samples: 1).
Class 130 has less than two samples (samples: 1).
Class 132 has less than two samples (samples: 1).
Class 134 has less than two samples (samples: 1).
Class 135 has less than two samples (samples: 1).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

done
