In [1]:
import os

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

from src.dataset import Vietnam46AttrDataset
from src.model import FeatureExtractor
from src.retriever import transform
from src.utils import create_index, write_index

In [2]:
# Images are preprocessed with the same `transform` the retriever uses.
ds = Vietnam46AttrDataset(transform=transform)

# Iterate strictly in dataset order: the extraction loop discards the paths,
# so the i-th vector added to the index can only be tied back to ds[i]
# through this ordering.
sampler = SequentialSampler(ds)
BATCH_SIZE = 64
dl = DataLoader(ds, sampler=sampler, batch_size=BATCH_SIZE)

len(ds), len(dl)

(9142, 143)

In [3]:
# Pull a single batch to sanity-check the tensor shape and that paths are returned.
first_batch = next(iter(dl))
batch, paths = first_batch

batch.shape, len(paths)

(torch.Size([64, 3, 224, 224]), 64)

In [4]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [5]:
# Build the feature extractor, then move it to the selected device.
extractor = FeatureExtractor()
extractor = extractor.to(device)

In [6]:
# The index dimensionality must match the extractor's output size.
# "IP" requests an inner-product metric. NOTE(review): inner product equals
# cosine similarity only if the extractor L2-normalizes its features; confirm
# in FeatureExtractor.
feature_dim = extractor.feature_size
index = create_index(feature_dim, metrics="IP")

In [7]:
# Embed every image in dataset order, append the vectors to the index, and
# persist the index to disk.
INDEX_PATH = "index/Vietnam46Attr_full.resnet50_IP.bin"

# Guard against re-running this cell: a second pass would append a duplicate
# copy of every vector. Re-run the create_index cell first to start clean.
assert index.ntotal == 0, "index is not empty; re-run the create_index cell before extracting"

# Inference only. eval() switches layers such as BatchNorm/Dropout (if the
# extractor has any; assumes it is a torch nn.Module) to inference behaviour,
# and no_grad() stops autograd from building a graph for every batch.
extractor.eval()
with torch.no_grad():
    for batch, _ in tqdm(dl, desc="Extracting features"):
        features = extractor(batch.to(device))
        index.add(features)

# Make sure the output directory exists before writing.
os.makedirs(os.path.dirname(INDEX_PATH), exist_ok=True)
write_index(index, INDEX_PATH)

Extracting features: 100%|██████████| 143/143 [01:32<00:00,  1.55it/s]


In [8]:
# Sanity check: exactly one vector should have been indexed per dataset item.
# This catches duplicated extraction (cell re-run) or dropped batches.
assert index.ntotal == len(ds), f"expected {len(ds)} vectors, index holds {index.ntotal}"
index.ntotal

9142