## Load and manage Segment Anything

[Segment Anything](https://segment-anything.com/) (SA-1B) is an image segmentation dataset containing 1B masks and 11M images. The source dataset has the following file layout:

```
sa_000000/sa_1.json
         /sa_2.json
         ...
sa_000001/...
```

where each JSON file contains the information of one image and its labels (annotations). The equivalent PyArrow schema for a JSON file is:

In [None]:
import pyarrow as pa

# Segmentation payload of one annotation: an integer `size` list plus a
# `counts` string (presumably COCO-style RLE; confirm against the SA-1B docs).
segmentation_schema = pa.struct([
  pa.field("size", pa.list_(pa.int64())),
  pa.field("counts", pa.string()),
])

# One entry of the "annotations" list in a JSON file.
annotation_schema = pa.struct([
  pa.field("area", pa.int64()),
  pa.field("bbox", pa.list_(pa.float64())),
  pa.field("crop_box", pa.list_(pa.float64())),
  pa.field("id", pa.int64()),
  pa.field("point_coords", pa.list_(pa.list_(pa.float64()))),
  pa.field("predicted_iou", pa.float64()),
  pa.field("segmentation", segmentation_schema),
  pa.field("stability_score", pa.float64()),
])

# The "image" object of a JSON file.
image_schema = pa.struct([
  pa.field("image_id", pa.int64()),
  pa.field("file_name", pa.string()),
  pa.field("width", pa.int64()),
  pa.field("height", pa.int64()),
])

# Schema of the dataset; `image` and `annotations` mirror the JSON layout.
ds_schema = pa.schema([
  # Add an image_id as the primary key.
  pa.field("image_id", pa.int64()),
  # Record shard ("sa_000000") for inferring full image path.
  pa.field("shard", pa.string()),
  pa.field("image", image_schema),
  pa.field("annotations", pa.list_(annotation_schema)),
])

The following function converts the JSON files of each shard into an Arrow table. Each row represents an image with its annotations.

In [None]:
from typing import Iterator, List

import glob
import os
from pyarrow import json

# Root folder that contains the shard sub-folders ("sa_000000", ...).
sa_dir = "/space/segmeng_anything"  # Change to your data folder.
# Knobs for making things faster: both limit how much data is loaded.
# Maximum number of JSON files read from each shard folder.
images_per_shard = 100
# Number of shards to load, starting from "sa_000000".
total_shards = 10


def make_iter(shards: List[str]) -> Iterator[pa.Table]:
  """Yields one Arrow table per shard, built from the shard's JSON files.

  At most `images_per_shard` JSON files are read from each shard folder under
  `sa_dir`. Files are read in sorted order so the selected subset is stable
  across runs.

  Args:
    shards: names of the shard folders to read, e.g. "sa_000000".

  Yields:
    A table parsed with `ds_schema`, one row per JSON file, with the
    `image_id` and `shard` columns filled in. Shards without any JSON file
    are skipped.
  """

  def read_json_files(shard: str) -> List[pa.Table]:
    # Reads up to `images_per_shard` JSON files of a shard, one table each.
    batch: List[pa.Table] = []
    pattern = os.path.join(sa_dir, shard, "*.json")
    print(f"processing pattern: {pattern}")
    parse_options = json.ParseOptions(explicit_schema=ds_schema)
    # glob returns files in arbitrary order; sort to make the subset
    # deterministic.
    for f in sorted(glob.glob(pattern)):
      batch.append(json.read_json(f, parse_options=parse_options))
      if len(batch) >= images_per_shard:
        break

    return batch

  def json_files_to_arrow(shard: str, batch: List[pa.Table]) -> pa.Table:
    # Merges per-file tables and fills the `image_id` and `shard` columns.
    table = pa.concat_tables(batch)
    image_id_column = table.column("image").combine_chunks().field("image_id")
    shard_column = pa.array([shard] * table.num_rows, type=pa.string())

    # Overwrite the placeholder columns in place, so the column order of the
    # parsed table is preserved.
    table = table.set_column(
      table.schema.get_field_index("shard"),
      pa.field("shard", pa.string()), shard_column)
    table = table.set_column(
      table.schema.get_field_index("image_id"),
      pa.field("image_id", pa.int64()), image_id_column)
    return table

  for shard in shards:
    batch = read_json_files(shard)
    if not batch:
      # pa.concat_tables fails on an empty list; nothing to yield here.
      print(f"no JSON files found for shard: {shard}, skipping")
      continue

    yield json_files_to_arrow(shard, batch)

Create an empty Space dataset:

In [None]:
from space import Dataset

# Location where the Space dataset is stored.
ds_location = "/space/datasets/sa"
# There is no record field, because this dataset does not store image bytes.
# `image_id` is declared as the primary key.
ds = Dataset.create(ds_location, ds_schema,
  primary_keys=["image_id"], record_fields=[])

# Load the dataset from its location; presumably this is also how an existing
# dataset is re-opened in a later session (confirm in the Space docs).
ds = Dataset.load(ds_location)

Write data into it in a distributed manner using Ray:

In [None]:
# Build one source function per shard. The shard name is bound through a
# default argument because lambdas capture variables, not values: a plain
# `lambda: make_iter([f"sa_{i:06}"])` would see the final value of `i` when
# invoked, so every function would load the last shard.
ds.ray().append_from(
  [lambda shard=f"sa_{i:06}": make_iter([shard])
   for i in range(total_shards)])

After loading image information and annotations into Space, the next step is to read and preprocess the images, then persist the result in Space as a Materialized View (MV). When the source dataset is modified, we can refresh the MV to incrementally synchronize the changes.

In [None]:
from typing import Any, Dict
import cv2

def read_and_preprocess_image(data: Dict[str, Any]) -> Dict[str, Any]:
  """Replaces the `image` column of a batch with resized raw image bytes.

  Args:
    data: a batch as a dict of columns; the `shard` and `image` columns are
      read, and each `image` entry must provide a `file_name`.

  Returns:
    The same dict, with the `image` column removed and a new `image_bytes`
    column holding the raw bytes of each image resized to 100x100.

  Raises:
    OSError: if an image file is missing or cannot be decoded.
  """
  ims = []
  for shard, image in zip(data["shard"], data["image"]):
    full_path = os.path.join(sa_dir, shard, image["file_name"])
    im = cv2.imread(full_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; report the offending path rather than failing in resize.
    if im is None:
      raise OSError(f"Failed to read image: {full_path}")

    im = cv2.resize(im, dsize=(100, 100), interpolation=cv2.INTER_CUBIC)
    ims.append(im.tobytes())

  # Drop the `image` column that won't be written into the output view, and
  # add a new `image_bytes` column.
  del data["image"]
  data["image_bytes"] = ims
  return data


# Create a view by transforming the source dataset with
# `read_and_preprocess_image`.
view = ds.map_batches(
  fn=read_and_preprocess_image,
  input_fields=["image_id", "shard", "image"], # Annotations are not needed, so they are not read.
  output_schema=pa.schema([
    ("image_id", pa.int64()),
    ("shard", pa.string()),
    ("image_bytes", pa.binary())  # New field holding the image bytes.
  ]),
  output_record_fields=["image_bytes"] # Store this field in ArrayRecord.
)
# Create an empty materialized view (MV) at the given location.
mv = view.materialize("/space/datasets/sa_mv")

# Synchronize the MV to the source dataset.
mv.ray().refresh()

# Verify the content of the MV: number of rows after the refresh.
mv.ray().read_all().num_rows

Make some modifications to the source dataset, and synchronize the changes to the MV.

In [None]:
# `pc.field` builds the filter expression used for the deletion below.
import pyarrow.compute as pc

# Check all image IDs, and pick a few to delete.
ds.local().read_all().select(["image_id"])

# Delete two images.
ds.local().delete((pc.field("image_id") == 4811) | (pc.field("image_id") == 6973))

# Propagate the deletion to the materialized view.
mv.ray().refresh()
# The images are deleted in MV as well.
mv.local().read_all().select(["image_id"]).num_rows