In [0]:
%pip install ray[all] lancedb
# ^ Installs Ray (with all extras) and LanceDB into the notebook environment.
#   NOTE(review): versions are unpinned; pin them for reproducible runs.
# Restart Python so the packages installed above are picked up by this session.
dbutils.library.restartPython()


## Shard plan
1. Create parquet shards (e.g., 10); see `00a_shard_data_preparation`.
2. Construct inference Ray actors:
  * Each Ray actor receives a shard number as input.
  * Each Ray actor builds a local LanceDB instance from its parquet shard.
3. Inference:
  * Use `ray.data` to send the query vectors through the actors in batches (e.g., 1k).
  * Rerank the results? (open question)

In [0]:
import pyarrow as pa
# --- Notebook configuration -------------------------------------------------
# Unity Catalog coordinates, plus the name shared by the parquet volume path
# and the local LanceDB table.
catalog_name = 'jon_cheung'
schema_name = 'vizio_poc'
lance_table_name = 'audio_100M_chunk_shard_5'

# Number of random query vectors generated for the first benchmark run.
num_test_vectors = 5_000

# Parquet source data lives under a UC volume named after the table.
audio_parquet_path = f'/Volumes/{catalog_name}/{schema_name}/{lance_table_name}'

# Arrow schema of the parquet data: an int64 row id and a fixed-size list of
# 35 float16 values in "list_col". A uint8 alternative,
# pa.list_(pa.uint8(), 35), is kept here for reference.
pyarrow_schema = pa.schema([
    ("id", pa.int64()),
    ("list_col", pa.list_(pa.float16(), 35)),
])


In [0]:
# Create (if needed) a Unity Catalog managed volume that Ray uses as a FUSE
# temp directory when writing results back to Spark.
import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog

# Single source of truth for the volume name and its FUSE path; previously the
# path literal was duplicated below and could drift out of sync.
ray_tmp_volume_name = 'ray_data_tmp_dir'
UC_fuse_temp_dir = f'/Volumes/{catalog_name}/{schema_name}/{ray_tmp_volume_name}'

w = WorkspaceClient()
if not os.path.exists(UC_fuse_temp_dir):
    # Only create the volume when its FUSE path is not already visible.
    w.volumes.create(catalog_name=catalog_name,
                     schema_name=schema_name,
                     name=ray_tmp_volume_name,
                     volume_type=catalog.VolumeType.MANAGED)

