In [1]:
import numpy as np
from numpy.typing import NDArray
import atdata
from atdata.local import LocalDatasetEntry, S3DataStore, Index
import webdataset as wds

In [2]:
@atdata.packable
class TrainingSample:
    """A sample containing features and label for training."""
    # Feature array for the sample (bare NDArray: no shape or dtype is fixed here).
    features: NDArray
    # Integer label paired with `features`.
    label: int

@atdata.packable
class TextSample:
    """A sample containing text data."""
    # The text payload of the sample.
    text: str
    # Free-form category string attached to the text.
    category: str

# Quick smoke test: build one TextSample by keyword from the class defined above.
x = TextSample(
    text = 'Hello',
    category = 'test',
)

---

In [3]:
from redis import Redis

# Open the S3-backed data store; credentials are read from a local env file
# rather than being written into the notebook.
store = S3DataStore(
    '.credentials/r2-analysis-hive.env',
    bucket = "analysis-hive",
)

print(f"Bucket: {store.bucket}")
print(f"Supports streaming: {store.supports_streaming()}")

# Build the index on top of that store.
# NOTE(review): no Redis client is passed in explicitly (the `Redis` import is
# unused in this cell); presumably `Index` opens its own default connection -- confirm.
index = Index( data_store = store, auto_stubs = True )

print("LocalIndex connected")

Bucket: analysis-hive
Supports streaming: True
LocalIndex connected


TextSample = index.decode_schema( 'atdata://local/sampleSchema/TextSample@1.0.1' )

In [4]:
# Build a sample by keyword using whatever class `TextSample` is currently bound to.
x = TextSample(
    text = 'hello',
    category = 'test',
)

In [5]:
# Register TrainingSample in the index under an explicit version
schema_ref = index.publish_schema(TrainingSample, version="1.0.0")
print(f"Published schema: {schema_ref}")

# Walk over every schema the index reports
for schema in index.list_schemas():
    print(f"  - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}")

# Fetch the stored record for the reference we just got back and list its field names
schema_record = index.get_schema(schema_ref)
field_names = [field['name'] for field in schema_record.get('fields', [])]
print(f"Schema fields: {field_names}")

# Rebuild a PackableSample class from the stored schema
decoded_type = index.decode_schema(schema_ref)
print(f"Decoded type: {decoded_type.__name__}")

Published schema: atdata://local/sampleSchema/TrainingSample@1.0.0
  - TrainingSample v1.0.0
Schema fields: ['features', 'label']
Decoded type: TrainingSample


In [6]:
# Register TextSample in the index under an explicit version
schema_ref_2 = index.publish_schema(TextSample, version="1.0.1")
print(f"Published schema: {schema_ref_2}")

# Walk over every schema the index reports
for schema in index.list_schemas():
    print(f"  - {schema.get('name', 'Unknown')} v{schema.get('version', '?')}")

# Fetch the stored record for the reference we just got back and list its field names
schema_record = index.get_schema(schema_ref_2)
field_names = [field['name'] for field in schema_record.get('fields', [])]
print(f"Schema fields: {field_names}")

# Rebuild a PackableSample class from the stored schema
decoded_type = index.decode_schema(schema_ref_2)
print(f"Decoded type: {decoded_type.__name__}")

Published schema: atdata://local/sampleSchema/TextSample@1.0.1
  - TrainingSample v1.0.0
  - TextSample v1.0.1
Schema fields: ['text', 'category']
Decoded type: TextSample


In [7]:
# Drop the current binding so the following cells have to obtain the class from the index.
del TextSample

In [8]:
# Load the published schema by reference; the cell output shows the generated class it returns.
index.load_schema( 'atdata://local/sampleSchema/TextSample@1.0.1' )

_atdata_generated_TextSample_1_0_1.TextSample

In [9]:
# Rebind the name to the class exposed under `index.types`
# (presumably populated by the `load_schema` call above -- confirm).
TextSample = index.types.TextSample

In [12]:
# Check that the class obtained from the index can be instantiated like the original.
x = TextSample(
    text = 'hello',
    category = 'test',
)

In [None]:
@atdata.packable
class LocalTextSample:
    """Locally defined text sample used as the target type of the `TextSample` lens."""
    content: str
    "Text content of the sample (filled from `TextSample.text` by the lens)."
    category: str
    "Category string of the sample (copied from `TextSample.category` by the lens)."

@atdata.lens
def _convert_text_sample( s: TextSample ) -> LocalTextSample:
    """Map a `TextSample` onto a `LocalTextSample`.

    `text` is carried over as `content`; `category` is copied unchanged.
    """
    converted = LocalTextSample( content = s.text, category = s.category )
    return converted

Notes:

* We get linting errors here on `@atdata.lens` because `LocalTextSample` doesn't show up as a subclass of `PackableSample`; is there a way to resolve this?

In [22]:
# Apply the lens to the sample `x`; per its annotation the result should be a LocalTextSample.
y = _convert_text_sample( x )

---

In [24]:
import webdataset as wds
from pathlib import Path
from uuid import uuid4

# Shard layout: `n_samples` samples in total, at most `samples_per_shard` per tar file.
data_pattern = 'data/TextSample_test-%06d.tar'
samples_per_shard = 1_000
n_samples = 10_000

# The writer opens each shard path for writing, so make sure the target directory
# exists first; otherwise this cell fails on a fresh checkout.
Path( data_pattern ).parent.mkdir( parents = True, exist_ok = True )

with wds.writer.ShardWriter( data_pattern, maxcount = samples_per_shard ) as sink:
    for i in range( n_samples ):
        # Random UUID strings serve as throwaway text payloads.
        new_sample = TextSample(
            text = str( uuid4() ),
            category = 'test',
        )
        sink.write( new_sample.as_wds )

# writing data/TextSample_test-000000.tar 0 0.0 GB 0
# writing data/TextSample_test-000001.tar 1000 0.0 GB 1000
# writing data/TextSample_test-000002.tar 1000 0.0 GB 2000
# writing data/TextSample_test-000003.tar 1000 0.0 GB 3000
# writing data/TextSample_test-000004.tar 1000 0.0 GB 4000
# writing data/TextSample_test-000005.tar 1000 0.0 GB 5000
# writing data/TextSample_test-000006.tar 1000 0.0 GB 6000
# writing data/TextSample_test-000007.tar 1000 0.0 GB 7000
# writing data/TextSample_test-000008.tar 1000 0.0 GB 8000
# writing data/TextSample_test-000009.tar 1000 0.0 GB 9000


In [30]:
from atdata import Dataset
# Open shards 000000-000009 (the ten files written by the previous cell), typed as TextSample.
ds = Dataset[TextSample]( 'data/TextSample_test-{000000..000009}.tar' )
# Take the first item of the ordered stream; `batch_size = None` is passed explicitly
# (presumably so that single samples rather than batches are yielded -- confirm).
x = next( iter( ds.ordered( batch_size = None ) ) )

Notes:

* We should change the default `batch_size` of `Dataset.ordered` and `Dataset.shuffled` from 1 to `None`.

In [32]:
# Register the dataset in the index under a name, storage prefix and schema reference.
# Per the cell output, this writes a shard under `analysis-hive/prototyping/`.
entry = index.insert_dataset(
    ds,
    name = 'proto-text-samples-2',
    prefix = 'prototyping',
    schema_ref = 'atdata://local/sampleSchema/TextSample@1.0.1',
)

# writing analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar 0 0.0 GB 0


In [33]:
# Display the entry returned by `insert_dataset`.
entry

LocalDatasetEntry(_name='proto-text-samples-2', _schema_ref='atdata://local/sampleSchema/TextSample@1.0.1', _data_urls=['s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar'], _metadata=None)

Notes:

* We should make sure that the `s3` URI-scheme here is properly used
    * Should we be using the `https` URI since actually this is doing data streaming with `wds`? Or does this indicate that we should think more deeply about the `Dataset` API design and generalizing how we're setting up the `wds` data streaming ...
    * No matter what, we're definitely going to want to make sure that we incorporate the actual host details of the `LocalIndex`'s `S3DataStore` for this, since the S3 host is definitely not local.
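    * Should the fields shown in the `LocalDatasetEntry` repr (`_name`, `_schema_ref`, `_data_urls`, `_metadata`) have leading underscores? These feel like public properties ...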

---

In [35]:
from atdata import load_dataset

# Load from local index
ds = load_dataset( "@local/proto-text-samples-2",
    index = index,
    split = 'train',
)

# The index resolves the dataset name to URLs and schema
# Pull a single batch and stop, just to exercise streaming.
# NOTE: with the `s3://` URL stored in the entry, this iteration currently fails with
# `ValueError: ... no gopen handler defined` (see this cell's error output).
for batch in ds.shuffled(batch_size=32):
    break

ValueError: ('s3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar: no gopen handler defined', 's3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar')

Notes:

* `load_dataset` also produces a linting error here: no matching overloads.

In [36]:
# Inspect the URL the loaded dataset was resolved to.
ds.url

's3://analysis-hive/prototyping/data--4a5ff662-803b-4700-81f4-45f288f6e565--000000.tar'

Notes:

* We're getting linting errors because `AbstractIndex` is used as a protocol; would it be better to subclass it, or is there a way for the index to be recognized as adhering to the protocol?
* The S3 URI error is showing up here now because of how dataset loading works! The data is uploaded correctly on my end, but it can't be accessed because of this URI not being the correct way to access the data for `wds` streaming over `https`; we should think of how best to encode this!