In [1]:
import pathlib

import clip
import ml_collections
import numpy as np
import torch
import tqdm

import main

In [2]:
# Use the GPU when one is available. Otherwise fall back to CPU (the same
# default clip.load uses) so the notebook still runs end to end without CUDA.
# Note that full-dataset extraction on CPU is dramatically slower.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# CLIP architecture to load. `model_name` is derived from it, rather than typed
# by hand, so the output directory name cannot drift out of sync with the
# weights actually used ('ViT-B/32' -> 'vitb32').
clip_arch = 'ViT-B/32'
model, preprocess = clip.load(clip_arch, device=device)
model_name = clip_arch.replace('/', '').replace('-', '').lower()

In [4]:
# Dataset to extract features for: 'imagenet' or 'inat21'.
# `dataset_name` is also used to name the output directory further down.
dataset_name = 'imagenet'

# NOTE(review): dataset roots are absolute, author-local paths. Adjust them
# for your machine.
if dataset_name == 'imagenet':
    import configs.imagenet
    config = configs.imagenet.get_config()
    config.dataset_root = '/home/jack/data/torchvision/imagenet/'
elif dataset_name == 'inat21':
    import configs.inaturalist2021mini
    config = configs.inaturalist2021mini.get_config()
    config.dataset_root = '/home/jack/data/manual/inaturalist2021/'
else:
    raise ValueError(f'unknown dataset_name: {dataset_name!r}')

In [5]:
# main.make_datasets returns six values. Only the train/eval splits, the
# tree and the node names are kept.
train_dataset, eval_dataset, tree, node_names, _, _ = main.make_datasets(config)
datasets = {
    config.train_split: train_dataset,
    config.eval_split: eval_dataset,
}

In [6]:
# Number of examples in each split.
{split: len(split_dataset) for split, split_dataset in datasets.items()}

{'train': 1281167, 'val': 50000}

In [7]:
# Apply the preprocessing transform returned by clip.load to every split.
for split_dataset in datasets.values():
    split_dataset.transform = preprocess

In [8]:
def extract_features(dataset, encoder=None, target_device=None, batch_size=256,
                     num_workers=8, progress=True):
    """Encode every image in `dataset` with an image encoder.

    Args:
        dataset: Map-style dataset yielding `(image, label)` pairs whose images
            are already transformed into tensors the encoder accepts.
            NOTE(review): assumes labels collate into a tensor. Confirm
            against `main.make_datasets`.
        encoder: Object exposing `eval()` and `encode_image(batch)`. Defaults
            to the notebook-level `model`.
        target_device: Device that image batches are moved to before encoding.
            Defaults to the notebook-level `device`.
        batch_size: Number of images per forward pass.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        progress: Whether to show a tqdm progress bar.

    Returns:
        `(features, labels)` as NumPy arrays. The first axis follows dataset
        order, because the loader does not shuffle.

    Raises:
        ValueError: If `dataset` yields no batches (e.g. it is empty).
    """
    # Resolve defaults lazily so the notebook globals are only required when
    # the caller does not supply them.
    encoder = model if encoder is None else encoder
    target_device = device if target_device is None else target_device

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=False,
        num_workers=num_workers,
    )
    # prefetch_factor only applies to worker processes. Recent PyTorch raises
    # if it is passed together with num_workers=0.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 2
    loader = torch.utils.data.DataLoader(**loader_kwargs)

    feature_batches = []
    label_batches = []

    encoder.eval()
    with torch.inference_mode():
        batches = tqdm.tqdm(loader) if progress else loader
        for image_batch, label_batch in batches:
            feature_batch = encoder.encode_image(image_batch.to(target_device))
            # Keep only independent copies, never the DataLoader's own
            # (possibly shared-memory) tensors. On a CPU device the encoder
            # output could otherwise alias loader storage.
            # https://github.com/pytorch/pytorch/issues/11201#issuecomment-486232056
            feature_batches.append(feature_batch.cpu().numpy().copy())
            label_batches.append(np.array(label_batch.numpy()))

    if not feature_batches:
        raise ValueError('extract_features: dataset produced no batches')

    features = np.concatenate(feature_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return features, labels

In [9]:
# Features for each split go to resources/features/<dataset>_<model>/<split>.npz
# (relative to the current working directory).
out_dir = pathlib.Path(f'resources/features/{dataset_name}_{model_name}')
# parents=True: on a fresh checkout resources/features/ may not exist yet, and
# mkdir(exist_ok=True) alone would then raise FileNotFoundError.
out_dir.mkdir(parents=True, exist_ok=True)

for k in datasets:
    features, labels = extract_features(datasets[k])
    np.savez(out_dir / f'{k}.npz', features=features, labels=labels)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 5005/5005 [20:30<00:00,  4.07it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:51<00:00,  3.77it/s]
