### Load libraries

In [1]:
import pathlib
import sys
from typing import Type

import git.repo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import transformers
from tqdm import tqdm

# Resolve the repository root by searching upwards from the current directory,
# so project code can be imported from anywhere inside the repository.
_working_tree_dir = git.repo.Repo(
    ".", search_parent_directories=True
).working_tree_dir
if _working_tree_dir is None:
    # GitPython returns None for a bare repository; str(None) would silently
    # turn into the bogus relative path "None" and the import below would fail
    # with an unrelated-looking error.
    raise RuntimeError(
        "Could not determine the git working tree (bare repository?)"
    )
GIT_ROOT = pathlib.Path(str(_working_tree_dir))

# Extend sys.path only once so that re-running this cell is idempotent.
if str(GIT_ROOT) not in sys.path:
    sys.path.append(str(GIT_ROOT))

# Project import: must come after the sys.path tweak above.
from src.pretrain.models import BaseEmbedderConfig, get_embedder_index

# Set the transformers logger verbosity to the error level.
transformers.logging.set_verbosity_error()

### Iterate across all models

In [2]:
# Every value held by the embedder index, frozen into a tuple.
# The tuple has arbitrary length, hence the trailing ellipsis: without it the
# annotation would describe a tuple of exactly one element.
embedder_config_ts: tuple[Type[BaseEmbedderConfig], ...] = tuple(
    get_embedder_index().values()
)

In [3]:
# Build one record per model id: the id itself, the embedder's embedding
# dimension, and its parameter count.
data = []
for embedder_cfg_t in embedder_config_ts:
    # A default-constructed config is used only to read its valid model ids.
    # The loop variable is deliberately not called `id`: at notebook scope that
    # would shadow the builtin id() for every later cell.
    for model_id in embedder_cfg_t().valid_model_ids:  # type: ignore
        embedder_cfg = embedder_cfg_t(id=model_id)
        embedder = embedder_cfg.get_model()

        info = {
            "model": embedder_cfg.id,
            "embed_dim": embedder.embed_dim,
            "num_params": embedder.n_embedder_params,
        }
        data.append(info)
        # Print each record as it is produced so progress is visible.
        print(info)


{'model': 'hf/laion/CLIP-ViT-B-32-laion2B-s34B-b79K', 'embed_dim': 768, 'num_params': 87456000}
{'model': 'hf/laion/CLIP-ViT-g-14-laion2B-s12B-b42K', 'embed_dim': 1408, 'num_params': 1011203840}
{'model': 'hf/laion/CLIP-ViT-H-14-laion2B-s32B-b79K', 'embed_dim': 1280, 'num_params': 630766080}
{'model': 'hf/laion/CLIP-ViT-L-14-laion2B-s32B-b82K', 'embed_dim': 1024, 'num_params': 303179776}
{'model': 'hf/microsoft/beit-base-finetuned-ade-640-640', 'embed_dim': 768, 'num_params': 86555712}
{'model': 'hf/microsoft/beit-base-patch16-224-pt22k-ft22k', 'embed_dim': 768, 'num_params': 85761984}
{'model': 'hf/microsoft/beit-base-patch16-224-pt22k', 'embed_dim': 768, 'num_params': 85666128}
{'model': 'hf/microsoft/beit-base-patch16-224', 'embed_dim': 768, 'num_params': 85761984}
{'model': 'hf/microsoft/beit-base-patch16-384', 'embed_dim': 768, 'num_params': 85975104}
{'model': 'hf/microsoft/beit-large-finetuned-ade-640-640', 'embed_dim': 1024, 'num_params': 305522176}
{'model': 'hf/microsoft/beit

In [9]:
def _model_family(model_id: str) -> str:
    """Derive a coarse grouping label from a model id.

    Ids beginning with "hf/" are cut at the first "-"; any other id is
    reduced to its first nine characters.
    """
    if model_id.startswith("hf/"):
        return model_id.split("-")[0]
    return model_id[:9]


# One row per collected record, plus a grouping column used for colouring.
df = pd.DataFrame(data)
df["model_type"] = df["model"].apply(_model_family)

# Scatter of embedding dimension against parameter count, one colour per
# model_type, with both axes on a logarithmic scale.
fig = (
    px.scatter(
        df,
        x="embed_dim",
        y="num_params",
        hover_name="model",
        color="model_type",
        title="Embedding Dimension vs. Number of Parameters",
    )
    .update_xaxes(type="log")
    .update_yaxes(type="log")
)
fig.show()