# Encoding CIFAR100 with a set of pretrained models

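We will use the [HuggingFace datasets library](https://huggingface.co/docs/datasets) to load the CIFAR100 dataset. We will then encode the images with a set of pretrained models: vision backbones from the [timm library](https://huggingface.co/docs/timm) and a CLIP model from [HuggingFace Transformers](https://huggingface.co/docs/transformers). The embeddings of each encoder are stored as a new column of the dataset, which is cached on disk.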

# Imports and configuration

In [None]:
# IPython autoreload extension: re-import changed modules before each cell runs (mode 2 = all modules).
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from nn_core.common import PROJECT_ROOT
import random

from pathlib import Path

try:
    # enum.StrEnum only exists from Python 3.11 on; older interpreters fall back to the backport.
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from pytorch_lightning import seed_everything
import matplotlib.pyplot as plt
import random
from collections import namedtuple
import timm
from transformers import AutoModel, AutoProcessor
from typing import Sequence, List
from PIL.Image import Image
from tqdm import tqdm
import functools
from timm.data import resolve_data_config
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict

from timm.data import create_transform

# Data loading

In [None]:
# When True and DATASET_DIR already exists, the dataset (including any encoder columns computed
# in earlier runs) is loaded from disk instead of being rebuilt from the HuggingFace hub.
USE_CACHED: bool = True

In [None]:
def get_dataset(name: str, split: str, perc: float, seed: int = 42):
    """
    Load one split of a dataset from the HuggingFace datasets library, optionally subsampled.

    :param name: dataset identifier passed to `load_dataset` (e.g. "cifar100").
    :param split: split to load (e.g. "train" or "test").
    :param perc: fraction of the split to keep, in the interval (0, 1]. When it is below 1,
        the split is shuffled with `seed` and its first `int(len(split) * perc)` rows are kept.
    :param seed: seed for the subsampling shuffle; also passed to `seed_everything`.
    :return: the loaded (and possibly subsampled) dataset split.
    :raises ValueError: if `perc` is not in (0, 1] (this also rejects NaN).
    """
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if not 0 < perc <= 1:
        raise ValueError(f"perc must be in the interval (0, 1], got {perc!r}")

    dataset = load_dataset(
        name,
        split=split,
        use_auth_token=True,
    )
    seed_everything(seed)

    # Select a random subset; `select` accepts any iterable of indices, so no list is needed.
    if perc != 1:
        n_samples = int(len(dataset) * perc)
        dataset = dataset.shuffle(seed=seed).select(range(n_samples))

    return dataset

In [None]:
# Description of which dataset to load and how; every populated field except `hf_key`
# also becomes part of the on-disk cache key (DATASET_KEY).
#   name         -- dataset identifier passed to `load_dataset`
#   fine_grained -- optional granularity setting (None for CIFAR100 here; semantics not shown in this notebook)
#   train_split  -- split name used for the "train" entry of the DatasetDict
#   test_split   -- split name used for the "test" entry of the DatasetDict
#   perc         -- fraction of each split to keep, in (0, 1]
#   hf_key       -- excluded from the cache key; not otherwise read in the visible cells
DatasetParams = namedtuple(
    "DatasetParams",
    "name fine_grained train_split test_split perc hf_key",
)

In [None]:
# Full CIFAR100 (perc=1) with its standard train/test splits.
dataset_params: DatasetParams = DatasetParams(
    name="cifar100",
    fine_grained=None,
    train_split="train",
    test_split="test",
    perc=1,
    hf_key=("cifar100",),
)
dataset_params

In [None]:
# Cache key built from every populated parameter except `hf_key`, joined with "_"
# (for the parameters above this is "cifar100_train_test_1").
DATASET_KEY = "_".join(
    str(value)
    for field, value in zip(dataset_params._fields, dataset_params)
    if field != "hf_key" and value is not None
)
DATASET_DIR: Path = PROJECT_ROOT / "data" / "encoded_data" / DATASET_KEY
DATASET_DIR

In [None]:
if USE_CACHED and DATASET_DIR.exists():
    # Previously saved DatasetDict, including any encoder columns computed in earlier runs.
    # `load_from_disk` on a saved DatasetDict returns a DatasetDict (indexed by split below).
    data: DatasetDict = load_from_disk(dataset_path=str(DATASET_DIR))
else:
    # Fresh build from the HuggingFace hub; `perc` is applied to both splits.
    data: DatasetDict = DatasetDict(
        train=get_dataset(name=dataset_params.name, split=dataset_params.train_split, perc=dataset_params.perc),
        test=get_dataset(name=dataset_params.name, split=dataset_params.test_split, perc=dataset_params.perc),
    )

data

# Embed

In [None]:
# When True, re-encode with every encoder even if its column already exists in the dataset.
FORCE_RECOMPUTE: bool = False
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Encoders to apply, in order. Each name also becomes the name of the column holding its
# embeddings. Names starting with "openai/clip" are loaded with HuggingFace Transformers;
# every other name is created with `timm.create_model`.
ENCODERS = tuple(
    [
        "rexnet_100",  # timm
        "vit_base_patch16_224",  # timm
        "vit_base_patch16_384",  # timm
        "vit_base_resnet50_384",  # timm
        "openai/clip-vit-base-patch32",  # Transformers CLIP
        "vit_small_patch16_224",  # timm
    ]
)

In [None]:
def encode_field(batch, src_field: str, tgt_field: str, transformation):
    """
    Map a batch to a single new column.

    `transformation` receives the value stored under `src_field` in `batch`; its result is
    returned as the only entry of a dict keyed by `tgt_field`, the shape `Dataset.map` expects.
    """
    return {tgt_field: transformation(batch[src_field])}


@torch.no_grad()
def image_encode(images: Sequence[Image], transform, encoder):
    """
    Encode images using a timm model.

    Every image is converted to RGB and preprocessed with `transform`. The results are stacked
    into a single batch, moved to DEVICE and passed through `encoder`. Returns the encoder
    output as a list with one NumPy array per input image.
    """
    batch = torch.stack([transform(picture.convert("RGB")) for picture in images], dim=0)
    features = encoder(batch.to(DEVICE))
    return list(features.cpu().numpy())


@torch.no_grad()
def clip_image_encode(images: Sequence[Image], transform, encoder):
    """
    Encode images using the OpenAI CLIP model.

    `transform` is the Transformers processor, which builds the PyTorch input batch. The
    embedding is the `pooler_output` of `encoder.vision_model`; no visual projection is applied
    here. Returns one NumPy array per input image.
    """
    rgb_images = [picture.convert("RGB") for picture in images]
    model_inputs = transform(images=rgb_images, return_tensors="pt").to(DEVICE)
    pooled = encoder.vision_model(**model_inputs).pooler_output
    return list(pooled.cpu().numpy())

In [None]:
def _load_encoder(encoder_name: str):
    """
    Build a frozen, eval-mode encoder on DEVICE together with its preprocessing and encode function.

    Names starting with "openai/clip" are loaded with Transformers (AutoModel + AutoProcessor);
    every other name is created as a timm model.

    :return: an `(encoder, transform, encode_func)` triple.
    """
    if encoder_name.startswith("openai/clip"):
        encoder = AutoModel.from_pretrained(encoder_name).requires_grad_(False).eval().to(DEVICE)
        transform = AutoProcessor.from_pretrained(encoder_name)
        return encoder, transform, clip_image_encode

    # num_classes=0 removes the classifier head, so the model returns features.
    encoder = timm.create_model(encoder_name, pretrained=True, num_classes=0).requires_grad_(False).eval().to(DEVICE)
    config = resolve_data_config({}, model=encoder)
    return encoder, create_transform(**config), image_encode


def _save_dataset(dataset: DatasetDict, target_dir: Path) -> DatasetDict:
    """
    Save `dataset` to `target_dir` and return the dataset to continue working with.

    `datasets` refuses to `save_to_disk` over the directory whose files back the dataset
    ("a dataset can't overwrite itself"). That is the situation whenever `data` was loaded from
    DATASET_DIR and new columns were then added. In that case the dataset is written to a
    sibling temporary directory, swapped into place and re-opened from `target_dir`.
    Otherwise it is saved directly, as before.
    """
    import shutil

    target = target_dir.resolve()
    backed_by_target = any(
        target in Path(cache_file["filename"]).resolve().parents
        for split_files in dataset.cache_files.values()
        for cache_file in split_files
    )
    if not backed_by_target:
        dataset.save_to_disk(str(target_dir))
        return dataset

    tmp_dir = target_dir.with_name(target_dir.name + ".tmp")
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    dataset.save_to_disk(str(tmp_dir))

    shutil.rmtree(target_dir)
    tmp_dir.rename(target_dir)

    # Re-open from the final location so later `map` calls write their cache files there.
    return load_from_disk(str(target_dir))


missing_encoders = [encoder for encoder in ENCODERS if FORCE_RECOMPUTE or encoder not in data["train"].column_names]

for encoder_name in tqdm(missing_encoders):
    # The embeddings of each encoder are stored in a column named after the encoder.
    tgt_field: str = encoder_name
    encoder, transform, encode_func = _load_encoder(encoder_name)

    data = data.map(
        functools.partial(
            encode_field,
            src_field="img",
            tgt_field=tgt_field,
            transformation=functools.partial(
                encode_func,
                transform=transform,
                encoder=encoder,
            ),
        ),
        num_proc=1,
        batched=True,
        batch_size=32,
        desc=f"{encoder_name}",
    )

    # Move the weights off the accelerator and release cached memory before the next encoder.
    encoder = encoder.cpu()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Checkpoint after every encoder so finished columns survive an interrupted run.
    data = _save_dataset(data, DATASET_DIR)

if "index" not in data["train"].column_names:
    # Attach each row's position within its split as an explicit "index" column.
    data = data.map(lambda x, index: {"index": index}, with_indices=True)
    data = _save_dataset(data, DATASET_DIR)

# Encoder columns are returned as torch tensors; all other columns are still returned unformatted.
data.set_format(columns=list(ENCODERS), output_all_columns=True, type="torch")

In [None]:
# Display the final DatasetDict (its splits, features and row counts).
data