# Point Ray at the volume so it can stage data it writes back to Spark.
os.environ['RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = UC_fuse_temp_dir


In [0]:
import contextlib

import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster, setup_global_ray_cluster

# Set to False to skip tearing down any previous Ray session before setup.
restart = True
if restart:
    # Best-effort teardown: failures are ignored. Suppress Exception rather
    # than using a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    with contextlib.suppress(Exception):
        shutdown_ray_cluster()
    with contextlib.suppress(Exception):
        ray.shutdown()

# Fixed-size Ray-on-Spark cluster: 3 worker nodes x 64 CPUs, no GPUs.
# NOTE(review): Ray's API docs describe setup_global_ray_cluster as blocking by
# default (is_blocking=True), and shutdown_ray_cluster as targeting clusters
# started with setup_ray_cluster. Confirm against the installed Ray version
# that later cells can run, and that `restart` actually stops a global cluster.
# NOTE(review): hard-coded per-user log path; consider making it configurable.
setup_global_ray_cluster(
    min_worker_nodes=3,
    max_worker_nodes=3,
    num_cpus_worker_node=64,
    num_gpus_worker_node=0,
    collect_log_to_path="/dbfs/Users/jon.cheung@databricks.com/ray_collected_logs"
)

In [0]:
import os
import numpy as np
import pyarrow.dataset as ds
import lancedb
from tqdm import tqdm

# Intended number of nearest neighbours to return per query.
# NOTE(review): this value is not passed to LanceDBActor; the actor sets its own search limit.
num_closest_vectors = 1

class LanceDBActor:
    """Ray Data callable class that answers vector queries from a local LanceDB table.

    On construction each actor connects to a LanceDB database at ``/tmp/lancedb``.
    It opens the table named by the notebook-level ``lance_table_name``. If the
    table is not there, it builds the table from parquet and creates an index on
    ``list_col``. Each call runs a vector search for the query vectors in the
    batch's ``'item'`` column.
    """

    def __init__(self, vector_parquet_path, pyarrow_schema, num_results: int = 1):
        """Open, or build if missing, the actor's local LanceDB table.

        Args:
            vector_parquet_path: Parquet dataset path used to build the table when
                it cannot be opened locally.
            pyarrow_schema: Arrow schema used to read the parquet data and to
                create the table.
            num_results: Search ``limit`` applied per call. Defaults to 1, the
                value that was previously hard-coded.
        """
        # TODO: consider parameterizing the index settings.
        # num_sub_vectors should divide the vector dimension (35 in the notebook's
        # schema -> 7 dims per sub-vector). With a single IVF partition the index
        # effectively scans all PQ codes.
        num_partitions = 1
        num_sub_vectors = 5
        self.num_results = num_results

        # Lance thread pools, sized for an assumed 48-core CPU.
        # NOTE(review): the Ray task requests num_cpus=64; confirm the intended count.
        os.environ["LANCE_CPU_THREADS"] = "48"
        os.environ["LANCE_IO_THREADS"] = "48"

        # Local database directory. exist_ok avoids failing when it already exists,
        # and no longer hides unrelated errors such as permission problems.
        lance_db_uri = "/tmp/lancedb"
        os.makedirs(lance_db_uri, exist_ok=True)
        self.db = lancedb.connect(lance_db_uri)

        # Reuse the table if it can be opened; otherwise rebuild it from parquet.
        # NOTE(review): actors placed on the same node share this directory, so
        # concurrent rebuilds with mode='overwrite' could race.
        try:
            self.table_arrow = self.db.open_table(lance_table_name)
            print(f'Found LanceDB table {lance_table_name}. Using this for vector search.')
        except Exception as exc:  # narrowed from a bare `except:`
            print(f'No LanceDB table {lance_table_name}. Rebuilding from scratch')
            print(f'  (open_table failed with {exc!r})')
            self.table_arrow = self.db.create_table(
                lance_table_name,
                data=LanceDBActor._get_batches_from_parquet(vector_parquet_path,
                                                            pyarrow_schema,
                                                            batch_size=200_000),
                # Explicit schema for the streamed batches, as LanceDB's examples
                # do for iterator input.
                schema=pyarrow_schema,
                mode='overwrite',
            )
            self.table_arrow.create_index(
                metric="l2",
                vector_column_name="list_col",
                num_partitions=num_partitions,
                num_sub_vectors=num_sub_vectors,
            )

    @staticmethod
    def _get_batches_from_parquet(vector_parquet_path,
                                  schema,
                                  batch_size: int = 4096):
        """Yield PyArrow RecordBatches from a parquet dataset, with tqdm progress.

        Args:
            vector_parquet_path: Path of the parquet dataset.
            schema: Arrow schema to read the dataset with.
            batch_size: Maximum rows per scanned batch.

        Yields:
            pyarrow.RecordBatch objects from the dataset scanner.
        """
        dataset = ds.dataset(vector_parquet_path, format="parquet", schema=schema)

        total_rows = dataset.count_rows()
        # Integer ceiling division so tqdm gets a whole-number total (np.ceil
        # returned a float). This is an estimate: the scanner may emit smaller
        # batches, e.g. at file boundaries.
        total_batches = (total_rows + batch_size - 1) // batch_size
        scanner = dataset.scanner(batch_size=batch_size)

        with tqdm(total=total_batches, unit="batch", desc="Ingesting Parquet Batches") as pbar:
            for batch in scanner.to_batches():
                yield batch
                pbar.update(1)  # advance once per yielded batch
                pbar.set_postfix({"rows_in_batch": len(batch)})

    def __call__(self, batch: dict[str, np.ndarray]):
        """Search the table for the query vectors in ``batch['item']``.

        Args:
            batch: Ray Data batch. The query vectors are read from its ``'item'``
                column (assumed Ray's default dict-of-numpy batch format).

        Returns:
            The pandas DataFrame produced by LanceDB's ``to_pandas()``.
        """
        return self.table_arrow.search(batch['item']).limit(self.num_results).to_pandas()



def create_arrays(n, dimensions, seed=None):
    """Generate ``n`` random query vectors for benchmarking.

    Each vector holds ``dimensions`` integer values drawn uniformly from
    [0, 255], stored as float16.

    Args:
        n: Number of vectors to generate (0 gives an empty list).
        dimensions: Length of each vector.
        seed: Optional seed for a reproducible set of vectors. When None, the
            global ``np.random`` state is used, as before.

    Returns:
        A list of ``n`` 1-D float16 arrays of length ``dimensions``.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # One vectorised draw instead of n small ones. The rows are split back into a
    # list because callers pass the result to ray.data.from_items.
    values = rng.randint(0, 256, size=(n, dimensions)).astype(np.float16)
    return list(values)

# Make Ray Data dataset and inference with it
# Build the benchmark query dataset and attach the LanceDB search actors.
# Actor pool: 3 actors x 64 CPUs, matching the 3 x 64-CPU Ray worker nodes.
large_query_batch = ray.data.from_items(
    create_arrays(num_test_vectors, dimensions=35)
)
output = large_query_batch.map_batches(
    LanceDBActor,
    fn_constructor_args=(audio_parquet_path, pyarrow_schema),
    concurrency=3,
    num_cpus=64,
    batch_size=250,
)




In [0]:
# Collect every result row from the search pipeline above onto the driver.
all_results = output.take_all()

In [0]:
# Second, larger benchmark run. It uses its own vector count instead of
# overwriting the `num_test_vectors` config value, so re-running earlier cells
# still uses the configured size.
# NOTE(review): this rebinds `large_query_batch` and `output` to the second run.
num_second_run_vectors = 10_000
large_query_batch = ray.data.from_items(
    create_arrays(num_second_run_vectors, dimensions=35)
)
output = large_query_batch.map_batches(
    LanceDBActor,
    fn_constructor_args=(audio_parquet_path, pyarrow_schema),
    concurrency=3,
    num_cpus=64,
    batch_size=250,
)
second_batch = output.take_all()

In [0]:
# Inspect the first result record from the second run.
second_batch[0]

In [0]:
# Optional (disabled): write the search results to a Delta table named
# <catalog>.<schema>.<lance_table_name>_results, presumably staging through the
# RAY_UC_VOLUMES_FUSE_TEMP_DIR volume configured earlier.
# The target name is a single f-string; a stray nested f'...' would put a
# literal "f'" into the table name.
# ray.data.Dataset.write_databricks_table(output,
#                                         f"{catalog_name}.{schema_name}.{lance_table_name}_results",
#                                         mode='overwrite')