In [0]:
%pip install ray[all]
# Restart the Python process so the freshly installed Ray package is importable in later cells.
# NOTE(review): ray[all] is unpinned — pin a version (e.g. ray[all]==X.Y.Z) for reproducible runs.
dbutils.library.restartPython()


## Shard plan
1. Create Parquet shards (e.g., 10) --> see 00a_shard_data_preparation
2. Construct inference Ray actors
  * Each Ray actor receives a shard number as input.
  * Each Ray actor builds a local LanceDB instance from its Parquet shard.
3. Inference:
  * Send queries through ray.data in batches (e.g., 1k vectors per batch).
  * Rerank? (open question)

In [0]:
# Catalog / schema used to build the /Volumes/... path of the Parquet source.
catalog_name, schema_name = 'jon_cheung', 'vizio_poc'
# Name shared by the Parquet location and the LanceDB table built inside each actor.
lance_table_name = 'audio_100M_chunk_shard_1'
# How many random query vectors are pushed through the Ray Data pipeline.
num_test_vectors = 2000


In [0]:
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Tear down any Ray-on-Spark cluster / driver connection left over from a previous
# run of this cell so that setup_ray_cluster below starts from a clean state.
restart = True
if restart:
  try:
    shutdown_ray_cluster()
  except Exception as exc:  # best-effort: there may be no cluster running yet
    print(f"shutdown_ray_cluster skipped: {exc!r}")
  try:
    ray.shutdown()
  except Exception as exc:  # best-effort: ray may not be initialised in this process
    print(f"ray.shutdown skipped: {exc!r}")

# Start Ray worker nodes on the Spark cluster: between 2 and 4 workers,
# each with 48 CPUs and no GPUs.
# NOTE(review): the log path is user-specific; parameterize it before sharing.
setup_ray_cluster(
  min_worker_nodes=2,
  max_worker_nodes=4,
  num_cpus_worker_node=48,
  num_gpus_worker_node=0,
  collect_log_to_path="/dbfs/Users/jon.cheung@databricks.com/ray_collected_logs"
)

In [0]:
import os
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import lancedb


