## Load and manage Segment Anything

### Load Raw Data into Space Datasets

[Segment Anything](https://segment-anything.com/) (SA-1B) is an image segmentation dataset containing 1B masks and 11M images. The source dataset has the following file layout:

```
sa_000000/sa_1.json
         /sa_1.jpg
         /sa_2.json
         /sa_2.jpg
         ...
sa_000001/...
```

Each JSON file contains the information and labels of one image. The equivalent PyArrow schema is shown below, extended with two top-level columns (`image_id` and `shard`) that this example adds:

In [None]:
import pyarrow as pa

segmentation_schema = pa.struct([
  ("size", pa.list_(pa.int64())),
  ("counts", pa.string())
])

annotation_schema = pa.struct([
  ("area", pa.int64()),
  ("bbox", pa.list_(pa.float64())),
  ("crop_box", pa.list_(pa.float64())),
  ("id", pa.int64()),
  ("point_coords", pa.list_(pa.list_(pa.float64()))),
  ("predicted_iou", pa.float64()),
  ("segmentation", segmentation_schema),
  ("stability_score", pa.float64())
])

image_schema = pa.struct([
  ("image_id", pa.int64()),
  ("file_name", pa.string()),
  ("width", pa.int64()),
  ("height", pa.int64())
])

ds_schema = pa.schema([
  ("image_id", pa.int64()),  # Add an image_id as the primary key.
  ("shard", pa.string()),  # Add shard ("sa_000000") for inferring full image path later.
  ("image", image_schema),
  ("annotations", pa.list_(annotation_schema))
])
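Because each JSON file sits next to the JPEG with the same stem, the `shard` column together with `image.file_name` is enough to rebuild each image's path, as `read_and_preprocess_image` does later. The following standard-library sketch shows that pairing, run against a tiny throwaway layout in a temporary directory. The helper names `shard_name` and `list_shard_files` are hypothetical and are not part of Space.

```python
import glob
import os
import tempfile
from typing import List, Tuple


def shard_name(idx: int) -> str:
  # Shard directories use a zero-padded 6-digit index, e.g. "sa_000000".
  return f"sa_{idx:06}"


def list_shard_files(root: str, shard: str) -> List[Tuple[str, str]]:
  """Hypothetical helper: returns sorted (json_path, jpg_path) pairs of a shard."""
  pairs = []
  for json_path in sorted(glob.glob(os.path.join(root, shard, "*.json"))):
    jpg_path = os.path.splitext(json_path)[0] + ".jpg"
    if os.path.exists(jpg_path):  # Skip labels whose image is missing.
      pairs.append((json_path, jpg_path))
  return pairs


# Build a tiny fake layout: sa_000000/sa_1.{json,jpg} and sa_000000/sa_2.{json,jpg}.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, shard_name(0)))
for stem in ["sa_1", "sa_2"]:
  for ext in [".json", ".jpg"]:
    open(os.path.join(root, shard_name(0), stem + ext), "w").close()

pairs = list_shard_files(root, shard_name(0))
print([os.path.basename(j) for j, _ in pairs])  # ['sa_1.json', 'sa_2.json']
```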

The following method reads JSON files into Arrow tables, one row per file. Each row represents an image with annotations.

In [None]:
from typing import Iterator, List

import glob
import os
from pyarrow import json

sa_dir = "/segment_anything"  # Change to your data folder.
# Knobs that limit how much data is loaded, so the example runs faster.
images_per_shard = 100
total_shards = 10


def make_iter(shards: List[str]) -> Iterator[pa.Table]:

  def json_files_to_arrow(shard: str) -> pa.Table:
    batch: List[pa.Table] = []

    pattern = os.path.join(sa_dir, shard, "*.json")
    print(f"processing pattern: {pattern}")
    for f in glob.glob(pattern):
      batch.append(json.read_json(f,
        parse_options=json.ParseOptions(explicit_schema=ds_schema)))
      if len(batch) >= images_per_shard:
        break

    table = pa.concat_tables(batch)
    image_id_column = table.column("image").combine_chunks().field("image_id")
    shard_column = pa.StringArray.from_pandas([shard] * table.num_rows)

    # Replace the empty top-level columns: image_id is copied from image.image_id,
    # and shard is the shard directory name.
    table = table.drop("shard").append_column(
      pa.field("shard", pa.string()), shard_column)
    table = table.drop("image_id").append_column(
      pa.field("image_id", pa.int64()), image_id_column)
    return table

  for shard in shards:
    yield json_files_to_arrow(shard)
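The column surgery at the end of `json_files_to_arrow` promotes `image.image_id` to a top-level primary key and tags each row with its shard. A plain-Python sketch of the same per-file transformation is shown below for intuition only; it does not replace the Arrow version, which parses with an explicit schema. It uses the standard `json` module (aliased, because `json` above refers to `pyarrow.json`) on a made-up, heavily truncated record whose values are hypothetical, not real SA-1B data. The helper name `json_to_row` is also hypothetical.

```python
import json as std_json  # Aliased: `json` in the cell above is `pyarrow.json`.
from typing import Any, Dict


def json_to_row(text: str, shard: str) -> Dict[str, Any]:
  """Turns one SA-1B style JSON document into a row shaped like `ds_schema`."""
  record = std_json.loads(text)
  return {
    "image_id": record["image"]["image_id"],  # Promoted to the primary key.
    "shard": shard,  # Kept so the full image path can be rebuilt later.
    "image": record["image"],
    "annotations": record["annotations"],
  }


# A hypothetical record; real files carry many annotations with full masks.
sample = std_json.dumps({
  "image": {"image_id": 1, "file_name": "sa_1.jpg", "width": 1500, "height": 2250},
  "annotations": [{"id": 7, "area": 120, "bbox": [10.0, 20.0, 30.0, 40.0]}],
})

row = json_to_row(sample, "sa_000000")
print(row["image_id"], row["shard"])  # 1 sa_000000
```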

Create an empty Space dataset:

In [None]:
from space import Dataset

raw_ds_location = "/space/datasets/raw_data"
# There is no record field, because this dataset does not store image bytes.
raw_ds = Dataset.create(raw_ds_location, ds_schema,
  primary_keys=["image_id"], record_fields=[])

After creation, the dataset can be loaded from its location later:

In [None]:
raw_ds = Dataset.load(raw_ds_location)

Write data into Space in parallel using Ray:

In [None]:
# append_from accepts a list of no-argument functions, each returning an iterator.
# Functions are passed instead of generators because generators cannot be pickled
# for distributed execution.
raw_ds.ray().append_from(
  [lambda idx=i: make_iter([f"sa_{idx:06}"]) for i in range(total_shards)])
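The `idx=i` default argument matters. Python closures look up `i` when the lambda is called, not when it is defined, so without the default every function would read the last shard. The standard-library sketch below shows both that pitfall and why a generator object itself cannot be shipped to a worker by pickling. `fake_iter` is a hypothetical stand-in for `make_iter`.

```python
import pickle
from typing import Iterator, List


def fake_iter(shards: List[str]) -> Iterator[str]:
  # Hypothetical stand-in for `make_iter`: yields one item per shard.
  yield from shards


total = 3
# Late binding: every lambda reads `i` at call time, when the loop has finished.
late = [lambda: fake_iter([f"sa_{i:06}"]) for i in range(total)]
# Default argument: `idx` is bound to the current `i` at definition time.
bound = [lambda idx=i: fake_iter([f"sa_{idx:06}"]) for i in range(total)]

print([next(f()) for f in late])   # ['sa_000002', 'sa_000002', 'sa_000002']
print([next(f()) for f in bound])  # ['sa_000000', 'sa_000001', 'sa_000002']

# Generator objects hold live interpreter state and cannot be pickled.
try:
  pickle.dumps(fake_iter(["sa_000000"]))
except TypeError as e:
  print("not picklable:", e)
```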

### Build Pipelines of Ray Transforms and Materialized Views (MVs)

Our goal in this example is to build two Space Materialized Views (MVs), for pre-processed images and annotations separately. When the source raw dataset is modified, the MVs can be refreshed to incrementally synchronize changes. Only the changes are processed.
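Conceptually, an incremental refresh only needs the rows whose primary keys were added to or deleted from the source since the MV was last synchronized. The sketch below illustrates that idea with plain Python sets of `image_id`s. It is a simplification for intuition, not how Space itself computes the changes: updates and versions are not modeled. `plan_refresh` is a hypothetical name.

```python
from typing import Set, Tuple


def plan_refresh(source_keys: Set[int], mv_keys: Set[int]) -> Tuple[Set[int], Set[int]]:
  """Hypothetical key-based sync plan: (keys to transform and add, keys to delete)."""
  to_add = source_keys - mv_keys  # New rows: only these go through the transform.
  to_delete = mv_keys - source_keys  # Rows removed from the source.
  return to_add, to_delete


mv_keys = {1, 2, 3, 4}
source_keys = {1, 2, 4, 5, 6}  # 3 was deleted; 5 and 6 were appended.
to_add, to_delete = plan_refresh(source_keys, mv_keys)
print(sorted(to_add), sorted(to_delete))  # [5, 6] [3]
```

Unchanged rows (1, 2 and 4 here) are never re-processed, which is what keeps refreshes cheap when only a small part of the source changes.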

By separating images and annotations, the pre-processed image MV can serve as a central repository shared by multiple training tasks. It can be joined with other datasets or annotation MVs to construct the final training input. Low-cost JOINs are achieved by joining references (e.g., record addresses in row-format files) instead of data content, and the JOINs can be distributed effectively using column statistics and data skipping.

#### Pre-processed Image MVs

After loading image information and annotations into Space, the next step is to read and pre-process the images, then persist the result in a Space Materialized View (MV). When the source dataset is modified, we can refresh the MV to incrementally synchronize the changes.

In [None]:
from typing import Any, Dict
import cv2

def read_and_preprocess_image(data: Dict[str, Any]) -> Dict[str, Any]:
  ims = []
  for image_id, shard, image in zip(data["image_id"], data["shard"], data["image"]):
    full_path = os.path.join(sa_dir, shard, image["file_name"])
    im = cv2.imread(full_path)
    im = cv2.resize(im, dsize=(100, 100), interpolation=cv2.INTER_CUBIC)
    ims.append(im.tobytes())

  # Drop the `image` column, which is not written into the output view, and add
  # a new `image_bytes` column.
  del data["image"]
  data["image_bytes"] = ims
  return data


# Create a view by transforming the source dataset.
image_view = raw_ds.map_batches(
  fn=read_and_preprocess_image,
  input_fields=["image_id", "shard", "image"], # Don't need to read annotations.
  output_schema=pa.schema([
    ("image_id", pa.int64()),
    ("shard", pa.string()),
    ("image_bytes", pa.binary())  # Add a new field for image bytes.
  ]),
  output_record_fields=["image_bytes"] # Store this field in ArrayRecord.
)

# Create an empty MV.
image_mv_location = "/space/mvs/image"
image_mv = image_view.materialize(image_mv_location)

# Synchronize the MV to the source dataset. It populates the MV storage.
image_mv.ray().refresh()

# Verify the number of rows in the MV.
image_mv.ray().read_all().num_rows
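`im.tobytes()` stores only the raw pixels, without shape or dtype, so readers of `image_bytes` must know the layout. Assuming the settings above (`cv2.imread` returns a 3-channel `uint8` array in BGR order, resized here to 100x100), the bytes can be turned back into an array with NumPy. The sketch uses a synthetic array in place of a real image, and `decode_image_bytes` is a hypothetical helper.

```python
import numpy as np

# Shape produced by `cv2.resize(im, dsize=(100, 100))` on a 3-channel image.
# OpenCV's `dsize` is (width, height); the NumPy shape is (height, width, channels).
HEIGHT, WIDTH, CHANNELS = 100, 100, 3


def decode_image_bytes(data: bytes) -> np.ndarray:
  """Rebuilds an HxWxC uint8 image from the raw bytes stored in `image_bytes`."""
  return np.frombuffer(data, dtype=np.uint8).reshape(HEIGHT, WIDTH, CHANNELS)


# Synthetic stand-in for a pre-processed image.
original = (np.arange(HEIGHT * WIDTH * CHANNELS) % 256).astype(np.uint8)
original = original.reshape(HEIGHT, WIDTH, CHANNELS)
stored = original.tobytes()  # What `read_and_preprocess_image` writes.

restored = decode_image_bytes(stored)
print(len(stored), restored.shape)  # 30000 (100, 100, 3)
```

Storing a fixed shape keeps the records small and uniform; if images of different sizes were kept, the shape would need to be stored alongside the bytes.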

The MV can be loaded from its location at any time:

In [None]:
from space import MaterializedView

image_mv = MaterializedView.load(image_mv_location)

Modify the source dataset, then synchronize the changes to the MV:

In [None]:
import pyarrow.compute as pc

# Show all image IDs, and pick a few to delete.
raw_ds.local().read_all().select(["image_id"])

# Delete two images.
raw_ds.local().delete((pc.field("image_id") == 4811) | (pc.field("image_id") == 6973))

image_mv.ray().refresh()
# The images are deleted in MV.
image_mv.local().read_all().select(["image_id"]).num_rows

#### Annotations MVs

Next, we create an MV for annotations. Annotations are stored as binary in ArrayRecord files instead of in columnar format in Parquet. There can be hundreds of object annotations per image, and each annotation contains segmentation data that is much larger than a class label. Storing annotations in the ArrayRecord row format allows us to read them by reference instead of by data.

In [None]:
import pickle

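def serialize_annotations(data: Dict[str, Any]) -> Dict[str, Any]:
  # pickle is just an example of serialization.
  data["annotations_bytes"] = [pickle.dumps(d) for d in data["annotations"]]
  # Drop the original `annotations` column, which is not in the output schema.
  del data["annotations"]
  return data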


# Create a view by transforming the source dataset.
annotations_view = raw_ds.map_batches(
  fn=serialize_annotations,
  input_fields=["image_id", "annotations"],
  output_schema=pa.schema([
    ("image_id", pa.int64()),
    ("annotations_bytes", pa.binary())
  ]),
  output_record_fields=["annotations_bytes"]
)

# Create an empty MV.
annotations_mv_location = "/space/mvs/annotations"
annotations_mv = annotations_view.materialize(annotations_mv_location)

annotations_mv.ray().refresh()
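On the consumer side, each `annotations_bytes` value is the pickled list of annotation dicts for one image, so it can be restored with `pickle.loads`. A standard-library sketch with a made-up, truncated annotation (all values are hypothetical):

```python
import pickle
from typing import Any, Dict, List

# A hypothetical, heavily truncated annotation list for one image.
annotations: List[Dict[str, Any]] = [
  {"id": 7, "area": 120, "bbox": [10.0, 20.0, 30.0, 40.0],
   "segmentation": {"size": [2250, 1500], "counts": "hypothetical-rle"}},
]

# What `serialize_annotations` stores for one row.
annotations_bytes = pickle.dumps(annotations)

# What a training input pipeline would do after reading the record.
restored = pickle.loads(annotations_bytes)
print(restored[0]["bbox"])  # [10.0, 20.0, 30.0, 40.0]
```

pickle is convenient but Python-specific, and it must only be used to load data from trusted sources; a portable serialization format is a reasonable substitute if other languages need to read the MV.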

#### Joining Images and Annotations MVs

Pre-processed images and their annotations are fed together into the training framework after their rows are matched by `image_id` with a JOIN. A JOIN is relatively cheap in Space, because it can:

- Read references (addresses) of records instead of data.
- Join the references and persist the result.
- Scan the result, and read data from the references (de-reference).
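The three steps can be illustrated with plain Python. Here a reference is a `(file_path, row_id)` tuple, and dereferencing reads that row from an in-memory stand-in for the row-format files. This is a toy model for intuition only: the file names, the `files` dict and the `deref` helper are hypothetical and do not describe Space's storage layout.

```python
from typing import Dict, List, Tuple

Ref = Tuple[str, int]  # (file_path, row_id) of a record.

# In-memory stand-ins for row-format files: file path -> list of records.
files: Dict[str, List[bytes]] = {
  "images_0.rec": [b"img-1", b"img-2", b"img-3"],
  "annotations_0.rec": [b"ann-3", b"ann-1"],
}

# Step 1: read references keyed by image_id, not the record bytes.
image_refs: Dict[int, Ref] = {1: ("images_0.rec", 0),
                              2: ("images_0.rec", 1),
                              3: ("images_0.rec", 2)}
annotation_refs: Dict[int, Ref] = {3: ("annotations_0.rec", 0),
                                   1: ("annotations_0.rec", 1)}

# Step 2: inner join on image_id; only small tuples are compared and kept.
joined: List[Tuple[int, Ref, Ref]] = [
  (k, image_refs[k], annotation_refs[k])
  for k in sorted(image_refs.keys() & annotation_refs.keys())
]


# Step 3: scan the join result and dereference each side.
def deref(ref: Ref) -> bytes:
  path, row_id = ref
  return files[path][row_id]


rows = [(k, deref(i), deref(a)) for k, i, a in joined]
print(rows)  # [(1, b'img-1', b'ann-1'), (3, b'img-3', b'ann-3')]
```

Image 2 has no annotations, so the inner join drops it without ever reading its bytes.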

In addition, with OLAP-style column statistics, the JOIN can easily be parallelized across multiple workers:

- For the primary key `image_id`, read its min and max across all files from manifests.
- If we have N workers, split the min-max range into N ranges. Each worker is responsible for one range.
- On each worker, a filter is built from the assigned range and pushed down to the data source. Because Space stores column statistics in its manifest files, files whose key range does not overlap the filter are skipped (data skipping), saving IO.
  - The column statistics can be read as follows. `_STATS_f0` holds the statistics for field ID `0`. The field-ID-to-name mapping can be read from the storage schema.

In [None]:
# Show the schema to know field ID to name mapping.
print(annotations_view.view.schema)

# `_STATS_f0` stores stats for `image_id`.
min_max = pa.concat_tables(
  annotations_mv.storage.index_manifest()).column("_STATS_f0").combine_chunks()

# This is the full range to split.
min_ = min(min_max.field("_MIN").to_pylist())
max_ = max(min_max.field("_MAX").to_pylist())

- Perform the JOIN locally on each worker, and aggregate the results:

In [None]:
import pyarrow.compute as pc

def join_image_and_annotations(
    location_image: str, location_annotations: str,
    min_: int, max_: int, last: bool):
  mv_image = MaterializedView.load(location_image)
  mv_annotations = MaterializedView.load(location_annotations)

  if not last:
    filter_ = (pc.field("image_id") >= min_) & (pc.field("image_id") < max_)
  else:
    filter_ = (pc.field("image_id") >= min_) & (pc.field("image_id") <= max_)

  # `reference_read=True` reads tuple (file_path, row_id) of a record, instead of data content.
  image_refs = mv_image.local().read_all(filter_, reference_read=True).flatten()
  annotations_refs = mv_annotations.local().read_all(filter_, reference_read=True).flatten()
  return image_refs.join(annotations_refs, keys="image_id")


# This example skips the details of splitting the full range and uses [0, 10000)
# as one JOIN partition.
joined_references = join_image_and_annotations(
  image_mv_location,
  annotations_mv_location,
  # A range assigned to a partitioned JOIN.
  min_=0, max_=10000, last=False)

The JOIN result table contains the record addresses of `image_bytes` and `annotations_bytes`. The data values can be read (de-referenced) as follows:

In [None]:
from space.core.ops.read import read_record_column

images = read_record_column(image_mv.storage, joined_references, "image_bytes")
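The JOIN example above skips how the full `[min_, max_]` range is split across workers. One way to fill in that step is sketched below under stated assumptions. It divides the range into `n` contiguous integer partitions whose `(start, end, last)` triples match the filter semantics of `join_image_and_annotations`: half-open, except the last partition, which includes `max_`. The second helper shows the data-skipping idea: given per-file `(min, max)` statistics like those read from `_STATS_f0`, a worker only needs the files whose range overlaps its own. Both helpers and the sample statistics are hypothetical.

```python
from typing import List, Tuple


def split_range(min_: int, max_: int, n: int) -> List[Tuple[int, int, bool]]:
  """Splits the closed range [min_, max_] into n (start, end, last) partitions.

  All partitions but the last are half-open [start, end); the last is closed
  [start, end]. Some partitions may be empty if the range is narrower than n.
  """
  bounds = [min_ + (max_ - min_) * i // n for i in range(n)] + [max_]
  return [(bounds[i], bounds[i + 1], i == n - 1) for i in range(n)]


def files_to_scan(stats: List[Tuple[str, int, int]],
                  start: int, end: int, last: bool) -> List[str]:
  """Keeps files whose [file_min, file_max] can overlap the partition."""
  selected = []
  for path, file_min, file_max in stats:
    below_end = file_min <= end if last else file_min < end
    if file_max >= start and below_end:
      selected.append(path)
  return selected


# Hypothetical per-file min/max of `image_id`, as could be read from `_STATS_f0`.
stats = [("file_0", 0, 2999), ("file_1", 3000, 6999), ("file_2", 7000, 10000)]

for start, end, last in split_range(0, 10000, 2):
  print((start, end, last), files_to_scan(stats, start, end, last))
# (0, 5000, False) ['file_0', 'file_1']
# (5000, 10000, True) ['file_1', 'file_2']
```

Each `(start, end, last)` triple could then be passed as `min_`, `max_` and `last` to `join_image_and_annotations` on its own worker (for example, as a Ray task), and the per-partition results concatenated. Splitting by key value assumes `image_id`s are spread fairly evenly over the range; skewed keys would call for a different split, such as quantiles.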