In [30]:
import re
from typing import List, Any
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px

In [18]:
class InvalidFormatError(Exception):
    """Raised by ``format_states`` when a states DataFrame cannot be parsed."""


def pipe(raw_input: Any, *functions: Any, **functions_with_args: Any) -> Any:
    """Thread ``raw_input`` through a chain of callables and return the final result.

    Args:
        raw_input: Initial value fed to the first function.
        *functions: Single-argument callables applied left to right, each one
            receiving the previous result.
        **functions_with_args: Applied after ``functions``, in keyword order. Each
            keyword is the *name* of a callable, resolved in this module's globals
            (builtins included), and called as ``callable(result, *args_list)``.

    Returns:
        The output of the last function applied, or ``raw_input`` unchanged when
        no functions are given.
    """
    output = raw_input

    for function in functions:
        output = function(output)

    for function_name, args_list in functions_with_args.items():
        # Resolve against module globals only. A bare eval() would also see this
        # function's locals, so a keyword such as `output=` would resolve to the
        # intermediate value instead of a callable.
        # SECURITY: this is still eval -- only pass trusted, hard-coded names.
        output = eval(function_name, globals())(output, *args_list)

    return output


def _clean_states(text: str) -> List[str]:
    return text.strip('][').strip().replace('\n', '')   


def _format_whitespaces(text: str) -> List[str]:
    return re.sub(' {2,}', ' ', text).split()


def _evaluate_states(states: List[str]) -> List[float]:
    return [eval(state) for state in states]


def _format_states(states: str) -> List[float]:
    """Parse one stringified, bracketed, whitespace-separated state vector into numbers.

    The cleaning, tokenising and numeric-parsing steps run in order via ``pipe``.
    """
    return pipe(
        states,
        _clean_states,        # drop surrounding brackets and newlines
        _format_whitespaces,  # split into tokens
        _evaluate_states,     # convert tokens to numbers
    )


