In [None]:
from gorillatracker.model.wrappers_ssl import MoCoWrapper
from gorillatracker.utils.embedding_generator import generate_embeddings, df_from_predictions
from gorillatracker.model.wrappers_supervised import TimmEvalWrapper, BaseModuleSupervised
from pathlib import Path
from gorillatracker.data.nlet_dm import NletDataModule
from gorillatracker.data.nlet import build_onelet, SupervisedDataset
from torchvision.transforms import Resize, Normalize, Compose
import pandas as pd
import numpy as np


def get_finetuned_vit() -> MoCoWrapper:
    """Load ViT-Large (DINOv2 init) fine-tuned with SSL and the MoCo loss.

    Training run:
    https://wandb.ai/gorillas/Embedding-VitLarge-MoCo-Face-Sweep/runs/rlemhfix

    The checkpoint is restored without a data module or wandb run attached,
    since the model is only used for embedding generation here.
    """
    ckpt_path = "/workspaces/gorillatracker/models/ssl/moco-accuracy-0.58.ckpt"
    restored = MoCoWrapper.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        wandb_run=None,
        data_module=None,
    )
    return restored

def get_mock_loss_kwargs() -> dict:
    """Return placeholder loss/training kwargs for ``BaseModuleSupervised``.

    ``get_pretrained_vit`` and ``get_pretrained_efnet`` splat this dict into
    ``BaseModuleSupervised(...)``. Those models get no wandb run and no data
    module, so these values only satisfy the constructor signature.

    A fresh dict is returned on every call, so callers may mutate it.
    """
    # "cfg"     = value copied from the original training config file.
    # "default" = value that config did not specify.
    return dict(
        margin=1.0,  # cfg
        s=64.0,  # cfg
        temperature=0.07,  # default
        memory_bank_size=4096,  # default
        embedding_size=128,  # cfg
        batch_size=64,  # cfg
        num_classes=None,  # default
        class_distribution=None,  # default
        use_focal_loss=False,  # default
        k_subcenters=1,  # default
        # NOTE(review): "cuda" even when the notebook later moves models to CPU.
        accelerator="cuda",  # cfg
        label_smoothing=0.1,  # default
        l2_alpha=0.1,  # cfg
        l2_beta=0.01,  # cfg
        path_to_pretrained_weights="",  # cfg
        use_class_weights=False,  # default
        use_dist_term=False,  # default
    )


def get_pretrained_vit() -> BaseModuleSupervised:
    """Build an off-the-shelf ViT-Large with DINOv2 weights and a frozen backbone.

    The timm backbone is ``vit_large_patch14_dinov2.lvd142m``, loaded through
    the ``timm_eval/`` prefix of ``BaseModuleSupervised`` with a fixed 224px
    input size and ``freeze_backbone=True``. No wandb run or data module is
    attached, so the loss kwargs from ``get_mock_loss_kwargs`` are only
    placeholders.

    Returns:
        The constructed ``BaseModuleSupervised`` instance. The previous
        annotation, ``TimmEvalWrapper``, did not match what is returned.
    """
    # Previously built via TimmEvalWrapper(backbone_name=..., img_size=224);
    # see VCS history for that variant.
    return BaseModuleSupervised(
        model_name_or_path="timm_eval/vit_large_patch14_dinov2.lvd142m",
        fix_img_size=224,
        freeze_backbone=True,
        wandb_run=None,
        data_module=None,
        loss_mode="offline",
        **get_mock_loss_kwargs(),
    )


def get_pretrained_efnet() -> BaseModuleSupervised:
    """Build an off-the-shelf EfficientNetV2 RW_M with a frozen backbone.

    The timm backbone is ``efficientnetv2_rw_m``, loaded through the
    ``timm_eval/`` prefix of ``BaseModuleSupervised`` with a fixed 224px input
    size and ``freeze_backbone=True``. No wandb run or data module is attached,
    so the loss kwargs from ``get_mock_loss_kwargs`` are only placeholders.

    Returns:
        The constructed ``BaseModuleSupervised`` instance. The previous
        annotation, ``TimmEvalWrapper``, did not match what is returned.
    """
    # Previously built via TimmEvalWrapper(backbone_name="efficientnetv2_rw_m");
    # see VCS history for that variant.
    return BaseModuleSupervised(
        model_name_or_path="timm_eval/efficientnetv2_rw_m",
        fix_img_size=224,
        freeze_backbone=True,
        wandb_run=None,
        data_module=None,
        loss_mode="offline",
        **get_mock_loss_kwargs(),
    )


def get_finetuned_efnet() -> "TimmEvalWrapper | None":
    """Placeholder for EfficientNetV2 RW_M fine-tuned with SSL and the MoCo loss.

    No such checkpoint exists yet, so this always returns ``None``. Callers
    treat ``None`` as "model not available" and skip it.
    """
    # TODO(liamvdv): add SSL trained effnet model.
    return None


