Skip to content

Commit

Permalink
Merge pull request #214 from nomic-ai/map_text_pa
Browse files Browse the repository at this point in the history
feat: allow arrow table for upload in map_text
  • Loading branch information
bmschmidt committed Oct 5, 2023
2 parents d647963 + c6b13f9 commit a3a91c2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
6 changes: 5 additions & 1 deletion nomic/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import uuid
from typing import Dict, Iterable, List, Optional, Union

import pyarrow as pa
import numpy as np
from pandas import DataFrame
from loguru import logger
from tqdm import tqdm

from .project import AtlasProject
from .settings import *
from .utils import b64int, get_random_name
from .utils import b64int, get_random_name, arrow_iterator


def map_embeddings(
Expand Down Expand Up @@ -209,6 +210,9 @@ def map_text(
if isinstance(data, DataFrame):
# Convert DataFrame to a generator of dictionaries
data_iterator = (row._asdict() for row in data.itertuples(index=False))
elif isinstance(data, pa.Table):
# Create generator from pyarrow table
data_iterator = arrow_iterator(data)
else:
data_iterator = iter(data)

Expand Down
35 changes: 33 additions & 2 deletions nomic/tests/test_atlas_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import requests
from nomic import AtlasProject, atlas
import pyarrow as pa
import pandas as pd

def gen_random_datetime(min_year=1900, max_year=datetime.now().year):
Expand Down Expand Up @@ -353,8 +354,13 @@ def test_map_embeddings():

map = project.get_map(name='UNITTEST1')

with project.wait_for_project_lock():
time.sleep(1)
num_tries = 0
while map.project.is_locked:
time.sleep(10)
num_tries += 1
if num_tries > 5:
raise TimeoutError('Timed out while waiting for project to unlock')

retrieved_embeddings = map.embeddings.latent

assert project.total_datums == num_embeddings
Expand Down Expand Up @@ -410,6 +416,31 @@ def test_map_text_pandas():

project.delete()


def test_map_text_arrow():
size = 50
data = pa.Table.from_pydict({
'field': [str(uuid.uuid4()) for i in range(size)],
'id': [str(uuid.uuid4()) for i in range(size)],
'color': [random.choice(['red', 'blue', 'green']) for i in range(size)],
})

project = atlas.map_text(
name='UNITTEST_arrow_text',
id_field='id',
indexed_field="color",
data=data,
is_public=True,
colorable_fields=['color'],
reset_project_if_exists=True,
)

map = project.get_map(name='UNITTEST_arrow_text')

assert project.total_datums == 50

project.delete()


def test_map_text_iterator():
size = 50
Expand Down
10 changes: 10 additions & 0 deletions nomic/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import base64
import gc
import sys
import pyarrow as pa
from uuid import UUID

from wonderwords import RandomWord


def arrow_iterator(table: pa.Table):
# TODO: setting to 10k so it doesn't take too long?
# Wrote this as a generator so we don't realize the whole table in memory
reader = table.to_reader(max_chunksize=10_000)
for batch in reader:
for item in batch.to_pylist():
yield item


def b64int(i: int):
ibytes = int.to_bytes(i, length=8, byteorder='big').lstrip(b'\x00')
if ibytes == b'':
Expand Down

0 comments on commit a3a91c2

Please sign in to comment.