In [31]:
import pandas as pd
import numpy as np
from Functions.LoadData import get_data
import plotly.express as px
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

In [19]:
df_combined, df_fine_patient_phases, medication_settings = get_data('EM2')

In [20]:
fine_phases = df_fine_patient_phases['phase'].unique()
dbs_change_dates = []


for i in range(1,len(fine_phases)):
    current_phase = df_fine_patient_phases[df_fine_patient_phases['phase'] == fine_phases[i]].iloc[:,2:].to_numpy()
    previous_phase = df_fine_patient_phases[df_fine_patient_phases['phase'] == fine_phases[i-1]].iloc[:,2:].to_numpy()

    if(np.any(current_phase - previous_phase) != 0):
        dbs_change_dates.append(df_fine_patient_phases['date'][df_fine_patient_phases['phase'] == fine_phases[i]].iloc[0])

In [21]:
df_dbs_settings = df_fine_patient_phases[df_fine_patient_phases['date'].isin(dbs_change_dates)].iloc[:,2:]

In [28]:
def plotly_line(df_combined, dbs_change_dates, df_dbs_settings, column, agregation='10min', width=1200, height=1200):
    df_column = df_combined[['time', column]].groupby(pd.Grouper(freq=agregation, key='time')).mean()

    fig = px.line(df_column, x=df_column.index, y=column, title=column, line_shape='hv', width=width, height=height)
    for i in range(len(dbs_change_dates)):
        date = dbs_change_dates[i]
        fig.add_vline(x=date, line_width=3, line_dash="dash", line_color="red")
        
        # for each column in df_dbs_settings.iloc[i,:]: do column_name: column_value new line
        str_annotation = ''
        for column_name, column_value in df_dbs_settings.iloc[i,:].items():
            str_annotation += column_name + ': ' + str(column_value) + '<br>'

        fig.add_annotation(x=date,
                        y=df_column[column].max(),
                        text="DBS change:<br>" + str_annotation,
                        showarrow=True,
                        arrowhead=1,
                        yshift=10)
    
    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1,
                        label="1m",
                        step="month",
                        stepmode="backward"),
                    dict(count=7,
                        label="1w",
                        step="day",
                        stepmode="backward"),
                    dict(count=1,
                        label="1d",
                        step="day",
                        stepmode="backward")
                ])
            ),
            rangeslider=dict(
                visible=True
            ),
            type="date"
        )
    )            

    fig.write_html("plot1.html")
    fig.show()

def plotly_heatmap(df_combined, dbs_change_dates, df_dbs_settings, column, agregation='15min', width=1200, height=1200):
    df_column = df_combined[['time', column]].groupby(pd.Grouper(freq=agregation, key='time')).mean()

    # split date and time: date as row, time as column
    df_aggregated = pd.DataFrame(df_column.index.to_series().dt.date.unique(), columns=['date'])
    df_aggregated = df_aggregated.set_index('date')

    for time in df_column.index.to_series().dt.time.unique():
        df_aggregated[time] = df_column[df_column.index.to_series().dt.time == time][column].to_numpy()
    
    # transpose dataframe
    df_aggregated = df_aggregated.transpose()
    df_aggregated = df_aggregated.iloc[::-1]
    
    # cmap=Reds
    fig = px.imshow(df_aggregated, title=column, labels=dict(x="Time", y="Date", color=column), width=width, height=height, color_continuous_scale='Reds')

    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1,
                        label="1m",
                        step="month",
                        stepmode="backward"),
                    dict(count=7,
                        label="1w",
                        step="day",
                        stepmode="backward"),
                    dict(count=1,
                        label="1d",
                        step="day",
                        stepmode="backward")
                ])
            ),
            rangeslider=dict(
                visible=True
            ),
            type="date"
        )
    )
    
    for i in range(len(dbs_change_dates)):
        date = dbs_change_dates[i]
        fig.add_vline(x=date, line_width=3, line_dash="dash", line_color="red")
        
        # for each column in df_dbs_settings.iloc[i,:]: do column_name: column_value new line
        str_annotation = ''
        for column_name, column_value in df_dbs_settings.iloc[i,:].items():
            str_annotation += column_name + ': ' + str(column_value) + '<br>'

        fig.add_annotation(x=date,
                        y=df_aggregated.index.max(),
                        text="DBS change:<br>" + str_annotation,
                        showarrow=True,
                        arrowhead=1,
                        yshift=10)
    
    fig.write_html("plot2.html")
    fig.show()

def plot_data(df, column, agregation='10min', width='1200', height='1200'):
    width=int(width)
    height=int(height)
    plotly_line(df, dbs_change_dates, df_dbs_settings, column, agregation, width, height)
    plotly_heatmap(df, dbs_change_dates, df_dbs_settings, column, agregation, width, height),

    

In [34]:
interact(plot_data, df=fixed(df_combined), column=df_combined.columns[1:-2], agregation=['1min', '5min', '10min', '15min', '30min', '1h', '2h', '3h', '6h', '12h', '1d'], width='1200', height='1200')

interactive(children=(Dropdown(description='column', options=('probability_dyskinesia', 'probability_tremor', …

<function __main__.plot_data(df, column, agregation='10min', width='1200', height='1200')>

In [30]:
import webbrowser
import os

# generate an URL
url = 'file://' + os.path.realpath("plot1.html")
webbrowser.open(url)

url = 'file://' + os.path.realpath("plot2.html")
webbrowser.open(url)

True

True