In [1]:
import logging
import os

import wandb
import datasets
import torch
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.decomposition import PCA
from transformers import pipeline, AutoImageProcessor, AutoModel
from transformers.image_processing_base import BatchFeature

logger = logging.getLogger(__name__)

device = "cuda:2"

if 'notebooks' in os.getcwd():
    os.chdir("../")

In [2]:
%load_ext autoreload
%autoreload 2
from exrep.registry import load_data, save_tensor
from exrep.utils import generic_map

crops_dataset = load_data(
    base_name="imagenet",
    phase="crops",
    load_local=True,
)

image_dataset = load_data(
    base_name="imagenet",
    phase="images",
    load_local=True,
)

In [3]:
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dataloader = torch.utils.data.DataLoader(
    crops_dataset.with_transform(
        lambda x: processor(images=x['crops'], return_tensors="pt")
    ),
    batch_size=128,
)
embeddings = generic_map(
    model,
    dataloader,
    post_proc_fn=lambda x: x.pooler_output,
    input_format="keyword",
    device=device
)
embeddings.shape

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
100%|██████████| 81/81 [03:44<00:00,  2.77s/it]


torch.Size([10291, 768])

In [25]:
n_clusters = 30
preview_size = 5
clustering = KMeans(n_clusters=n_clusters).fit(embeddings.cpu().numpy())
cluster_labels = clustering.predict(embeddings.cpu().numpy())
for cluster_id in range(n_clusters):
    indices = np.where(cluster_labels == cluster_id)[0]
    # fig, axs = plt.subplots(1, preview_size, figsize=(20, 2))
    # for ax, i in zip(axs, indices):
    #     display_im = crops_dataset[int(i)]["crops"].resize((80, 80))
    #     ax.imshow(display_im)
    #     ax.axis("off")

In [26]:
local_encoding = np.zeros((len(image_dataset), n_clusters))
for i in range(n_clusters):
    crop_indices = np.where(clustering.labels_ == i)[0]
    image_indices, _, _ = np.strings.partition(crops_dataset[crop_indices]['index'], "_")
    local_encoding[np.unique(image_indices).astype(int), i] = 1

In [27]:
local_encoding

array([[0., 0., 1., ..., 1., 0., 1.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 1., 0.],
       ...,
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.]], shape=(2000, 30))

In [28]:
from dotenv import dotenv_values

local_config = dotenv_values(".env")

run = wandb.init(
    project=local_config["WANDB_PROJECT"],
    config={
        "job_type": "concept-generation",
    },
    save_code=False,
)

save_tensor(
    torch.from_numpy(local_encoding),
    base_name="imagenet",
    phase="local-encoding",
    model_name=f"kmeans-{n_clusters}",
    wandb_run=run,
)

In [29]:
wandb.finish()