In [1]:
from atek.dataset.atek_webdataset import create_atek_webdataset, create_wds_dataloader
import yaml
import os
import webdataset as wds

In [2]:
# YAML manifest whose top-level "tars" entry lists the dataset shard files
# (read by get_tars below).
# NOTE(review): absolute, machine-specific path — consider making it configurable.
tars_yaml = "/source_1a/data/atek/adt_503849115/train_debug.yaml"

def get_tars(tar_yaml, use_relative_path=False):
    """Load the list of shard (tar) paths stored in a YAML manifest.

    Args:
        tar_yaml: Path to a YAML file containing a top-level ``tars`` entry.
        use_relative_path: When True, every entry is joined onto the
            directory that contains ``tar_yaml``; when False the entries are
            returned exactly as written in the manifest.

    Returns:
        The value of the ``tars`` entry, optionally prefixed with the
        manifest's directory.

    Raises:
        KeyError: If the manifest has no ``tars`` entry.
    """
    with open(tar_yaml, "r") as manifest:
        shard_paths = yaml.safe_load(manifest)["tars"]

    if not use_relative_path:
        return shard_paths

    # Entries are interpreted relative to the manifest's own directory.
    base_dir = os.path.dirname(tar_yaml)
    return [os.path.join(base_dir, shard) for shard in shard_paths]

# Resolve each manifest entry against the manifest's directory, then report
# how many shards were listed.
tars = get_tars(tars_yaml, use_relative_path=True)
print(len(tars))

828


In [3]:
# Work on just the first 10 shards (presumably to keep the experiments in the
# following cells fast).
small_tars = tars[:10]
print(small_tars)

['/source_1a/data/atek/adt_503849115/1WM10360071292_optitrack_release_multiskeleton_party_seq121/shards-0015.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0000.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0001.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0002.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0003.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0004.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0005.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0006.tar', '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0007.tar', '/source_1a/data/atek/adt_503849115/1

In [4]:
import webdataset as wds

def filter_keys(k: str) -> bool:
    """Tell whether a sample key should be kept.

    Args:
        k: Name of a key found in a sample.

    Returns:
        True only when ``k`` is exactly ``"FSG+T_world_local"``; False for
        every other key.
    """
    return k == "FSG+T_world_local"

def remap_keys(k: str)->str:
    """Translate a sample key to its short name.

    ``"FSG+T_world_local"`` becomes ``"T"``; any other key is returned
    unchanged.
    """
    renamed = {"FSG+T_world_local": "T"}
    # Fall back to the key itself when no rename is registered for it.
    return renamed.get(k, k)

def create_dup(data, num_copies=2):
    """Yield every sample of ``data`` ``num_copies`` times in a row.

    The very same object is yielded each time (no copy is made), so
    downstream in-place mutation of one duplicate affects the others.

    Args:
        data: Iterable of samples.
        num_copies: How many times each sample is emitted. Defaults to 2,
            the previously hard-coded value. A value below 1 emits nothing.

    Yields:
        Each sample of ``data``, repeated ``num_copies`` times consecutively.
    """
    for sample in data:
        for _ in range(num_copies):
            yield sample

# Build the ATEK WebDataset over the small shard subset, handing it the key
# filter, the key renamer and the sample-duplicating generator defined above
# (the latter wrapped into a WebDataset pipeline stage by pipelinefilter).
# NOTE(review): where data_transform_fn runs relative to batching is decided
# inside create_atek_webdataset — confirm there.
ds = create_atek_webdataset(
    urls=small_tars,
    batch_size=2,
    nodesplitter=wds.shardlists.split_by_node,
    select_key_fn=filter_keys,
    remap_key_fn=remap_keys,
    data_transform_fn=wds.filters.pipelinefilter(create_dup)(),
)

# Alternative hooks, currently disabled:
# ds = ds.compose(wds.filters.log_keys())
# ds = ds.compose(wds.filters.pipelinefilter(create_dup)())

# Print the first batch only, as a quick sanity check of keys and shapes.
for obj in ds:
    print(obj)
    break

{'__key__': ['fsg_000480', 'fsg_000480'], '__url__': ['/source_1a/data/atek/adt_503849115/1WM10360071292_optitrack_release_multiskeleton_party_seq121/shards-0015.tar', '/source_1a/data/atek/adt_503849115/1WM10360071292_optitrack_release_multiskeleton_party_seq121/shards-0015.tar'], 'T': tensor([[[-0.7943, -0.2227,  0.5653, -2.5847],
         [ 0.0740, -0.9590, -0.2737,  1.5859],
         [ 0.6030, -0.1756,  0.7782, -2.0159]],

        [[-0.7943, -0.2227,  0.5653, -2.5847],
         [ 0.0740, -0.9590, -0.2737,  1.5859],
         [ 0.6030, -0.1756,  0.7782, -2.0159]]])}


In [5]:
# Wrap ds in a dataloader (num_workers=4, pinned memory off) so the whole
# subset can be iterated in the next cell.
dataloader = create_wds_dataloader(ds, num_workers=4, pin_memory=False)

In [6]:
from collections import Counter

from tqdm import tqdm

# Tally how many samples each shard contributes by counting the "__url__"
# entries of every batch (one entry per sample in the batch).
per_url_sample_count = Counter()
for data in tqdm(dataloader):
    per_url_sample_count.update(data["__url__"])

# Every sample contributes exactly one URL, so the total is the sum of tallies.
sample_count = sum(per_url_sample_count.values())

# Backward-compatible alias for the original, misspelled variable name.
per_url_smaple_count = per_url_sample_count

print(sample_count)
# Convert to a plain dict so the printout keeps first-seen order and the
# original `{...}` format instead of Counter's count-sorted repr.
print(dict(per_url_sample_count))

289it [00:02, 117.96it/s]

578
{'/source_1a/data/atek/adt_503849115/1WM10360071292_optitrack_release_multiskeleton_party_seq121/shards-0015.tar': 2, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0000.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0001.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0002.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0003.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0004.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0005.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0006.tar': 64, '/source_1a/data/atek/adt_503849115/1WM103600M1292_optitrack_release_recognition_seq133/shards-0007.tar': 64




In [7]:
from typing import List, Optional, Dict
import torch

from detectron2.data import detection_utils
from detectron2.structures import Boxes, BoxMode, Instances

from webdataset.filters import pipelinefilter
from atek.dataset.omni3d_adapter import create_omni3d_webdataset


In [8]:
from tqdm import tqdm

# Same shard subset, this time loaded through the Omni3D adapter.
omni3d_wds = create_omni3d_webdataset(small_tars, batch_size=2)
dataloader = create_wds_dataloader(omni3d_wds, num_workers=10, pin_memory=False)

# Accumulate the size of every batch to get the total number of samples.
# NOTE(review): assumes each batch is a sequence of per-sample entries, so
# len(data) is the batch's sample count; if batches were dicts this would
# count keys instead — confirm against create_omni3d_webdataset.
sample_count = 0
for data in tqdm(dataloader):
    sample_count += len(data)

print(sample_count)

145it [00:01, 77.84it/s] 

289



