In [1]:
# torch is there, torchvision is there
!pip install timm

import os
import numpy as np

# Feature extraction
from torch.utils.data import DataLoader
import torch
from torchvision import transforms
import torchvision
import timm

# No grad's needed
torch.set_grad_enabled(False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

from google.colab import drive

mount_str = "/content/drive" # mount gdrive
drive.mount(mount_str)

Collecting timm
  Downloading timm-1.0.3-py3-none-any.whl (2.3 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/2.3 MB[0m [31m9.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.3/2.3 MB[0m [31m42.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x8

In [15]:
# Dataset
#dataset_root_path = f"{mount_str}/MyDrive/microcosmus/rotifer-locomotion/data/images/resized_cropped_squarebbox_gs_224/"
dataset_root_path = f"{mount_str}/MyDrive/microcosmus/rotifer-locomotion/data/images/resized_cropped_squarebbox_rgb_wbg_224/"
ds = timm.data.create_dataset("folder", root=dataset_root_path)  # ds = torchvision.datasets.ImageFolder(root=dataset_root_path)
print(ds)

# Transforms
tfs = transforms.Compose([
    #transforms.Resize(img_size),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5], std=[0.5]),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Add transforms
ds.transform = tfs
dataloader = DataLoader(ds, batch_size=96, num_workers=12, shuffle=False)

# possible timm transforms instead of pytorch
#data_config = timm.data.resolve_model_data_config(model)
#transforms = timm.data.create_transform(**data_config, is_training=False)

<timm.data.dataset.ImageDataset object at 0x7da45f773340>




In [25]:
# Model selection
model_names = ["eva02_base_patch14_224.mim_in22k",
               "hiera_base_plus_224.mae_in1k_ft_in1k",
               "vit_base_patch16_224_in21k"]  # "vit_base_patch16_rope_reg1_gap_256.sbb_in1k"

for model_name in model_names:
    
    # Model
    model_params = {"model_name" : model_name, "pretrained" : True,
                    "num_classes" : 0, "global_pool" : "avg"}  # instead of num_classes, set features_only=True
    
    model = timm.create_model(**model_params).eval();
    model.to(device);

    # FEATURE Extraction
    features = list()
    for nth, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            outputs = model(images)
        print(nth)
        features.append(outputs.detach().cpu().numpy())

    feature_vectors = np.concatenate(features)
    del features
    print(feature_vectors.shape)

    # SAVE FEATURES
    SAVE_FEATS = True
    save_features_path = f"{mount_str}/MyDrive/microcosmus/rotifer-locomotion/data/results/feature_vectors_{model_params['model_name']}.npy"

    if SAVE_FEATS:
        np.save(save_features_path, feature_vectors)
    else:
        feature_vectors = np.load(save_features_path)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

AssertionError: Input height (224) doesn't match model (256).