In [1]:
%load_ext autoreload
%autoreload 2
import contex

In [200]:
from typing import Callable
import pandas as pd
import numpy as np

from collections import defaultdict
from functools import partial
from torchdata.datapipes import functional_datapipe
import torchdata.datapipes.iter as pipes
import torch
# import torchdata.datapipes as dp


In [11]:
@functional_datapipe("rolling")
class RollingWindow(pipes.IterDataPipe):
    """From https://github.com/tcapelle/torchdata/blob/main/02_Custom_timeseries_datapipe.ipynb """ 

    def __init__(self, source_dp: pipes.IterDataPipe, window_size, step=1) -> None:
        super().__init__()
        self.source_dp = source_dp
        self.window_size = window_size
        self.step = step
    
    def __iter__(self):
        it = iter(self.source_dp)
        cur = []
        while True:
            try:
                while len(cur) < self.window_size:
                    cur.append(next(it))
                yield np.array(cur)
                for _ in range(self.step):
                    if cur:
                        cur.pop(0)
                    else:
                        next(it)
            except StopIteration:
                return

In [3]:
@functional_datapipe("parse_pandas_dataframe")
class PandasParserIterDataPipe(pipes.IterDataPipe):
    def __init__(self, df) -> None:
        self.source_df = df

    def __iter__(self):
         for row in self.source_df.iterrows():
            yield list(row[1])

