diff --git a/examples/image/map_cifar10.py b/examples/image/map_cifar10.py
index 9f681f37..02708c97 100644
--- a/examples/image/map_cifar10.py
+++ b/examples/image/map_cifar10.py
@@ -1,9 +1,6 @@
 from nomic import embed
 from nomic import atlas
 from tqdm import tqdm
-import numpy as np
-import base64
-from io import BytesIO
 
 from datasets import load_dataset
 
@@ -24,29 +21,19 @@
 images = []
 datums = []
 
-max_embeddings = 200_000
+max_embeddings = 100_000
 
 for idx, image in enumerate(tqdm(dataset)):
     images.append(image['img'])
-    buffered = BytesIO()
-    image['img'].save(buffered, format="PNG")
-
-    b64img = base64.b64encode(buffered.getvalue()).decode('utf-8')
 
     datums.append({'id': str(idx),
                    'label': labels[image['label']],
-                   "img": f''
                    }
                   )
 
     if idx >= max_embeddings:
         break
 
-output = embed.image(images=images)
-
-embeddings = np.array(output['embeddings'])
-atlas.map_data(embeddings=embeddings,
-               identifier='cifar',
-               data=datums,
-               id_field='id',
-               topic_model=False)
\ No newline at end of file
+atlas.map_data(blobs=images,
+               identifier='cifar-50k'
+)
\ No newline at end of file
diff --git a/nomic/atlas.py b/nomic/atlas.py
index 5ac18c86..fd42df9c 100644
--- a/nomic/atlas.py
+++ b/nomic/atlas.py
@@ -10,6 +10,7 @@
 import pyarrow as pa
 from loguru import logger
 from pandas import DataFrame
+from PIL import Image
 from pyarrow import Table
 from tqdm import tqdm
 
@@ -21,6 +22,7 @@ def map_data(
     data: Optional[Union[DataFrame, List[Dict], Table]] = None,
+    blobs: Optional[List[Union[str, bytes, Image.Image]]] = None,
     embeddings: Optional[np.ndarray] = None,
     identifier: Optional[str] = None,
     description: str = "",
@@ -54,8 +56,30 @@
         raise Exception("Your embeddings cannot be empty")
 
     if indexed_field is not None:
+        if embeddings is not None:
+            logger.warning("You have specified an indexed field but are using embeddings. Embeddings will be ignored.")
         modality = "text"
 
+    if blobs is not None:
+        # change this when we support other modalities
+        modality = "image"
+        indexed_field = "_blob_hash"
+        if embedding_model is not None:
+            if isinstance(embedding_model, str):
+                model_name = embedding_model
+            elif isinstance(embedding_model, dict):
+                model_name = embedding_model["model"]
+            elif isinstance(embedding_model, NomicEmbedOptions):
+                model_name = embedding_model.model
+            else:
+                raise ValueError("embedding_model must be a string, dictionary, or NomicEmbedOptions object")
+
+            if model_name in ["nomic-embed-text-v1", "nomic-embed-text-v1.5"]:
+                raise Exception("You cannot use a text embedding model with blobs")
+        else:
+            # default to vision v1.5
+            embedding_model = NomicEmbedOptions(model="nomic-embed-vision-v1.5")
+
     if id_field is None:
         id_field = ATLAS_DEFAULT_ID_FIELD
 
@@ -73,9 +97,14 @@
     # no metadata was specified
     added_id_field = False
 
-    if data is None and embeddings is not None:
-        data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(embeddings))]
+    if data is None:
         added_id_field = True
+        if embeddings is not None:
+            data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(embeddings))]
+        elif blobs is not None:
+            data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(blobs))]
+        else:
+            raise ValueError("You must specify either data, embeddings, or blobs")
 
     if id_field == ATLAS_DEFAULT_ID_FIELD and data is not None:
         if isinstance(data, list) and id_field not in data[0]:
@@ -116,6 +145,9 @@
                 embeddings=embeddings,
                 data=data,
             )
+        elif modality == "image":
+            dataset.add_data(blobs=blobs, data=data)
+
     except BaseException as e:
         if number_of_datums_before_upload == 0:
             logger.info(f"{dataset.identifier}: Deleting dataset due to failure in initial upload.")
diff --git a/nomic/data_inference.py b/nomic/data_inference.py
index d14e3e82..782009d3 100644
--- a/nomic/data_inference.py
+++ b/nomic/data_inference.py
@@ -1,7 +1,7 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import pyarrow as pa
-from pydantic import BaseModel, Field
+from pydantic import AliasChoices, BaseModel, Field
 
 from .settings import DEFAULT_DUPLICATE_THRESHOLD
 
@@ -29,7 +29,7 @@ def convert_pyarrow_schema_for_atlas(schema: pa.Schema) -> pa.Schema:
     for field in schema:
         if field.name.startswith("_"):
             # Underscore fields are private to Atlas and will be handled with their own logic.
-            if not field.name in {"_embeddings"}:
+            if not field.name in {"_embeddings", "_blob_hash"}:
                 raise ValueError(f"Underscore fields are reserved for Atlas internal use: {field.name}")
             whitelist[field.name] = field.type
         elif pa.types.is_boolean(field.type):
@@ -114,4 +114,6 @@ class NomicEmbedOptions(BaseModel):
         model: The Nomic Embedding Model to use.
""" - model: str = "NomicEmbed" + model: Literal[ + "nomic-embed-text-v1", "nomic-embed-vision-v1", "nomic-embed-text-v1.5", "nomic-embed-vision-v1.5" + ] = "nomic-embed-text-v1.5" diff --git a/nomic/dataset.py b/nomic/dataset.py index 80389dba..09ce9b64 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -7,6 +7,7 @@ import time from contextlib import contextmanager from datetime import datetime +from io import BytesIO from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -15,6 +16,7 @@ import requests from loguru import logger from pandas import DataFrame +from PIL import Image from pyarrow import compute as pc from pyarrow import feather, ipc from tqdm import tqdm @@ -337,7 +339,7 @@ def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasData for key in data.column_names: if key.startswith("_"): - if key == "_embeddings": + if key == "_embeddings" or key == "_blob_hash": continue raise ValueError("Metadata fields cannot start with _") if pa.compute.max(pa.compute.utf8_length(data[project.id_field])).as_py() > 36: # type: ignore @@ -1080,7 +1082,7 @@ def create_index( elif isinstance(embedding_model, NomicEmbedOptions): pass elif isinstance(embedding_model, str): - embedding_model = NomicEmbedOptions(model=embedding_model) + embedding_model = NomicEmbedOptions(model=embedding_model) # type: ignore else: embedding_model = NomicEmbedOptions() @@ -1133,7 +1135,7 @@ def create_index( ), } - elif self.modality == "text": + elif self.modality == "text" or self.modality == "image": # find the index id of the index with name reuse_embeddings_from_index reuse_embedding_from_index_id = None indices = self.indices @@ -1153,6 +1155,18 @@ def create_index( if indexed_field not in self.dataset_fields: raise Exception(f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}") + if self.modality == "image": + if topic_model.topic_label_field is None: + print( + "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics." + ) + topic_field = None + topic_model.build_topic_model = False + else: + topic_field = topic_model.topic_label_field + else: + topic_field = topic_model.topic_label_field + build_template = { "project_id": self.id, "index_name": name, @@ -1185,7 +1199,7 @@ def create_index( "topic_model_hyperparameters": json.dumps( { "build_topic_model": topic_model.build_topic_model, - "community_description_target_field": indexed_field, # TODO change key to topic_label_field post v0.0.85 + "community_description_target_field": topic_field, "cluster_method": topic_model.build_topic_model, "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy, } @@ -1320,7 +1334,13 @@ def delete_data(self, ids: List[str]) -> bool: else: raise Exception(response.text) - def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Optional[np.ndarray] = None, pbar=None): + def add_data( + self, + data=Union[DataFrame, List[Dict], pa.Table], + embeddings: Optional[np.ndarray] = None, + blobs: Optional[List[Union[str, bytes, Image.Image]]] = None, + pbar=None, + ): """ Adds data of varying modality to an Atlas dataset. 
         Args:
@@ -1333,9 +1353,109 @@ def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Opti
         elif isinstance(data, pa.Table) and "_embeddings" in data.column_names:  # type: ignore
             embeddings = np.array(data.column("_embeddings").to_pylist())  # type: ignore
             self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar)
+        elif blobs is not None:
+            self._add_blobs(data=data, blobs=blobs, pbar=pbar)
         else:
             self._add_text(data=data, pbar=pbar)
 
+    def _add_blobs(
+        self, data: Union[DataFrame, List[Dict], pa.Table], blobs: List[Union[str, bytes, Image.Image]], pbar=None
+    ):
+        """
+        Add data, with associated blobs, to the dataset.
+        Uploads blobs to the server and associates them with the data.
+        """
+        if isinstance(data, DataFrame):
+            data = pa.Table.from_pandas(data)
+        elif isinstance(data, list):
+            data = pa.Table.from_pylist(data)
+        elif not isinstance(data, pa.Table):
+            raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
+
+        blob_upload_endpoint = "/v1/project/data/add/blobs"
+
+        # upload a batch of blobs
+        # return hash of blob
+        # add hash to data as _blob_hash
+        # set indexed_field to _blob_hash
+        # call _add_data
+
+        # Cast the id field to string for the merged data lower down in this function
+        data = data.set_column(  # type: ignore
+            data.schema.get_field_index(self.id_field), self.id_field, pc.cast(data[self.id_field], pa.string())  # type: ignore
+        )
+
+        ids = data[self.id_field].to_pylist()  # type: ignore
+        if not isinstance(ids[0], str):
+            ids = [str(uuid) for uuid in ids]
+
+        # TODO: add support for other modalities
+        images = []
+        for uuid, blob in tqdm(zip(ids, blobs), total=len(ids), desc="Loading images"):
+            if isinstance(blob, str) and os.path.exists(blob):
+                # Auto resize to max 512x512
+                image = Image.open(blob)
+                if image.height > 512 or image.width > 512:
+                    image = image.resize((512, 512))
+                buffered = BytesIO()
+                image.save(buffered, format="JPEG")
+                images.append((uuid, buffered.getvalue()))
+            elif isinstance(blob, bytes):
+                images.append((uuid, blob))
+            elif isinstance(blob, Image.Image):
+                if blob.height > 512 or blob.width > 512:
+                    blob = blob.resize((512, 512))
+                buffered = BytesIO()
+                blob.save(buffered, format="JPEG")
+                images.append((uuid, buffered.getvalue()))
+            else:
+                raise ValueError(f"Invalid blob type for {uuid}. Must be a path to an image, bytes, or PIL Image.")
+        batch_size = 40
+        num_workers = 10
+
+        def send_request(i):
+            image_batch = images[i : i + batch_size]
+            ids = [uuid for uuid, _ in image_batch]
+            blobs = [("blobs", blob) for _, blob in image_batch]
+            response = requests.post(
+                self.atlas_api_path + blob_upload_endpoint,
+                headers=self.header,
+                data={"dataset_id": self.id},
+                files=blobs,
+            )
+            if response.status_code != 200:
+                raise Exception(response.text)
+            return {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
+
+        # if this method is being called internally, we pass a global progress bar
+        if pbar is None:
+            pbar = tqdm(total=len(data), desc="Uploading blobs to Atlas")
+
+        hash_schema = pa.schema([(self.id_field, pa.string()), ("_blob_hash", pa.string())])
+        returned_ids = []
+        returned_hashes = []
+
+        succeeded = 0
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = {executor.submit(send_request, i): i for i in range(0, len(data), batch_size)}
+
+            for future in concurrent.futures.as_completed(futures):
+                response = future.result()
+                # add hash to data as _blob_hash
+                for uuid, blob_hash in response.items():
+                    returned_ids.append(uuid)
+                    returned_hashes.append(blob_hash)
+
+                # A successful upload.
+                succeeded += len(response)
+                pbar.update(len(response))
+
+        hash_tb = pa.Table.from_pydict({self.id_field: returned_ids, "_blob_hash": returned_hashes}, schema=hash_schema)
+        merged_data = data.join(right_table=hash_tb, keys=self.id_field)  # type: ignore
+
+        self._add_data(merged_data, pbar=pbar)
+
     def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None):
         """
         Add text data to the dataset.
diff --git a/setup.py b/setup.py
index 22688243..62be90cc 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 
 setup(
     name="nomic",
-    version="3.0.35",
+    version="3.0.36",
     url="https://github.com/nomic-ai/nomic",
     description=description,
     long_description=description,
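
Usage sketch (not part of the patch): assuming this change is applied, the new blobs upload path can be exercised roughly as below. The image file names, the "id" field, and the "my-image-map" identifier are placeholders; the _blob_hash indexing and the nomic-embed-vision-v1.5 default come from the diff above.

    # Hedged sketch of the blobs flow introduced by this PR (nomic >= 3.0.36 with this patch).
    from PIL import Image
    from nomic import atlas

    # blobs may be file paths, raw bytes, or PIL Images; images over 512x512 are resized before upload.
    blobs = [Image.open("cat.jpg"), "dog.png"]
    data = [{"id": str(i)} for i in range(len(blobs))]  # optional; ids are auto-generated when data is omitted

    # With blobs set, map_data switches modality to "image", indexes on the internal
    # _blob_hash field, and defaults to the nomic-embed-vision-v1.5 embedding model.
    dataset = atlas.map_data(blobs=blobs, data=data, id_field="id", identifier="my-image-map")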