In [27]:
import typing as t
import itertools

import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

In [29]:
# Load the CSV, parsing the `Date` column as datetimes and using it as the index.
df = pd.read_csv('data/yahoo_stock.csv', parse_dates=['Date'], index_col='Date')

In [32]:
# Display the loaded frame to inspect its columns and date index.
df

Unnamed: 0_level_0,High,Low,Open,Close,Volume,Adj Close
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2015-11-23,2095.610107,2081.389893,2089.409912,2086.590088,3.587980e+09,2086.590088
2015-11-24,2094.120117,2070.290039,2084.419922,2089.139893,3.884930e+09,2089.139893
2015-11-25,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117
2015-11-26,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117
2015-11-27,2093.290039,2084.129883,2088.820068,2090.110107,1.466840e+09,2090.110107
...,...,...,...,...,...,...
2020-11-16,3628.510010,3600.159912,3600.159912,3626.909912,5.281980e+09,3626.909912
2020-11-17,3623.110107,3588.679932,3610.310059,3609.530029,4.799570e+09,3609.530029
2020-11-18,3619.090088,3567.330078,3612.090088,3567.790039,5.274450e+09,3567.790039
2020-11-19,3585.219971,3543.840088,3559.409912,3581.870117,4.347200e+09,3581.870117


In [35]:
# Earliest date present in the index.
df.index.min()

Timestamp('2015-11-23 00:00:00')

In [34]:
# Latest date present in the index.
df.index.max()

Timestamp('2020-11-20 00:00:00')

In [36]:
# Label every row by the quantile bin its date falls into: the earliest 80% of
# index values become 'train', the next 10% 'val', and the final 10% 'test'.
df['subset'] = pd.qcut(df.index, q=[.0, .8, .9, 1.], labels=['train', 'val', 'test'])

In [71]:
# Latest date that was labelled as part of the 'train' subset.
df[df.subset == 'train'].index.max()

Timestamp('2019-11-21 00:00:00')

In [37]:
# Display the frame again to check the `subset` column.
df

Unnamed: 0_level_0,High,Low,Open,Close,Volume,Adj Close,subset
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2015-11-23,2095.610107,2081.389893,2089.409912,2086.590088,3.587980e+09,2086.590088,train
2015-11-24,2094.120117,2070.290039,2084.419922,2089.139893,3.884930e+09,2089.139893,train
2015-11-25,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117,train
2015-11-26,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117,train
2015-11-27,2093.290039,2084.129883,2088.820068,2090.110107,1.466840e+09,2090.110107,train
...,...,...,...,...,...,...,...
2020-11-16,3628.510010,3600.159912,3600.159912,3626.909912,5.281980e+09,3626.909912,test
2020-11-17,3623.110107,3588.679932,3610.310059,3609.530029,4.799570e+09,3609.530029,test
2020-11-18,3619.090088,3567.330078,3612.090088,3567.790039,5.274450e+09,3567.790039,test
2020-11-19,3585.219971,3543.840088,3559.409912,3581.870117,4.347200e+09,3581.870117,test


In [39]:
# Whole days elapsed between each row's date and the earliest date in the index.
df['day'] = (df.index - df.index.min()).days

In [40]:
# Display the frame again to check the `day` column.
df

Unnamed: 0_level_0,High,Low,Open,Close,Volume,Adj Close,subset,day
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2015-11-23,2095.610107,2081.389893,2089.409912,2086.590088,3.587980e+09,2086.590088,train,0
2015-11-24,2094.120117,2070.290039,2084.419922,2089.139893,3.884930e+09,2089.139893,train,1
2015-11-25,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117,train,2
2015-11-26,2093.000000,2086.300049,2089.300049,2088.870117,2.852940e+09,2088.870117,train,3
2015-11-27,2093.290039,2084.129883,2088.820068,2090.110107,1.466840e+09,2090.110107,train,4
...,...,...,...,...,...,...,...,...
2020-11-16,3628.510010,3600.159912,3600.159912,3626.909912,5.281980e+09,3626.909912,test,1820
2020-11-17,3623.110107,3588.679932,3610.310059,3609.530029,4.799570e+09,3609.530029,test,1821
2020-11-18,3619.090088,3567.330078,3612.090088,3567.790039,5.274450e+09,3567.790039,test,1822
2020-11-19,3585.219971,3543.840088,3559.409912,3581.870117,4.347200e+09,3581.870117,test,1823


