In [None]:
from plotly import graph_objects as go
from plotly.graph_objects import Layout
import pandas as pd
import datetime

In [None]:
# Load the measurements table that the plotting cells operate on.
# NOTE(review): create_figure filters on df.index by date, so the index is
# presumably date-like (DatetimeIndex or ISO date strings) — confirm.
path = "./weight_loss_dfs/jordan_df.pqt"
df = pd.read_parquet(path=path)

In [None]:
class FigureGenerator:
    """Build plotly line figures from columns of a DataFrame.

    The DataFrame's index supplies the x values and each requested column
    becomes one ``go.Scatter`` trace. Rows can be restricted to an inclusive
    ``[start, end]`` window, and each column can optionally be re-expressed
    relative to one of its own observations.

    Attributes:
        modes: Values accepted by ``create_figure(mode=...)``.
            ``"start_delta"`` subtracts each column's earliest (by index)
            non-null value; ``"end_delta"`` subtracts its latest.
        df: The DataFrame figures are built from.
        columns: ``df.columns`` as captured at construction time.
    """

    def __init__(self, df):
        self.modes = ["start_delta", "end_delta"]
        self.df = df
        self.columns = df.columns

    def create_figure(self, axes, start=None, end=None, mode=None):
        """Return a ``go.Figure`` with one line trace per column in *axes*.

        Args:
            axes: Column names of ``self.df`` to plot.
            start: Optional ``datetime.date``; rows whose index is before it
                are dropped (the bound itself is kept).
            end: Optional ``datetime.date``; rows whose index is after it
                are dropped (the bound itself is kept).
            mode: ``None`` to plot raw values, or one of ``self.modes``.

        Raises:
            ValueError: if a name in *axes* is not a column, if *start* or
                *end* is not a ``datetime.date``, or if *mode* is unknown.
        """
        df = self.df

        for axis in axes:
            if axis not in df.columns:
                raise ValueError(
                    "{} is not a column in the dataframe. All axes must be columns in the dataframe.".format(
                        axis
                    )
                )

        # Bounds are compared as strings against the index; this assumes the
        # index is a DatetimeIndex or ISO-formatted date strings — confirm.
        if start is not None:
            if not isinstance(start, datetime.date):
                raise ValueError("start must be an instance of datetime.date")
            df = df[df.index >= str(start)]

        if end is not None:
            if not isinstance(end, datetime.date):
                raise ValueError("end must be an instance of datetime.date")
            df = df[df.index <= str(end)]

        if mode is not None and mode not in self.modes:
            raise ValueError("{} mode is invalid".format(mode))

        traces = self.get_traces(df, axes, mode=mode)
        return go.Figure(data=traces)

    def get_traces(self, df, axes, mode):
        """Return one ``go.Scatter`` per column in *axes*, adjusted for *mode*.

        Args:
            df: DataFrame (already windowed) whose index supplies x values.
            axes: Column names to turn into traces, in legend order.
            mode: ``None``, ``"start_delta"`` or ``"end_delta"``; any other
                value leaves the data unchanged.

        Returns:
            A list of ``go.Scatter`` traces.
        """
        # Axis ranges are left to plotly's autoscaling for now.
        traces = []
        for axis in axes:
            series = self._apply_mode(df[axis], mode)
            traces.append(
                go.Scatter(
                    x=series.index,
                    y=series,
                    name=axis,
                    showlegend=True,
                    # Draw across missing measurements instead of breaking the line.
                    connectgaps=True,
                )
            )
        return traces

    @staticmethod
    def _apply_mode(series, mode):
        """Return *series* re-expressed relative to one of its observations.

        ``"start_delta"`` subtracts the earliest non-null value and
        ``"end_delta"`` the latest, where earliest/latest follow index order
        rather than row order. Any other *mode* (including ``None``), or a
        series with no non-null values, returns *series* unchanged.
        """
        observed = series.dropna()
        # An all-missing column has no reference value; plot it as-is
        # rather than failing with IndexError.
        if mode not in ("start_delta", "end_delta") or observed.empty:
            return series
        # Sort so the reference is chronological even if rows are stored
        # newest-first; "stable" keeps row order among duplicate index labels.
        observed = observed.sort_index(kind="stable")
        # .iloc is explicit positional access; plain [] with an int on a
        # non-integer index is a deprecated fallback in pandas.
        reference = observed.iloc[0] if mode == "start_delta" else observed.iloc[-1]
        # Arithmetic produces a new Series; the caller's DataFrame is untouched.
        return series - reference

In [None]:
# Wrap the loaded measurements in a figure builder.
fg = FigureGenerator(df)

In [None]:
# Plot the listed measurement columns in "start_delta" mode; the figure is
# the cell's final expression, so the notebook displays it.
fg.create_figure(
    axes=["Belly", "Chest", "Hips", "Waist", "Bicep"],
    mode="start_delta",
)