In [4]:
@functional_datapipe('rolling_groupby')
class RollingGrouperIterDataPipe(pipes.IterDataPipe):
    r"""
    """
    def __init__(self,
                 datapipe: pipes.IterDataPipe,
                 group_key_fn: Callable,
                 *,
                 window_size=1,
                 step_size=1,
                 buffer_size: int = 10000,
                 # group_size: Optional[int] = None,
                 # guaranteed_group_size: Optional[int] = None,
                 # drop_remaining: bool = False
                ):
        # check_lambda_fn(group_key_fn)
        self.datapipe = datapipe
        self.group_key_fn = group_key_fn

        self.window_size = window_size
        self.step_size = step_size
        
        self.group_size = window_size
        self.buffer_size = buffer_size
        self.guaranteed_group_size = None
        if self.group_size is not None and buffer_size is not None:
            assert 0 < self.group_size <= buffer_size
            self.guaranteed_group_size = self.group_size
        #if guaranteed_group_size is not None:
        #    assert group_size is not None and 0 < guaranteed_group_size <= group_size
        #    self.guaranteed_group_size = guaranteed_group_size
        self.drop_remaining = True
        # self.wrapper_class = DataChunk

    def _remove_biggest_key(self, buffer_elements, buffer_size):
        biggest_key = None
        biggest_size = 0
        result_to_yield = None
        for findkey in buffer_elements.keys():
            if len(buffer_elements[findkey]) > biggest_size:
                biggest_size = len(buffer_elements[findkey])
                biggest_key = findkey

        if self.guaranteed_group_size is not None and biggest_size < self.guaranteed_group_size and not self.drop_remaining:
            raise RuntimeError('Failed to group items', str(buffer_elements[biggest_key]))

        if self.guaranteed_group_size is None or biggest_size >= self.guaranteed_group_size:
            result_to_yield = buffer_elements[biggest_key]

        new_buffer_size = buffer_size - biggest_size
        del buffer_elements[biggest_key]

        return result_to_yield, new_buffer_size

    def __iter__(self):
        buffer_elements: DefaultDict[Any, List] = defaultdict(list)
        buffer_size = 0
        for x in self.datapipe:
            key = self.group_key_fn(x)

            buffer_elements[key].append(x)
            buffer_size += 1

            if self.group_size is not None and self.group_size == len(buffer_elements[key]):
                #yield self.wrapper_class(buffer_elements[key])
                yield buffer_elements[key]
                if self.step_size < self.window_size:
                    del buffer_elements[key][:self.step_size]
                    buffer_size -= self.step_size
                else:
                    del buffer_elements[key]
                    buffer_size -= self.window_size

            if buffer_size == self.buffer_size:
                (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
                if result_to_yield is not None:
                    yield result_to_yield

In [214]:
df = pd.read_parquet("../data/stallion.parquet")
# add time index
df["time_idx"] = df["date"].dt.year * 12 + df["date"].dt.month
df["time_idx"] -= df["time_idx"].min()

# add additional features
df["month"] = df.date.dt.month.astype(str).astype("category")  # categories have be strings
df["log_volume"] = np.log(df.volume + 1e-8)
df["avg_volume_by_sku"] = df.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
df["avg_volume_by_agency"] = df.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")
time_idx="time_idx"
target="volume"
group_ids=["agency", "sku"]
df = df.sort_values(by="time_idx")
df = df[[time_idx, target] + group_ids]
df = df.reset_index(drop=True)
df.head(10), df.shape

(   time_idx       volume     agency     sku
 0         0    52.272000  Agency_22  SKU_01
 1         0  3324.269700  Agency_32  SKU_04
 2         0   110.700000  Agency_22  SKU_02
 3         0     0.000000  Agency_58  SKU_23
 4         0    28.320000  Agency_48  SKU_07
 5         0   238.538700  Agency_22  SKU_05
 6         0     0.000000  Agency_58  SKU_17
 7         0   126.360000  Agency_31  SKU_01
 8         0   475.790396  Agency_48  SKU_02
 9         0     1.150200  Agency_40  SKU_04,
 (21000, 4))

In [215]:
datapipe = PandasParserIterDataPipe(df)
for x in datapipe:
    print(x)
    break

[0, 52.272, 'Agency_22', 'SKU_01']


In [218]:
ds = datapipe.rolling_groupby(group_key_fn=lambda x: x[2] + x[3], window_size=1, step_size=1)

In [219]:
for i, x in enumerate(ds):
    print()
    print(x)
    if i > 2:
        break


[[0, 52.272, 'Agency_22', 'SKU_01']]

[[0, 3324.2697, 'Agency_32', 'SKU_04']]

[[0, 110.7, 'Agency_22', 'SKU_02']]

[[0, 0.0, 'Agency_58', 'SKU_23']]


In [158]:
hashed = hash("".join([ str(y) for y in x[0]]))
print(hashed, format(hashed, 'b'))
%timeit (hashed % 5) > 2

-2558461773174316290 -10001110000001011110110110110101101001011110001001110100000010
41.2 ns ± 0.199 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [83]:
hashed2 = 6405261876
print(hashed2, format(hashed2, 'b'))
%timeit (hashed2 % 5) > 2

6405261876 101111101110010001000101000110100
35.5 ns ± 0.745 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [84]:
hashed2 = 640526187
print(hashed2, format(hashed2, 'b'))
%timeit (hashed2 % 5) > 2

640526187 100110001011011010011101101011
24.4 ns ± 0.0304 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [85]:
hashed2 = 61
print(hashed2, format(hashed2, 'b'))
%timeit (hashed2 % 5) > 2

61 111101
24.5 ns ± 0.141 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [220]:
def hash_split(x, which_id=0):
    hashed = hash("".join([ str(y) for y in x[0]]))
    perc = hashed % 100 
    if perc < 70:
        return 0
    if perc < 85:
        return 1
    return 2

train, test, val = ds.demux(num_instances=3, classifier_fn=hash_split, buffer_size=-1)
# train, test, val = ds.demux(num_instances=3, classifier_fn=hash_split)



In [221]:
try:
    for i, _ in enumerate(train):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

try:
    for i, _ in enumerate(test):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

try:
    for i, _ in enumerate(val):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

14665
3160
3172


In [224]:
def hash_split(x, which_id=0):
    hashed = hash("".join([ str(y) for y in x[0]]))
    perc = hashed % 100 
    split_id = 0
    if perc >= 70:
        split_id =  1
    if perc >= 85:
        split_id = 2
    return split_id == which_id

train_split = partial(hash_split, which_id = 0)
test_split = partial(hash_split, which_id = 1)
val_split = partial(hash_split, which_id = 2)

train, test, val = ds.filter(filter_fn=train_split), ds.filter(filter_fn=test_split), ds.filter(filter_fn=val_split)

In [225]:
try:
    for i, _ in enumerate(train):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

try:
    for i, _ in enumerate(test):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

try:
    for i, _ in enumerate(val):
        length = i
except BufferError:
    print("Buffer Error")
print(length)

14665
3160
3172


In [162]:
next(iter(train)), next(iter(test)), next(iter(val)), next(iter(ds))

([[0, 375.1569, 'Agency_18', 'SKU_05'], [1, 386.595, 'Agency_18', 'SKU_05']],
 [[0, 161.2836, 'Agency_15', 'SKU_05'], [1, 182.8818, 'Agency_15', 'SKU_05']],
 [[0, 1618.8426, 'Agency_12', 'SKU_05'], [1, 1681.017, 'Agency_12', 'SKU_05']],
 [[0, 375.1569, 'Agency_18', 'SKU_05'], [1, 386.595, 'Agency_18', 'SKU_05']])

In [110]:
next(iter(train.shuffle()))

[[-9, 24.4287, 'Agency_25', 'SKU_04'], [-8, 39.1707, 'Agency_25', 'SKU_04']]

In [77]:
df.time_idx.max()

59

In [185]:
def stratified_split(x):
    time_idx = x[0][0]
    if time_idx < 40:
        return 0
    if time_idx  < 50:
        return 1
    return 2
    
stratified_split(next(iter(ds)))

0

In [208]:
train, test, val = ds.demux(num_instances=3, classifier_fn=stratified_split, buffer_size=-1)
# train, test, val = ds.demux(num_instances=3, classifier_fn=stratified_split)



In [197]:
train, test, val = ds.fork(num_instances=3)

In [209]:
for i, _ in enumerate(train):
    length = i
print(length)
for i, _ in enumerate(test):
    length = i
print(length)
for i, _ in enumerate(val):
    length = i
print(length)

6999
1749
1749


In [188]:
next(iter(train)), next(iter(test)), next(iter(val))

([[0, 375.1569, 'Agency_18', 'SKU_05'], [1, 386.595, 'Agency_18', 'SKU_05']],
 [[40, 3081.258, 'Agency_49', 'SKU_04'],
  [41, 2362.383, 'Agency_49', 'SKU_04']],
 [[50, 4.95285, 'Agency_38', 'SKU_14'], [51, 5.2824, 'Agency_38', 'SKU_14']])

In [None]:
def groups_

grouped = ds.demux(num_instances=3, classifier_fn=)

In [None]:
#ds  = (pipes.FileOpener(datapipe, mode='rt').parse_csv(delimiter=',', skip_lines=1)
#            .map(parse_price)
#            .rolling(window_size=5, step=1)
#            .batch(4)
#      )

In [16]:
from torchdata.datapipes.iter import FileLister
# import torcharrow.dtypes as dt
# dp = pipes.FileLister([camvid_path/"images"], masks="*.png")
# DTYPE = dt.Struct([dt.Field("Values", dt.int32)])
#source_dp = FileLister(".", masks="df*.parquet")
#parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE)
#arquet_df_dp = source_dp.load_parquet_as_df()
# list(parquet_df_dp)[0]

ImportError: The library 'torcharrow' is necessary for this DataPipe but it is not available.Please visit https://github.com/facebookresearch/torcharrow/ to install it.

In [None]:
datapipe = pipes.IterableWrapper(["../data/HistoricalQuotes.csv"])
csv = pipes.FileOpener(datapipe, mode='rt').parse_csv(delimiter=',', skip_lines=1)

next(iter(csv))
