In [1]:
import os

import numpy as np
from numpy.typing import NDArray
import atdata
from atdata.local import LocalDatasetEntry, S3DataStore, Index
import webdataset as wds

In [34]:
# Sample type for the training-data prototype. Its field names, annotations and
# docstring are what `index.publish_schema` stores (the docstring shows up as the
# schema's `description`), so edits here presumably warrant a new schema version.
@atdata.packable
class TrainingSample:
    """A sample containing features and label for training."""
    # Array features; the published schema records this as `local#ndarray` with
    # dtype 'float32' (NDArray itself does not pin a dtype here).
    features: NDArray
    # Integer label (`local#primitive` int in the published schema).
    label: int

from dataclasses import dataclass

# Text sample type used throughout the rest of this notebook (shard writing,
# lenses, index insertion). Published to the index as TextSample@1.0.1.
# NOTE(review): unlike TrainingSample, this class also inherits
# atdata.PackableSample explicitly on top of @atdata.packable -- confirm whether
# both are needed and make the two definitions consistent.
@atdata.packable
class TextSample( atdata.PackableSample ):
    """A sample containing text data."""
    # Free-form text payload (the shard-writing cell stores uuid4 strings here).
    text: str
    # Grouping label (always 'test' in this notebook).
    category: str

In [3]:
# Smoke-test: build one instance of the locally defined TextSample type.
x = TextSample( text = 'Hello', category = 'test' )

---

In [4]:
# Connection settings. The credentials file can be relocated (e.g. outside the
# repository checkout) by setting ATDATA_R2_CREDENTIALS; the default keeps the
# previous relative path.
R2_CREDENTIALS_PATH = os.environ.get( 'ATDATA_R2_CREDENTIALS', '.credentials/r2-analysis-hive.env' )
R2_BUCKET = "analysis-hive"

# Connect to S3 (Cloudflare R2 bucket)
store = S3DataStore( R2_CREDENTIALS_PATH,
    bucket = R2_BUCKET,
)
print(f"Bucket: {store.bucket}")
print(f"Supports streaming: {store.supports_streaming()}")

# Connect to Redis; the index writes dataset shards through `store`.
# `auto_stubs=True` is what makes generated classes such as
# `index.types.TextSample` available later in this notebook (presumably -- confirm).
index = Index(
    data_store = store,
    auto_stubs = True,
)
print( "LocalIndex connected" )

Bucket: analysis-hive
Supports streaming: True
LocalIndex connected


In [6]:
[ *index.list_schemas() ]

[{'name': 'TrainingSample',
  'version': '1.0.0',
  'fields': [{'name': 'features',
    'fieldType': {'$type': 'local#ndarray', 'dtype': 'float32'},
    'optional': False},
   {'name': 'label',
    'fieldType': {'$type': 'local#primitive', 'primitive': 'int'},
    'optional': False}],
  '$ref': 'atdata://local/sampleSchema/TrainingSample@1.0.0',
  'description': 'A sample containing features and label for training.',
  'createdAt': '2026-01-22T22:01:47.560660+00:00'},
 {'name': 'TextSample',
  'version': '1.0.1',
  'fields': [{'name': 'text',
    'fieldType': {'$type': 'local#primitive', 'primitive': 'str'},
    'optional': False},
   {'name': 'category',
    'fieldType': {'$type': 'local#primitive', 'primitive': 'str'},
    'optional': False}],
  '$ref': 'atdata://local/sampleSchema/TextSample@1.0.1',
  'description': 'A sample containing text data.',
  'createdAt': '2026-01-22T22:09:51.907476+00:00'}]

In [7]:
# First schema yielded by the index. `iter()` makes this work whether
# `index.schemas` is a one-shot iterator or a re-iterable collection.
s = next( iter( index.schemas ) )

In [8]:
# `$ref` URI of the first schema taken from `index.schemas`
# (matches the '$ref' entry shown by `list_schemas()`).
s.ref

'atdata://local/sampleSchema/TrainingSample@1.0.0'

In [9]:
# Publish the TrainingSample schema at an explicit version; keep the returned ref.
schema_ref = index.publish_schema( TrainingSample, version = "1.0.0" )
print( f"Published schema: {schema_ref}" )

# Every schema the index currently knows about, as "name vVERSION".
for schema in index.list_schemas():
    print( f"  - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}" )

# Round-trip: read the stored record back and report its field names.
schema_record = index.get_schema( schema_ref )
field_names = [ field['name'] for field in schema_record.get( 'fields', [] ) ]
print( f"Schema fields: {field_names}" )

# Rebuild a PackableSample class from the stored schema and show its name.
decoded_type = index.decode_schema( schema_ref )
print( f"Decoded type: {decoded_type.__name__}" )

Published schema: atdata://local/sampleSchema/TrainingSample@1.0.0
  - TrainingSample v1.0.0
  - TextSample v1.0.1
Schema fields: ['features', 'label']
Decoded type: TrainingSample


In [10]:
# Same publish -> list -> inspect -> decode round-trip, for TextSample@1.0.1.
schema_ref_2 = index.publish_schema(TextSample, version="1.0.1")
print(f"Published schema: {schema_ref_2}")

for schema in index.list_schemas():
    schema_name = schema.get('name', 'Unknown')
    schema_version = schema.get('version', '?')
    print(f"  - {schema_name} v{schema_version}")

schema_record = index.get_schema(schema_ref_2)
text_field_names = []
for field in schema_record.get('fields', []):
    text_field_names.append(field['name'])
print(f"Schema fields: {text_field_names}")

decoded_type = index.decode_schema(schema_ref_2)
decoded_name = decoded_type.__name__
print(f"Decoded type: {decoded_name}")

Published schema: atdata://local/sampleSchema/TextSample@1.0.1
  - TrainingSample v1.0.0
  - TextSample v1.0.1
Schema fields: ['text', 'category']
Decoded type: TextSample


In [12]:
from typing import TypeVar, TypeAlias, Generic, Callable, Any

S = TypeVar( 'S', bound = atdata.PackableSample )
V = TypeVar( 'V', bound = atdata.PackableSample )

# Any callable that builds a `V` sample from an arbitrary input object.
FromAnyTo = Callable[[Any], V]

def make_local_lens( f: FromAnyTo[V], remote: type[S], local: type[V] ) -> atdata.Lens[S, V]:
    """Wrap a conversion function as an ``atdata`` lens from ``remote`` to ``local``.

    Args:
        f: Callable that builds a ``local`` sample from a ``remote`` sample.
        remote: Source sample class (e.g. a class obtained from ``index.types``).
        local: Target sample class defined in this notebook.

    Returns:
        The object produced by applying ``atdata.lens`` to a thin wrapper
        around ``f``; it can be called directly on a ``remote`` sample.
    """
    def _to_local( s ):
        return f( s )

    # Annotate the wrapper with the concrete classes. Previously it was
    # annotated only with the TypeVars S/V, which left `remote`/`local` unused
    # and exposed TypeVars to anything that inspects the annotations.
    # NOTE(review): assumes atdata.lens reads parameter/return annotations to
    # learn the source/view types -- confirm against the atdata implementation.
    _to_local.__annotations__ = { 's': remote, 'return': local }

    # Equivalent to decorating `_to_local` with `@atdata.lens`, but applied
    # after the annotations above are in place.
    return atdata.lens( _to_local )

In [18]:
# Load the stored TextSample schema into the index, then grab the resulting
# generated class from `index.types` under a distinct local name.
index.load_schema( 'atdata://local/sampleSchema/TextSample@1.0.1' )
TextSampleRemote = getattr( index.types, 'TextSample' )

In [20]:
# An instance of the index-generated class, to be fed through the lens below.
x = TextSampleRemote( text = 'hello', category = 'test' )

In [26]:
def _to_text_sample( s: Any ) -> TextSample:
    """Copy ``text`` and ``category`` from any object that has them into a local ``TextSample``."""
    copied = { attr: getattr( s, attr ) for attr in ( 'text', 'category' ) }
    return TextSample( **copied )

# Lens from the index-generated TextSample class to the locally defined one.
l = make_local_lens( _to_text_sample, TextSampleRemote, TextSample )

In [27]:
# Apply the lens: convert the TextSampleRemote instance `x` into a local TextSample.
y = l( x )

---

In [29]:
import webdataset as wds
from uuid import uuid4

data_pattern = 'data/TextSample_test-%06d.tar'

# Size knobs: N_SAMPLES / SHARD_MAXCOUNT = 10 shards, which is what the
# `{000000..000009}` brace range in the loading cell expects.
N_SAMPLES = 10_000
SHARD_MAXCOUNT = 1_000

# Create the output directory up front. ShardWriter is not guaranteed to create
# missing parent directories, so without this a fresh checkout can fail here.
os.makedirs( os.path.dirname( data_pattern ) or '.', exist_ok = True )

with wds.writer.ShardWriter( data_pattern, maxcount = SHARD_MAXCOUNT ) as sink:
    for _ in range( N_SAMPLES ):
        # uuid4 is random, so every run of this cell writes different text payloads.
        new_sample = TextSample(
            text = str( uuid4() ),
            category = 'test',
        )
        sink.write( new_sample.as_wds )

# writing data/TextSample_test-000000.tar 0 0.0 GB 0
# writing data/TextSample_test-000001.tar 1000 0.0 GB 1000
# writing data/TextSample_test-000002.tar 1000 0.0 GB 2000
# writing data/TextSample_test-000003.tar 1000 0.0 GB 3000
# writing data/TextSample_test-000004.tar 1000 0.0 GB 4000
# writing data/TextSample_test-000005.tar 1000 0.0 GB 5000
# writing data/TextSample_test-000006.tar 1000 0.0 GB 6000
# writing data/TextSample_test-000007.tar 1000 0.0 GB 7000
# writing data/TextSample_test-000008.tar 1000 0.0 GB 8000
# writing data/TextSample_test-000009.tar 1000 0.0 GB 9000


In [36]:
from atdata import load_dataset

# Re-open the ten locally written shards and view every sample as a TextSample.
ds = load_dataset(
    'data/TextSample_test-{000000..000009}.tar',
    split = 'test',
).as_type( TextSample )

In [40]:
# First sample from `ds.ordered()` iteration (contrast with `ds.shuffled()` used later).
x = next( iter( ds.ordered() ) )

In [41]:
# Display the sample pulled in the previous cell.
x

TextSample(text='d06a8072-5833-4867-9bc6-03baa3cee75b', category='test')

In [42]:
# Write the dataset's samples into the S3-backed store under `prefix` and
# register them in the index as 'proto-text-samples-3' with the TextSample schema.
entry = index.insert_dataset(
    ds,
    name = 'proto-text-samples-3',
    prefix = 'prototyping',
    schema_ref = 'atdata://local/sampleSchema/TextSample@1.0.1',
)

# writing analysis-hive/prototyping/data--2b5dd738-c9f6-46d4-8c31-e218531601be--000000.tar 0 0.0 GB 0


In [None]:
# Index entry returned by `insert_dataset`: name, schema ref, and the shard URL(s)
# now in the store (note the `s3://` scheme -- see notes below).
entry

LocalDatasetEntry(name='proto-text-samples-3', schema_ref='atdata://local/sampleSchema/TextSample@1.0.1', data_urls=['s3://analysis-hive/prototyping/data--2b5dd738-c9f6-46d4-8c31-e218531601be--000000.tar'], metadata=None)

Notes:

* We should make sure the `s3://` URI scheme in `data_urls` is used correctly.
    * Should we store the `https` URL instead, since `wds` actually streams the data over HTTP(S)? Or does this mean we need to rethink the `Dataset` API design and generalize how `wds` streaming is set up?
    * Either way, the URL must incorporate the actual host/endpoint details of the `LocalIndex`'s `S3DataStore`, because the S3 host is not local.
    * Should there be underscores here? These feel like public properties ...

---

In [45]:
# Dataset handle from the local-shard load; its default repr shows only class and address.
ds

<atdata.dataset.Dataset at 0x114eb9160>

In [None]:
from atdata import load_dataset

# Load from local index
# Load by name through the index: "@local/<name>" is resolved by `index` to the
# dataset's shard URLs and schema; samples are decoded as TextSample.
ds = load_dataset(
    "@local/proto-text-samples-3",
    TextSample,
    split = 'train',
    index = index,
)

# Pull a single sample to exercise end-to-end streaming.
# NOTE(review): this currently raises OSError (curl exit 22) -- see the recorded
# error output and the notes after the `ds.url` cell.
for batch in ds.shuffled():
    break

OSError: ("((['curl', '--connect-timeout', '30', '--retry', '30', '--retry-delay', '2', '-f', '-s', '-L', 'https://f5bf77c06cb35b5136ff6d61ab4b7dbc.r2.cloudflarestorage.com/analysis-hive/prototyping/data--2b5dd738-c9f6-46d4-8c31-e218531601be--000000.tar'],), {'bufsize': 8192}): exit 22 (read) {}", <webdataset.gopen.Pipe object at 0x11425e150>, 'https://f5bf77c06cb35b5136ff6d61ab4b7dbc.r2.cloudflarestorage.com/analysis-hive/prototyping/data--2b5dd738-c9f6-46d4-8c31-e218531601be--000000.tar')

Notes:

* This cell also triggers a linter error on `load_dataset`: "no matching overloads".

In [36]:
# Shard URL the index resolved for this dataset -- an `s3://` URI, which the
# notes below identify as the cause of the streaming failure.
# NOTE(review): the recorded output's shard UUID (4a5ff662...) differs from
# `entry` (2b5dd738...) and the execution count repeats In [36]; it is likely
# stale -- re-run to refresh.
ds.url

's3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar'

Notes:

* We're getting linting errors because of how the `AbstractIndex` protocol is used. Would it be better to subclass it explicitly, or is there a way to get the type checker to recognize protocol conformance?
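* The S3 URI problem surfaces here because of how dataset loading works. The data is uploaded correctly, but `wds` cannot read it through this URI. The `OSError` in the `@local` loading cell shows the shard being fetched as `https://<account>.r2.cloudflarestorage.com/analysis-hive/...` with plain `curl -f`. curl exit code 22 means the server answered with an HTTP error (status ≥ 400), presumably because the private bucket rejects unsigned requests (to confirm). We should decide how best to encode shard URLs (and credentials) so that `wds` can stream the data over `https`.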