feat: map_data support for Nomic Embed Vision #308

Merged: 29 commits, merged Jul 9, 2024

Changes from 21 commits

Commits (29)
d019c39
feat: image upload
zanussbaum Jun 13, 2024
ac9a55c
feat: update cifar10 update
zanussbaum Jun 13, 2024
8983981
style: black isort
zanussbaum Jun 17, 2024
e707f94
fix: pyright errors
zanussbaum Jun 17, 2024
675642a
fix: simplify
zanussbaum Jun 17, 2024
53bd566
fix: black isort
zanussbaum Jun 17, 2024
1e9b55a
Merge branch 'main' into map_image
zanussbaum Jun 19, 2024
1af4d18
style: isort
zanussbaum Jun 19, 2024
ca3746f
refactor: remove bytes upload
zanussbaum Jun 20, 2024
fa031db
fix: id type casting to str when not str
zanussbaum Jun 21, 2024
2850c8a
fix: communityh_description_target_field -> topic_label_field
zanussbaum Jun 21, 2024
ca1b60f
fix: alias for pydantic 2x
zanussbaum Jun 21, 2024
da0411f
substituting id for self id_field and casting id_field to str
eelegiap Jun 24, 2024
3b0715d
Topic label field bug alias bug(#310)
zanussbaum Jun 25, 2024
9978a8a
chore: version bump
zanussbaum Jun 25, 2024
f205e5d
style: black isort
zanussbaum Jun 25, 2024
7501db3
fix: pyright
zanussbaum Jun 25, 2024
3e9e0f2
fix: ellipsis suggestion
zanussbaum Jun 25, 2024
4977dd3
Merge branch 'main' into map_image
zanussbaum Jun 25, 2024
d02344e
fix: pyright pydantic + pyarrow errors
zanussbaum Jun 25, 2024
5a134db
style: black isort
zanussbaum Jun 25, 2024
016b3e1
fix: remove unused bytes
zanussbaum Jun 25, 2024
e286831
fix: increase batch size
zanussbaum Jun 25, 2024
4ad4af5
Merge branch 'main' into map_image
zanussbaum Jul 2, 2024
af646e1
fix: allow only blobs to be passed with no metadata
zanussbaum Jul 2, 2024
d7a33a7
chore: better logging for when indexed field and embeddings present
zanussbaum Jul 4, 2024
e2be038
fix: add id field if not present for blobs
zanussbaum Jul 4, 2024
2b00dc3
fix: better logging for blob upload
zanussbaum Jul 4, 2024
c8db85e
fix: map cifar 10
zanussbaum Jul 9, 2024
16 changes: 8 additions & 8 deletions examples/image/map_cifar10.py
@@ -24,7 +24,7 @@
images = []
datums = []

-max_embeddings = 200_000
+max_embeddings = 100_000

for idx, image in enumerate(tqdm(dataset)):
images.append(image['img'])
@@ -34,19 +34,19 @@
b64img = base64.b64encode(buffered.getvalue()).decode('utf-8')
datums.append({'id': str(idx),
'label': labels[image['label']],
"img": f'<img src="data:image/jpeg;base64,{b64img}" style="min-width:150px"/>'
}
)

if idx >= max_embeddings:
break

-output = embed.image(images=images)
-
-embeddings = np.array(output['embeddings'])
-
-atlas.map_data(embeddings=embeddings,
-               identifier='cifar',
+atlas.map_data(blobs=images,
+               identifier='cifar-50k-image-upload-with-topic',
               data=datums,
               id_field='id',
-               topic_model=False)
+               topic_model={
+                   "build_topic_model": True,
+                   "topic_label_field": "label"
+               },
+               )
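
Note for reviewers: per the nomic/atlas.py changes below and commit af646e1 ("fix: allow only blobs to be passed with no metadata"), the new blobs argument also accepts local file paths and can be used without any accompanying data rows. A minimal sketch of that path; the directory and dataset identifier are illustrative, not part of this PR:

from glob import glob

from nomic import atlas

# Hypothetical folder of JPEGs; each blob may be a file path, raw bytes, or a PIL Image.
image_paths = sorted(glob("images/*.jpg"))

# With no data rows, an id field is generated automatically
# (see commit e2be038, "fix: add id field if not present for blobs").
atlas.map_data(
    blobs=image_paths,
    identifier="image-upload-sketch",  # illustrative dataset name
)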
25 changes: 25 additions & 0 deletions nomic/atlas.py
@@ -10,6 +10,7 @@
import pyarrow as pa
from loguru import logger
from pandas import DataFrame
from PIL import Image
from pyarrow import Table
from tqdm import tqdm

@@ -21,6 +22,7 @@

def map_data(
data: Optional[Union[DataFrame, List[Dict], Table]] = None,
blobs: Optional[List[Union[str, bytes, Image.Image]]] = None,
embeddings: Optional[np.ndarray] = None,
identifier: Optional[str] = None,
description: str = "",
@@ -56,6 +58,26 @@ def map_data(
if indexed_field is not None:
modality = "text"

if blobs is not None:
# change this when we support other modalities
modality = "image"
indexed_field = "_blob_hash"
if embedding_model is not None:
if isinstance(embedding_model, str):
model_name = embedding_model
elif isinstance(embedding_model, dict):
model_name = embedding_model["model"]
elif isinstance(embedding_model, NomicEmbedOptions):
model_name = embedding_model.model
else:
raise ValueError("embedding_model must be a string, dictionary, or NomicEmbedOptions object")

if model_name in ["nomic-embed-text-v1", "nomic-embed-text-v1.5"]:
raise Exception("You cannot use a text embedding model with blobs")
else:
# default to vision v1.5
embedding_model = NomicEmbedOptions(model="nomic-embed-vision-v1.5")

if id_field is None:
id_field = ATLAS_DEFAULT_ID_FIELD

@@ -116,6 +138,9 @@ def map_data(
embeddings=embeddings,
data=data,
)
elif modality == "image":
dataset.add_data(blobs=blobs, data=data)

except BaseException as e:
if number_of_datums_before_upload == 0:
logger.info(f"{dataset.identifier}: Deleting dataset due to failure in initial upload.")
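
To illustrate the embedding model handling added above: when blobs are passed, a text model is rejected and an unset model defaults to nomic-embed-vision-v1.5. A rough sketch of the accepted forms, assuming NomicEmbedOptions is importable from nomic.data_inference as in this diff; paths and identifiers are illustrative:

from nomic import atlas
from nomic.data_inference import NomicEmbedOptions

images = ["photos/cat.jpg", "photos/dog.jpg"]  # hypothetical image paths

# No model given: the branch above falls back to nomic-embed-vision-v1.5.
atlas.map_data(blobs=images, identifier="pets-default")

# A string, a dict with a "model" key, or a NomicEmbedOptions instance all work.
atlas.map_data(blobs=images, identifier="pets-options",
               embedding_model=NomicEmbedOptions(model="nomic-embed-vision-v1.5"))

# A text model together with blobs raises, per the explicit check above.
atlas.map_data(blobs=images, identifier="pets-error",
               embedding_model="nomic-embed-text-v1.5")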
10 changes: 6 additions & 4 deletions nomic/data_inference.py
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import pyarrow as pa
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field

from .settings import DEFAULT_DUPLICATE_THRESHOLD

@@ -29,7 +29,7 @@ def convert_pyarrow_schema_for_atlas(schema: pa.Schema) -> pa.Schema:
for field in schema:
if field.name.startswith("_"):
# Underscore fields are private to Atlas and will be handled with their own logic.
if not field.name in {"_embeddings"}:
if not field.name in {"_embeddings", "_blob_hash"}:
raise ValueError(f"Underscore fields are reserved for Atlas internal use: {field.name}")
whitelist[field.name] = field.type
elif pa.types.is_boolean(field.type):
@@ -114,4 +114,6 @@ class NomicEmbedOptions(BaseModel):
model: The Nomic Embedding Model to use.
"""

model: str = "NomicEmbed"
model: Literal[
"nomic-embed-text-v1", "nomic-embed-vision-v1", "nomic-embed-text-v1.5", "nomic-embed-vision-v1.5"
] = "nomic-embed-text-v1.5"
130 changes: 125 additions & 5 deletions nomic/dataset.py
@@ -7,6 +7,7 @@
import time
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

@@ -15,6 +16,7 @@
import requests
from loguru import logger
from pandas import DataFrame
from PIL import Image
from pyarrow import compute as pc
from pyarrow import feather, ipc
from tqdm import tqdm
@@ -337,7 +339,7 @@ def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasData

for key in data.column_names:
if key.startswith("_"):
if key == "_embeddings":
if key == "_embeddings" or key == "_blob_hash":
continue
raise ValueError("Metadata fields cannot start with _")
if pa.compute.max(pa.compute.utf8_length(data[project.id_field])).as_py() > 36: # type: ignore
@@ -1073,7 +1075,7 @@ def create_index(
elif isinstance(embedding_model, NomicEmbedOptions):
pass
elif isinstance(embedding_model, str):
-embedding_model = NomicEmbedOptions(model=embedding_model)
+embedding_model = NomicEmbedOptions(model=embedding_model)  # type: ignore
else:
embedding_model = NomicEmbedOptions()

@@ -1126,7 +1128,7 @@ def create_index(
),
}

elif self.modality == "text":
elif self.modality == "text" or self.modality == "image":
# find the index id of the index with name reuse_embeddings_from_index
reuse_embedding_from_index_id = None
indices = self.indices
@@ -1146,6 +1148,18 @@
if indexed_field not in self.dataset_fields:
raise Exception(f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}")

if self.modality == "image":
if topic_model.topic_label_field is None:
print(
"You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
)
topic_field = None
topic_model.build_topic_model = False
else:
topic_field = topic_model.topic_label_field
else:
topic_field = topic_model.topic_label_field

build_template = {
"project_id": self.id,
"index_name": name,
@@ -1178,7 +1192,7 @@ def create_index(
"topic_model_hyperparameters": json.dumps(
{
"build_topic_model": topic_model.build_topic_model,
"community_description_target_field": indexed_field, # TODO change key to topic_label_field post v0.0.85
"community_description_target_field": topic_field,
"cluster_method": topic_model.build_topic_model,
"enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
}
@@ -1313,7 +1327,13 @@ def delete_data(self, ids: List[str]) -> bool:
else:
raise Exception(response.text)

-def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Optional[np.ndarray] = None, pbar=None):
+def add_data(
+    self,
+    data=Union[DataFrame, List[Dict], pa.Table],
+    embeddings: Optional[np.ndarray] = None,
+    blobs: Optional[List[Union[str, bytes, Image.Image]]] = None,
+    pbar=None,
+):
"""
Adds data of varying modality to an Atlas dataset.
Args:
@@ -1326,9 +1346,109 @@ def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Opti
elif isinstance(data, pa.Table) and "_embeddings" in data.column_names: # type: ignore
embeddings = np.array(data.column("_embeddings").to_pylist()) # type: ignore
self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar)
elif blobs is not None:
self._add_blobs(data=data, blobs=blobs, pbar=pbar)
else:
self._add_text(data=data, pbar=pbar)

def _add_blobs(
self, data: Union[DataFrame, List[Dict], pa.Table], blobs: List[Union[str, bytes, Image.Image]], pbar=None
):
"""
Add data, with associated blobs, to the dataset.
Uploads blobs to the server and associates them with the data.
"""
if isinstance(data, DataFrame):
data = pa.Table.from_pandas(data)
elif isinstance(data, list):
data = pa.Table.from_pylist(data)
elif not isinstance(data, pa.Table):
raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")

blob_upload_endpoint = "/v1/project/data/add/blobs"

# upload batch of blobs
# return hash of blob
# add hash to data as _blob_hash
# set indexed_field to _blob_hash
# call _add_data

# Cast self id field to string for merged data lower down on function
data = data.set_column( # type: ignore
data.schema.get_field_index(self.id_field), self.id_field, pc.cast(data[self.id_field], pa.string()) # type: ignore
)

ids = data[self.id_field].to_pylist() # type: ignore
if not isinstance(ids[0], str):
ids = [str(uuid) for uuid in ids]

# TODO: add support for other modalities
images = []
for uuid, blob in tqdm(zip(ids, blobs), total=len(ids), desc="Processing blobs"):
if isinstance(blob, str) and os.path.exists(blob):
# Auto resize to max 512x512
image = Image.open(blob)
if image.height > 512 or image.width > 512:
image = image.resize((512, 512))
buffered = BytesIO()
image.save(buffered, format="JPEG")
images.append((uuid, buffered.getvalue()))
elif isinstance(blob, bytes):
images.append((uuid, blob))
elif isinstance(blob, Image.Image):
if blob.height > 512 or blob.width > 512:
blob = blob.resize((512, 512))
buffered = BytesIO()
blob.save(buffered, format="JPEG")
images.append((uuid, buffered.getvalue()))
else:
raise ValueError("Invalid blob type")

batch_size = 20
num_workers = 10

def send_request(i):
image_batch = images[i : i + batch_size]
ids = [uuid for uuid, _ in image_batch]
blobs = [("blobs", blob) for _, blob in image_batch]
response = requests.post(
self.atlas_api_path + blob_upload_endpoint,
headers=self.header,
data={"dataset_id": self.id},
files=blobs,
)
if response.status_code != 200:
raise Exception(response.text)
return {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}

# if this method is being called internally, we pass a global progress bar
if pbar is None:
pbar = tqdm(total=int(len(data)) // batch_size)

hash_schema = pa.schema([(self.id_field, pa.string()), ("_blob_hash", pa.string())])
returned_ids = []
returned_hashes = []

succeeded = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(send_request, i): i for i in range(0, len(data), batch_size)}

for future in concurrent.futures.as_completed(futures):
response = future.result()
# add hash to data as _blob_hash
for uuid, blob_hash in response.items():
returned_ids.append(uuid)
returned_hashes.append(blob_hash)

# A successful upload.
succeeded += batch_size
pbar.update(1)

hash_tb = pa.Table.from_pydict({self.id_field: returned_ids, "_blob_hash": returned_hashes}, schema=hash_schema)
merged_data = data.join(right_table=hash_tb, keys=self.id_field) # type: ignore

self._add_data(merged_data, pbar=pbar)

def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None):
"""
Add text data to the dataset.
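
Putting the new _add_blobs path together: add_data dispatches to it whenever blobs is supplied, each blob may be a file path, raw bytes, or a PIL Image, oversized images are resized to 512x512 client-side, and uploads go out in batches of 20 across 10 worker threads before the returned _blob_hash values are joined back onto the data. A rough usage sketch against an existing dataset; the dataset identifier, file paths, and constructor arguments are illustrative, not taken from this PR:

from io import BytesIO

from PIL import Image

from nomic import AtlasDataset

dataset = AtlasDataset("my-org/pet-photos", unique_id_field="id")  # hypothetical dataset

# One blob of each supported type; rows and blobs are matched positionally.
path_blob = "photos/cat_0001.jpg"                 # file path on disk
pil_blob = Image.open("photos/dog_0001.jpg")      # PIL Image, resized before upload if larger than 512x512
buffer = BytesIO()
Image.open("photos/bird_0001.jpg").save(buffer, format="JPEG")
bytes_blob = buffer.getvalue()                    # raw JPEG bytes

rows = [{"id": "0"}, {"id": "1"}, {"id": "2"}]

# Ids are cast to str and joined against the returned _blob_hash column before upload.
dataset.add_data(data=rows, blobs=[path_blob, pil_blob, bytes_blob])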
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@

setup(
name="nomic",
version="3.0.34",
version="3.0.35",
url="https://github.com/nomic-ai/nomic",
description=description,
long_description=description,