In [18]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display
import os

In [19]:
# Publication styling: serif fonts with text rendered through LaTeX (needs a TeX install).
plt.rc('font', family='serif', size=20)
matplotlib.rc('text', usetex=True)
matplotlib.rc('legend', fontsize=20)

# Directory that figures are written to. Set FIGURE_OUTPUT_DIR (e.g. "./output_figures")
# to run this notebook on a machine without the author's local drive layout.
output_path = os.environ.get("FIGURE_OUTPUT_DIR", "D:/Dissetation/overleaf/dissertation/pics")
os.makedirs(output_path, exist_ok=True)


In [20]:

def plot(data, x, y, hue, filter_column1, filter_values1, filter_column2, filter_values2, normalize_by_bmi, normalize_by_moment, output_path):
    """Draw a grouped bar chart of ``y`` by ``x``, split by ``hue``, with CI error bars.

    Rows are optionally restricted by up to two ``column.isin(values)`` filters, and
    ``y`` is optionally divided by the ``BMI`` or ``Total`` column before plotting.

    Parameters
    ----------
    data : pandas.DataFrame
        Source table; it is copied, never modified.
    x, y, hue : str
        Column names for the category axis, the bar height and the colour grouping.
    filter_column1, filter_column2 : str or None
        Columns to filter on. A filter is skipped when its column or values are empty.
    filter_values1, filter_values2 : sequence
        Values to keep for the corresponding filter column.
    normalize_by_bmi : bool
        Divide ``y`` by ``BMI`` (only if that column exists). Wins over
        ``normalize_by_moment`` when both are set.
    normalize_by_moment : bool
        Divide ``y`` by ``Total`` (only if that column exists).
    output_path : str
        Directory used by the (currently commented-out) export lines.

    Returns
    -------
    pandas.DataFrame
        The filtered and possibly normalized rows that were plotted.
    """
    # A (column, values) filter is active only when both parts are non-empty;
    # len() also accepts numpy arrays, whose truth value is ambiguous.
    active_filters = [
        (column, values)
        for column, values in ((filter_column1, filter_values1), (filter_column2, filter_values2))
        if column and values is not None and len(values) > 0
    ]

    filtered_data = data.copy()
    for column, values in active_filters:
        filtered_data = filtered_data[filtered_data[column].isin(values)]

    if filtered_data.empty:
        # Filters excluded every row: skip the figure instead of drawing an empty, broken one.
        print("No rows match the selected filters; nothing to plot.")
        return filtered_data

    y_name = y.replace('AnteriorPosteriorShear', 'AP Shear').replace('LateralShear', 'Lateral Shear')

    # Remember which normalization was really applied, so that the axis label and the
    # file name describe the plotted data even when the requested column is missing.
    if normalize_by_bmi and 'BMI' in filtered_data.columns:
        filtered_data[y] = filtered_data[y] / filtered_data['BMI']
        normalization = 'bmi'
        y_label = r'\textbf{Normalized ' + y_name + r' ($\frac{N}{\mathrm{kg/m}^2}$)}'
    elif normalize_by_moment and 'Total' in filtered_data.columns:
        # NOTE(review): 'Total' is also a y option; normalizing y='Total' by it yields 1 everywhere.
        filtered_data[y] = filtered_data[y] / filtered_data['Total']
        normalization = 'moment'
        y_label = r'\textbf{Normalized ' + y_name + r' ($\frac{N}{Nm}$)}'
    else:
        normalization = None
        y_label = f'\\textbf{{{y_name} (N)}}'

    if isinstance(filtered_data[hue].dtype, pd.CategoricalDtype):
        # Seaborn takes hue levels from a Categorical's declared categories (e.g. 'Decade'),
        # which may include levels the filters removed; drop them so the palette below has
        # exactly one colour per level that is drawn.
        filtered_data[hue] = filtered_data[hue].cat.remove_unused_categories()
    palette = sns.color_palette("Blues", filtered_data[hue].nunique())

    fig, ax = plt.subplots(figsize=(16, 9))
    sns.barplot(
        x=x, y=y, hue=hue, data=filtered_data, errorbar="ci", capsize=.4,
        err_kws={"color": "0.2", "linewidth": 1}, palette=palette, ax=ax
    )

    ax.set_xlabel(f'\\textbf{{{x.replace("Decade", "Age Group")}}}', fontsize=22)
    ax.set_ylabel(y_label, fontsize=22)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Base name (no extension) encoding the plot configuration, for the export lines below.
    filename = f"barplot_{y.lower()}_by_{hue.lower()}"
    if normalization:
        filename += f"_normalized_by_{normalization}"
    for column, values in active_filters:
        filename += f"_filtered_by_{column.lower()}_{'_'.join(map(str, values)).lower()}"
    # To export the data and figure, uncomment:
    # filtered_data.to_csv(os.path.join(output_path, filename + ".csv"), index=False)
    # fig.savefig(os.path.join(output_path, filename + ".png"), bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)
    return filtered_data

