In [4]:
import os
import json
import multiprocessing
from pathlib import Path

import polars as pl
import webdataset as wds
from tqdm import tqdm

from graphnet.data.constants import FEATURES, TRUTH


[1;34mgraphnet[0m: [32mINFO    [0m 2023-03-09 00:12:30 - get_logger - Writing log to [1mlogs/graphnet_20230309-001230.log[0m


In [11]:
import pandas as pd

# Load a single raw batch (batch 51) with pandas for ad-hoc inspection.
# NOTE(review): `input_data_path` is assigned in a later cell of this notebook, so on a
# fresh kernel that configuration cell has to be executed before this one.
df = pd.read_parquet(input_data_path.format(batch_id=51))

In [14]:
# Show the "time" and "charge" columns of every row whose index label is 162683598.
df.loc[162683598, ["time", "charge"]]

Unnamed: 0_level_0,time,charge
event_id,Unnamed: 1_level_1,Unnamed: 2_level_1
162683598,6491,0.525
162683598,6812,2.825
162683598,6853,0.975
162683598,7091,1.125
162683598,7147,0.575
...,...,...
162683598,16012,0.925
162683598,16286,1.075
162683598,16756,0.325
162683598,17207,0.925


In [5]:
# Input files of the dataset.
meta_data_path = Path("../../input/icecube/icecube-neutrinos-in-deep-ice/train_meta.parquet")
geometry_path = Path("../../input/icecube/icecube-neutrinos-in-deep-ice/sensor_geometry.csv")
# Template for the per-batch parquet files; fill it in with `.format(batch_id=...)`.
input_data_path = "../../input/icecube/icecube-neutrinos-in-deep-ice/train/batch_{batch_id}.parquet"

# Output directory for the generated tar shards.
shard_dir = Path("../../input/webdatasets")
# parents=True so the cell also works when "../../input" itself does not exist yet;
# exist_ok=True keeps the cell idempotent on re-runs.
shard_dir.mkdir(parents=True, exist_ok=True)

# printf-style filename pattern; "%03d" is filled in with an integer via the `%` operator.
shard_filename = str(shard_dir / 'shards-%03d.tar')

# 1 GiB expressed in bytes.
# NOTE(review): no visible cell reads `shard_size` -- confirm whether it is still needed.
shard_size = int(1 * 1024**3)

In [10]:
import datetime
from webdataset import TarWriter
from multiprocessing import Process
from tqdm import tqdm

def make_wds_shards(pattern, batch_ids, num_workers, map_func, **kwargs):
    """Convert batches into tar shards using several worker processes.

    The metadata parquet and the sensor geometry CSV are read once in the parent
    process and handed to every worker. Worker ``i`` is given
    ``batch_ids[i::num_workers]``, so every batch id is handled by exactly one worker.

    Args:
        pattern: printf-style shard filename pattern; the workers format it with the
            batch id (``pattern % batch_id``).
        batch_ids: Sliceable sequence of batch ids (list, tuple, range, ...).
        num_workers: Number of worker processes to start; must be at least 1.
        map_func: Callable forwarded to the workers; it is called once per event to
            build the sample that gets written.
        **kwargs: Collected into a dict and forwarded to the workers, which pass the
            entries on to ``TarWriter``.

    Raises:
        ValueError: If ``num_workers`` is smaller than 1.
    """
    # Fail before the (potentially slow) reads below rather than silently doing nothing.
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")

    meta_data = pl.read_parquet(meta_data_path)
    geometry_table = pl.read_csv(geometry_path, dtypes={"sensor_id": pl.Int16})
    print("Read meta data")

    processes = [
        Process(
            target=write_partial_samples,
            args=(
                pattern,
                # Strided slice starting at the worker index: without the `i` offset
                # every worker would receive the same subset of batches.
                batch_ids[i::num_workers],
                meta_data,
                geometry_table,
                map_func,
                kwargs,
            ),
        )
        for i in range(num_workers)
    ]

    for p in processes:
        p.start()
    # Block until every worker has finished.
    for p in processes:
        p.join()


def write_partial_samples(pattern, batch_ids, meta_data, geometry_table, map_func, kwargs):
    """Write one tar shard per batch id, containing one sample per event of that batch.

    For each batch the metadata is filtered down to that batch, the batch parquet file
    is read from ``input_data_path``, and a single ``TarWriter`` is opened on
    ``pattern % batch_id``. Every event of the batch is then written into that one
    writer. Opening a new writer on the same filename for each event would rewrite
    the shard every time and leave only the last event in it.

    Args:
        pattern: printf-style shard filename pattern, formatted with the batch id.
        batch_ids: Iterable of batch ids to process.
        meta_data: Frame with at least ``batch_id`` and ``event_id`` columns.
        geometry_table: Passed through unchanged to ``map_func``.
        map_func: Called as ``map_func(event_id, meta_data_batch, df_batch,
            geometry_table)``; must return the sample to write.
        kwargs: Dict of keyword arguments for ``TarWriter``.
    """
    for batch_id in batch_ids:
        meta_data_batch = meta_data.filter(pl.col("batch_id") == batch_id)
        event_ids = meta_data_batch["event_id"].unique()
        df_batch = pl.read_parquet(input_data_path.format(batch_id=batch_id))

        # One writer per batch, shared by all of its events.
        stream = TarWriter(pattern % batch_id, **kwargs)
        try:
            for event_id in tqdm(event_ids, desc=f"Batch {batch_id}"):
                write_samples_into_single_shard(
                    pattern,
                    meta_data_batch,
                    batch_id,
                    event_id,
                    df_batch,
                    geometry_table,
                    map_func,
                    kwargs,
                    stream=stream,
                )
        finally:
            # Close the shard even when building or writing a sample fails.
            stream.close()


def write_samples_into_single_shard(pattern, meta_data_batch, batch_id, event_id, df_batch, geometry_table, map_func, kwargs, stream=None):
    """Write the sample of one event into a shard and return what ``write`` returns.

    Args:
        pattern: printf-style shard filename pattern, formatted with ``batch_id``.
            Only used when ``stream`` is None.
        meta_data_batch: Passed through to ``map_func``.
        batch_id: Batch id used to build the shard filename when ``stream`` is None.
        event_id: Event to convert; passed to ``map_func``.
        df_batch: Passed through to ``map_func``.
        geometry_table: Passed through to ``map_func``.
        map_func: Called as ``map_func(event_id, meta_data_batch, df_batch,
            geometry_table)``; must return the sample to write.
        kwargs: Dict of keyword arguments for ``TarWriter``. Only used when
            ``stream`` is None.
        stream: Optional, already-open writer. When given, the sample is written to
            it and it is left open (the caller owns it). When omitted, a new
            ``TarWriter`` is opened on ``pattern % batch_id``, the sample is written
            and the writer is closed again.

    Returns:
        The value returned by ``stream.write``.
    """
    if stream is not None:
        return stream.write(map_func(event_id, meta_data_batch, df_batch, geometry_table))

    fname = pattern % batch_id
    own_stream = TarWriter(fname, **kwargs)
    try:
        return own_stream.write(map_func(event_id, meta_data_batch, df_batch, geometry_table))
    finally:
        # Always release the file handle, including when map_func or write raises.
        own_stream.close()

def map_func(event_id, meta_data_batch, df_batch, geometry_table):
    """Assemble the sample dict for a single event.

    Args:
        event_id: Value matched against the ``event_id`` column of both frames.
        meta_data_batch: Frame the truth columns (``TRUTH.KAGGLE``) are taken from.
        df_batch: Frame the feature rows are taken from; left-joined with
            ``geometry_table`` on ``sensor_id`` before the ``FEATURES.KAGGLE``
            columns are selected.
        geometry_table: Frame joined onto the event rows via ``sensor_id``.

    Returns:
        A dict with ``"__key__"`` set to ``str(event_id)`` and ``"pickle"`` set to
        the tuple ``(features, truth)``, both obtained through ``to_numpy()``.
    """
    # The same row filter is applied to the metadata and to the batch frame.
    matches_event = pl.col("event_id") == event_id

    event_meta = meta_data_batch.filter(matches_event)
    truth = event_meta[TRUTH.KAGGLE].to_numpy()

    event_rows = df_batch.filter(matches_event)
    rows_with_geometry = event_rows.join(geometry_table, on="sensor_id", how="left")
    features = rows_with_geometry[FEATURES.KAGGLE].to_numpy()

    sample = {"__key__": str(event_id), "pickle": (features, truth)}
    return sample

In [6]:
import pandas as pd
# Load the sensor geometry CSV into a pandas DataFrame for interactive inspection.
# NOTE(review): `make_wds_shards` reads the same file itself (with polars) into a local
# variable of the same name, so this notebook-level frame is not what the workers use.
geometry_table = pd.read_csv(geometry_path)

In [8]:
# Label-based lookup with a repeated label: the row labelled 0 is returned twice.
geometry_table.loc[[0, 1, 0]]

Unnamed: 0,sensor_id,x,y,z
0,0,-256.14,-521.08,496.03
1,1,-256.14,-521.08,479.01
0,0,-256.14,-521.08,496.03


In [11]:
# Batch ids 51 through 100 inclusive (the upper bound of `range` is exclusive).
batch_ids = range(51, 101)


# Run the conversion with 8 worker processes; the call returns only after all of
# them have been joined. No extra keyword arguments are forwarded to TarWriter here.
make_wds_shards(
    pattern=shard_filename,
    batch_ids=batch_ids,
    num_workers=8,
    map_func=map_func,
)

Read meta data


KeyboardInterrupt: 

In [None]:

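# NOTE(review): commented-out single-process draft. It references `meta_data`, `sink`
# and `get_features_truth`, none of which are defined in the visible cells.
# for batch_id in batch_ids: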
#     meta_data_batch = meta_data.filter(pl.col("batch_id") == batch_id)
#     event_ids = meta_data_batch["event_id"].unique()
#     df_batch = pl.read_parquet(input_data_path.format(batch_id=batch_id))

#     for event_id in tqdm(event_ids, desc=f"Batch {batch_id}"):
#         features, truth = get_features_truth(event_id)
#         sink.write({
#             "__key__": str(event_id),
#             "pickle": (
#                 features, truth,
#             )
#         })
        
#     del df_batch, meta_data_batch