In [16]:
import re
from typing import List, Any

import pandas as pd
import plotly.express as px

In [17]:
class InvalidFormatError(Exception):
    """Raised by ``format_states`` when a DataFrame cannot be parsed into states."""


def pipe(raw_input: Any, *functions: Any, **functions_with_args: Any) -> Any:
    """Thread a value through a series of functions and return the final result.

    Positional ``functions`` run first, in order; each is called with the
    current value as its only argument. Keyword entries run afterwards, in
    keyword order: each key is the *name* of a callable and its value is an
    iterable of extra positional arguments, so the call performed is
    ``name(current_value, *extra_args)``.

    Args:
        raw_input: Initial value handed to the first function.
        *functions: Callables accepting a single argument.
        **functions_with_args: Callable name -> iterable of extra positional
            arguments.

    Returns:
        The value produced by the last function applied, or ``raw_input``
        itself when no functions are supplied.
    """
    output = raw_input

    # Iterating an empty tuple/dict is a no-op, so no emptiness checks needed.
    for function in functions:
        output = function(output)

    for function, args_list in functions_with_args.items():
        # NOTE(review): the key is resolved with eval(), i.e. executed as an
        # arbitrary Python expression in this function's scope (its locals are
        # visible too). Only pass trusted names.
        output = eval(function)(output, *args_list)

    return output


def _clean_states(text: str) -> List[str]:
    return text.strip('][').strip().replace('\n', '')   


def _format_whitespaces(text: str) -> List[str]:
    return re.sub(' {2,}', ' ', text).split()


def _evaluate_states(states: List[str]) -> List[float]:
    return [eval(state) for state in states]


def _format_states(states: str) -> List[float]:
    """Parse one stringified state into a list of numbers.

    The text is passed through three steps in order: bracket/newline cleanup
    (``_clean_states``), whitespace tokenisation (``_format_whitespaces``) and
    numeric conversion (``_evaluate_states``).

    Args:
        states: Textual representation of a single state.

    Returns:
        The numeric values of the state, in their original order.
    """
    return pipe(
        states,
        _clean_states,
        _format_whitespaces,
        _evaluate_states
    )


def format_states(df: pd.DataFrame) -> pd.DataFrame:
    """Expand a frame of stringified states into one numeric column per component.

    For a single-row frame the row labelled ``0`` is used (via the transpose),
    so each original column contributes one state. Otherwise the column
    labelled ``0`` is used, so each original row contributes one state.

    Args:
        df: Raw frame whose cells contain textual states.

    Returns:
        A frame with one row per state and one column per parsed value.

    Raises:
        InvalidFormatError: If label ``0`` is missing or any state cannot be
            parsed. The underlying exception is attached as ``__cause__``.
    """
    try:
        raw_states = df.T[0] if df.shape[0] == 1 else df[0]
        return raw_states.apply(_format_states).apply(pd.Series)
    except Exception as exc:
        # Both layouts now report failures the same way (previously only the
        # multi-row branch was wrapped), and the original error stays chained.
        raise InvalidFormatError("Invalid DataFrame format") from exc

In [18]:
# NOTE(review): absolute, machine-specific location of the states CSV; the
# notebook will not run on another machine unless this path is adjusted.
path = '/home/ubuntu/Dabid/epi-rl/states.csv'

In [19]:
# header=None: the file is read without a header row, so the columns receive
# integer labels (0, 1, ...); format_states() looks up label 0.
df = pd.read_csv(path, header=None)

In [20]:
# Parse every stringified state of the raw frame into a row of numeric columns.
df_states = format_states(df)

In [21]:
# Show the parsed frame (one row per state, one column per parsed value).
df_states

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
0,34218.0,107188.0,168098.0,11652.0,0.034218,0.107188,0.168098,0.011652,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
1,34216.973,107179.57,168096.84,11651.532,0.000313,1.029953,0.0,0.0,1.711166,2.414158,1.721944,0.902732,1.955139,6.777658,1.724705,0.72066,2.0
2,34199.7188,107113.281,168074.938,11652.1455,5.327175,20.778029,7.755228,0.0,8.923193,21.90869,8.005012,0.83055,7.729094,33.819969,11.033174,1.995465,2.0
3,34120.9688,106762.195,167970.875,11647.2256,19.332945,83.183624,18.793667,0.0,22.464193,98.576881,35.385826,2.401134,59.228333,245.840912,76.676895,7.14044,2.0
4,33758.355,105018.74,167484.64,11630.247,82.41011,431.55847,100.65711,5.201004,115.26726,526.85309,162.75548,5.56075,265.96075,1212.6461,353.67105,16.513575,2.0
5,32109.055,96931.336,165079.91,11536.477,408.42142,1848.8801,540.97083,21.327322,517.28943,2464.2893,792.58575,31.622288,1187.2299,5945.2993,1688.2606,68.098335,2.0
6,26263.387,71211.336,156100.53,11208.546,1171.9341,4759.457,1879.6831,70.336334,1694.6167,7405.4956,2590.543,100.39578,5092.0576,23813.504,7530.9741,278.24606,2.0
7,24935.518,68859.82,153068.56,11047.682,46.790451,120.00272,134.30284,9.154311,253.54269,601.14435,557.83301,27.142998,8986.1455,37608.828,14341.022,573.54401,1.0
8,23965.5664,65684.7266,151550.875,10978.7705,145.382721,522.509705,220.730011,4.373356,235.708237,809.531738,406.985291,28.368282,9875.34277,40173.0391,15923.1377,646.010437,1.0
9,22686.148,61186.336,149449.28,10896.979,189.68022,745.2356,335.91638,13.9542,352.19116,1176.8209,566.57123,19.337362,10993.985,44081.414,17749.967,727.25256,1.0


In [22]:
# Labels for the 17 state components: each of the four prefixes (s, e, i, r)
# combined with each of the four age groups, followed by 'budget'.
# NOTE(review): the prefixes presumably denote S/E/I/R epidemic compartments -- confirm.
states_columns = [
    f'{prefix}_{age_group}'
    for prefix in ('s', 'e', 'i', 'r')
    for age_group in ('children', 'adolescents', 'adults', 'elderly')
] + ['budget']

In [23]:
# Replace the positional integer labels with descriptive names (in place); the
# number of names must equal the number of columns or pandas raises ValueError.
df_states.columns = states_columns

In [24]:
# Wide-form line plot: each column of df_states is drawn as its own line
# against the frame's index.
px.line(df_states)