In [1]:
import os
# Move to the repository root so `examples.three_tank...` imports and the
# relative `examples/three_tank/...` paths below resolve. Guarded so that
# re-running this cell does not keep climbing the directory tree.
if not os.path.isdir(os.path.join('examples', 'three_tank')):
    os.chdir('../../..')

In [2]:
import pandas as pd
import examples.three_tank.constants as const
import numpy as np
import plotly.express as px
from ipywidgets import interact
import ipywidgets as widgets
from plotly.subplots import make_subplots
import plotly.graph_objects as go
%load_ext autoreload
%autoreload 2

In [8]:
# %run examples/three_tank/data_gen.py

In [9]:
# Simulated three-tank trajectories: one row per time step, grouped by sample uid.
df = pd.read_parquet(path=const.DATA_PATH)
df.head(n=5)

Unnamed: 0,h1,h2,h3,time,uid_ts_sample,q1,q3,kv12,kv23
0,10.0,95.0,33.0,0.0,0,6.825291,4.157763,0.235768,0.841853
1,10.571552,94.209491,34.022954,0.204082,0,6.825291,4.157763,0.235768,0.841853
2,11.141523,93.429404,35.035404,0.408163,0,6.825291,4.157763,0.235768,0.841853
3,11.709916,92.65977,36.037314,0.612245,0,6.825291,4.157763,0.235768,0.841853
4,12.27673,91.900619,37.028645,0.816327,0,6.825291,4.157763,0.235768,0.841853


In [10]:
# Column dtypes, non-null counts and memory footprint of the loaded frame.
df.info(buf=None)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 500000 entries, 0 to 499999
Data columns (total 9 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   h1             500000 non-null  float64
 1   h2             500000 non-null  float64
 2   h3             500000 non-null  float64
 3   time           500000 non-null  float64
 4   uid_ts_sample  500000 non-null  int64  
 5   q1             500000 non-null  float64
 6   q3             500000 non-null  float64
 7   kv12           500000 non-null  float64
 8   kv23           500000 non-null  float64
dtypes: float64(8), int64(1)
memory usage: 38.1 MB


## From the parquet file

In [13]:
def _create_ts_plot(idx):
    """Plot the tank levels of one simulated sample read from the parquet file.

    Prints the sample's q1/q3/kv12/kv23 values (first row only) and shows a
    plotly line chart of the state columns against the recorded `time`.

    Args:
        idx: value of the `const.UID_SAMPLE_COL_NAME` column selecting the sample.
    """
    df_plot = df[df[const.UID_SAMPLE_COL_NAME] == idx]
    # In the df.head() output these parameters are constant within a sample,
    # so the first row is enough to show them.
    print(df_plot[['q1', 'q3', 'kv12', 'kv23']].iloc[:1])

    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    for col, name in zip(const.STATE_COL_NAMES, ['h1(t)', 'h2(t)', 'h3(t)']):
        fig.add_trace(
            go.Scatter(x=df_plot['time'], y=df_plot[col], name=name,
                       mode="lines", opacity=1),
            row=1, col=1,
        )

    fig.update_xaxes(title_text='time')
    # The figure has a single subplot row, so only row 1 gets a y-axis title.
    fig.update_yaxes(title_text='x', row=1)
    fig.update_layout(title_text="load from parquet file", showlegend=True)
    fig.show()

# Trailing `;` keeps the returned function object from being echoed as cell output.
interact(_create_ts_plot, idx=list(range(100)));

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

<function __main__._create_ts_plot(idx)>

## From the PyTorch dataset

In [22]:
from examples.three_tank.dataset import ThreeTankDataSet
dataset = ThreeTankDataSet()

def _create_ts_plot(idx):
    """Plot the tank levels of one sample as returned by `ThreeTankDataSet`.

    Args:
        idx: position of the sample in `dataset`.
    """
    # The dataset item unpacks to (states, index); the index is not needed here,
    # and unpacking into `_` avoids shadowing the `idx` parameter.
    x, _ = dataset[idx]
    df_plot = pd.DataFrame(x, columns=const.STATE_COL_NAMES)
    # The frame has no time column, so plot against the step index. Deriving it
    # from the data keeps the x and y lengths equal.
    steps = np.arange(len(df_plot))

    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    for col, name in zip(const.STATE_COL_NAMES, ['h1(t)', 'h2(t)', 'h3(t)']):
        fig.add_trace(
            go.Scatter(x=steps, y=df_plot[col], name=name,
                       mode="lines", opacity=1),
            row=1, col=1,
        )

    fig.update_xaxes(title_text='time step')
    # The figure has a single subplot row, so only row 1 gets a y-axis title.
    fig.update_yaxes(title_text='x', row=1)
    fig.update_layout(title_text="load from pytorch dataset", showlegend=True)
    fig.show()

# Trailing `;` keeps the returned function object from being echoed as cell output.
interact(_create_ts_plot, idx=list(range(100)));

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

<function __main__._create_ts_plot(idx)>

## From the Lightning data module

In [25]:
from examples.three_tank.data_module import ThreeTankDataModule
# Keyword arguments forwarded to ThreeTankDataModule.
# NOTE(review): `validdation_split` looks like a misspelling of `validation_split`.
# Confirm against ThreeTankDataModule's signature: if the module accepts **kwargs,
# this key may be silently ignored.
hparams = dict(
    validdation_split=.1,
    batch_size=100,
    dl_num_workers=0
)
ttdm = ThreeTankDataModule(**hparams)
ttdm.setup()
dl = ttdm.train_dataloader()
# Use the built-in next(): DataLoader iterators do not provide a `.next()`
# method in recent PyTorch releases.
train_batch = next(iter(dl))

In [26]:
# Split the batch into its state tensor and the accompanying (presumably sample) indices.
(x_train, idx) = train_batch

In [27]:
# Indices delivered with this training batch (presumably positions in
# ThreeTankDataSet; their order looks shuffled — verify against the data module).
idx

tensor([4949, 8190, 7238, 9815, 3407, 8975, 4954, 2617, 6490, 4517, 2685, 4841,
        8759, 6693, 4110, 2411, 6173, 6031, 3648, 5426, 5203, 5038, 8565, 4147,
        4729, 6881, 7380, 8257, 2952, 3708, 5940, 9730, 4704, 4456, 9309, 5204,
        7891, 5389, 5677, 7720, 9248, 1838, 3427, 1400, 5576, 4117, 4839,   41,
        9765,  698, 7786, 8288,  710, 2110, 3253,  609, 7401, 7406, 5987, 6094,
        7513, 9384, 9967, 7957,  578, 3497, 5259, 5966, 1845, 9652,    4, 7177,
         983, 7899, 8920, 9700, 5705, 8812, 1469, 3560, 5452, 2588, 6634, 7071,
        2462,  148, 2312, 7905,  370, 4773, 7301, 7620, 6559, 2406, 6020, 8380,
        7947, 2953, 5053, 9485])

In [29]:
# Batch tensor dimensions; with batch_size=100 this displays torch.Size([100, 50, 3]).
# Presumably (batch, time steps, tank levels) — TODO confirm axis order.
x_train.size()

torch.Size([100, 50, 3])