In [1]:
from datasets import load_dataset

dataset = load_dataset("dnth/Eyewear-Dataset-1024-with-captions")

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['brand', 'prompt', 'product_type', 'image', 'control_image', 'caption'],
        num_rows: 20964
    })
})

In [3]:
dataset['train'][0]

{'brand': 'Polaroid',
 'prompt': 'shape is round / oval, technology is r, frame material is polycarbonate,.',
 'product_type': 'sunglasses',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>,
 'control_image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>,
 'caption': 'A close-up view of a pair of Polaroid sunglasses. The frame is a dark brown color with a tortoiseshell pattern and a slightly curved shape. The clear lenses reflect a blue tint. The right arm of the sunglasses is slightly extended, showcasing the sleek design. The brand name "Polaroid" is prominently displayed on the right arm of the frame. Two small white dots are visible on the left arm, possibly indicating fasteners. The sunglasses are set against a stark white background, which highlights their features and design.'}

In [4]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')

def embed_captions(batch):
    """
    Function to embed captions in batches for efficiency
    """
    captions = batch['caption']
    # Generate embeddings
    embeddings = model.encode(captions, convert_to_tensor=False)
    # Convert to list for compatibility with datasets
    batch['caption_embeddings'] = embeddings.tolist()
    return batch

In [5]:
embedded_dataset = dataset.map(
    embed_captions,
    batched=True,
    batch_size=32,  # Adjust based on your GPU memory
    desc="Embedding captions"
)


Embedding captions:   0%|          | 0/20964 [00:00<?, ? examples/s]

  return forward_call(*args, **kwargs)


In [7]:
embedded_dataset['train'].push_to_hub("dnth/eyewear-dataset-1024-embedded")

Uploading the dataset shards:   0%|          | 0/5 [00:00<?, ? shards/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4193 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

Map:   0%|          | 0/4192 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/dnth/eyewear-dataset-1024-embedded/commit/e9d1173fab6d6bdbb29f3da588790603b063d50a', commit_message='Upload dataset', commit_description='', oid='e9d1173fab6d6bdbb29f3da588790603b063d50a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/dnth/eyewear-dataset-1024-embedded', endpoint='https://huggingface.co', repo_type='dataset', repo_id='dnth/eyewear-dataset-1024-embedded'), pr_revision=None, pr_num=None)

In [10]:
embedded_dataset['train'][0]

{'brand': 'Polaroid',
 'prompt': 'shape is round / oval, technology is r, frame material is polycarbonate,.',
 'product_type': 'sunglasses',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>,
 'control_image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>,
 'caption': 'A close-up view of a pair of Polaroid sunglasses. The frame is a dark brown color with a tortoiseshell pattern and a slightly curved shape. The clear lenses reflect a blue tint. The right arm of the sunglasses is slightly extended, showcasing the sleek design. The brand name "Polaroid" is prominently displayed on the right arm of the frame. Two small white dots are visible on the left arm, possibly indicating fasteners. The sunglasses are set against a stark white background, which highlights their features and design.',
 'caption_embeddings': [-0.03259830176830292,
  0.0042353603057563305,
  -0.014557523652911186,
  -0.041319455951452255,
  0.0383530855178833,
  -0.017793025