In [68]:
def get_historical_samples(
    subset: str,
    history_size: int = 5,
    frame: t.Optional[pd.DataFrame] = None,
) -> t.Iterable[t.Tuple[pd.DataFrame, pd.Series]]:
    """Lazily produce ``(past_rows, current_row)`` pairs for one data subset.

    A rolling window of ``history_size + 1`` consecutive rows is slid over the
    frame. Each full window is split into its first ``history_size`` rows (the
    history) and its last row (the current row). A pair is kept only when the
    current row's ``subset`` value equals ``subset``; the history rows may
    carry any subset label.

    Args:
        subset: Value of the ``subset`` column the current row must have.
        history_size: Number of rows preceding the current row in each pair.
        frame: Frame to sample from; it must have a ``subset`` column. When
            omitted, the notebook-level ``df`` is used, so existing calls
            behave as before.

    Returns:
        A lazy iterator of ``(past_rows, current_row)`` tuples.

    Raises:
        ValueError: If ``history_size`` is negative.
    """
    if history_size < 0:
        raise ValueError(f"history_size must be non-negative, got {history_size}")

    # Fall back to the notebook-level frame so callers that pass no frame keep working.
    source = df if frame is None else frame
    window_length = history_size + 1

    # Iterating a Rolling object yields one sub-frame per row; windows near the
    # start contain fewer than `window_length` rows and are skipped.
    full_windows = (
        window
        for window in source.rolling(window=window_length)
        if len(window) >= window_length
    )
    # Split each window into historical rows and the current (last) row.
    pairs = ((window.iloc[:-1], window.iloc[-1]) for window in full_windows)
    # Only keep samples whose current row belongs to the requested subset.
    return (
        (past_rows, current_row)
        for past_rows, current_row in pairs
        if current_row['subset'] == subset
    )

In [11]:
def _float_feature(value: float) -> tf.train.Feature:
    """Wrap one float / double value in a single-element float_list feature."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

In [65]:
def _float_list_feature(values: t.Iterable[float]) -> tf.train.Feature:
    """Pack an iterable of floats into one float_list feature.

    The iterable is materialised with ``list`` before it is handed to
    ``tf.train.FloatList``.
    """
    materialised = list(values)
    float_list = tf.train.FloatList(value=materialised)
    return tf.train.Feature(float_list=float_list)

In [19]:
def _int64_feature(value: int) -> tf.train.Feature:
    """Wrap one bool / enum / int / uint value in a single-element int64_list feature."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

In [12]:
def _string_encode(value: str) -> bytes:
    assert value
    return value.encode("utf-8")

def _string_feature(value: str) -> tf.train.Feature:
    """Build a single-element bytes_list feature from a string.

    The string is converted to bytes by ``_string_encode`` first.
    """
    encoded = _string_encode(value)
    bytes_list = tf.train.BytesList(value=[encoded])
    return tf.train.Feature(bytes_list=bytes_list)

In [66]:
def serialize_example(past_rows: pd.DataFrame, current_row: t.Dict[str, t.Any]) -> bytes:
    """Serialize one sample into a ``tf.train.Example`` byte string.

    Args:
        past_rows: Historical rows; the ``High``, ``Low``, ``Open``, ``Close``,
            ``Volume`` and ``Adj Close`` columns are each stored as a float
            list under a ``past_*`` key.
        current_row: Row-like mapping; its ``Open`` and ``Close`` values are
            stored as single-float features under ``open`` and ``close``.

    Returns:
        The serialized ``tf.train.Example`` message.
    """
    # (feature key, source column) pairs for the history features, in output order.
    history_columns = (
        ('past_high', 'High'),
        ('past_low', 'Low'),
        ('past_open', 'Open'),
        ('past_close', 'Close'),
        ('past_volume', 'Volume'),
        ('past_adj_close', 'Adj Close'),
    )
    features = {
        key: _float_list_feature(past_rows[column])
        for key, column in history_columns
    }

    # Scalar features taken from the current row.
    features['open'] = _float_feature(current_row['Open'])
    features['close'] = _float_feature(current_row['Close'])

    # Wrap the feature mapping in an Example message and serialize it.
    feature_message = tf.train.Features(feature=features)
    example_proto = tf.train.Example(features=feature_message)
    return example_proto.SerializeToString()

In [70]:
# Write one TFRecord file per subset under data/, with one serialized example
# per (past_rows, current_row) sample. tqdm reports progress over the subsets.
for subset in tqdm(['train', 'val', 'test']):
    with tf.io.TFRecordWriter(f'data/yahoo_stock.{subset}.tfrecord') as writer:
        # Samples use the default history_size of get_historical_samples.
        for past_rows, current_row in get_historical_samples(subset):
            example = serialize_example(past_rows, current_row)
            writer.write(example)

  0%|          | 0/3 [00:00<?, ?it/s]