def plot_distribution(data, x, y, hue, output_path):
    """Plot a layered histogram of ``x`` coloured by ``hue``.

    When ``y`` is not None, seaborn's ``histplot`` draws a bivariate histogram of
    ``x`` against ``y`` instead, and the y-axis is labelled with ``y``.

    Parameters
    ----------
    data : pandas.DataFrame
        Table to plot.
    x : str
        Column binned along the x-axis.
    y : str or None
        Optional second column for a bivariate histogram.
    hue : str or None
        Column used to colour and layer the groups.
    output_path : str
        Directory used by the (currently commented-out) save line.
    """
    fig, ax = plt.subplots(figsize=(16, 9))
    sns.histplot(
        data=data,
        x=x,
        y=y,
        hue=hue,
        multiple="layer",
        kde=False,
        element="bars",
        ax=ax,
    )

    ax.set_xlabel(f'\\textbf{{{x}}}', fontsize=22)
    # A count axis exists only for a univariate histogram; with y given the axis shows y's values.
    if y is None:
        ax.set_ylabel(r'\textbf{Frequency}', fontsize=22)
    else:
        ax.set_ylabel(f'\\textbf{{{y}}}', fontsize=22)
    ax.set_title(f'Distribution of {x} by {hue}', fontsize=24)

    filename = f"distribution_{x.lower()}_by_{hue.lower()}.png"
    filepath = os.path.join(output_path, filename)
    # fig.savefig(filepath, bbox_inches='tight', dpi=300)

    plt.show()
    plt.close(fig)

In [21]:
def plotReg(data, x, y, hue, filter_column1, filter_values1, filter_column2, filter_values2, normalize_by_bmi, normalize_by_moment, output_path):
    """Scatter ``y`` against ``x`` for each ``hue`` level with a linear fit, and save the figure.

    Each hue level gets its own marker (cycling 'o' and 'X') and a first-order
    ``regplot`` line with a 90% confidence band, both in the same colour. The figure is
    written as a 300-dpi PNG to ``output_path``; the file name encodes y, hue,
    normalization and filters.

    Parameters
    ----------
    data : pandas.DataFrame
        Source table; it is copied, never modified.
    x, y : str
        Columns for the horizontal and vertical axes; presumably both numeric, since
        regplot fits a line (verify for the chosen x).
    hue : str
        Column whose unique values define the plotted groups.
    filter_column1, filter_column2 : str or None
        Columns to filter on. A filter is skipped when its column or values are empty.
    filter_values1, filter_values2 : sequence
        Values to keep for the corresponding filter column.
    normalize_by_bmi : bool
        Divide ``y`` by ``BMI`` (only if that column exists). Wins over
        ``normalize_by_moment`` when both are set.
    normalize_by_moment : bool
        Divide ``y`` by ``Total`` (only if that column exists).
    output_path : str
        Directory the PNG is saved to.

    Returns
    -------
    pandas.DataFrame
        The filtered and possibly normalized rows that were plotted.
    """
    # A (column, values) filter is active only when both parts are non-empty.
    active_filters = [
        (column, values)
        for column, values in ((filter_column1, filter_values1), (filter_column2, filter_values2))
        if column and values is not None and len(values) > 0
    ]

    filtered_data = data.copy()
    for column, values in active_filters:
        filtered_data = filtered_data[filtered_data[column].isin(values)]

    if filtered_data.empty:
        # Don't save an empty figure when the filters exclude every row.
        print("No rows match the selected filters; nothing to plot.")
        return filtered_data

    if normalize_by_bmi and 'BMI' in filtered_data.columns:
        filtered_data[y] = filtered_data[y] / filtered_data['BMI']
        normalization = "bmi"
    elif normalize_by_moment and 'Total' in filtered_data.columns:
        filtered_data[y] = filtered_data[y] / filtered_data['Total']
        normalization = "moment"
    else:
        normalization = "none"

    fig, ax = plt.subplots(figsize=(16, 9))

    markers = ['o', 'X']
    for i, hue_val in enumerate(filtered_data[hue].unique()):
        hue_data = filtered_data[filtered_data[hue] == hue_val]
        # The points and their fit line share one explicit colour from the active
        # colour cycle, rather than each drawing the next colour from its own cycle.
        color = f"C{i}"
        sns.scatterplot(
            data=hue_data,
            x=x,
            y=y,
            label=f"{hue_val}",
            color=color,
            marker=markers[i % len(markers)],
            s=200,
            ax=ax
        )
        sns.regplot(
            data=hue_data,
            x=x,
            y=y,
            x_ci='ci',
            ci=90,
            scatter=False,
            order=1,
            color=color,
            ax=ax,
        )

    # Axis labels in the same style as the bar-chart helper, reflecting the applied normalization.
    x_label = x.replace("Decade", "Age Group")
    y_label = y.replace('AnteriorPosteriorShear', 'AP Shear').replace('LateralShear', 'Lateral Shear')
    ax.set_xlabel(f'\\textbf{{{x_label}}}', fontsize=22)
    if normalization == "bmi":
        ax.set_ylabel(r'\textbf{Normalized ' + y_label + r' ($\frac{N}{\mathrm{kg/m}^2}$)}', fontsize=22)
    elif normalization == "moment":
        ax.set_ylabel(r'\textbf{Normalized ' + y_label + r' ($\frac{N}{Nm}$)}', fontsize=22)
    else:
        ax.set_ylabel(f'\\textbf{{{y_label} (N)}}', fontsize=22)

    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    filename_parts = [y.lower(), f"by_{hue.lower()}"]
    if normalization != "none":
        filename_parts.append(f"normalized_by_{normalization}")
    for column, values in active_filters:
        filter_str = "_".join(f"{column.lower()}_{str(v).lower()}" for v in values)
        filename_parts.append(f"filtered_by_{filter_str}")
    filepath = os.path.join(output_path, "_".join(filename_parts) + ".png")

    fig.savefig(filepath, bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)
    return filtered_data


