In [None]:
import dataclasses

import jax
from jax.sharding import Mesh

from deeprte.input_pipeline import input_pipeline_interface
from deeprte.train_lib import utils
from deeprte.train_lib.multihost_dataloading import prefetch_to_device

In [None]:
@dataclasses.dataclass
class Config:
    """Settings used by this notebook to build the device mesh and data iterators.

    An instance is handed to ``utils.create_device_mesh`` and to
    ``input_pipeline_interface.create_data_iterator``; ``mesh_axes`` is also
    used directly as the axis names of the ``Mesh``.

    ``global_batch_size_to_train_on`` defaults to the placeholder ``-1`` and is
    resolved in ``__post_init__`` to ``global_batch_size_to_load``, so
    overriding only the load size keeps both values in sync.
    """

    # Data
    dataset_type: str = "tfds"
    dataset_name: str = "rte"
    # NOTE(review): absolute dev-container path; override when running elsewhere.
    data_dir: str = "/workspaces/deeprte/data/tfds"
    # NOTE(review): TFDS slice syntax; "train[80%:]" is the last 20% of the
    # "train" split -- confirm this is the intended training subset.
    train_split: str = "train[80%:]"
    enable_data_shuffling: bool = True
    data_shuffle_seed: int = 42
    prefetch_to_device: bool = True

    # Parallelism
    mesh_axes: tuple[str, ...] = ("data", "fsdp", "tensor")
    # A tuple of axis-name groups (each group is itself a tuple of mesh axes).
    data_partitions: tuple[tuple[str, ...], ...] = (("data", "fsdp", "tensor"),)
    # One axis for each parallelism type may hold a placeholder (-1)
    # value to auto-shard based on available slices and devices.
    # By default, the product of the DCN axes should equal the number of
    # slices and the product of the ICI axes should equal the number of
    # devices per slice.
    # ICI (Inter-Chip Interconnection): a high-speed connection between
    # sets of TPU chips, which form the TPU network.
    # DCN (Data Center Network): a connection between the TPU networks;
    # not as fast as ICI.
    # ICI has around 100x the bandwidth of DCN, but it is not a general
    # purpose connection, which is why DCN is necessary for scaling to
    # extremely large ML models.
    dcn_data_parallelism: int = -1
    dcn_fsdp_parallelism: int = 1
    dcn_tensor_parallelism: int = 1
    ici_data_parallelism: int = 1
    ici_fsdp_parallelism: int = -1
    ici_tensor_parallelism: int = 1

    # Train
    global_batch_size_to_load: int = 8
    # -1 is a placeholder meaning "same as global_batch_size_to_load"; it is
    # replaced in __post_init__. A plain `= global_batch_size_to_load` default
    # would be frozen to 8 when the class body runs and would not follow a
    # per-instance override of the load size.
    global_batch_size_to_train_on: int = -1
    collocation_sizes: tuple[int, ...] = (128,)
    repeat_batch: int = 1
    expansion_factor_real_data: int = -1

    # Evaluation
    eval_interval: int = -1

    def __post_init__(self) -> None:
        """Resolve the ``-1`` placeholder for ``global_batch_size_to_train_on``."""
        if self.global_batch_size_to_train_on == -1:
            self.global_batch_size_to_train_on = self.global_batch_size_to_load


config = Config()

In [None]:
# Ask the project helper for a device array built from the config, then wrap
# it in a Mesh whose axes are named after config.mesh_axes.
device_array = utils.create_device_mesh(config)
global_mesh = Mesh(devices=device_array, axis_names=config.mesh_axes)
# Bare last expression: display the mesh as the cell output.
global_mesh

In [None]:
# Build the train and eval iterators for this config on the global mesh.
train_iter, eval_iter = input_pipeline_interface.create_data_iterator(config, global_mesh)

# When enabled in the config, wrap only the training iterator with
# prefetch_to_device (second argument 4 -- presumably the prefetch buffer
# size; confirm against deeprte.train_lib.multihost_dataloading).
train_iter = (
    prefetch_to_device(train_iter, 4) if config.prefetch_to_device else train_iter
)

In [None]:
# Take the next training batch and show, for every leaf of the batch pytree,
# its (shape, sharding) pair. Note that this advances train_iter by one batch.
jax.tree.map(
    lambda leaf: (leaf.shape, leaf.sharding),
    next(train_iter),
)

In [None]:
# Show the partition spec of the "psi_label" entry's sharding.
# NOTE: next() pulls a fresh batch here, so train_iter advances by one more
# batch each time this cell runs.
next(train_iter)["psi_label"].sharding.spec