class LanceDBActor:
    """Ray Data callable class that answers vector queries against a worker-local LanceDB table.

    On construction each actor copies the Parquet data at ``parquet_path`` into a
    LanceDB table on local disk and builds an L2 vector index on ``list_col``.
    Calling the actor with a batch of query vectors returns the ``limit`` nearest
    rows for every query vector.

    NOTE(review): the table name is read from the notebook-level global
    ``lance_table_name``; presumably cloudpickle captures it when Ray serializes
    this class to the workers — confirm, or pass it in explicitly.
    """

    def __init__(self, parquet_path, pyarrow_schema):
        """Load ``parquet_path`` into a local LanceDB table and build a vector index.

        Args:
            parquet_path: Path of the Parquet file/directory to ingest.
            pyarrow_schema: Arrow schema of the Parquet data; must contain the
                ``list_col`` vector column that gets indexed.
        """
        # TODO: consider parameterizing
        # One directory per actor process: actors sharing a node would otherwise share
        # /tmp/lancedb and clobber each other's table (mode="overwrite").
        lance_db_uri = os.path.join("/tmp/lancedb", str(os.getpid()))
        num_partitions = 10
        # NOTE(review): PQ sub-vectors presumably must split the 35-wide vectors evenly.
        num_sub_vectors = 5

        # Create lanceDB directory
        os.makedirs(lance_db_uri, exist_ok=True)

        # Size Lance's thread pools to the 48 CPUs each actor reserves (num_cpus=48
        # in the map_batches call).
        os.environ["LANCE_CPU_THREADS"] = "48"
        os.environ["LANCE_IO_THREADS"] = "48"

        db = lancedb.connect(lance_db_uri)
        batches = LanceDBActor._get_batches_from_parquet_with_progress(
            parquet_path, pyarrow_schema, batch_size=200_000
        )
        # A RecordBatch generator carries no schema up front, so give it explicitly.
        self.table_arrow = db.create_table(
            lance_table_name,
            data=batches,
            schema=pyarrow_schema,
            mode="overwrite",
        )

        self.table_arrow.create_index(
            metric="l2",
            vector_column_name="list_col",
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
        )

    @staticmethod
    def _get_batches_from_parquet_with_progress(parquet_path: str,
                                                schema: "pa.Schema",
                                                batch_size: int = 4096):
        """Yield PyArrow RecordBatches read from a Parquet dataset.

        Args:
            parquet_path: Path of the Parquet file/directory.
            schema: Arrow schema used to read the dataset.
            batch_size: Maximum number of rows per yielded batch.

        Yields:
            ``pyarrow.RecordBatch`` objects. When tqdm is installed, a progress bar
            ticks once per batch; otherwise batches are yielded silently.
        """
        try:
            from tqdm.auto import tqdm
        except ImportError:  # progress reporting is optional
            tqdm = None

        dataset = ds.dataset(parquet_path, format="parquet", schema=schema)
        scanner = dataset.scanner(batch_size=batch_size)

        if tqdm is None:
            yield from scanner.to_batches()
            return

        # Integer ceiling division. This is only an estimate: batch_size is a maximum,
        # so the scanner may emit more, smaller batches.
        total_batches = -(-dataset.count_rows() // batch_size)
        with tqdm(total=total_batches, unit="batch", desc="Ingesting Parquet Batches") as pbar:
            for batch in scanner.to_batches():
                yield batch
                pbar.update(1)
                pbar.set_postfix({"rows_in_batch": len(batch)})

    def __call__(self, batch: "dict | np.ndarray", limit: int):
        """Return the ``limit`` nearest table rows for every query vector in ``batch``.

        Args:
            batch: Either a Ray Data column dict whose ``"item"`` entry holds the
                query vectors (the layout produced by ``ray.data.from_items`` on
                non-dict items), or a 2-D array of query vectors.
            limit: Number of neighbours to return per query vector.

        Returns:
            A pandas DataFrame with the LanceDB search results for all queries, plus
            a ``query_index`` column giving each query's position within this batch.
        """
        import pandas as pd

        query_vectors = batch["item"] if isinstance(batch, dict) else batch
        frames = []
        for query_index, query_vector in enumerate(query_vectors):
            hits = self.table_arrow.search(query_vector).limit(limit).to_pandas()
            hits["query_index"] = query_index
            frames.append(hits)

        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)


## TODO: Can actors read from Volumes?! Do we need an S3 bucket instead?
# Built from the config-cell names catalog_name / schema_name / lance_table_name.
audio_parquet_path = f'/Volumes/{catalog_name}/{schema_name}/{lance_table_name}'
# Schema of the Parquet shards: an int64 id plus a 35-wide float16 vector column.
pyarrow_schema = pa.schema(
    [
        pa.field("id", pa.int64()),
        pa.field("list_col", pa.list_(pa.float16(), 35)),   # Fixed size list
    ]
)

def create_arrays(n, dimensions, seed=None):
    """Return ``n`` random query vectors of length ``dimensions``.

    Each vector holds integers drawn uniformly from [0, 256), cast to float16.

    Args:
        n: Number of vectors to generate.
        dimensions: Length of each vector.
        seed: Optional seed for reproducible vectors. ``None`` (default) keeps the
            original behaviour of drawing from NumPy's global random state.

    Returns:
        A list of ``n`` 1-D ``np.float16`` arrays.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    return [rng.randint(0, 256, size=dimensions).astype(np.float16) for _ in range(n)]

# Make Ray Data dataset and inference with it.
# Non-dict items passed to from_items end up in a column named "item".
large_query_batch = ray.data.from_items(create_arrays(num_test_vectors, dimensions=35))
query_results = large_query_batch.map_batches(
    LanceDBActor,
    # fn_args / fn_constructor_args are *positional* iterables (a dict would pass
    # its keys); named arguments go through the *_kwargs variants.
    fn_constructor_kwargs={'parquet_path': audio_parquet_path,
                           'pyarrow_schema': pyarrow_schema},
    fn_kwargs={'limit': 1},
    # Callable-class UDFs run in an actor pool whose size must be explicit.
    # 2 actors x 48 CPUs matches min_worker_nodes=2 x 48 CPUs in the setup cell.
    concurrency=2,
    num_cpus=48,
    batch_size=250,
)

# map_batches is lazy: nothing executes until the dataset is consumed.
query_results_df = query_results.to_pandas()
query_results_df.head()