def get_model_transforms(model):
    """Build the eval-time image transform for ``model``.

    The transform always resizes to ``model.data_resize_transform``, falling
    back to ``(224, 224)`` when the attribute is missing. Unless
    ``model.use_normalization`` is falsy (missing counts as ``True``), the
    resize is followed by a ``Normalize`` with the ImageNet mean/std constants.
    """
    target_size = getattr(model, "data_resize_transform", (224, 224))
    resize_step = Resize(target_size)
    # NOTE(liamvdv): normalization_mean, normalization_std are always default.
    if not getattr(model, "use_normalization", True):
        return resize_step
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return Compose([resize_step, Normalize(channel_mean, channel_std)])


def _get_dataloader(model, path: Path):
    """Build the single evaluation DataLoader for the images under ``path``.

    Args:
        model: Model whose resize/normalization attributes drive the
            transforms; see ``get_model_transforms``.
        path: Directory passed to ``NletDataModule`` as ``data_dir``.

    Returns:
        The only DataLoader returned by ``val_dataloader()``.

    Raises:
        RuntimeError: If the data module does not yield exactly one
            validation DataLoader.
    """
    data_module = NletDataModule(
        data_dir=path,
        dataset_class=SupervisedDataset,
        nlet_builder=build_onelet,
        batch_size=64,
        workers=10,
        model_transforms=get_model_transforms(model),
        training_transforms=lambda x: x,  # identity: no training augmentation
        dataset_names=["Showcase"],
    )

    data_module.setup("validate")
    # Use the validation loader (original note: "val for transforms").
    loaders = data_module.val_dataloader()
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if len(loaders) != 1:
        raise RuntimeError(
            f"Expected exactly one validation DataLoader for {path}, got {len(loaders)}"
        )
    return loaders[0]


def get_df(model, path: Path):
    """Embed every image under ``path`` with ``model`` and return the results.

    The DataFrame comes from ``df_from_predictions``. Two columns are then
    unwrapped:
      * ``embedding``: each entry becomes a NumPy array built from ``.item()``
        of every element, i.e. plain Python scalars.
      * ``label``: each entry is unwrapped via ``.item()``.
    """
    predictions = generate_embeddings(model, _get_dataloader(model, path))
    frame = df_from_predictions(predictions)
    # TODO(liamvdv): Should be DF of
    #                id, embedding, label, label_string, input, model, dataset

    frame["embedding"] = frame["embedding"].apply(
        lambda values: np.array([value.item() for value in values])
    )
    frame["label"] = frame["label"].apply(lambda label: label.item())
    return frame

In [None]:
import gc

import torch

on_cpu = True
# Model factories. A factory returning None marks a model that is not available yet.
models = {
    "ViT-Pretrained": get_pretrained_vit,
    "ViT-Finetuned": get_finetuned_vit,
    "EfN-Pretrained": get_pretrained_efnet,
    "EfN-Finetuned": get_finetuned_efnet,
}

# TODO(liamvdv): @robert: why filtered? What are the dataset stats based on?
BRISTOL = Path(
    "/workspaces/gorillatracker/data/supervised/bristol/cross_encounter_validation/cropped_frames_square_filtered"
)
SPAC = Path("/workspaces/gorillatracker/data/supervised/cxl_all/face_images_square")
datasets = {
    "Bristol": BRISTOL,
    "SPAC": SPAC,
}
dfs = []
for model_name, get_model in models.items():
    # Build each model once and reuse it for every dataset, instead of
    # reloading weights/checkpoints per dataset.
    model = get_model()
    if model is None:
        print("Skipping model", model_name, "Model not yet implemented.")
        continue
    for dataset_name, dataset_path in datasets.items():
        if on_cpu:
            # Re-applied per dataset in case embedding generation moved the model.
            model = model.cpu()
        df = get_df(model, dataset_path)
        df["dataset"] = dataset_name
        df["model"] = model_name
        print("Appending", model_name, dataset_name, f"{len(df)} rows")
        dfs.append(df)

    # Cleanup: drop the model and release cached GPU memory before loading the next one.
    del model
    gc.collect()
    torch.cuda.empty_cache()

if not dfs:
    # pd.concat([]) would raise an opaque "No objects to concatenate" error.
    raise RuntimeError("No embeddings were generated: every model was skipped.")
merged_df = pd.concat(dfs, ignore_index=True)
merged_df.to_pickle("merged.pkl")
print("done")
# Example: select one model/dataset slice of the results.
# vitf_spac = merged_df[(merged_df['model'] == 'ViT-Finetuned') & (merged_df['dataset'] == 'SPAC')]