# Data loading example

In [None]:
from atek.dataset.atek_webdataset import create_atek_webdataset, create_wds_dataloader
import yaml
import os
import webdataset as wds

In [None]:
# Demo ATEK WebDataset shard files (tar archives) to load. The paths are
# generated from a directory and a shard count so the demo can be pointed at a
# different shard set by editing two values instead of a hand-written list.
shard_dir = "/source/data/atek/demo"
num_shards = 4
tars = [
    f"{shard_dir}/shards-{shard_idx:04d}.tar"
    for shard_idx in range(num_shards)
]

## Default native loading

In [None]:
# Build the dataset with every optional hook disabled (no key selection, no
# key remapping, no extra transform), so elements come through unmodified.
ds = create_atek_webdataset(
    urls=tars,
    batch_size=2,
    nodesplitter=wds.shardlists.split_by_node,
    select_key_fn=None,
    remap_key_fn=None,
    data_transform_fn=None,
)

# Peek at only the first element yielded. The loop variable `obj` stays bound
# afterwards and is reused by the next cell.
for obj in ds:
    print(obj.keys())
    break

In [None]:
# Inspect the tensor stored under the RGB image key of the element peeked above.
print(
    "RGB image tensor shape: ",
    obj['f#214-1+image'].shape,
)

In [None]:
# Wrap the native dataset in a loader with 4 worker processes and no pinned memory.
dataloader = create_wds_dataloader(
    ds,
    pin_memory=False,
    num_workers=4,
)

In [None]:
from collections import Counter

from tqdm import tqdm

# Iterate the whole loader once, counting samples overall and per source URL.
# Each batch's "__url__" entry is iterated as one URL per sample, which is
# what the totals below assume.
sample_count = 0
per_url_sample_count = Counter()
for data in tqdm(dataloader):
    urls = data["__url__"]
    per_url_sample_count.update(urls)
    sample_count += len(urls)
print("Total samples loaded: ", sample_count)
# Convert to a plain dict so the printed output looks like a regular mapping
# (Counter preserves first-seen insertion order, like the original dict).
print(dict(per_url_sample_count))

## Data transform example

In [None]:
# Raw ATEK key -> name it should have after remapping. Only keys listed here
# survive key selection.
target_key_map = {
    "f#214-1+image": "rgb_image",
}


def simple_key_selection_fn(key: str) -> bool:
    """Return True if ``key`` is one of the keys in ``target_key_map``."""
    # Membership on the dict itself tests its keys; ``.keys()`` is redundant.
    return key in target_key_map


def simple_key_remap(key: str) -> str:
    """Return the mapped name for ``key``, or ``key`` itself if it is not in ``target_key_map``."""
    return target_key_map.get(key, key)

def simple_transform_fn(data):
    """Yield each sample with its ``rgb_image`` value replaced by that value's first element.

    Very naive: only the first rgb image in the group is kept. Every other
    key/value pair is copied through unchanged into a new dict; the input
    samples are not mutated.
    """
    for sample in data:
        yield {
            key: (value[0] if key == 'rgb_image' else value)
            for key, value in sample.items()
        }


# Note that the transform order is:
#   urls -> full atek dict -> [select keys] -> [remap keys] -> [more transform] -> [batch collation]
# Wrap the per-sample generator as a webdataset pipeline stage for the "more transform" step.
transform_stage = wds.filters.pipelinefilter(simple_transform_fn)()
transformed_ds = create_atek_webdataset(
    urls=tars,
    batch_size=2,
    nodesplitter=wds.shardlists.split_by_node,
    select_key_fn=simple_key_selection_fn,
    remap_key_fn=simple_key_remap,
    data_transform_fn=transform_stage,
)

In [None]:
# Pull one element through the transformed pipeline. Only the selected key
# (renamed to "rgb_image") should remain, with the transform applied.
trans_obj = next(
    iter(transformed_ds)
)
print(trans_obj.keys())
print(trans_obj["rgb_image"].shape)

## Omni3D loading for Cube R-CNN

In [None]:
from typing import List, Optional, Dict
import torch

from detectron2.data import detection_utils
from detectron2.structures import Boxes, BoxMode, Instances

from atek.dataset.omni3d_adapter import create_omni3d_webdataset


In [None]:
from tqdm import tqdm

# Stream the Omni3D-formatted dataset through a 4-worker loader and count samples.
omni3d_wds = create_omni3d_webdataset(tars, batch_size=2)
dataloader = create_wds_dataloader(omni3d_wds, num_workers=4, pin_memory=False)

# Each batch is presumably a list of per-sample dicts (the next cell indexes
# data[0]), so len(data) counts the samples in it.
# NOTE: `data` stays bound to the LAST batch after the loop.
sample_count = 0
for data in tqdm(dataloader):
    sample_count += len(data)

print(sample_count)

In [None]:
# Inspect the first sample of the last batch left over from the loop above.
# 'K' is presumably the camera intrinsics matrix — confirm against the Omni3D adapter.
first_sample = data[0]
print(first_sample.keys())
print("Image shape: ", first_sample['image'].shape)
print("K: ", first_sample['K'])