def format_states(df: pd.DataFrame) -> pd.DataFrame:
    """Parse stringified state vectors into a numeric DataFrame, one column per component.

    Two layouts are accepted:
      * a single row, where each column holds one state string (the frame is
        transposed so that row becomes a Series);
      * otherwise, one state string per row in column ``0``.

    Each state string is parsed with ``_format_states`` and expanded into columns
    with ``pd.Series``.

    Raises:
        InvalidFormatError: if either layout cannot be parsed. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        # Both layouts are inside the try so either failure surfaces the same way.
        raw_states = df.T[0] if df.shape[0] == 1 else df[0]
        return raw_states.apply(_format_states).apply(pd.Series)
    except Exception as err:
        raise InvalidFormatError("Invalid DataFrame format") from err
            
            
def format_others(path: Path) -> np.ndarray:
    """Read a headerless CSV, flatten all its values to 1-D, and prepend a NaN.

    The leading NaN presumably lines entry i up with state row i (the first
    state has no preceding action/reward); confirm against how the arrays are used.

    Args:
        path: CSV file (or anything ``pd.read_csv`` accepts) with no header row.

    Returns:
        1-D array of length ``n_values + 1`` whose first element is NaN.
    """
    values = pd.read_csv(path, header=None).values.flatten()
    # np.insert casts the inserted value to the array's dtype. NaN into an integer
    # array raises ValueError and into a bool array silently becomes True, so
    # promote those dtypes to float first.
    if values.dtype.kind in 'biu':
        values = values.astype(float)
    return np.insert(values, 0, np.nan)

In [3]:
def path(directory: str, filename: str) -> Path:
    """Return ``<parent of cwd>/data/processed/<directory>/<filename>``.

    The location is resolved from the kernel's current working directory, so the
    notebook is presumably meant to be launched from a sibling of ``data/``; confirm.
    """
    return Path.cwd().parent / 'data' / 'processed' / directory / filename

In [32]:
# Actions and rewards for the budget-10 run, each padded with a leading NaN by format_others.
actions_10, rewards_10 = (
    format_others(path(subdir, csv_name))
    for subdir, csv_name in (
        ('actions', 'actions_seir_budget_10.csv'),
        ('rewards', 'rewards_seir_budget_10.csv'),
    )
)

In [33]:
# Raw stringified state vectors for the budget-10 run, read without a header row.
df_10 = pd.read_csv(
    path('states', 'states_seir_budget_10.csv'),
    header=None,
)

In [39]:
# Expand each stringified state into numeric columns.
df_states_10 = df_10.pipe(format_states)

In [40]:
# Attach the budget-10 actions and rewards. `actions`/`rewards` are not defined in this
# notebook (NameError on a fresh kernel); the arrays loaded above are `actions_10`/`rewards_10`.
# Both were padded with a leading NaN by format_others, presumably to match the state row count.
df_states_10['actions'] = actions_10
df_states_10['rewards'] = rewards_10

In [41]:
# Display the combined frame: parsed state components plus the actions/rewards columns.
df_states_10

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,actions,rewards
0,34218.0,107188.0,168098.0,11652.0,0.034218,0.107188,0.168098,0.011652,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,,
1,34213.602,107178.7,168092.94,11651.81,4.101307,1.081136,1.19603,1.21252,0.0,4.099381,1.075274,0.0,3.683418,7.014423,4.517403,1.147615,10.0,1.0,1.6e-05
2,34197.379,107093.47,168077.78,11650.303,3.481745,12.598425,4.010811,0.685008,6.560257,28.649109,5.881249,0.0,13.963347,56.1731,14.363219,4.559427,10.0,1.0,0.000257
3,34099.277,106634.69,167932.64,11642.339,28.358915,104.33933,30.28558,2.907368,28.117466,142.70163,49.279537,0.539464,65.631744,309.1597,89.820564,10.947827,10.0,1.0,0.001591
4,33628.9648,104357.984,167316.688,11621.4941,115.054878,543.100281,136.082718,6.01796,142.594849,672.883911,202.003204,4.929506,334.769897,1616.91455,447.263794,26.814171,10.0,1.0,0.008835
5,31825.93,94975.688,164621.14,11526.121,414.31201,2103.9539,609.96472,23.227119,516.43384,2860.4707,807.75342,24.940081,1464.7065,7250.7793,2063.1855,84.968788,10.0,1.0,0.044636
6,25661.371,67815.977,155009.11,11198.568,1235.5907,4987.8955,1965.5515,69.908073,1834.3503,7717.9961,2811.6306,88.230438,5490.0693,26669.035,8315.7734,302.55032,10.0,1.0,0.167521
7,24327.559,65430.559,151866.8,11017.091,82.99868,94.254135,144.61655,9.68015,275.70822,579.0,552.88428,29.186918,9535.1133,41087.074,15537.766,603.29889,9.0,0.0,0.189403
8,23430.305,62660.633,150301.36,10952.429,134.16406,434.21399,231.7348,9.527807,210.94156,765.27057,394.14471,9.661077,10445.966,43330.77,17174.809,687.63855,9.0,1.0,0.20673
9,22440.8027,59245.1875,148551.062,10882.2686,132.552872,516.502625,256.016449,10.871448,266.148987,907.087402,437.186066,12.95128,11381.874,46522.1055,18857.7637,753.162903,9.0,1.0,0.2297


In [42]:
# Every column (state components, actions, rewards) shares one y-axis, so small-valued
# series are dwarfed by the large ones; consider separate plots or a log axis.
px.line(
    df_states_10,
    title='SEIR run, budget 10: state components, actions and rewards',
    labels={'index': 'step', 'value': 'value', 'variable': 'column'},
)

In [43]:
# NOTE(review): `df_states` is never defined in this notebook (likely a leftover from a
# deleted/renamed cell) and raises NameError on a fresh kernel. This plots only the
# state components of the budget-10 run, without actions/rewards; confirm intent.
px.line(
    df_states_10.drop(columns=['actions', 'rewards'], errors='ignore'),
    title='SEIR run, budget 10: state components',
    labels={'index': 'step'},
)