Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow arrow table for upload in map_text #214

Merged
merged 4 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion nomic/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import uuid
from typing import Dict, Iterable, List, Optional, Union

import pyarrow as pa
import numpy as np
from pandas import DataFrame
from loguru import logger
from tqdm import tqdm

from .project import AtlasProject
from .settings import *
from .utils import b64int, get_random_name
from .utils import b64int, get_random_name, arrow_iterator


def map_embeddings(
Expand Down Expand Up @@ -209,6 +210,9 @@ def map_text(
if isinstance(data, DataFrame):
# Convert DataFrame to a generator of dictionaries
data_iterator = (row._asdict() for row in data.itertuples(index=False))
elif isinstance(data, pa.Table):
# Create generator from pyarrow table
data_iterator = arrow_iterator(data)
else:
data_iterator = iter(data)

Expand Down
35 changes: 33 additions & 2 deletions nomic/tests/test_atlas_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import requests
from nomic import AtlasProject, atlas
import pyarrow as pa
import pandas as pd

def gen_random_datetime(min_year=1900, max_year=datetime.now().year):
Expand Down Expand Up @@ -353,8 +354,13 @@ def test_map_embeddings():

map = project.get_map(name='UNITTEST1')

with project.wait_for_project_lock():
time.sleep(1)
num_tries = 0
while map.project.is_locked:
time.sleep(10)
num_tries += 1
if num_tries > 5:
raise TimeoutError('Timed out while waiting for project to unlock')

retrieved_embeddings = map.embeddings.latent

assert project.total_datums == num_embeddings
Expand Down Expand Up @@ -410,6 +416,31 @@ def test_map_text_pandas():

project.delete()


def test_map_text_arrow():
    """Upload a pyarrow table through ``atlas.map_text`` and check the datum count.

    Builds a ``size``-row table with random ``field``/``id`` strings and a
    random ``color`` label, maps it with ``color`` as the indexed field, and
    asserts that the resulting project reports exactly ``size`` datums.
    """
    size = 50
    data = pa.Table.from_pydict(
        {
            'field': [str(uuid.uuid4()) for _ in range(size)],
            'id': [str(uuid.uuid4()) for _ in range(size)],
            'color': [random.choice(['red', 'blue', 'green']) for _ in range(size)],
        }
    )

    project = atlas.map_text(
        name='UNITTEST_arrow_text',
        id_field='id',
        indexed_field="color",
        data=data,
        is_public=True,
        colorable_fields=['color'],
        reset_project_if_exists=True,
    )

    try:
        # The returned map is not inspected; the lookup is kept because the
        # original test performed it (presumably it fails when the map is
        # missing -- confirm against AtlasProject.get_map).
        project.get_map(name='UNITTEST_arrow_text')

        # Compare against `size` so the check follows the fixture if it changes.
        assert project.total_datums == size
    finally:
        # Always remove the remote project, even when a check above fails.
        project.delete()


def test_map_text_iterator():
size = 50
Expand Down
10 changes: 10 additions & 0 deletions nomic/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import base64
import gc
import sys
import pyarrow as pa
from uuid import UUID

from wonderwords import RandomWord


def arrow_iterator(table: pa.Table, max_chunksize: int = 10_000):
    """Lazily yield the rows of a pyarrow table, one Python object per row.

    The table is read in record batches of at most ``max_chunksize`` rows and
    only the current batch is converted with ``to_pylist()``, so the whole
    table is never turned into Python objects at once.

    Args:
        table: The pyarrow table to iterate over.
        max_chunksize: Upper bound on rows per record batch, passed to
            ``Table.to_reader``. Defaults to 10,000, the value this helper
            previously hard-coded.

    Yields:
        Each item of ``batch.to_pylist()`` in table order (per the pyarrow
        documentation, a dict mapping column name to value for that row).
    """
    reader = table.to_reader(max_chunksize=max_chunksize)
    for batch in reader:
        # Convert one batch at a time to bound the number of live Python rows.
        yield from batch.to_pylist()


def b64int(i: int):
ibytes = int.to_bytes(i, length=8, byteorder='big').lstrip(b'\x00')
if ibytes == b'':
Expand Down