In [22]:
# Load the per-muscle force summary and harmonise labels for plotting.
path = "E:/Quanitifying EMG/Summary-Muscle Forces.csv"
data = pd.read_csv(path)

data['Status'] = data['Status'].replace('Control', 'Asymptomatic')

# Ordered categorical so decades plot in ascending order. NaN cannot be a category (and
# breaks sorting), so it is excluded from the category list; such rows stay NaN.
sorted_decades = sorted(data['Decade'].dropna().unique())
data['Decade'] = pd.Categorical(data['Decade'], categories=sorted_decades, ordered=True)

# Short axis/legend labels for each muscle.
MUSCLE_ABBREVIATIONS = {
    'RightErectorSpinae': 'RES',
    'LeftErectorSpinae': 'LES',
    'RightInternalOblique': 'RIO',
    'LeftInternalOblique': 'LIO',
    'RightLatissimusDorsi': 'RLD',
    'LeftLatissimusDorsi': 'LLD',
    'RightExternalOblique': 'REO',
    'LeftExternalOblique': 'LEO',
    'RightRectusAbdominis': 'RRA',
    'LeftRectusAbdominis': 'LRA',
}
data['Muscle'] = data['Muscle'].replace(MUSCLE_ABBREVIATIONS)

x_options = ['Status', 'Sex', 'Decade', 'Muscle']
y_options = ['Total', 'Active', 'Passive', 'ActiveUnmodulated']
hue_options = ['Status', 'Decade', 'Sex',  'Trial Type']
filter_column_options = ['Status', 'Sex', 'Subject', 'Decade', 'Trial Name', 'Trial Type', 'Muscle', None]
endplate_options = ['Superior', 'Inferior']

# Initial filter choices, sorted the same way (by string form) as the
# filter-column change callbacks sort them, so the list order never jumps.
default_filter_column1 = 'Sex'
unique_levels = data[default_filter_column1].unique()
default_filter_values1 = sorted(unique_levels, key=str)
default_filter_column2 = 'Trial Name'
unique_trials = data[default_filter_column2].unique()
default_filter_values2 = sorted(unique_trials, key=str)

# Plot-axis selectors: (options, initial value, label) for x, y and hue.
x_dropdown, y_dropdown, hue_dropdown = (
    widgets.Dropdown(options=choices, value=initial, description=label)
    for choices, initial, label in (
        (x_options, 'Muscle', 'X-axis:'),
        (y_options, 'Total', 'Y-axis:'),
        (hue_options, 'Status', 'Hue:'),
    )
)

# Two independent filters, each a column picker plus a multi-select of that column's values.
filter_column_dropdown1 = widgets.Dropdown(
    description='Filter by:', options=filter_column_options, value=default_filter_column1
)
filter_values_select1 = widgets.SelectMultiple(
    description='Filter Values:', options=default_filter_values1, disabled=False
)
filter_column_dropdown2 = widgets.Dropdown(
    description='Filter by:', options=filter_column_options, value=default_filter_column2
)
filter_values_select2 = widgets.SelectMultiple(
    description='Filter Values:', options=default_filter_values2, disabled=False
)

