In [None]:
import kv_bottleneck_experiments.utils.model as model_utils
import kv_bottleneck_experiments.utils.data as data_utils

import key_value_bottleneck.core as kv_core

from addict import Dict
import os
import torch

In [None]:
# Project root taken from the environment; evaluates to None when
# PROJECT_ROOT_DIR is not set.
ROOT_DIR = os.environ.get("PROJECT_ROOT_DIR")

In [None]:
# Base experiment configuration. Key order is kept as-is so that iterating the
# dict yields the same sequence as before.
args_dict = {
    "method": "ours",
    "seed": 0,
    "batch_size": 256,
    "num_workers": 0,
    "dim_key": 16,
    "dim_value": 8,
    "topk": 1,
    "t_mode": "uniform_importance",
    "scaling_mode": "free_num_keys",
    "threshold_factor": 0.1,
    "accept_image_fmap": False,
    "init_mode": "random",
    "num_pairs": 10,
    "num_books": 128,
    "decay": 0.95,
    "splitting_mode": "random_projection",
    "root_dir": ROOT_DIR,
    "sample_codebook_temperature": 0,
}

# Wrap the plain dict in an addict Dict so entries can be read as attributes.
args = Dict(args_dict)

# Initial values only: the sweep loop further down reassigns these three
# fields on every iteration.
args.dataset_name = "CIFAR10"
args.pretrain_layer = 3
args.backbone = "clip_vit_b32"

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
# Sweep over every (dataset, backbone, extracted layer) combination: build a
# frozen bottlenecked encoder for it and pass it, together with the train/test
# dataloaders, to data_utils.create_embedding_dataset.
for dataset in ["STL10", "CIFAR10", "CIFAR100"]:
    for backbone in ["swav_resnet50w2", "dino_resnet50", "clip_vit_b32", "resnet50_imagenet_v2"]:
        for pretrain_layer in [3, 4]:
            # Record the current combination on the shared `args` object; every
            # helper called below receives `args`.
            args.dataset_name = dataset
            args.backbone = backbone
            args.pretrain_layer = pretrain_layer

            dataloader_train, dataloader_test = data_utils.get_dataloaders(
                dataset=dataset, args=args
            )
            key_value_pairs_per_codebook = model_utils.get_key_value_pairs_per_codebook(args)
            threshold_ema_dead_code = model_utils.get_threshold_ema_dead_code(args)

            # Choose the encoder wrapper class from the backbone name.
            if (
                "cifar10" in args.backbone
                or "cifar100" in args.backbone
                or args.backbone == "resnet50_imagenet_v2"
            ):
                bottleneck_encoder_cls = kv_core.SLPretrainedBottleneckedEncoder
            elif args.backbone == "clip_vit_b32":
                bottleneck_encoder_cls = kv_core.CLIPBottleneckedEncoder
            elif "swav" in args.backbone:
                bottleneck_encoder_cls = kv_core.SwavBottleneckedEncoder
            else:
                # Any backbone name not matched above falls through to Dino.
                bottleneck_encoder_cls = kv_core.DinoBottleneckedEncoder

            dim_values = args.dim_value

            bottlenecked_encoder = bottleneck_encoder_cls(
                # Backbone selection.
                backbone=args.backbone,
                extracted_layer=args.pretrain_layer,
                pool_embedding=not args.accept_image_fmap,
                # Codebook layout.
                num_codebooks=args.num_books,
                key_value_pairs_per_codebook=key_value_pairs_per_codebook,
                dim_keys=args.dim_key,
                dim_values=dim_values,
                splitting_mode=args.splitting_mode,
                init_mode=args.init_mode,
                # Remaining bottleneck options.
                decay=args.decay,
                eps=1e-5,
                threshold_ema_dead_code=threshold_ema_dead_code,
                sample_codebook_temperature=args.sample_codebook_temperature,
                topk=args.topk,
                concat_values_from_all_codebooks=False,
                return_values_only=False,
            )
            bottlenecked_encoder = bottlenecked_encoder.freeze_encoder()
            bottlenecked_encoder.eval()

            # When the encoder exposes its own transforms, install them on both
            # datasets; this is only done when the train loader is a
            # torch DataLoader.
            deviating_transforms = bottlenecked_encoder.transforms
            if deviating_transforms is not None and isinstance(
                dataloader_train, torch.utils.data.DataLoader
            ):
                dataloader_train.dataset.transform = deviating_transforms
                dataloader_test.dataset.transform = deviating_transforms

            bottlenecked_encoder.to(device)

            data_utils.create_embedding_dataset(
                dataloader_train, dataloader_test, bottlenecked_encoder, args
            )