In [1]:
from datasets import load_dataset

# Download (or reuse the cached copy of) the eyewear retrieval dataset from the
# Hugging Face Hub. Without a `split` argument the result is a DatasetDict.
dataset = load_dataset(path="dnth/eyewear-retrieval-dataset")

# Bare expression so the notebook renders the splits, columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['brand', 'prompt', 'product_type', 'image', 'control_image', 'caption', 'caption_embeddings', 'local_path'],
        num_rows: 20964
    })
})

In [2]:
import timm
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained NaFlexViT backbone; num_classes=0 asks timm to drop the classifier
# head so the forward pass yields pooled feature vectors instead of logits.
model = timm.create_model(
    'naflexvit_base_patch16_gap.e300_s576_in1k',
    pretrained=True,
    num_classes=0,
)

# Both eval() and to() return the module itself, so they can be chained:
# switch to inference mode, then move the weights onto the selected device.
model = model.eval().to(device)

# Build the inference-time preprocessing pipeline from the model's own
# pretrained data configuration.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

In [3]:

from PIL import Image

def embed_image(batch):
    """Embed a batch of on-disk images with the notebook's global `model`.

    Written for `Dataset.map(..., batched=True)`, where `batch` maps column
    names to lists of values. Relies on the notebook-level `transforms`,
    `model` and `device` objects.

    Args:
        batch: Mapping of column name -> list; only the `local_path` column
            (image file paths) is read.

    Returns:
        The same `batch` object with an added `image_embeddings` entry holding
        one embedding (a list of floats) per input path, in input order.
    """
    image_paths = batch['local_path']

    # Nothing to embed: torch.stack would raise on an empty list.
    if len(image_paths) == 0:
        batch['image_embeddings'] = []
        return batch

    pixel_tensors = []
    for image_path in image_paths:
        # The context manager closes the underlying file as soon as the pixels
        # are decoded, so file handles do not pile up over a large dataset.
        with Image.open(image_path) as image:
            # Force three channels: grayscale, palette or RGBA files would
            # otherwise not match the 3-channel normalisation in `transforms`.
            pixel_tensors.append(transforms(image.convert('RGB')))

    # Stack the per-image (C, H, W) tensors into one (N, C, H, W) batch.
    images = torch.stack(pixel_tensors).to(device)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        embeddings = model(images)

    # Plain nested lists so the result can be stored as a dataset column.
    batch['image_embeddings'] = embeddings.tolist()
    return batch



In [4]:
# Run embed_image over the train split 32 rows at a time; the returned dataset
# carries the extra `image_embeddings` column produced by the function.
dataset_with_image_embeddings = dataset['train'].map(
    function=embed_image,
    batched=True,
    batch_size=32,
)

Map:   0%|          | 0/20964 [00:00<?, ? examples/s]

In [5]:
# Sanity check: length of the image embedding stored for the first row.
len(dataset_with_image_embeddings[0]['image_embeddings'])

768

In [6]:
# For comparison: length of the `caption_embeddings` vector of the first row
# (a column that was already present in the loaded dataset).
len(dataset_with_image_embeddings[0]['caption_embeddings'])

384

In [8]:
# Upload the dataset, now including `image_embeddings`, to the Hugging Face Hub.
# NOTE(review): this is the same repo id the data was loaded from, so the
# upload lands as a new commit on that existing dataset repo; presumably this
# requires a logged-in account with write access — confirm before re-running.
dataset_with_image_embeddings.push_to_hub("dnth/eyewear-retrieval-dataset")

Uploading the dataset shards:   0%|          | 0/5 [00:00<?, ? shards/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4192 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/dnth/eyewear-retrieval-dataset/commit/b2a7d240c67444cd2f0a49b907682b3db548a146', commit_message='Upload dataset', commit_description='', oid='b2a7d240c67444cd2f0a49b907682b3db548a146', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/dnth/eyewear-retrieval-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='dnth/eyewear-retrieval-dataset'), pr_revision=None, pr_num=None)