<a href="https://colab.research.google.com/github/dracifer/td_exploration/blob/main/TorchData_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Option 1: Cohesive with the current Dataloader V2 design
* Business logic expressed by `Datapipe` (or conventional `Dataset`)
  * Can be traced to a backend-independent execution plan/graph (the *Logical IR*)
  * Function calls like `shuffle` and `map_batches` incrementally build the plan rather than executing immediately. Actual data loading and processing are triggered by an "execution" call.
  * Datapipe can be used to compose plans for both training and inference. In certain cases (e.g. eager serving mode), Datapipe is not needed. 
```
class Datapipe:
    def shuffle(self, shuffle_spec: ShuffleSpec) -> "Datapipe": ...
    def repeat(self, num: int) -> "Datapipe": ...
    def sort(self, fn: Callable[[Row], bool]) -> "Datapipe": ...
    def map_batches(self, fn: Callable[[RowBatch], RowBatch], batch_size: Optional[int] = None) -> "Datapipe": ...
    def filter(self, fn: Callable[[Row], bool]) -> "Datapipe": ...
    # aggregation functions like max, mean, min, sum, etc. will be added later
```
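To make the "accumulate a plan, execute later" idea concrete, here is a minimal sketch. It assumes the *Logical IR* is simply an ordered list of `(op_name, fn)` tuples and that execution is a plain Python loop over in-memory batches. `LazyPipe`, `plan` and `execute` are hypothetical names, not part of the design above.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple

Batch = List[Dict[str, Any]]  # hypothetical: a batch as a list of row dicts


class LazyPipe:
    """Records operations into a plan; nothing runs until execute()."""

    def __init__(self) -> None:
        self.plan: List[Tuple[str, Callable]] = []  # the "Logical IR"

    def map_batches(self, fn: Callable[[Batch], Batch]) -> "LazyPipe":
        self.plan.append(("map_batches", fn))
        return self

    def filter(self, fn: Callable[[Dict[str, Any]], bool]) -> "LazyPipe":
        self.plan.append(("filter", fn))
        return self

    def execute(self, source: Iterable[Batch]) -> Iterator[Batch]:
        # The "execution" call: replay the recorded plan on every batch
        for batch in source:
            for op, fn in self.plan:
                if op == "map_batches":
                    batch = fn(batch)
                elif op == "filter":
                    batch = [row for row in batch if fn(row)]
            yield batch


calls: List[int] = []

def double(batch: Batch) -> Batch:
    calls.append(len(batch))  # record each real execution
    return [{"x": row["x"] * 2} for row in batch]

pipe = LazyPipe().map_batches(double).filter(lambda row: row["x"] > 2)
assert calls == []  # building the plan ran nothing
out = list(pipe.execute([[{"x": 1}, {"x": 2}], [{"x": 3}]]))
```

Because the plan is plain data, a backend could inspect or rewrite it before running anything, which is what lets the same `Datapipe` target different engines.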


* Execution specified by `ReadingService`
  * Can adapt *Logical IR* to a backend *Execution Plan* 

```
class ReadingService(ABC):
  def initialize(self, dp: Datapipe) -> Datapipe: ...
  def finalize(self) -> bool: ...
  def iter_batches(self, state: Dict[str, Any]) -> RowBatchIterator: ...
```
```   
class SimpleReadingService(ReadingService): ...
class MultiprocessingReadingService(ReadingService): ...
class OnboxDppReadingService(ReadingService): ...
class DisaggDppReadingService(ReadingService): ...
class RayDppReadingService(ReadingService): ...
```
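One adaptation a backend might plausibly perform when turning the *Logical IR* into an *Execution Plan* is operator fusion: merging consecutive `map_batches` steps into a single stage so that they run in one pass per worker. The sketch below assumes the Logical IR is an ordered list of `(op_name, fn)` tuples. `fuse_maps` and the `"fused_map"` stage name are hypothetical and are not an API of any reading service named above.

```python
from functools import reduce
from typing import Callable, List, Optional, Tuple

Op = Tuple[str, Optional[Callable]]


def compose(fns: List[Callable]) -> Callable:
    # Apply fns left to right: compose([f, g])(x) == g(f(x))
    return lambda x: reduce(lambda acc, fn: fn(acc), fns, x)


def fuse_maps(plan: List[Op]) -> List[Op]:
    """Rewrite a logical plan so that runs of consecutive map_batches become one stage."""
    physical: List[Op] = []
    pending: List[Callable] = []
    for op, fn in plan:
        if op == "map_batches":
            pending.append(fn)
            continue
        if pending:
            physical.append(("fused_map", compose(pending)))
            pending = []  # rebind (not clear()) so the fused stage keeps its own list
        physical.append((op, fn))
    if pending:
        physical.append(("fused_map", compose(pending)))
    return physical


logical: List[Op] = [
    ("map_batches", lambda b: [x + 1 for x in b]),
    ("map_batches", lambda b: [x * 10 for x in b]),
    ("shuffle", None),
    ("map_batches", lambda b: b[::-1]),
]
physical = fuse_maps(logical)
```

Here `physical` has three stages (`fused_map`, `shuffle`, `fused_map`), and the first fused stage maps `[1, 2]` to `[20, 30]`.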
* `Dataloader` connects the business logic (specified by `Datapipe`) with the execution engine (specified by `ReadingService`) and produces data through iteration

```
class Dataloader:
  def __init__(self, datapipe: Datapipe, reading_service: ReadingService) -> None: ...
  def reset(self) -> None: ...
  def iter_batches(self) -> RowBatchIterator: ...
  def __iter__(self) -> "Dataloader": ...
  def __next__(self) -> RowBatch: ...
  def state_dict(self) -> Dict[str, Any]: ...
  def load_state_dict(self, state: Dict[str, Any]) -> bool: ...
```

### Note
* Ray Dataset combines all three concepts above. Because we need to support multiple execution engines (e.g. DPP, multiprocessing, and potentially Sparse and Ray), a separate `ReadingService` abstraction is useful. In addition, PyTorch already has a `DataLoader` concept that works with `Dataset`, so this three-component paradigm is a natural fit.
* Ray DatasetPipeline is a parallelism mechanism that pipelines operations on batches so that data reading overlaps with training. We ignore this use case in this doc; TorchRec has a similar implementation. A follow-up design will cover enabling highly efficient pipelining.

## Install libs

In [None]:
!python --version

Python 3.7.13


In [None]:
# MapArray.offsets needs pyarrow version 7.0.0
!pip3 install pyarrow==7.0.0 pandas numpy 
!pip3 install torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/nightly/cpu


## Utils

In [None]:
import itertools
from enum import Enum
from typing import (
    Any, Collection, Dict, Iterable, Iterator, List, Optional, 
    Tuple, Type, Union, cast, Mapping, Callable, 
)
from collections import UserDict

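def compute_offsets(input: Iterable[Collection], include_last: bool) -> List[int]:
    # Running sum of element lengths starting at 0; element i spans [offsets[i], offsets[i + 1])
    offsets = list(itertools.accumulate([0] + list(input), lambda x, y: x + len(y)))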
    if not include_last:
        return offsets[:-1]
    return offsets


def is_primitive(input: Any) -> bool:
    PRIMITIVE = (int, str, bool, float)
    if isinstance(input, PRIMITIVE):
        return True
    return False


class StructType(Enum):
    NON_LEAF = 0
    LEAF_FLOAT = 1
    LEAF_ID_LIST = 2
    LEAF_ID_SCORE_MAP = 3


def struct_type(input: Any) -> StructType:
    if type(input) == float:
        return StructType.LEAF_FLOAT
    if isinstance(input, dict) and (
        len(input) == 0 or (type(next(iter(input.keys()))), type(next(iter(input.values())))) == (int, float)
    ):
        return StructType.LEAF_ID_SCORE_MAP
    if type(input) == list and (len(input) == 0 or type(next(iter(input))) == int):
        return StructType.LEAF_ID_LIST
    return StructType.NON_LEAF
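`compute_offsets` produces the offsets half of a jagged (ragged) encoding: row `i` owns the flat values in `values[offsets[i]:offsets[i + 1]]`. Below is a self-contained sketch of the full round trip, using plain Python lists in place of Arrow or torch buffers. `encode_jagged` and `decode_jagged` are hypothetical helpers. For the `id1` column of the test data, the offsets match the ones printed for `id1` in the Inference section's output.

```python
import itertools
from typing import List, Tuple


def encode_jagged(rows: List[List[int]]) -> Tuple[List[int], List[int]]:
    # offsets has len(rows) + 1 entries: the first is 0, the last is len(values)
    offsets = list(itertools.accumulate([0] + [len(r) for r in rows]))
    values = list(itertools.chain.from_iterable(rows))
    return offsets, values


def decode_jagged(offsets: List[int], values: List[int]) -> List[List[int]]:
    # Row i is the half-open slice [offsets[i], offsets[i + 1])
    return [values[b:e] for b, e in zip(offsets[:-1], offsets[1:])]


id1 = [[111, 112], [121], [131, 132], [], [151], [161]]
offsets, values = encode_jagged(id1)
# offsets == [0, 2, 3, 5, 5, 6, 7]
# values  == [111, 112, 121, 131, 132, 151, 161]
assert decode_jagged(offsets, values) == id1
```

Empty rows cost nothing in `values`; they appear only as a repeated offset (here `5, 5`).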
    

## Datapipe: Data logical processor
* Datapipe describes operations. It is STATELESS.
The operation plan/IR can be compiled and reused (not implemented yet). Both training and inference logic can be described by a Datapipe, but they should use different data sources and operations (e.g. training specifies shuffle, while inference does not).

* RowBatch: A batch of rows. 
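Although a RowBatch represents a batch of rows, the accessors in this notebook store it column-wise, as a nested dict whose leaves hold one entry per row. Below is a minimal sketch of recovering the row view. `iter_rows` is a hypothetical helper, and it assumes every leaf is a Python list with one entry per row, as in the test data.

```python
from typing import Any, Dict, Iterator


def iter_rows(columns: Dict[str, Any], num_rows: int) -> Iterator[Dict[str, Any]]:
    """Yield one nested dict per row from a nested dict of equal-length lists."""

    def pick(node: Any, i: int) -> Any:
        if isinstance(node, dict):
            return {key: pick(child, i) for key, child in node.items()}
        return node[i]  # leaf: a list with one entry per row

    for i in range(num_rows):
        yield pick(columns, i)


sample_batch = {
    "float_features": {"f1": [1.0, 1.1]},
    "id_list_features": {"id1": [[111, 112], [121]]},
}
rows = list(iter_rows(sample_batch, num_rows=2))
```

In this example, `rows[1]` is `{"float_features": {"f1": 1.1}, "id_list_features": {"id1": [121]}}`.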

## Datapipe Accessor

Accessors fetch data in a specific format (native Python, Arrow, or torch tensors).


In [None]:
import pyarrow
import torch
from typing import TypeVar, Generic
from abc import ABC, abstractmethod, abstractclassmethod
from dataclasses import dataclass

@dataclass
class JaggedTensor:
  offsets: Optional[torch.Tensor] = None
  indices: Optional[torch.Tensor] = None
  values: Optional[torch.Tensor] = None


def to_torch(input: Union[dict, pyarrow.Array, JaggedTensor]) -> Union[dict, JaggedTensor]:
  if type(input) == pyarrow.lib.FloatArray:
    return JaggedTensor(
        values=torch.from_numpy(input.to_numpy())
    )
  elif type(input) == pyarrow.lib.ListArray:
    return JaggedTensor(
        offsets=torch.from_numpy(input.offsets.to_numpy()),
        indices=torch.from_numpy(input.values.to_numpy()),
    )
  elif type(input) == pyarrow.lib.MapArray:
    return JaggedTensor(
        offsets=torch.from_numpy(input.offsets.to_numpy()),
        indices=torch.from_numpy(input.keys.to_numpy()),
        values=torch.from_numpy(input.items.to_numpy()),
    )
  elif type(input) == JaggedTensor:
    return input

  output = {}
  if isinstance(input, dict):
    for key, val in input.items():
      output[key] = to_torch(val)
  return output



Row = dict

# A collate operation such as to_torch() can be applied to the batch data directly, OR
# built into the Datapipe execution plan so that it runs on the server side.
# Here we use RowBatch.to_torch() for demonstration purposes.

class RowBatch(dict):
  def __init__(self, data, row_num: int) -> None:
    self._row_num = row_num
    super().__init__(data)
    return

  @property
  def row_num(self) -> int:
    return self._row_num

  def to_torch(self, schema: Optional[dict] = None) -> dict:
    if schema is not None:
      raise NotImplementedError("not supporting customized schema")
    return to_torch(self)
        

# A more robust design is needed for this class.
# In this example, PyDictArrowDataAccessor constructs Arrow data from a Python dict for demo purposes.
class BatchAccessor(Enum):
  NATIVE = "native_batch_accessor"
  ARROW = "pydict_arrow_batch_accessor"
  TORCH_TENSOR = "torch_tensor_batch_accessor"

BATCH_ACCESSORS = {}

def get_batch_accessor(name: BatchAccessor):
  global BATCH_ACCESSORS
  accessor = BATCH_ACCESSORS[name]()
  return accessor

def register_batch_accessor(name: BatchAccessor): 
  def register(cls):
    print(f"Registering Batch Accessor {name}: {cls}")
    global BATCH_ACCESSORS
    BATCH_ACCESSORS[name] = cls
    return cls
  return register

T = TypeVar("T")
class PyDictDataAccessor(ABC, Generic[T]):
  @abstractmethod
  def fetch_batch(self, data: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
    ...

  @abstractclassmethod
  def build_column_val(cls, input: List[Any]) -> T:
    ...

  @classmethod
  def build_columns(cls, input: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
      assert len(input) > 0, f"input {input} cannot have empty columns"
      col_values = []
      row_num = block_size
      for val in input.values():
          if type(val) == list:
              block_size = block_size or len(val)
              end = min(start + block_size, len(val))
              row_num = end - start
              col_val = cls.build_column_val(val[start:end])
          elif type(val) == dict:
              col_val = cls.build_columns(val, block_size, start)
              row_num = col_val.row_num
          else:
              raise TypeError(f"unsupported column type {type(val)}; expected list or dict")
          col_values.append(col_val)

      return RowBatch(data=zip(input.keys(), col_values), row_num=row_num)

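# NOTE: Native mode does not work yet. build_column_val is not implemented, so this class
# cannot be instantiated, and fetch_batch ignores block_size/start and sets row_num to the
# number of top-level columns rather than rows.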
@register_batch_accessor(BatchAccessor.NATIVE)
class PyDictNativeDataAccessor(PyDictDataAccessor[list]):
  def fetch_batch(self, data: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
    return RowBatch(data, len(data))

@register_batch_accessor(BatchAccessor.ARROW)
class PyDictArrowDataAccessor(PyDictDataAccessor[pyarrow.Array]):
  def fetch_batch(self, data: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
    return PyDictArrowDataAccessor.build_columns(input=data, block_size=block_size, start=start)

  @classmethod
  def build_column_val(cls, input: List[Any]) -> pyarrow.Array:
      child_type = struct_type(next(iter(input)))
      assert child_type != StructType.NON_LEAF, f"Invalid input {input}"
      # List[List[int]]
      if child_type == StructType.LEAF_ID_LIST:
          col = pyarrow.ListArray.from_arrays(
              offsets=pyarrow.array(compute_offsets(input, True)),
              values=pyarrow.array(list(itertools.chain(*input)), type=pyarrow.int64()),
          )
      # List[float]
      elif child_type == StructType.LEAF_FLOAT:
          col = pyarrow.array(input, type=pyarrow.float32())
      # List[Dict[int, float]]
      elif child_type == StructType.LEAF_ID_SCORE_MAP:
          col = pyarrow.MapArray.from_arrays(
              offsets=pyarrow.array(compute_offsets(input, True)),
              keys=pyarrow.array(list(itertools.chain(*[item.keys() for item in input])), type=pyarrow.int64()),
              items=pyarrow.array(list(itertools.chain(*[item.values() for item in input])), type=pyarrow.float32()),
          )
      return col


@register_batch_accessor(BatchAccessor.TORCH_TENSOR)
class PyDictTorchTensorDataAccessor(PyDictDataAccessor[JaggedTensor]):
  def fetch_batch(self, data: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
    return PyDictTorchTensorDataAccessor.build_columns(input=data, block_size=block_size, start=start)

  @classmethod
  def build_column_val(cls, input: List[Any]) -> JaggedTensor:
      child_type = struct_type(next(iter(input)))
      assert child_type != StructType.NON_LEAF, f"Invalid input {input}"
      # List[List[int]]
      if child_type == StructType.LEAF_ID_LIST:
          col = JaggedTensor(
              offsets=torch.tensor(compute_offsets(input, True), dtype=torch.int32),
              indices=torch.tensor(list(itertools.chain(*input)), dtype=torch.int64),
          )
      # List[float]
      elif child_type == StructType.LEAF_FLOAT:
          col = JaggedTensor(
              values=torch.tensor(input, dtype=torch.float32),
          )
      # List[Dict[int, float]]
      elif child_type == StructType.LEAF_ID_SCORE_MAP:
          col = JaggedTensor(
              offsets=torch.tensor(compute_offsets(input, True), dtype=torch.int32),
              indices=torch.tensor(list(itertools.chain(*[item.keys() for item in input])), dtype=torch.int64),
              values=torch.tensor(list(itertools.chain(*[item.values() for item in input])), dtype=torch.float32),
          )
      return col

  


Registering Batch Accessor BatchAccessor.NATIVE: <class '__main__.PyDictNativeDataAccessor'>
Registering Batch Accessor BatchAccessor.ARROW: <class '__main__.PyDictArrowDataAccessor'>
Registering Batch Accessor BatchAccessor.TORCH_TENSOR: <class '__main__.PyDictTorchTensorDataAccessor'>
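For the `List[Dict[int, float]]` case, both accessors flatten the per-row maps into three parallel buffers, which `JaggedTensor` names `offsets`, `indices` (the ids) and `values` (the scores). Below is a stdlib-only sketch of that layout. It mirrors the `LEAF_ID_SCORE_MAP` branch of `build_column_val` but uses lists instead of Arrow arrays or tensors. `encode_id_score` and `row_scores` are hypothetical helpers.

```python
import itertools
from typing import Dict, List, Tuple


def encode_id_score(rows: List[Dict[int, float]]) -> Tuple[List[int], List[int], List[float]]:
    offsets = list(itertools.accumulate([0] + [len(r) for r in rows]))
    # Dicts preserve insertion order (Python 3.7+), so keys and values line up
    indices = list(itertools.chain.from_iterable(r.keys() for r in rows))
    values = list(itertools.chain.from_iterable(r.values() for r in rows))
    return offsets, indices, values


def row_scores(offsets: List[int], indices: List[int], values: List[float], i: int) -> Dict[int, float]:
    # Rebuild row i's map from its slice of the flat buffers
    b, e = offsets[i], offsets[i + 1]
    return dict(zip(indices[b:e], values[b:e]))


ids1 = [{411: 0.1}, {421: 0.2}, {431: 0.7}, {}, {451: 0.4, 452: 0.1}, {}]
offsets, indices, values = encode_id_score(ids1)
# offsets == [0, 1, 2, 3, 3, 5, 5], indices == [411, 421, 431, 451, 452]
```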


In [None]:
class RowBatchIterator(Iterator[RowBatch]):
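  def __init__(self, fetch_fn: Callable[[Dict[str, Any]], Tuple[RowBatch, Dict[str, Any]]], states: Dict[str, Any]) -> None:
    self._fetch_fn = fetch_fn
    # Keep private copies: fetch_fn updates the states dict in place,
    # so sharing one dict would turn the reset in __iter__ into a no-op
    self._orig_states = dict(states)
    self._states = dict(states)

  def __iter__(self) -> "RowBatchIterator":
    # reset iterator
    self._states = dict(self._orig_states)
    return self

  def __next__(self) -> RowBatch:
    # Iteration ends when fetch_fn raises StopIteration
    data, self._states = self._fetch_fn(self._states)
    return data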


class ShuffleSpec(Enum):
  NO_SHUFFLE = 0
  FULL_SHUFFLE = 1


class Datapipe(ABC):
  def __init__(self) -> None:
    self._batch_size = 3
    self._batch_accessor = BatchAccessor.NATIVE

  # Commenting out all plan building logic for now.
  def shuffle(self, shuffle_spec: ShuffleSpec) -> "Datapipe":
    self._shuffle_spec = shuffle_spec
    # self._plan = PlanBuilder.addShuffle(shuffle_spec)
    return self

  def repeat(self, num: int) -> "Datapipe":
    # self._plan = PlanBuilder.addRepeat(num)
    self._total_passes = num
    return self

  def sort(self, fn: Callable[[Row], bool]) -> "Datapipe":
    # self._plan = PlanBuilder.addSort(fn)
    return self

  def map_batches(
      self, 
      fn: Callable[[RowBatch], RowBatch], 
      batch_size: Optional[int] = None,
      batch_accessor: Optional[BatchAccessor] = None,
    ) -> "Datapipe":
    # self._plan = PlanBuilder.addMapBatches(fn, batch_size)
    self._map_batch_fn = fn
    self._batch_size = batch_size or self._batch_size
    self._batch_accessor = batch_accessor or self._batch_accessor
    return self

  # can be implemented with map_batches()
  def filter(self, fn: Callable[[Row], bool]) -> "Datapipe":
    # self._plan = PlanBuilder.addFilter(fn)
    return self

  # can be implemented with map_batches()
  def map(self, fn: Callable[[Row], Row]) -> "Datapipe":
    # self._plan = PlanBuilder.addMap(fn)
    return self

  def rebatch(self, batch_size: int) -> "Datapipe":
    # self._plan = PlanBuilder.rebatch(batch_size)
    self._batch_size = batch_size
    return self

  def change_batch_accessor(self, batch_accessor: BatchAccessor) -> "Datapipe":
    # self._plan = PlanBuilder.batchFormat(batch_format)
    self._batch_accessor = batch_accessor
    return self

  # aggregation functions like max, mean, min, sum, etc. will be added later

  @property
  def batch_size(self) -> int:
    return self._batch_size

# IterDatapipe -- iterate over data source to produce data. 
# Used in training together with Dataloader and ReadingService
class IterDatapipe(Datapipe):
  @abstractmethod
  def iter_batches(self, batch_size: int, start: int = 0, stride: int = 1) -> RowBatchIterator:
    pass


class PyDictDatapipe(IterDatapipe): 
  def __init__(
      self, data: Dict[str, Any], 
      batch_size: int = 3, 
      batch_accessor: BatchAccessor = BatchAccessor.NATIVE,
  ) -> None:
    super().__init__()
    self._data = data
    self._batch_size = batch_size
    self._batch_accessor = batch_accessor

  def fetch_batch(self, states: Dict[str, Any]) -> Tuple[RowBatch, Dict[str, Any]]:
    # The current implementation is for demo purposes.
    # Real implementation should be based on executing self._plan rather than calling specific functions
    batch = get_batch_accessor(self._batch_accessor).fetch_batch(self._data, states["batch_size"], states["start"])
    if hasattr(self, "_map_batch_fn"):
      batch = self._map_batch_fn(batch)
    states["start"] += states["batch_size"] * states["stride"]
    return (batch, states)

  def iter_batches(self, batch_size: int, start: int = 0, stride: int = 1) -> RowBatchIterator:
    return RowBatchIterator(self.fetch_batch, {"start": start, "batch_size": batch_size, "stride": stride})

# class HiveDatapipe(IterDatapipe):
#   def __init__(self, hive_specs: Dict[str, Any], block_size: int, compute_adaptor) -> None: 

# CallableDatapipe -- applies a transformation to data.
# Acts reactively (must be called with the provided data). Used in inference.
class CallableDatapipe(Datapipe):
  def __call__(self, data):
    raise NotImplementedError("subclasses must implement __call__()")

class SchemaDatapipe(CallableDatapipe):
  def __init__(self, schema: Dict[str, Any], batch_accessor: BatchAccessor) -> None:
    super().__init__()
    self._schema = schema
    self._batch_accessor = batch_accessor

  def conversion(self, data) -> RowBatch:
    # The current implementation is for demo purposes.
    # Real implementation should be based on executing self._plan rather than calling specific functions
    batch = get_batch_accessor(self._batch_accessor).fetch_batch(data, None, 0)
    if hasattr(self, "_map_batch_fn"):
      batch = self._map_batch_fn(batch)
    return batch

  def __call__(self, data) -> RowBatch: 
    return self.conversion(data)
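`PyDictDatapipe.fetch_batch` follows a state-passing style: it receives a plain `states` dict and returns `(batch, states)`. With the Arrow and tensor accessors, iteration currently ends only because slicing past the last row yields an empty list, so `next(iter(input))` in `build_column_val` raises `StopIteration`. Below is a minimal stdlib sketch of the same contract. It adds an explicit end-of-data check and never mutates the incoming state, so the iterator can restart from its initial state. `fetch` and `StateIterator` are hypothetical names, and a batch here is just a list slice.

```python
from typing import Any, Callable, Dict, Iterator, List, Tuple

State = Dict[str, int]


def fetch(rows: List[Any], states: State) -> Tuple[List[Any], State]:
    start, size, stride = states["start"], states["batch_size"], states["stride"]
    if start >= len(rows):
        raise StopIteration  # explicit end-of-data signal
    new_states = dict(states, start=start + size * stride)  # never mutate the input
    return rows[start:start + size], new_states


class StateIterator(Iterator[List[Any]]):
    def __init__(self, fetch_fn: Callable[[State], Tuple[List[Any], State]], states: State) -> None:
        self._fetch_fn = fetch_fn
        self._orig_states = dict(states)
        self._states = dict(states)

    def __iter__(self) -> "StateIterator":
        self._states = dict(self._orig_states)  # restart from the initial state
        return self

    def __next__(self) -> List[Any]:
        batch, self._states = self._fetch_fn(self._states)
        return batch


source = list(range(6))
init_states = {"start": 0, "batch_size": 4, "stride": 1}
it = StateIterator(lambda s: fetch(source, s), init_states)
first_pass = list(it)
second_pass = list(it)  # __iter__ resets, so the pass repeats
```

Both passes yield `[[0, 1, 2, 3], [4, 5]]`. Because the state is a small plain dict, the same object could also be what a checkpoint saves.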

## Test Data

In [None]:
data={
    "float_features": {
        "f1": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
        "f2": [2.0, 2.1, 2.2, 2.3, 2.4, 2.5],
    },
    "id_list_features": {
        "id1": [[111, 112], [121], [131, 132], [], [151], [161]],
        "id2": [[], [221], [222], [223, 224], [225], []],
        "id3": [[311], [], [331], [341], [351], [361]],
    },
    "id_score_list_features": {
        "ids1": [{411: 0.1}, {421: 0.2}, {431: 0.7}, {}, {451: 0.4, 452: 0.1}, {}],
    },
}


In [None]:
BATCH_ACCESSORS

{<BatchAccessor.ARROW: 'pydict_arrow_batch_accessor'>: __main__.PyDictArrowDataAccessor,
 <BatchAccessor.NATIVE: 'native_batch_accessor'>: __main__.PyDictNativeDataAccessor}

In [None]:
arrow_dp = PyDictDatapipe(
    data=data,
    batch_size=4,
    batch_accessor=BatchAccessor.ARROW,
)
# test single iter
rb = next(iter(arrow_dp.iter_batches(batch_size=4, start=0, stride=1)))
type(rb)
rb.row_num

4

In [None]:
import pyarrow.compute as F

print(rb["float_features"]["f1"])
print(F.ln(F.add(rb["float_features"]["f1"], 1)))

[
  1,
  1.1,
  1.2,
  1.3
]
[
  0.6931472,
  0.7419373,
  0.7884574,
  0.8329091
]


In [None]:
# test iter_batches
for batch in arrow_dp.iter_batches(batch_size=5):
  print(batch["float_features"]["f1"])
  print(batch.row_num)

[
  1,
  1.1,
  1.2,
  1.3,
  1.4
]
5
[
  1.5
]
1


In [None]:
# Tensor DP
tensor_dp = PyDictDatapipe(
    data=data,
    batch_size=4,
    batch_accessor=BatchAccessor.TORCH_TENSOR,
)

In [None]:
# test single iter
rb = next(iter(tensor_dp.iter_batches(batch_size=4, start=0, stride=1)))
type(rb)
rb.row_num

4

In [None]:
for batch in tensor_dp.iter_batches(batch_size=5):
  print(batch["float_features"]["f1"])
  print(batch.row_num)

JaggedTensor(offsets=None, indices=None, values=tensor([1.0000, 1.1000, 1.2000, 1.3000, 1.4000]))
5
JaggedTensor(offsets=None, indices=None, values=tensor([1.5000]))
1


## ReadingService
Specifies the data-reading execution engine. It should be STATELESS.


* SimpleReadingService
* DisaggDppReadingService
* OnboxDppReadingService
* OnboxMultiprocessReadingService

In [None]:
class ReadingService(ABC):
  def initialize(self, dp: IterDatapipe) -> IterDatapipe:
    raise NotImplementedError("initialize() is not implemented")

  def finalize(self) -> bool:
    return True

  def iter_batches(self, dp: IterDatapipe, state: Dict[str, Any]) -> RowBatchIterator:
    raise NotImplementedError("iter_batches() is not implemented")


class SimpleReadingService(ReadingService):
  def initialize(self, dp: Datapipe) -> IterDatapipe:
    return dp

  def iter_batches(self, dp: IterDatapipe, state: Dict[str, Any]) -> RowBatchIterator:
    yield from dp.iter_batches(batch_size=dp.batch_size, start=state["start"], stride=1)
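`SimpleReadingService` always reads with `stride=1`. A multi-worker reading service could use the same `start`/`stride` parameters to shard the data. `PyDictDatapipe.fetch_batch` advances `start` by `batch_size * stride` rows, so with `n` workers, worker `i` can start at row `i * batch_size` with `stride=n`, and the workers together read every row exactly once. Below is a stdlib-only sketch of that assignment. `worker_batches` is a hypothetical helper. A real multiprocessing reading service would also have to manage worker processes and merge their outputs.

```python
from typing import List


def worker_batches(num_rows: int, batch_size: int, worker_id: int, num_workers: int) -> List[range]:
    """Row ranges read by one worker, using the start/stride arithmetic of
    PyDictDatapipe.fetch_batch (start is a row index; each step advances
    batch_size * stride rows)."""
    start = worker_id * batch_size
    stride = num_workers
    batches = []
    while start < num_rows:
        batches.append(range(start, min(start + batch_size, num_rows)))
        start += batch_size * stride
    return batches


num_rows, batch_size, num_workers = 6, 2, 2
per_worker = [worker_batches(num_rows, batch_size, w, num_workers) for w in range(num_workers)]
# worker 0 reads rows [0, 1] and [4, 5]; worker 1 reads rows [2, 3]
covered = sorted(r for batches in per_worker for b in batches for r in b)
assert covered == list(range(num_rows))  # every row exactly once
```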


## Dataloader
* Takes a Datapipe (describing business logic) and a ReadingService (specifying the execution engine) to produce data batches.
* Dataloader is STATEFUL.
* Provides a checkpointing interface (`state_dict` / `load_state_dict`).

In [None]:
class Dataloader:
  def __init__(self, datapipe: Datapipe, reading_service: ReadingService) -> None:
    self._dp = datapipe
    self._rs = reading_service
    self.reset()

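  def reset(self) -> None:
    self._states = {
        "start": 0,
    }
    self._dp = self._rs.initialize(self._dp)
    # Active batch generator used by __next__; cleared on every reset
    self._iter = None

  def iter_batches(self) -> RowBatchIterator:
    # TODO: need to update self._states appropriately
    yield from self._rs.iter_batches(self._dp, self._states)

  def __iter__(self) -> "Dataloader":
    self.reset()
    return self

  def __next__(self) -> RowBatch:
    # Keep one generator alive across calls; creating a new one on
    # every call would always return the first batch
    if self._iter is None:
      self._iter = self.iter_batches()
    return next(self._iter)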

  def state_dict(self) -> Dict[str, Any]:
    return self._states

  def load_state_dict(self, states: Dict[str, Any]) -> bool:
    self._states = states
    return True
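The `state_dict` / `load_state_dict` pair is the checkpointing interface, but as the TODO in `iter_batches` notes, `self._states` is not yet advanced during iteration. Below is a minimal sketch of the intended behavior, assuming the whole state is the index of the next row. The loader records its progress after each batch, and a fresh loader restored from a saved state resumes where the first one stopped. `ResumableLoader` is a hypothetical stand-in, not the `Dataloader` class above.

```python
import copy
from typing import Any, Dict, Iterator, List


class ResumableLoader:
    def __init__(self, rows: List[Any], batch_size: int) -> None:
        self._rows = rows
        self._batch_size = batch_size
        self._states: Dict[str, Any] = {"start": 0}

    def __iter__(self) -> Iterator[List[Any]]:
        while self._states["start"] < len(self._rows):
            start = self._states["start"]
            batch = self._rows[start:start + self._batch_size]
            # Advance the state before yielding, so a checkpoint taken after
            # consuming this batch points at the next one
            self._states["start"] = start + self._batch_size
            yield batch

    def state_dict(self) -> Dict[str, Any]:
        return copy.deepcopy(self._states)  # a snapshot, not a live reference

    def load_state_dict(self, states: Dict[str, Any]) -> bool:
        self._states = copy.deepcopy(states)
        return True


loader = ResumableLoader(list(range(6)), batch_size=2)
it = iter(loader)
first = next(it)                  # [0, 1]
checkpoint = loader.state_dict()  # {"start": 2}

restored = ResumableLoader(list(range(6)), batch_size=2)
restored.load_state_dict(checkpoint)
rest = list(restored)             # [[2, 3], [4, 5]]
```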

## Training

### Arrow Preproc

In [None]:
# Arrow Preproc
import pyarrow.compute as F

def collate(rows: RowBatch) -> RowBatch:
  return rows.to_torch()

# could also operate on a single Row
def arrow_preproc(rows: RowBatch) -> RowBatch:
    rows["float_features"]["f1"] = F.ln(F.add(rows["float_features"]["f1"], 3))
    # can collate in preproc. Or can collate in training loop.
    output = collate(rows)
    return output

In [None]:
arrow_dp = PyDictDatapipe(data=data)
# arrow_dp = arrow_dp.repeat(num=2)
# arrow_dp = arrow_dp.shuffle(ShuffleSpec.FULL_SHUFFLE)
arrow_dp = arrow_dp.map_batches(arrow_preproc, batch_size=5, batch_accessor=BatchAccessor.ARROW)

rs = SimpleReadingService()

dl = Dataloader(arrow_dp, rs)

for idx, batch in enumerate(dl.iter_batches()):
  # print(f"==Batch {idx}, size {batch.row_num}==")
  # input = batch.to_torch()
  print(batch)

{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([1.3863, 1.4110, 1.4351, 1.4586, 1.4816])), 'f2': JaggedTensor(offsets=None, indices=None, values=tensor([2.0000, 2.1000, 2.2000, 2.3000, 2.4000]))}, 'id_list_features': {'id1': JaggedTensor(offsets=tensor([0, 2, 3, 5, 5, 6], dtype=torch.int32), indices=tensor([111, 112, 121, 131, 132, 151]), values=None), 'id2': JaggedTensor(offsets=tensor([0, 0, 1, 2, 4, 5], dtype=torch.int32), indices=tensor([221, 222, 223, 224, 225]), values=None), 'id3': JaggedTensor(offsets=tensor([0, 1, 1, 2, 3, 4], dtype=torch.int32), indices=tensor([311, 331, 341, 351]), values=None)}, 'id_score_list_features': {'ids1': JaggedTensor(offsets=tensor([0, 1, 2, 3, 3, 5], dtype=torch.int32), indices=tensor([411, 421, 431, 451, 452]), values=tensor([0.1000, 0.2000, 0.7000, 0.4000, 0.1000]))}}
{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([1.5041])), 'f2': JaggedTensor(offsets=None, indices=None, values



### Tensor Preproc

In [None]:
# could also operate on a single Row
def tensor_preproc(rows: RowBatch) -> RowBatch:
    rows["float_features"]["f1"].values = rows["float_features"]["f1"].values.add(3)
    rows["id_list_features"]["id1"].indices = torch.remainder(rows["id_list_features"]["id1"].indices, 100)
    # can collate in preproc. Or can collate in training loop.
    output = collate(rows)
    return output

tensor_dp = PyDictDatapipe(data=data)
# tensor_dp = tensor_dp.repeat(num=2)
# tensor_dp = tensor_dp.shuffle(ShuffleSpec.FULL_SHUFFLE)
tensor_dp = tensor_dp.map_batches(tensor_preproc, batch_size=5, batch_accessor=BatchAccessor.TORCH_TENSOR)

rs = SimpleReadingService()

dl = Dataloader(tensor_dp, rs)

for idx, batch in enumerate(dl.iter_batches()):
  # print(f"==Batch {idx}, size {batch.row_num}==")
  # input = batch.to_torch()
  print(batch)

{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([4.0000, 4.1000, 4.2000, 4.3000, 4.4000])), 'f2': JaggedTensor(offsets=None, indices=None, values=tensor([2.0000, 2.1000, 2.2000, 2.3000, 2.4000]))}, 'id_list_features': {'id1': JaggedTensor(offsets=tensor([0, 2, 3, 5, 5, 6], dtype=torch.int32), indices=tensor([11, 12, 21, 31, 32, 51]), values=None), 'id2': JaggedTensor(offsets=tensor([0, 0, 1, 2, 4, 5], dtype=torch.int32), indices=tensor([221, 222, 223, 224, 225]), values=None), 'id3': JaggedTensor(offsets=tensor([0, 1, 1, 2, 3, 4], dtype=torch.int32), indices=tensor([311, 331, 341, 351]), values=None)}, 'id_score_list_features': {'ids1': JaggedTensor(offsets=tensor([0, 1, 2, 3, 3, 5], dtype=torch.int32), indices=tensor([411, 421, 431, 451, 452]), values=tensor([0.1000, 0.2000, 0.7000, 0.4000, 0.1000]))}}
{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([4.5000])), 'f2': JaggedTensor(offsets=None, indices=None, values=tenso

# Inference

In [None]:
INPUT_SCHEMA = (
    ("float_features", Dict[str, List[float]]),
    ("id_list_features", Dict[str, List[List[int]]]),
    ("id_score_list_features", Dict[str, List[Dict[int, float]]]),
)
arrow_dp = SchemaDatapipe(schema=INPUT_SCHEMA, batch_accessor=BatchAccessor.ARROW)
arrow_dp = arrow_dp.map_batches(arrow_preproc)
pred = arrow_dp(data)
pred

{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([1.3863, 1.4110, 1.4351, 1.4586, 1.4816, 1.5041])),
  'f2': JaggedTensor(offsets=None, indices=None, values=tensor([2.0000, 2.1000, 2.2000, 2.3000, 2.4000, 2.5000]))},
 'id_list_features': {'id1': JaggedTensor(offsets=tensor([0, 2, 3, 5, 5, 6, 7], dtype=torch.int32), indices=tensor([111, 112, 121, 131, 132, 151, 161]), values=None),
  'id2': JaggedTensor(offsets=tensor([0, 0, 1, 2, 4, 5, 5], dtype=torch.int32), indices=tensor([221, 222, 223, 224, 225]), values=None),
  'id3': JaggedTensor(offsets=tensor([0, 1, 1, 2, 3, 4, 5], dtype=torch.int32), indices=tensor([311, 331, 341, 351, 361]), values=None)},
 'id_score_list_features': {'ids1': JaggedTensor(offsets=tensor([0, 1, 2, 3, 3, 5, 5], dtype=torch.int32), indices=tensor([411, 421, 431, 451, 452]), values=tensor([0.1000, 0.2000, 0.7000, 0.4000, 0.1000]))}}

In [None]:
tensor_dp = SchemaDatapipe(schema=INPUT_SCHEMA, batch_accessor=BatchAccessor.TORCH_TENSOR)
tensor_dp = tensor_dp.map_batches(tensor_preproc)
pred = tensor_dp(data)
pred

{'float_features': {'f1': JaggedTensor(offsets=None, indices=None, values=tensor([4.0000, 4.1000, 4.2000, 4.3000, 4.4000, 4.5000])),
  'f2': JaggedTensor(offsets=None, indices=None, values=tensor([2.0000, 2.1000, 2.2000, 2.3000, 2.4000, 2.5000]))},
 'id_list_features': {'id1': JaggedTensor(offsets=tensor([0, 2, 3, 5, 5, 6, 7], dtype=torch.int32), indices=tensor([11, 12, 21, 31, 32, 51, 61]), values=None),
  'id2': JaggedTensor(offsets=tensor([0, 0, 1, 2, 4, 5, 5], dtype=torch.int32), indices=tensor([221, 222, 223, 224, 225]), values=None),
  'id3': JaggedTensor(offsets=tensor([0, 1, 1, 2, 3, 4, 5], dtype=torch.int32), indices=tensor([311, 331, 341, 351, 361]), values=None)},
 'id_score_list_features': {'ids1': JaggedTensor(offsets=tensor([0, 1, 2, 3, 3, 5, 5], dtype=torch.int32), indices=tensor([411, 421, 431, 451, 452]), values=tensor([0.1000, 0.2000, 0.7000, 0.4000, 0.1000]))}}
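This notebook reuses the same preprocessing function (`arrow_preproc` / `tensor_preproc`) in an iterating pipeline for training and in a callable pipeline for inference. Below is a stdlib-only sketch of that pattern, assuming a batch is a dict of equal-length lists. `preproc`, `train_batches` and `Predictor` are hypothetical names. The point it illustrates is that both paths apply identical feature transforms.

```python
import math
from typing import Callable, Dict, Iterator, List

Batch = Dict[str, List[float]]


def preproc(batch: Batch) -> Batch:
    # Same transform arrow_preproc applies to f1: ln(x + 3)
    return {**batch, "f1": [math.log(x + 3) for x in batch["f1"]]}


def train_batches(source: Batch, batch_size: int, fn: Callable[[Batch], Batch]) -> Iterator[Batch]:
    # Training: the pipeline pulls batches from the source itself
    n = len(next(iter(source.values())))
    for start in range(0, n, batch_size):
        yield fn({k: v[start:start + batch_size] for k, v in source.items()})


class Predictor:
    # Inference: the caller pushes a request batch in
    def __init__(self, fn: Callable[[Batch], Batch]) -> None:
        self._fn = fn

    def __call__(self, batch: Batch) -> Batch:
        return self._fn(batch)


sample = {"f1": [1.0, 1.1, 1.2], "f2": [2.0, 2.1, 2.2]}
trained = list(train_batches(sample, batch_size=2, fn=preproc))
served = Predictor(preproc)(sample)
# Both paths produce identical f1 features
assert trained[0]["f1"] + trained[1]["f1"] == served["f1"]
```

Keeping a single function for both paths is what prevents training/serving skew in the feature transforms.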

# Archived

In [None]:



class Table(ABC, dict):
    def __init__(self, data: Dict[str, Any], block_size: int) -> None: # , start: int = 0, stride: int = 1) -> None:
        self._data = data
        self._block_size = block_size
        # self._start = start
        # self._stride = stride

    # def __iter__(self) -> RowBatchIterator:
    #     return RowBatchIterator(self, self._start, self._block_size, self._stride)

    # TODO: wonder if we should have iterator method in Table?
    #       considering Table is an "accessors of data", but not necessary storing the real data, having it seems reasonable
    def iter_batches(self, batch_size: int, start: int = 0, stride: int = 1) -> RowBatchIterator:
      return RowBatchIterator(self, start=start, batch_size=batch_size, stride=stride)

    # TODO: this function will be called by RowBatchIterator.__next__(). Seems a bit cyclic...
    @abstractmethod
    def fetch_batch(self, start: int, batch_size: int) -> RowBatch:
      ...


class ArrowTable(Table):
    # def __next__(self) -> RowBatch:
    #     columns = ArrowTable.build_columns(self._data, self._block_size, self._start)
    #     self._start += self._block_size * self._stride
    #     return columns
    
    def fetch_batch(self, start: int, batch_size: int) -> RowBatch:
      return ArrowTable.build_columns(self._data, batch_size, start)

    @staticmethod
    def build_column_val(input: List[Any]) -> pyarrow.Array:
        child_type = struct_type(next(iter(input)))
        assert child_type != StructType.NON_LEAF, f"Invalid input {input}"
        # List[List[int]]
        if child_type == StructType.LEAF_ID_LIST:
            col = pyarrow.ListArray.from_arrays(
                offsets=pyarrow.array(compute_offsets(input, True)),
                values=list(itertools.chain(*input)),
            )
        # List[float]
        elif child_type == StructType.LEAF_FLOAT:
            col = pyarrow.array(input, type=pyarrow.float32())
        # List[Dict[int, float]]
        elif child_type == StructType.LEAF_ID_SCORE_MAP:
            col = pyarrow.MapArray.from_arrays(
                offsets=pyarrow.array(compute_offsets(input, True)),
                keys=list(itertools.chain(*[item.keys() for item in input])),
                items=list(itertools.chain(*[item.values() for item in input])),
            )
        return col

    @staticmethod
    def build_columns(input: Dict[str, Collection], block_size: Optional[int] = None, start: int = 0) -> RowBatch:
        assert len(input) > 0, f"input {input} cannot have empty columns"
        col_values = []
        row_num = block_size
        for val in input.values():
            if type(val) == list:
                block_size = block_size or len(val)
                end = min(start + block_size, len(val))
                row_num = end - start
                col_val = ArrowTable.build_column_val(val[start:end])
            elif type(val) == dict:
                col_val = ArrowTable.build_columns(val, block_size, start)
                row_num = col_val.row_num
            col_values.append(col_val)

        return RowBatch(data=zip(input.keys(), col_values), row_num=row_num)

    @staticmethod
    def from_pydict(input: Dict[str, Any], block_size: int) -> "ArrowTable":
        return ArrowTable(data=input, block_size=block_size)