# Normalization toggles, both off initially.
normalize_by_bmi_checkbox, normalize_by_moment_checkbox = (
    widgets.Checkbox(value=False, description=label, disabled=False)
    for label in ('Normalize by BMI', 'Normalize by Moment')
)

def _refresh_filter_values(select, column):
    """Fill ``select`` with the sorted unique values of ``data[column]``, or clear and disable it.

    Values are sorted by their string form, so mixed types and NaN sort without errors.
    A falsy ``column`` (e.g. the ``None`` option) or one not in ``data`` empties the list.
    """
    if column and column in data.columns:
        select.options = sorted(data[column].unique(), key=str)
        select.disabled = False
    else:
        select.options = []
        select.disabled = True


def update_filter_values1(change):
    """Observer for the first filter-column dropdown: refresh its value list."""
    _refresh_filter_values(filter_values_select1, change['new'])


def update_filter_values2(change):
    """Observer for the second filter-column dropdown: refresh its value list."""
    _refresh_filter_values(filter_values_select2, change['new'])


filter_column_dropdown1.observe(update_filter_values1, names='value')
filter_column_dropdown2.observe(update_filter_values2, names='value')

def update_plot(x, y, hue, filter_column1, filter_values1, filter_column2, filter_values2, normalize_by_bmi, normalize_by_moment):
    """Redraw the bar chart for the current widget selections.

    Called by ``widgets.interactive_output`` with each control's current value under the
    matching parameter name. Plots the module-level ``data`` and uses ``output_path``.
    """
    plot(data, x, y, hue, filter_column1, filter_values1, filter_column2, filter_values2,
         normalize_by_bmi, normalize_by_moment, output_path)
    # For the per-hue scatter + regression view instead, call plotReg with the same arguments
    # (note: plotReg also saves a PNG to output_path on every update).

# Fixed widget sizes keep the four control groups aligned in a single row.
# Each widget receives its own Layout instance.
for _widget in (x_dropdown, y_dropdown, hue_dropdown, filter_column_dropdown1, filter_column_dropdown2):
    _widget.layout = Layout(width='200px', margin='5px')
for _widget in (filter_values_select1, filter_values_select2):
    _widget.layout = Layout(width='200px', height='100px', margin='5px')
for _widget in (normalize_by_bmi_checkbox, normalize_by_moment_checkbox):
    _widget.layout = Layout(width='auto', margin='5px')


def _titled_group(heading_html, children):
    """Stack an HTML heading above ``children`` in a VBox, wrapped in a margin-padded HBox."""
    column = widgets.VBox([widgets.HTML(heading_html)] + list(children))
    return widgets.HBox([column], layout=Layout(margin='10px'))


controls_group1 = _titled_group('<b>Plot Controls</b>', [x_dropdown, y_dropdown, hue_dropdown])
controls_group2 = _titled_group('<b>Filtering 1</b>', [filter_column_dropdown1, filter_values_select1])
controls_group3 = _titled_group('<b>Filtering 2</b>', [filter_column_dropdown2, filter_values_select2])
controls_group4 = _titled_group(
    '<b>Normalization</b>',
    [
        widgets.Box([checkbox], layout=Layout(align_items='flex-start', padding='0px'))
        for checkbox in (normalize_by_bmi_checkbox, normalize_by_moment_checkbox)
    ],
)

# One non-wrapping row holding all four groups, spread across the full width.
container_layout = None  # placeholder removed below; see widgets_container
widgets_container = widgets.HBox(
    children=[controls_group1, controls_group2, controls_group3, controls_group4],
    layout=Layout(
        display='inline-flex',
        flex_flow='row nowrap',
        align_items='flex-start',
        justify_content='space-around',
        width='100%',
    ),
)
del container_layout

display(widgets_container)

# Re-run update_plot whenever any control changes; each keyword is an update_plot parameter.
out = widgets.interactive_output(
    update_plot,
    dict(
        x=x_dropdown,
        y=y_dropdown,
        hue=hue_dropdown,
        filter_column1=filter_column_dropdown1,
        filter_values1=filter_values_select1,
        filter_column2=filter_column_dropdown2,
        filter_values2=filter_values_select2,
        normalize_by_bmi=normalize_by_bmi_checkbox,
        normalize_by_moment=normalize_by_moment_checkbox,
    ),
)

display(out)

HBox(children=(HBox(children=(VBox(children=(HTML(value='<b>Plot Controls</b>'), Dropdown(description='X-axis:…

Output()