In [1]:
# Core data processing libraries
import pandas as pd
import numpy as np

# Visualization libraries
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import matplotlib.lines as mlines
import matplotlib.patches as patches
import seaborn as sns
from itertools import cycle

# Utility imports
import csv, json, re, io, warnings, os, copy, base64
from io import BytesIO
from functools import partial
from itertools import combinations

# Jupyter-specific imports
import traitlets
from traitlets import HasTraits, Instance, observe
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import (
    interact, interactive, fixed, interact_manual,
    GridspecLayout, VBox, HBox, Layout, Output
)
from ipydatagrid import DataGrid

# Initialize settings
import _settings as settings

# Global variables from settings
# Module-level aliases for values defined in the project-local `_settings`
# module, so later cells can refer to them by bare name.
spec_translate_list = settings.SPEC_TRANSLATE_LIST
valid_discrete_cmaps = settings.valid_discrete_cmaps
valid_gradient_cmaps = settings.valid_gradient_cmaps
default_hm_color = settings.default_hm_color
default_lp_color = settings.default_lp_color
default_avglp_color = settings.default_avglp_color
hm_selected_color = settings.hm_selected_color
cmap = settings.cmap
#lp_selected_color = settings.lp_selected_color
#avglp_selected_color = settings.avglp_selected_color
#avg_cmap = settings.avg_cmap
legend_title = settings.legend_title
# Suppress every FutureWarning for the remainder of the kernel session.
warnings.simplefilter(action='ignore', category=FutureWarning)
# NOTE(review): presumably a misspelling of `chunk_size`; the functions in this
# notebook section use their own local/literal 78 instead of this name.
# TODO confirm nothing reads `chuck_size` before renaming it.
chuck_size = 78
# String flags ('yes'/'no'); how they are consumed is not visible in this section.
plot_heatmap, plot_zero = 'yes', 'no'

# Placeholders initialised to None; both names are bound to the same None object.
data_transformer = HeatmapPlotHandler = None

In [2]:

def create_download_link(fig, filename):
    """
    Build an HTML snippet that auto-triggers a browser download of a figure.

    The figure is rendered to an in-memory PNG (300 DPI, tight bounding box),
    base64-encoded into a ``data:`` URL, and wrapped in a hidden ``<a download>``
    element plus a small script that clicks it and then removes the wrapper.

    Parameters:
        fig: matplotlib figure object (anything exposing ``savefig`` and
            ``get_facecolor``)
        filename: str, name offered to the browser for the downloaded file

    Returns:
        str: HTML for automatic download trigger
    """
    from io import BytesIO
    import base64
    import html
    import re

    # Render the figure into memory rather than onto disk.
    buf = BytesIO()
    fig.savefig(
        buf,
        format='png',
        dpi=300,  # Standard publication quality DPI
        bbox_inches='tight',
        pad_inches=0.1,
        transparent=False,
        metadata={
            'Creator': 'Matplotlib',
            'Title': filename,
            'Software': 'Python'
        },
        facecolor=fig.get_facecolor(),
    )
    buf.seek(0)

    # Encode the bytes as base64 so they can be embedded in a data URL.
    img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
    href = f'data:image/png;base64,{img_str}'

    # The file name is user-supplied text. Escape it for use inside an HTML
    # attribute, and derive a separate id made only of characters that are safe
    # both in an id attribute and inside the single-quoted JS string below;
    # otherwise a quote or angle bracket in the name breaks the markup/script.
    download_name = html.escape(str(filename), quote=True)
    dom_id = re.sub(r'[^0-9A-Za-z_-]', '_', str(filename))

    # Create the HTML download trigger with error handling
    download_html = f'''
    <div id="download_{dom_id}">
        <a id="download_link_{dom_id}" 
           href="{href}" 
           download="{download_name}"
           style="display: none;"></a>
        <script>
            try {{
                const link = document.getElementById('download_link_{dom_id}');
                if (link) {{
                    link.click();
                    setTimeout(() => {{
                        const container = document.getElementById('download_{dom_id}');
                        if (container) container.remove();
                    }}, 1000);
                }}
            }} catch(e) {{
                console.error("Error initiating download:", e);
            }}
        </script>
    </div>
    '''

    return download_html

def calculate_optimal_dpi(fig):
    """
    Pick an export DPI for ``fig`` between 300 and 450.

    Starts from 300 DPI and raises it only as far as needed for the rendered
    image to reach at least 1500 px wide and 1000 px tall, never exceeding 450.

    Parameters:
        fig: object exposing ``get_size_inches()`` -> (width, height)

    Returns:
        The chosen DPI value.
    """
    standard_dpi = 300   # baseline publication DPI
    dpi_ceiling = 450    # upper bound that keeps file size reasonable
    target_px_wide, target_px_tall = 1500, 1000

    inches_wide, inches_tall = fig.get_size_inches()

    # DPI each dimension would need in order to hit its pixel target.
    dpi_for_width = target_px_wide / inches_wide
    dpi_for_height = target_px_tall / inches_tall

    # The most demanding requirement wins, then the ceiling is applied.
    required_dpi = max(standard_dpi, dpi_for_width, dpi_for_height)
    return min(required_dpi, dpi_ceiling)

In [3]:
# Define functions outside of class
def update_filenames(input_filename_port, input_filename_land):
    """
    Normalise the two export file names.

    A falsy name (empty string, None, ...) becomes None; anything else is
    rendered as a string. Note the result order is (landscape, portrait),
    i.e. the reverse of the parameter order.
    """
    def _as_name(raw_name):
        # Falsy input means "no file requested".
        if raw_name:
            return f'{raw_name}'
        return None

    port_name = _as_name(input_filename_port)
    land_name = _as_name(input_filename_land)
    return land_name, port_name

def proceed_with_label_specific_options(selected_bio_or_pep, bio_or_pep):
    """
    Route the current selection into a (peptides, functions) pair.

    ``bio_or_pep == '1'`` puts the selection in the peptide list,
    ``bio_or_pep == '2'`` puts it in the function list; any other value prints
    a notice and yields two empty lists. A list/tuple selection is copied into
    a new list, any other value is wrapped in a one-element list.
    """
    peptide_mode = bio_or_pep == '1'
    function_mode = bio_or_pep == '2'

    if not (peptide_mode or function_mode):
        print(f"Unexpected value for bio_or_pep: {bio_or_pep}")
        return [], []

    if isinstance(selected_bio_or_pep, (list, tuple)):
        selection = list(selected_bio_or_pep)
    else:
        selection = [selected_bio_or_pep]

    if peptide_mode:
        return selection, []
    return [], selection

"""-----------------Export_Function---------------------------------""" 
# Function to calculate y-ticks based on the min and max values of datasets
def calculate_y_ticks(min_values, max_values):
    """
    Calculate three y-ticks spanning the data on a power-of-ten grid.

    Parameters:
        min_values: iterable of per-dataset minimum values (list, tuple,
            numpy array, pandas Series, ...) or None
        max_values: iterable of per-dataset maximum values, same kinds

    Returns:
        list: ``[low, mid, high]`` where ``low`` is the power of ten at or below
        the smallest positive minimum, ``high`` the power of ten at or above the
        largest positive maximum, and ``mid`` the power of ten nearest their
        geometric mean. ``[0, 1, 10]`` is returned when either input has no
        strictly positive value (NaN and non-positive entries are ignored).
    """
    # Materialise the inputs first: truth-testing a multi-element numpy array or
    # pandas Series raises ValueError, whereas a plain list is safe to test.
    min_list = [] if min_values is None else list(min_values)
    max_list = [] if max_values is None else list(max_values)

    if not min_list or not max_list:
        return [0, 1, 10]  # Default scale if no data

    # Keep strictly positive values only; NaN fails the comparison and is dropped.
    valid_mins = [x for x in min_list if x > 0]
    valid_maxs = [x for x in max_list if x > 0]

    if not valid_mins or not valid_maxs:
        return [0, 1, 10]

    try:
        overall_min = np.nanmin(valid_mins)
        overall_max = np.nanmax(valid_maxs)

        # Defensive: log10 below requires strictly positive bounds.
        if overall_min <= 0 or overall_max <= 0:
            return [0, 1, 10]

        min_power = 10 ** np.floor(np.log10(overall_min))
        max_power = 10 ** np.ceil(np.log10(overall_max))

        # Midpoint in log space, snapped to the nearest power of ten.
        mid_point = np.sqrt(min_power * max_power)
        mid_point_rounded = 10 ** np.round(np.log10(mid_point))

        return [min_power, mid_point_rounded, max_power]

    except (ValueError, RuntimeWarning):
        # RuntimeWarning only surfaces as an exception when warnings are
        # configured to raise; kept so that configuration still falls back.
        return [0, 1, 10]  # Fallback for any calculation errors

def calculate_abundance(protein_sequence, peptide_dataframe, grouping_variable):
    protein_sequence_length = len(protein_sequence)
    data = []
    Avg_column = f'Avg_{grouping_variable}'

    for idx, row in peptide_dataframe.iterrows():
        try:
            start_idx = int(row['start']) - 1
            stop_idx = int(row['stop'])
            abundance_value = row[Avg_column]

            values = [0] * protein_sequence_length
            if abundance_value > 0:
                values[start_idx:stop_idx] = [abundance_value] * (stop_idx - start_idx)
            data.append(values)

        except (KeyError, ValueError) as e:
            print(f'{type(e).__name__}: {e} - Skipping this row')
            continue

    if not data:
        print("No valid data to process. Returning empty DataFrame.")
        return pd.DataFrame()

    # Create the initial DataFrame
    abundance_df = pd.DataFrame(data).T
    abundance_df.columns = [f'{int(row["start"])}-{int(row["stop"])}' for _, row in peptide_dataframe.iterrows()]

    
    # Initialize count list
    count_list = []
    
    # Check if abundance_df is empty or has no non-zero values
    if abundance_df.empty or abundance_df.eq(0).all().all():
        # Create all zeros for counts and averages
        count_list = [0] * len(protein_sequence)
        abundance_df['average'] = 0
    else:
        # Calculate counts and averages normally
        count_list = abundance_df.gt(0).sum(axis=1)
        abundance_df['average'] = abundance_df.replace(0, np.nan).mean(axis=1)
    
    # Assign the counts and amino acids
    abundance_df['count'] = count_list
    abundance_df['AA'] = list(protein_sequence)
    return abundance_df

def calculate_function(protein_sequence, peptide_dataframe, grouping_variable):
    """
    Build a per-residue table of peptide function annotations.

    Parameters:
        protein_sequence: sequence whose length defines the number of rows
        peptide_dataframe: DataFrame with ``start``/``stop`` columns and,
            optionally, a ``function`` column
        grouping_variable: accepted for signature parity with
            ``calculate_abundance``; not used here

    Returns:
        DataFrame with one row per residue and one ``"<start>-<stop>"`` column
        per peptide. Cells inside the peptide span hold the row's ``function``
        value (NaN when the column is absent); all other cells are None.
    """
    protein_sequence_length = len(protein_sequence)
    has_function_column = 'function' in peptide_dataframe.columns
    data = []
    column_labels = []

    for _, row in peptide_dataframe.iterrows():
        # Clamp the span to the sequence so a start of 0 cannot wrap around to
        # the last residue (index -1) and an overlong stop cannot index past
        # the end of the list.
        start_idx = max(int(row['start'] - 1), 0)
        stop_idx = min(int(row['stop']), protein_sequence_length)

        function_value = row['function'] if has_function_column else np.nan

        values = [None] * protein_sequence_length
        if stop_idx > start_idx:
            values[start_idx:stop_idx] = [function_value] * (stop_idx - start_idx)

        data.append(values)
        column_labels.append(f'{int(row["start"])}-{int(row["stop"])}')

    function_df = pd.DataFrame(data).T
    function_df.columns = column_labels

    return function_df

def export_heatmap_data_to_dict(protein_id, group_key, group_info, protein_sequence, 
                               protein_species, protein_name, protein_df, is_all_null):
    """
    Assemble the heatmap tables and protein metadata for one protein/group.

    Three tables are computed: the function table (``calculate_function``) on
    the annotated peptides, the abundance table (``calculate_abundance``) on all
    peptides, and the abundance table on the annotated peptides only. When
    ``protein_df`` lacks any of ``function``/``start``/``stop`` a one-row
    placeholder frame is used for the annotated subset.

    ``group_key`` and ``is_all_null`` are accepted but not used in this body.

    Returns:
        dict with keys ``protein_id``, ``protein_sequence``, ``protein_name``,
        ``protein_species``, ``func_heatmap_df``, ``heatmap_df`` and
        ``filtered_heatmap_df``.
    """
    grouping_var = group_info['grouping_variable']
    # Looked up (so the key must exist) but not used further in this function.
    relevant_columns = group_info['abundance_columns']

    annotation_columns = ['function', 'start', 'stop']
    lacks_annotation = any(col not in protein_df.columns for col in annotation_columns)

    if lacks_annotation:
        # Placeholder so the downstream calls still receive a well-formed frame.
        filtered_df = pd.DataFrame({
            'start': [0],
            'stop': [1],
            'function': ['0']
        })
        filtered_protein_df = filtered_df.copy()
    else:
        filtered_protein_df = protein_df.dropna(subset=['function', 'start', 'stop'])
        filtered_df = filtered_protein_df[['start', 'stop', 'function']]

    # Same call order as before: function table, full abundance, filtered abundance.
    function_table = calculate_function(protein_sequence, filtered_df, grouping_var)
    full_abundance = calculate_abundance(protein_sequence, protein_df, grouping_var)
    annotated_abundance = calculate_abundance(protein_sequence, filtered_protein_df, grouping_var)

    # NOTE(review): process_available_data reads a key named
    # 'function_heatmap_df'; confirm where 'func_heatmap_df' is renamed.
    return dict(
        protein_id=protein_id,
        protein_sequence=protein_sequence,
        protein_name=protein_name,
        protein_species=protein_species,
        func_heatmap_df=function_table,
        heatmap_df=full_abundance,
        filtered_heatmap_df=annotated_abundance,
    )

def chunk_dataframe(df, chunk_size, exclude_columns=3):
    """
    Split ``df`` row-wise into chunks of exactly ``chunk_size`` rows.

    Parameters:
        df: DataFrame to split
        chunk_size: number of rows per chunk
        exclude_columns: number of trailing columns to drop before chunking
            (0 keeps every column)

    Returns:
        list of DataFrames, each ``chunk_size`` rows tall. The final chunk is
        padded with zero rows when needed. If the selected data is empty or
        entirely zero, a single all-zero chunk is returned instead.
    """
    # Select all rows and all but the last 'exclude_columns' columns
    df_subset = df.iloc[:, :-exclude_columns] if exclude_columns else df

    # Empty or all-zero input: hand back one default chunk of zeros.
    if df_subset.empty or (df_subset == 0).all().all():
        default_chunk = pd.DataFrame(
            np.zeros((chunk_size, df_subset.shape[1])),
            columns=df_subset.columns
        )
        return [default_chunk]

    remainder = df_subset.shape[0] % chunk_size

    if remainder != 0:
        # Zero rows needed to complete the last chunk.
        rows_to_add = chunk_size - remainder
        additional_rows = pd.DataFrame(
            np.zeros((rows_to_add, df_subset.shape[1])),
            columns=df_subset.columns
        )
        # ignore_index renumbers the rows, so the result index is always unique.
        df_subset = pd.concat([df_subset, additional_rows], ignore_index=True)

    # Slice by position using the row count; index labels are not guaranteed to
    # be 0..n-1 (e.g. a filtered frame), so they must not drive the chunking.
    total_rows = len(df_subset)
    return [df_subset.iloc[i:i + chunk_size] for i in range(0, total_rows, chunk_size)]
    
def process_available_data(available_data_variables, filter_type, selected_functions):
    """
    Process data for visualization when update_plot is called.

    For every entry of ``available_data_variables`` the DataFrame matching
    ``filter_type`` is selected, its ``count``/``average``/``AA`` columns are
    read, and derived values (series, min/max, fixed-width chunks, the
    per-peptide frames) are written back into that entry -- the input mapping
    is mutated in place. Summary values across all entries are returned.

    Parameters:
        available_data_variables: mapping of variable name -> dict holding at
            least ``protein_sequence``, ``label``, ``function_heatmap_df`` and
            the DataFrame(s) required by ``filter_type``
        filter_type: ``'all-peptides'``, ``'bioactive-only'`` or
            ``'functional-only'``; any other value falls back to an all-zero
            table for each variable
        selected_functions: function names to match for ``'functional-only'``
            (may be empty or None)

    Returns:
        dict of summary values (see the return statement), or None when
        ``available_data_variables`` is empty.

    Side effects:
        Sets the module globals ``axis_number``, ``total_plots`` and
        ``y_ticks``; may display an HTML error message.
    """
    if not available_data_variables:
        return 

    count_list = []
    min_values = []
    max_values = []
    seq_len_list = []
    chunk_size = 78

    # Iterable form of the selection; None/empty means "match nothing".
    functions_to_match = list(selected_functions) if selected_functions else []

    # Process each variable's data for visualization
    for var in available_data_variables:
        # Reset per variable so an unrecognised filter_type reaches the
        # zero-filled fallback below instead of hitting an unbound name.
        df = None

        if filter_type == 'all-peptides':
            df = available_data_variables[var]['heatmap_df']
        elif filter_type == 'bioactive-only':
            df = available_data_variables[var]['filtered_heatmap_df']
        elif filter_type == 'functional-only':
            if not selected_functions:
                # Warn, but keep going: the variable then falls back to zeros.
                error_html = '<span style="color: red; font-weight: bold;">Invalid Selection. Please select "Bioactive Functions" for "Plot Specific Peptides", with at least one function from the "Specific Options" dropdown menu selected.</span>'
                display(HTML(error_html))

            # Filter both functional and absorbance data for specific functions
            func_df = available_data_variables[var]['function_heatmap_df']
            if func_df is not None and not func_df.empty:
                # Columns (peptides) in which any cell mentions a selected function.
                functional_positions = []
                for col in func_df.columns:
                    if any(any(func in str(cell) for func in functions_to_match) for cell in func_df[col]):
                        functional_positions.append(col)

                # Get absorbance data for these positions
                abs_df = available_data_variables[var]['filtered_heatmap_df']
                if abs_df is not None and not abs_df.empty and functional_positions:
                    # Only keep columns that exist in abs_df
                    valid_columns = [col for col in functional_positions if col in abs_df.columns]
                    if valid_columns:
                        selected_abs_df = abs_df[valid_columns]

                        # Sequence column first, then the selected peptide columns.
                        df = pd.DataFrame({
                            'AA': abs_df['AA'],
                        })
                        df = pd.concat([df, selected_abs_df], axis=1)

                        # Recompute count and average from the selected peptides only.
                        non_zero_mask = df.drop('AA', axis=1) > 0
                        df['count'] = non_zero_mask.sum(axis=1)
                        df['average'] = df.drop('AA', axis=1).where(non_zero_mask).mean(axis=1)
                    else:
                        df = None
                else:
                    df = None
            else:
                df = None

        # Fall back to an all-zero table when nothing usable was selected.
        if df is None or df.empty or 'count' not in df.columns or 'average' not in df.columns:
            protein_length = len(available_data_variables[var]['protein_sequence'])
            df = pd.DataFrame({
                'count': [0] * protein_length,
                'average': [0] * protein_length,
                'AA': list(available_data_variables[var]['protein_sequence'])
            })

        # Get counts and MS data
        peptide_counts = df['count']
        ms_data = df['average']
        count_list.append(peptide_counts)

        # Min over positive values only (NaN when there are none); plain max.
        min_ms = ms_data[ms_data > 0].min()
        max_ms = ms_data.max()
        min_values.append(min_ms)
        max_values.append(max_ms)

        # Add computed properties to the variable data
        available_data_variables[var].update({
            'peptide_counts': peptide_counts,
            'ms_data': ms_data,
            'ms_data_list': list(ms_data),
            'AA_list': df['AA'].tolist(),
            'max_peptide_counts': peptide_counts.max(),
            'min_peptide_counts': peptide_counts.min(),
            'max_ms_data': max_ms,
            'min_ms_data': min_ms,

            # Generate chunks
            'amino_acids_chunks': [
                available_data_variables[var]['protein_sequence'][i:i + chunk_size]
                for i in range(0, len(available_data_variables[var]['protein_sequence']), chunk_size)
            ],
            'peptide_counts_chunks': [
                peptide_counts[i:i + chunk_size]
                for i in range(0, len(peptide_counts), chunk_size)
            ],
            'ms_data_chunks': [
                ms_data[i:i + chunk_size]
                for i in range(0, len(ms_data), chunk_size)
            ]
        })

        # Everything except the sequence and count columns.
        columns_to_include = df.columns.difference(['AA', 'count'])
        df_filtered = df[columns_to_include]

        # NOTE(review): export_heatmap_data_to_dict stores this table under
        # 'func_heatmap_df'; confirm where 'function_heatmap_df' is populated.
        available_data_variables[var].update({
            'bioactive_peptide_abs_df': df_filtered,
            'bioactive_peptide_func_df': available_data_variables[var]['function_heatmap_df']
        })

        # Length of the first sequence chunk (at most chunk_size).
        seq_len_list.append(len(available_data_variables[var]['amino_acids_chunks'][0]))

    # Calculate global values
    global axis_number, total_plots, y_ticks
    max_sequence_length = max(seq_len_list)
    axis_number = len(available_data_variables) + 2
    num_sets = len(next(iter(available_data_variables.values()))['amino_acids_chunks'])
    total_plots = num_sets * axis_number
    style_map = assign_line_styles(available_data_variables)

    # Process counts
    if len(set([item for sublist in count_list for item in sublist])) >= 1:
        flat_list = [item for sublist in count_list for item in sublist]
        # Check if there are any non-zero values before filtering
        non_zero_values = [item for item in flat_list if item != 0]
        if non_zero_values:  # If there are non-zero values
            list_of_counts = set(non_zero_values)
            max_count = max(non_zero_values)
            num_unique_count = len(set(non_zero_values))
        else:  # If all values are zero
            list_of_counts = {0}
            max_count = 0
            num_unique_count = 1
    else:
        max_count = 0
        num_unique_count = 0
        list_of_counts = set()
    # Calculate number of colors (capped at 6)
    num_colors = 6 if num_unique_count >= 6 else num_unique_count

    # Calculate global values
    if min_values and max_values:
        y_ticks = calculate_y_ticks(min_values, max_values)
        y_ticks_str = ', '.join(f'{tick:.2e}' for tick in y_ticks)
        y_ticks_html = f'<b>Max/Min of MS data (y-ticks):</b> {y_ticks_str}'
    else:
        # NOTE(review): one value is appended per variable above, so this branch
        # looks unreachable here; the sequence-length text is only added in it.
        y_ticks = [0.1, 1, 10]  # Default log scale values
        y_ticks_html = '<span style="color:red;">Insufficient data to calculate MS data y-ticks.</span>'
        y_ticks_html += f'<b>Protein Sequence Length:</b> {max_sequence_length}'

    return {
        'list_of_counts': list_of_counts,
        'min_values': min_values,
        'max_values': max_values,
        'seq_len_list': seq_len_list,
        'max_sequence_length': max_sequence_length,
        'y_ticks': y_ticks,
        'y_ticks_html': y_ticks_html,
        'max_count': max_count,
        'num_unique_count': num_unique_count,
        'num_colors': num_colors,
        'total_plots': total_plots,
        'style_map': style_map,
        'error': False

    }

# Update plot function
def update_plot(available_data_variables, ms_average_choice, bio_or_pep, selected_peptides, selected_functions,
                hm_selected_color, lp_selected_color, avglp_selected_color,
                xaxis_label, yaxis_label, yaxis_position, legend_title_input_1, legend_title_input_2,
                legend_title_input_3, legend_title_input_4, legend_title_input_5, plot_land, plot_port, filter_type):
    """
    Prepare the data and draw the requested portrait and/or landscape figures.

    The data is first run through ``process_available_data``; its summary values
    are published as module globals (read by the plotting helpers), then the
    portrait and landscape visualisation functions are called according to
    ``plot_port`` / ``plot_land``. User-facing notices are shown via
    ``display(HTML(...))``.

    Parameters:
        available_data_variables: per-variable data mapping (mutated by
            ``process_available_data``)
        ms_average_choice: ``'yes'``/``'only'`` enable the averaged portrait plot
        bio_or_pep: ``'no'`` for averaged plotting; any other value disables the
            portrait plot
        selected_peptides, selected_functions: current selections, forwarded to
            the visualisation functions
        hm_selected_color, avglp_selected_color: colormap names passed to
            ``plt.get_cmap``
        lp_selected_color: forwarded unchanged to the visualisation functions
        xaxis_label, yaxis_label: axis labels, forwarded unchanged
        yaxis_position: converted to ``0.0 - 0.01 * yaxis_position`` for both plots
        legend_title_input_1..5: collected into the legend title list
        plot_land, plot_port: whether to draw the landscape / portrait figure
        filter_type: forwarded to ``process_available_data`` and the landscape plot

    Returns:
        tuple ``(fig_port, fig_land)``; a slot is None when that figure was not
        requested, could not be drawn, or has no axes.
    """
    global list_of_counts, min_values, max_values, seq_len_list, max_sequence_length
    global y_ticks_html, max_count, num_unique_count, num_colors, total_plots, style_map, y_ticks

    if not available_data_variables:  # Check if data is available
        display(HTML(f'<br><span style="color:red;">No data is available for plotting.</span>'))
        return (None, None)  # Return tuple of None values
    result = process_available_data(available_data_variables, filter_type, selected_functions)

    import _settings as settings
    if not result:
        display(HTML(f'<br><b>Reslect "Variable Options", no data is available for plotting.</b><br><br>'))
        # Same shape as every other exit so callers can always unpack two values.
        return (None, None)

    # Second tuple element of the settings entry is not used in this function.
    lineplot_height, _scale_factor = settings.port_hm_settings.get(len(available_data_variables), (20, 0.1))

    # Unpack the dictionary into individual global variables
    list_of_counts = result['list_of_counts']
    min_values = result['min_values']
    max_values = result['max_values']
    seq_len_list = result['seq_len_list']
    max_sequence_length = result['max_sequence_length']
    y_ticks_html = result['y_ticks_html']
    max_count = result['max_count']
    num_unique_count = result['num_unique_count']
    num_colors = result['num_colors']
    total_plots = result['total_plots']
    style_map = result['style_map']

    yaxis_position_land = yaxis_position_port = 0.0 - 0.01 * yaxis_position
    legend_title = [legend_title_input_1, legend_title_input_2, legend_title_input_3, legend_title_input_4,
                    legend_title_input_5]

    cmap = plt.get_cmap(hm_selected_color)
    avg_cmap = plt.get_cmap(avglp_selected_color)

    if ms_average_choice == 'no' and bio_or_pep == 'no':
        display(HTML(f'<br><b>Reslect "Ploting Options", no data is available for plotting.</b><br><br>'))
        return None, None

    # Start from "no figure" so every later check works even when a plot was
    # skipped or the visualisation function returned nothing.
    fig_port = None
    fig_land = None

    # Portrait plot: averaged abundance only.
    if plot_port and (bio_or_pep == 'no'):
        # Temporarily suppress specific warning
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            if ms_average_choice in ['yes', 'only']:
                fig_port = visualize_sequence_heatmap_portrait(
                    available_data_variables,
                    0.001,
                    lineplot_height,
                    1,
                    xaxis_label,
                    yaxis_label,
                    legend_title,
                    yaxis_position_port,
                    cmap,
                    avg_cmap,
                    lp_selected_color,
                    avglp_selected_color,
                    selected_functions,
                    ms_average_choice,
                    bio_or_pep,
                    78)  # chunk size, passed positionally
            else:
                display(HTML(f'<br><b>No was selected earlier regarding the plotting of averaged absorbances, preventing the plotting of the averaged plot.</b><br><br>'))
            if max_count <= 1:
                display(HTML(f'<br><b style="color:red;">You have to few peptides for proper heatmapping.</b>'))

    # Landscape plot.
    if plot_land:
        # Rows get taller once there are 8 or more variables.
        amino_acid_height = 0.25 + 0.1 * (
            len(available_data_variables) // 4 if len(available_data_variables) >= 8 else 0)
        indices_height = amino_acid_height + 0.08
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            fig_land = visualize_sequence_heatmap_lanscape(
                available_data_variables,
                amino_acid_height,
                7.5,
                indices_height,
                xaxis_label,
                yaxis_label,
                legend_title,
                yaxis_position_land,
                cmap,
                avg_cmap,
                lp_selected_color,
                avglp_selected_color,
                selected_peptides,
                selected_functions,
                ms_average_choice,
                bio_or_pep,
                filter_type)

        if max_count <= 1:
            display(HTML(f'<br><b style="color:red;">You have to few peptides for proper heatmapping.</b>'))

    if plot_port and bio_or_pep != 'no':
        display(HTML(f'<br><b>The portrait plot only supports ploting averaged abundance of peptides which is not formated to be combined with specified peptide intervals or functions.</b><br><br>'))

    # Only hand back figures that exist and actually contain axes.
    fig_port_return = fig_port if fig_port is not None and len(fig_port.axes) > 0 else None
    fig_land_return = fig_land if fig_land is not None and len(fig_land.axes) > 0 else None

    if not plot_port and not plot_land:
        display(HTML(f"<br><br><p><strong>One or both of the create plot checkbox must be selected to generate a plot.</strong></p>"))

    return fig_port_return, fig_land_return

"""_________________________________________Data Visualization Functions_________________________________"""
# Function to plot rows of amino acids with backgrounds colored
def plot_row_color(ax, amino_acids, colors):
    """
    Write one row of amino-acid letters, each on its own coloured background.

    The axis decorations are hidden and the x-range is set to
    ``0..max_sequence_length`` (module global). Letter ``j`` is centred at
    ``x = j + 0.5``; its background is the hex form of the matching entry of
    ``colors``. Pairs are formed with ``zip``, so extra items in the longer
    input are ignored.
    """
    ax.axis('off')
    ax.set_xlim(0, max_sequence_length)
    ax.set_xlabel('')

    # Styling shared by every letter in the row.
    letter_style = dict(color='black', ha='center', va='center', fontsize=14)

    for position, pair in enumerate(zip(amino_acids, colors)):
        residue, rgb = pair
        ax.text(position + 0.5, 0.5, residue,
                backgroundcolor=mcolors.rgb2hex(rgb), **letter_style)

# Assigns line type for landscape plot if plotting individual peptides
def assign_line_styles(data_variables):
    """
    Map each distinct ``label`` found in ``data_variables`` to a line style.

    Styles are handed out in the order ``'-'``, ``'--'``, ``':'``, ``'-.'`` and
    repeat from the start when there are more than four labels.

    Parameters:
        data_variables: mapping whose values are dicts containing a ``'label'`` key

    Returns:
        dict: label -> matplotlib line-style string
    """
    line_styles = cycle(['-', '--', ':', '-.'])

    # De-duplicate while keeping first-seen order. Iterating a set of strings
    # would hand out styles in an order that can change between kernel sessions
    # (string hashing is randomised), making plots non-reproducible.
    ordered_labels = dict.fromkeys(data['label'] for data in data_variables.values())

    # Map each unique label to a line style
    return {label: next(line_styles) for label in ordered_labels}

# Function to plot rows of amino acids with NO backgrounds colored
def plot_row(ax, amino_acids):
    """
    Write one row of amino-acid letters without any background colour.

    The axis decorations are hidden and the x-range is set to
    ``0..max_sequence_length`` (module global). Letter ``j`` is centred at
    ``x = j + 0.5``, ``y = 0.5``.
    """
    ax.axis('off')
    ax.set_xlim(0, max_sequence_length)
    ax.set_xlabel('')

    # Styling shared by every letter in the row.
    letter_style = dict(color='black', ha='center', va='center', fontsize=8)

    for position, residue in enumerate(amino_acids):
        ax.text(position + 0.5, 0.5, residue, **letter_style)

# Function to plot continuous averaged lines
def plot_average_ms_data(ax, ms_data, label, var_index, y_ticks, i, chunk_size, avg_cmap, line_style):
    """
    Draw one averaged-MS line on ``ax`` using a log-scaled y axis.

    Parameters:
        ax: axes-like object to draw on
        ms_data: pandas Series/DataFrame (its index gives the x values) or any
            sized sequence (x values are 0..len-1)
        label: legend label for the line
        var_index: position of the variable; selects the colour as
            ``avg_cmap(var_index % avg_cmap.N)``
        y_ticks: tick values; their min and max also become the y limits
        i: chunk number, used with ``chunk_size`` to set the x window
        chunk_size: width of one chunk
        avg_cmap: colormap-like callable exposing ``N``
        line_style: matplotlib line-style string

    Returns:
        tuple: (line object, label, var_index)
    """
    # x window covered by chunk i.
    window_first = i * chunk_size
    window_last = (i + 1) * chunk_size - 1
    ax.set_xlim(window_first, window_last)

    has_pandas_index = isinstance(ms_data, (pd.DataFrame, pd.Series))
    if has_pandas_index:
        x_values, y_values = ms_data.index.tolist(), ms_data.values
    else:
        # Plain sequence: positions 0..len-1 serve as x values.
        x_values, y_values = list(range(len(ms_data))), ms_data

    # Wrap around the palette when there are more variables than colours.
    palette_size = avg_cmap.N
    line_color = avg_cmap(var_index % palette_size)

    drawn_lines = ax.plot(x_values, y_values, label=label, color=line_color, linestyle=line_style)
    (line,) = drawn_lines

    # Log-scaled y axis limited to the tick range, ticks on the left.
    ax.set_yscale('log')
    ax.set_yticks(y_ticks)
    ax.set_ylim(min(y_ticks), max(y_ticks))
    ax.tick_params(axis='y', labelsize=16)
    ax.yaxis.tick_left()

    # No x ticks or tick labels on this axis.
    ax.set_xticks([])
    ax.set_xticklabels([])

    return line, label, var_index

# Function to extract non-zero, non-NaN values
def extract_non_zero_non_nan_values(df):
    """
    Collect the distinct non-zero, non-missing values found anywhere in ``df``.

    String cells are split on ``'; '`` and each piece is collected separately;
    any other value is collected as-is.

    Returns:
        set of the collected values
    """
    collected = set()

    # stack() flattens the frame into a single Series of cell values.
    flattened = df.stack().values

    for cell in flattened:
        keep = cell != 0 and not pd.isna(cell)
        if not keep:
            continue

        if isinstance(cell, str):
            # One cell may list several entries separated by '; '.
            collected.update(cell.split('; '))
        else:
            collected.add(cell)

    return collected

    # This function is used in plotting to filter the data to only plot selected peptides or bioactive peptides, independent of averaged data

def filter_data_by_selection(bp_abs, bp_func, selected_peptides, selected_functions, bio_or_pep, filter_type):
    """
    Restrict the abundance/function frames to the columns the user selected.

    Parameters:
        bp_abs: abundance DataFrame (one column per peptide, optionally 'AA'/'count' columns)
        bp_func: function-annotation DataFrame, or None
        selected_peptides: peptide column names to keep when bio_or_pep == '1'
        selected_functions: function names to look for when bio_or_pep == '2'
        bio_or_pep: '1' = filter by peptide, '2' = filter by function, anything else = no filtering
        filter_type: accepted for interface compatibility; not used by this function

    Returns:
        tuple: (filtered_bp_abs, filtered_bp_func). Two empty DataFrames are returned when
        the input is missing/empty, nothing matches, or an error occurs while filtering.
    """
    # Nothing to filter
    if bp_abs is None or bp_abs.empty:
        print("DataFrame is None or empty")
        return pd.DataFrame(), pd.DataFrame()

    # --- Peptide selection ---
    if bio_or_pep == '1' and selected_peptides:
        try:
            # Metadata columns are kept in front of the selected peptide columns
            meta_present = [name for name in ['AA', 'count'] if name in bp_abs.columns]
            peptides_present = [pep for pep in selected_peptides if pep in bp_abs.columns]
            wanted = meta_present + peptides_present

            if not peptides_present:
                print("No matching peptide columns found")
                return pd.DataFrame(), pd.DataFrame()

            # Peptide mode does not return function annotations
            return bp_abs[wanted].copy(), pd.DataFrame()

        except Exception as e:
            print(f"Error in peptide filtering: {str(e)}")
            import traceback
            traceback.print_exc()
            return pd.DataFrame(), pd.DataFrame()

    # --- Function selection ---
    if bio_or_pep == '2' and selected_functions:
        if bp_func is None or bp_func.empty:
            return pd.DataFrame(), pd.DataFrame()

        try:
            def mentions_selected(cell):
                # A cell matches when any selected function name occurs in its text
                if pd.isna(cell):
                    return False
                text = str(cell)
                for func in selected_functions:
                    if func in text:
                        return True
                return False

            # Cell-wise hit mask, reduced to one flag per column
            hit_mask = bp_func.apply(lambda column: column.apply(mentions_selected))
            column_flags = hit_mask.any()
            matched_columns = column_flags[column_flags].index

            if len(matched_columns) == 0:
                return pd.DataFrame(), pd.DataFrame()

            # Both frames are cut down to the same set of columns
            return bp_abs[matched_columns], bp_func[matched_columns]

        except Exception as e:
            print(f"Error in function filtering: {str(e)}")
            return pd.DataFrame(), pd.DataFrame()

    # --- No filtering requested ---
    return bp_abs, bp_func

def process_chunk_data(ax2, chunk_abs, chunk_func, chunk_size, i, y_ticks, handles, labels,
                      sample_list, var_name_list, line_style, var_name, var_ms_data,
                      selected_peptides, selected_functions, lp_selected_color, ms_average_choice, bio_or_pep):
    """
    Plot one line per abundance column of a chunk and record legend bookkeeping.

    Parameters:
        ax2: axis to draw on; its x-range is set to the window of chunk ``i``
        chunk_abs: DataFrame of abundances, one column per peptide
        chunk_func: DataFrame of function annotations with matching column names (may be None)
        chunk_size: number of positions per chunk
        i: zero-based chunk number
        y_ticks: y tick positions; their min/max also become the y limits
        handles, labels, sample_list, var_name_list: lists that are appended to in place,
            one entry per plotted line (line object, label, line style, sample name)
        line_style: line style used for every line of this sample
        var_name: sample name recorded in ``var_name_list``
        var_ms_data: not used by this function
        selected_peptides: peptides that receive their own colour when bio_or_pep == '1'
        selected_functions: functions that receive their own colour when bio_or_pep == '2'
        lp_selected_color: colormap name passed to ``plt.get_cmap``
        ms_average_choice: not used by this function
        bio_or_pep: '1' = label lines by peptide column, '2' = label lines by function

    Returns:
        list: one list of function names for every column that was relabelled 'Multiple'
    """
    print_list = []

    start_limit = i * chunk_size
    end_limit = (i + 1) * chunk_size - 1
    ax2.set_xlim(start_limit, end_limit)

    # One colour per selectable item, spread evenly over the colormap
    colormap = plt.get_cmap(lp_selected_color)
    if bio_or_pep == '1':
        items_to_color = selected_peptides
    elif bio_or_pep == '2':
        items_to_color = selected_functions
    else:
        items_to_color = []
    num_colors = max(len(items_to_color), 1)

    colors = colormap(np.linspace(0, 1, num_colors))
    function_colors = {item: colors[idx % len(colors)] for idx, item in enumerate(items_to_color)}
    function_colors['Multiple'] = 'black'

    for col in chunk_abs.columns:
        # Only strictly positive, non-missing abundances can be drawn on the log axis
        y_values = chunk_abs[col].dropna()
        y_values = y_values[y_values > 0]
        if y_values.empty:
            continue

        x_values = y_values.index
        label_value = 'No Label'

        if bio_or_pep == '1':
            label_value = col
        elif bio_or_pep == '2' and chunk_func is not None and col in chunk_func.columns:
            # Gather the distinct function names of this column. Entries are ';'-delimited
            # and are stripped so that "A; B" yields 'A' and 'B' (not ' B'); otherwise the
            # colour lookup misses and 'B'/' B' are counted as different functions.
            func_names = set()
            for val in chunk_func[col].dropna():
                parts = val.split(';') if isinstance(val, str) else [str(val)]
                func_names.update(part.strip() for part in parts)
            func_names.discard('')  # trailing delimiters leave empty fragments
            func_list = sorted(func_names)

            if len(func_list) > 1 and len(selected_functions) > 1:
                print_list.append(func_list)
                label_value = 'Multiple'
            elif len(func_list) > 1 and len(selected_functions) == 1:
                label_value = selected_functions[0]
            elif len(func_list) == 1:
                label_value = func_list[0]

        if label_value != 'No Label':
            # Labels without an assigned colour fall back to grey
            color = function_colors.get(label_value, 'grey')
            try:
                lines = ax2.plot(x_values, y_values, label=f'{label_value}',
                                 linestyle=line_style, color=color)
                line = lines[0]

                handles.append(line)
                labels.append(f'{label_value}')
                sample_list.append(f'{line_style}')
                var_name_list.append(f'{var_name}')
            except Exception as e:
                # A single unplottable column must not abort the whole chunk
                print(f"Error plotting line for {label_value}: {str(e)}")

    # Set up axis properties
    ax2.set_yscale('log')
    ax2.set_yticks(y_ticks)
    ax2.set_ylim(min(y_ticks), max(y_ticks))
    ax2.tick_params(axis='y', labelsize=16)
    ax2.yaxis.tick_left()

    return print_list

def get_grouped_colors(counts, max_count, num_groups, plot_zero, cmap):
    """
    Map each peptide count to a heatmap colour.

    Parameters:
        counts: sequence of counts, one per position
        max_count: largest count used for scaling
        num_groups: number of discrete colour groups (used when max_count <= 20)
        plot_zero: 'yes' to colour zero counts like any other value; with 'no',
            zero counts are drawn 'white'
        cmap: callable mapping a number to a colour

    Returns:
        list: one colour per entry of ``counts``
    """
    # Grouping starts at 0 only when zeros are plotted as a group of their own
    start_point = 0 if plot_zero == 'yes' else 1

    # Group index of every count; bins are closed on the right
    group_bounds = np.linspace(start_point, max_count, num_groups + 1)
    group_labels = np.digitize(counts, bins=group_bounds, right=True)
    norm = Normalize(vmin=start_point, vmax=num_groups)

    # Large ranges use a continuous scale, small ranges use the discrete groups
    use_continuous = max_count > 20

    colors = []
    for count, label in zip(counts, group_labels):
        if plot_zero == 'no' and count == 0:
            colors.append('white')
        elif use_continuous:
            colors.append(cmap(count / max_count))
        else:
            colors.append(cmap(norm(label)))
    return colors

# Function to create legend for heatmap
def create_heatmap_legend_handles(cmap, num_colors, max_count, plot_zero):
    """
    Build the colour patches and labels of the heatmap (peptide count) legend.

    Parameters:
        cmap: callable mapping a normalised value to a colour
        num_colors: number of colour intervals
        max_count: largest peptide count
        plot_zero: 'yes' starts the scale at 0, anything else starts it at 1

    Returns:
        tuple: (legend_handles, heatmap_labels) - parallel lists of patches and label strings.
        When ``max_count`` is 0, or when anything fails while building the legend, a single
        '0' entry is returned instead.
    """
    try:
        # All counts are zero: one entry is enough
        if max_count == 0:
            norm = Normalize(vmin=0, vmax=1)
            return [patches.Patch(color=cmap(norm(0)), label='0')], ['0']

        start_point = 0 if plot_zero == 'yes' else 1

        # Same boundaries as the ones get_grouped_colors derives from these arguments
        count_ranges = np.linspace(start_point, max_count, num_colors + 1)
        norm = Normalize(vmin=start_point, vmax=max_count)

        # Small scales list single values; larger ones list "start - end" ranges
        small_range = max(max_count, num_colors) <= 6
        single_value_labels = small_range and plot_zero in ('no', 'yes')

        # Small scales skip the first boundary (it duplicates the start value).
        # Range labels need a following boundary, so the last boundary gets no entry.
        first_idx = 1 if small_range else 0
        last_idx = len(count_ranges) if single_value_labels else len(count_ranges) - 1

        legend_handles = []
        heatmap_labels = []
        for idx in range(first_idx, last_idx):
            color = cmap(norm(count_ranges[idx]))
            if single_value_labels:
                label = f'{int(count_ranges[idx])}'
            else:
                label = f'{int(count_ranges[idx])} - {int(count_ranges[idx + 1])}'
            legend_handles.append(patches.Patch(color=color, label=label))
            heatmap_labels.append(label)
        return legend_handles, heatmap_labels

    except Exception:
        # Best-effort fallback so a legend problem never prevents the plot from rendering
        handle = patches.Patch(color=cmap(0), label='0')
        return [handle], ['0']

def create_custom_legends(fig, labels, handles, var_name_list, legend_titles, heatmap_legend_handles,
                          heatmap_legend_labels, ms_average_choice, bio_or_pep, plot_type):
    handles_dict = {}
    sample_handles_dict = {}

    # Modify label if needed and populate dictionaries
    for label, handle, sample_name in zip(labels, handles, var_name_list):
        handles_dict[label] = handle  # Store or update the handle with modified label

        # Store handles for sample types (assuming sample_name correctly aligns with the handles)
        if sample_name not in sample_handles_dict:
            sample_handles_dict[sample_name] = handle

    # Create new handles for the legend with modified properties
    new_handles_func = [copy.copy(handle) for handle in handles_dict.values()]
    new_labels_func = [label for label in handles_dict.keys()]

    # Initial empty lists for combined handles and labels
    combined_handles = []
    combined_labels = []
    # Filter for "Averaged" labels
    averaged_handles = []
    averaged_labels = []
    other_handles = []
    other_labels = []

    for handle, label in zip(new_handles_func, new_labels_func):
        if "Averaged" in label:
            clean_label = label.replace("Averaged ", "")  # Remove 'Averaged ' from the label
            averaged_handles.append(handle)
            averaged_labels.append(clean_label)  # Append the cleaned label

        else:
            handle.set_linestyle('-')  # Set line style to solid
            other_handles.append(handle)
            other_labels.append(label)

    new_handles_samples = []
    if not averaged_handles:
        for handle in sample_handles_dict.values():
            new_handle = copy.copy(handle)
            new_handle.set_color('black')  # Set color to black for sample type handles
            new_handles_samples.append(new_handle)

    # Dummy handles for subtitles
    line_type = mlines.Line2D([], [], color='none', label='Line Type')
    average_color = mlines.Line2D([], [], color='none', label='Average Absorbance')
    line_color = mlines.Line2D([], [], color='none', label='Line Color')
    pep_count_placeholder = mlines.Line2D([], [], color='none', label='Line Type')

    line_type_title = legend_titles[0]
    avgline_color_title = legend_titles[4]
    if bio_or_pep == '1':
        color_title = legend_titles[3]
    elif bio_or_pep == '2':
        color_title = legend_titles[2]

    if ms_average_choice == 'yes' and bio_or_pep != 'no':
        combined_handles = [line_color] + other_handles + [average_color] + averaged_handles
        combined_labels = [color_title] + other_labels + [avgline_color_title] + averaged_labels

    if ms_average_choice == 'yes' and bio_or_pep == 'no':
        combined_handles = [average_color] + averaged_handles
        combined_labels = [avgline_color_title] + averaged_labels

    if ms_average_choice == 'only':
        combined_handles = [average_color] + averaged_handles
        combined_labels = [avgline_color_title] + averaged_labels
        
    if ms_average_choice == 'no' and bio_or_pep != 'no':
        combined_handles = [line_color] + other_handles + [line_type] + new_handles_samples
        combined_labels = [color_title] + other_labels + [line_type_title] + [key for key in sample_handles_dict.keys()]

    legend_peptide_count = None
    if plot_type == "land":
        # Create the peptide count legend separately
        if plot_heatmap == 'yes':
            legend_peptide_count = fig.legend(handles=heatmap_legend_handles, loc='center',
                                              fontsize=14,
                                              title=legend_titles[1],
                                              title_fontsize=14,
                                              bbox_to_anchor=(0.5, -0.1),
                                              ncol=len(heatmap_legend_handles))

        # Create the combined legend (for other handles/labels)
        combined_legend = fig.legend(handles=combined_handles, labels=combined_labels,
                                     loc='upper left', bbox_to_anchor=(0.99, 0.975),
                                     fontsize=14, handlelength=2)
        plt.tight_layout()
        fig.subplots_adjust(right=0.9)  # Make room for side legend

    # Create a single combined legend
    elif plot_type == "port":

        # Combine heatmap handles with other handles for a single legend
        combined_handles = [line_color] + other_handles + [pep_count_placeholder] + heatmap_legend_handles
        combined_labels = [avgline_color_title] + other_labels + [legend_titles[1]] + heatmap_legend_labels

        # Create a single combined legend with just the combined handles (no need for additional labels)
        combined_legend = fig.legend(handles=combined_handles,
                                     labels=combined_labels,
                                     loc='upper left',
                                     bbox_to_anchor=(0.9025, 0.875),
                                     fontsize=14)

    return combined_legend, legend_peptide_count

"""_________________________________________Landscape Plot________________________________________"""
### def visualize_sequence_heatmap_individual_lines(available_data_variables, amino_acid_height, lineplot_height, indices_height, filename, xaxis_label, yaxis_label, legend_title, yaxis_position):
def visualize_sequence_heatmap_lanscape(available_data_variables,
                                        amino_acid_height,
                                        lineplot_height,
                                        indices_height,
                                        xaxis_label,
                                        yaxis_label,
                                        legend_title,
                                        yaxis_position,
                                        cmap,
                                        avg_cmap,
                                        lp_selected_color,
                                        avglp_selected_color,
                                        selected_peptides,
                                        selected_functions,
                                        ms_average_choice,
                                        bio_or_pep, filter_type):
    """
    Draw the landscape figure: one line plot, an index row, an amino-acid row and
    (when ``plot_heatmap`` is 'yes') one heatmap row per sample.

    Parameters:
        available_data_variables: dict of per-sample dicts; the keys read here are 'AA_list',
            'protein_id', 'peptide_counts', 'ms_data_list', 'label',
            'bioactive_peptide_abs_df' and 'bioactive_peptide_func_df'
        amino_acid_height, lineplot_height, indices_height: relative row heights
        xaxis_label, yaxis_label: axis captions written onto the figure
        legend_title: sequence of legend section titles (forwarded to create_custom_legends)
        yaxis_position: x position (figure coordinates) of the y-axis caption
        cmap: colormap used for the heatmap rows
        avg_cmap: colormap used for the averaged lines
        lp_selected_color: colormap name used for the individual lines
        avglp_selected_color: not used by this function
        selected_peptides, selected_functions: user selection forwarded to the filter/plot helpers
        ms_average_choice: 'yes', 'no' or 'only'
        bio_or_pep: '1', '2' or 'no'
        filter_type: forwarded to filter_data_by_selection

    Returns:
        the created figure

    Note:
        Besides its parameters this function reads the module-level names ``num_colors``,
        ``max_count``, ``plot_zero``, ``plot_heatmap``, ``y_ticks`` and ``style_map``.
    """
    # The longest sequence defines the shared x-range
    max_sequence_length = max(len(available_data_variables[var]['AA_list']) for var in available_data_variables)

    # Check if there are multiple distinct protein IDs in available_data_variables
    multiple_proteins = len({available_data_variables[var]['protein_id'] for var in available_data_variables}) > 1

    # Landscape mode shows the whole sequence as one single chunk
    chunk_size = max_sequence_length
    heatmap_legend_handles, heatmap_legend_labels = create_heatmap_legend_handles(cmap, num_colors, max_count, plot_zero)

    def plot_row_color_landscape(ax, amino_acids, colors):
        """Draw one heatmap row: a unit-sized coloured cell per sequence position."""
        ax.axis('off')
        ax.set_xlim(0, max_sequence_length)
        ax.set_xlabel('')
        cell_width = 1  # Each amino acid is spaced evenly by 1 unit on the x-axis
        cell_height = 1  # Set height of the row

        for j, (aa, color) in enumerate(zip(amino_acids, colors)):
            rect = patches.Rectangle((j, 0), cell_width, cell_height, color=mcolors.rgb2hex(color))
            ax.add_patch(rect)

    def plot_row_landscape(ax, amino_acids):
        """Write the amino-acid letters onto ``ax``; skipped for sequences longer than 300."""
        ax.axis('off')
        ax.set_xlim(0, max_sequence_length)
        ax.set_xlabel('')
        # Shrink the font step by step for long sequences
        aa_font_size = 10
        if max_sequence_length > 200:
            aa_font_size -= 0.5
        if max_sequence_length > 250:
            aa_font_size -= 1
        if max_sequence_length > 300:
            return
        for j, aa in enumerate(amino_acids):
            ax.text(j + 0.5, 0.5, aa, color='black', ha='center', va='center',
                    fontsize=aa_font_size)

    # Row layout: line plot, indices, amino acids (+ one heatmap row per sample)
    if plot_heatmap == 'yes':
        height_ratios = (
                    [lineplot_height] + [indices_height] + [amino_acid_height] * len(available_data_variables) + [amino_acid_height])
        axis_number = len(available_data_variables) + 3
    else:
        height_ratios = ([lineplot_height] + [indices_height] + [amino_acid_height])
        axis_number = 3

    fig, axes = plt.subplots(axis_number, 1, figsize=(25, (lineplot_height + indices_height + amino_acid_height)),
                             gridspec_kw={'height_ratios': height_ratios, 'hspace': 0})

    # Initialize for legend handling
    handles, labels, sample_list, var_name_list = [], [], [], []
    total_count = 0
    # Collected across ALL samples so the 'Multiple' footnote is complete
    print_list = []

    # Loop through each set of data and create plots
    for var_index, var in enumerate(available_data_variables):
        ax1 = axes[0]
        ax1.axis('off')
        ax1 = ax1.twinx()  # Create a twin y-axis
        ax1.yaxis.set_minor_locator(plt.NullLocator())

        var_amino_acids = available_data_variables[var]['AA_list']
        var_counts = available_data_variables[var]['peptide_counts']
        var_ms_data = available_data_variables[var]['ms_data_list']
        var_name = available_data_variables[var]['label']
        var_colors = get_grouped_colors(var_counts, max_count, num_colors, plot_zero, cmap)
        bp_abs = available_data_variables[var]['bioactive_peptide_abs_df']
        bp_func = available_data_variables[var]['bioactive_peptide_func_df']

        line_style = style_map[var_name]  # Get assigned line style from the style map

        if bio_or_pep != 'no' and ms_average_choice != 'only':
            filtered_bp_abs, filtered_bp_fun = filter_data_by_selection(bp_abs, bp_func, selected_peptides,
                                                                        selected_functions, bio_or_pep, filter_type)
            print_list.extend(
                process_chunk_data(ax1, filtered_bp_abs, filtered_bp_fun, chunk_size, 0, y_ticks, handles,
                                   labels, sample_list, var_name_list, line_style, var_name, var_ms_data,
                                   selected_peptides, selected_functions, lp_selected_color, ms_average_choice,
                                   bio_or_pep))

        if ms_average_choice in ['yes', 'only']:
            line, label, _ = plot_average_ms_data(ax1, var_ms_data, f'Averaged {var_name}', var_index, y_ticks, 0,
                                                  chunk_size, avg_cmap, line_style)
            handles.append(line)
            labels.append(f'{label}')
            sample_list.append(line_style)
            var_name_list.append(var_name)

        # Plot indices below the MS line plot
        ax2 = axes[1]
        ax2.axis('off')
        ax2.set_xlim(0, max_sequence_length)
        indices = [0]
        if var_index == 0:
            # Add indices at increments of 20, starting from 20 up to the length of the array, but not including the last index if it's less than 20 away
            indices.extend(range(20, max_sequence_length - 5, 20))
            # Always add the last index of the array
            indices.append(max_sequence_length - 1)
            for idx in indices:
                ax2.text(idx + 0.5, 0.5, str(total_count + idx + 1), ha='center', va='center', fontsize=16)
            total_count += max_sequence_length

            # Amino acid row
            ax3 = axes[2]

            # The letters are only meaningful when every sample shares the same protein
            if not multiple_proteins:
                plot_row_landscape(ax3, var_amino_acids)
            else:
                ax3.axis('off')  # Plot a blank line by turning off the axis

        if plot_heatmap == 'yes':
            # Heatmap row of this sample, labelled with the sample name
            ax = axes[var_index + 3]
            plot_row_color_landscape(ax, var_amino_acids, var_colors)
            ax.text(0, 0.5, var_name, ha='right', va='center', fontsize=14)

    # Create the legends after the plotting loop
    create_custom_legends(fig, labels, handles, var_name_list, legend_title, heatmap_legend_handles,
                          heatmap_legend_labels, ms_average_choice, bio_or_pep, plot_type="land")

    if bio_or_pep == '2' and ms_average_choice != 'only':
        # Flatten the list of lists, drop duplicates; sorted so the footnote order is stable
        print_list = sorted({item for sublist in print_list for item in sublist})

        # Remove the empty string if it exists
        if '' in print_list:
            print_list.remove('')

        # Enumerate the list and format the output
        print_list = [f"     {i + 1}. {label}" for i, label in enumerate(print_list)]
        print_list.insert(0, "The following labels have been relabeled as 'Multiple':")

        # Join all the elements of the list into a single string with newlines
        footnote = '\n'.join(print_list)

    fig.text(yaxis_position, 0.90, yaxis_label, va='top', rotation='vertical', fontsize=16)
    fig.text(0.5, -0.025, xaxis_label, ha='center', va='center', fontsize=16)
    plt.tight_layout()

    if ms_average_choice == 'only' and bio_or_pep == '1':
        display(HTML(f'<b style="color:red;">The selection of Plot Averaged Data option "only" and Plot Specific Peptides option "Peptide Intervals" is invalid. Only the average absrobance will be plotted.</b>'))

    if bio_or_pep == '2' and ms_average_choice != 'only':
        # Entry 0 is the header line, so anything beyond it means labels were relabelled
        if len(print_list) > 1:
            print(footnote)
    return fig

"""_______________________________Portrait Plot______________________________________________"""
def visualize_sequence_heatmap_portrait(available_data_variables,
                                        amino_acid_height,
                                        lineplot_height,
                                        indices_height,
                                        xaxis_label,
                                        yaxis_label,
                                        legend_title,
                                        yaxis_position,
                                        cmap,
                                        avg_cmap,
                                        lp_selected_color,
                                        avglp_selected_color,
                                        selected_functions,
                                        ms_average_choice,
                                        bio_or_pep,
                                        chunk_size):
    """
    Draw the portrait figure: the sequence is split into chunks and every chunk gets a
    line plot row, an index row and one heatmap row per sample.

    Parameters:
        available_data_variables: dict of per-sample dicts; the keys read here are
            'amino_acids_chunks', 'peptide_counts_chunks', 'ms_data_chunks' and 'label'
        amino_acid_height, indices_height: relative row heights
        lineplot_height: overridden by the preset from ``settings.port_hm_settings``
        xaxis_label, yaxis_label: axis captions written onto the figure
        legend_title: sequence of legend section titles (forwarded to create_custom_legends)
        yaxis_position: x position (figure coordinates) of the y-axis caption
        cmap: colormap used for the heatmap rows
        avg_cmap: colormap used for the averaged lines
        lp_selected_color, avglp_selected_color, selected_functions: not used by this function
        ms_average_choice, bio_or_pep: forwarded to create_custom_legends
        chunk_size: number of sequence positions per chunk

    Returns:
        the created figure

    Note:
        Besides its parameters this function reads the module-level names ``num_colors``,
        ``max_count``, ``plot_zero``, ``y_ticks``, ``max_sequence_length`` and ``plot_row_color``.
    """
    # Layout presets are keyed by the number of samples
    lineplot_height, scale_factor = settings.port_hm_settings.get(
        len(available_data_variables), (20, 0.1))
    # NOTE(review): the original contained the no-op statement `plot_zero == 'no'` here.
    # It never changed anything, so it was dropped; confirm whether forcing 'no' was intended.

    # Number of chunks (taken from the last sample, as before)
    for var in available_data_variables:
        num_sets = len(available_data_variables[var]['amino_acids_chunks'])

    # Create legend handles for the heatmap
    heatmap_legend_handles, heatmap_legend_labels = create_heatmap_legend_handles(cmap, num_colors, max_count,
                                                                                  plot_zero)

    # Each chunk occupies: line plot + indices + one heatmap row per sample
    axis_number = len(available_data_variables) + 2
    height_ratios = ([lineplot_height] + [indices_height] + [amino_acid_height] * len(available_data_variables)) * num_sets
    # The subplot grid must have exactly one row per height ratio
    total_plots = len(height_ratios)

    fig, axes = plt.subplots(total_plots, 1, figsize=(20, num_sets * (
            lineplot_height + indices_height + amino_acid_height * len(available_data_variables)) * scale_factor),
                             gridspec_kw={'height_ratios': height_ratios, 'hspace': 1})

    # Initialize for legend handling
    handles, labels, var_name_list = [], [], []
    total_count = 0
    line_style = '-'

    # Loop through each chunk and create its rows
    for i in range(num_sets):

        # Longest chunk i across all samples; drives the index labels
        max_var_amino_acids = max(
            len(available_data_variables[var]['amino_acids_chunks'][i])
            for var in available_data_variables
        )

        for var_index, var in enumerate(available_data_variables):
            ax1 = axes[axis_number * i]
            ax1.axis('off')
            ax2 = ax1.twinx()  # Create a twin y-axis

            # Get data chunks for the current variable
            var_amino_acids = available_data_variables[var]['amino_acids_chunks'][i]
            var_counts = available_data_variables[var]['peptide_counts_chunks'][i]
            var_ms_data = available_data_variables[var]['ms_data_chunks'][i]
            var_name = available_data_variables[var]['label']
            var_colors = get_grouped_colors(var_counts, max_count, num_colors, plot_zero, cmap)

            # Plot MS data using plot_average_ms_data and keep the returned line for the legend
            line, label, _ = plot_average_ms_data(ax2, var_ms_data, var_name, var_index, y_ticks, i, chunk_size,
                                                  avg_cmap, line_style=line_style)
            handles.append(line)
            labels.append(label)
            var_name_list.append(var_name)

            # Plot indices below the MS line plot
            ax = axes[axis_number * i + 1]
            ax.axis('off')
            ax.set_xlim(0, max_sequence_length)
            indices = [0]
            if var_index == 0:
                # Add indices at increments of 20, starting from 20 up to the length of the array, but not including the last index if it's less than 20 away
                indices.extend(range(20, max_var_amino_acids - 5, 20))
                # Always add the last index of the array
                indices.append(max_var_amino_acids - 1)
                for idx in indices:
                    ax.text(idx + 0.5, 0.5, str(total_count + idx + 1), ha='center', va='center', fontsize=16)
                total_count += max_var_amino_acids

            # Heatmap row of this sample within chunk i
            ax = axes[axis_number * i + var_index + 2]
            plot_row_color(ax, var_amino_acids, var_colors)
            ax.text(0, 0.5, f'{var_name}  ', ha='right', va='center', fontsize=14)

    # Create the legend after the plotting loop
    create_custom_legends(fig, labels, handles, var_name_list, legend_title, heatmap_legend_handles,
                          heatmap_legend_labels, ms_average_choice, bio_or_pep, plot_type="port")

    # Adjust layout, then add the axis captions
    plt.tight_layout()
    fig.text(yaxis_position, 0.5, yaxis_label, va='center', rotation='vertical', fontsize=16)
    fig.text(0.5, 0.05, xaxis_label, ha='center', va='center', fontsize=16)

    return fig

In [4]:
class DataTransformation(HasTraits):
    merged_df = Instance(pd.DataFrame, allow_none=True)
    group_data = Instance(dict, allow_none=True)
    def __init__(self):
        """Create an empty transformer; widgets are built later by setup_data_loading_ui()."""
        super().__init__()

        # --- data state -----------------------------------------------------
        self.merged_df = pd.DataFrame()
        self.protein_dict = {}
        self.group_data = {}
        # Group names in upload order (set by the merged-upload handler).
        # Initialised here because other code reads `col_order` right after
        # construction; without a default that read raises AttributeError
        # until a file has been uploaded.
        self.col_order = []

        # --- widgets / helpers (populated lazily) ---------------------------
        self.output_area = None
        self.merged_uploader = None
        self.fasta_uploader = None
        # Declared alongside the other widgets so every widget attribute exists
        # from construction onwards.
        self.uniprot_search = None
        self.reset_button = None
        self.uniprot_client = None
            
    def create_download_link(self, file_path, label):
        """Return an HTML widget offering ``file_path`` as a download.

        The file is read, base64-encoded and embedded in a ``data:`` URI.
        When the path does not exist, a red "not found" message is returned
        instead of a link.

        Parameters:
            file_path: str, path of the file to expose
            label: str, text shown for the link

        Returns:
            widgets.HTML: the anchor element, or the error message
        """
        outer_pad = " " * 16
        inner_pad = " " * 20

        # Guard clause: nothing to encode when the file is absent.
        if not os.path.exists(file_path):
            error_markup = "\n".join([
                "",
                f'{outer_pad}<span style="color: red; margin-left: 20px; font-size: 14px;">',
                f'{inner_pad}File "{file_path}" not found!',
                f'{outer_pad}</span>',
                outer_pad,
            ])
            return widgets.HTML(error_markup)

        with open(file_path, 'rb') as handle:
            encoded = base64.b64encode(handle.read()).decode('utf-8')
        file_name = os.path.basename(file_path)

        link_markup = "\n".join([
            "",
            f'{outer_pad}<a download="{file_name}" ',
            f'{inner_pad}href="data:application/octet-stream;base64,{encoded}" ',
            f'{inner_pad}style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">',
            f'{inner_pad}{label}',
            f'{outer_pad}</a>',
            " " * 12,
        ])
        return widgets.HTML(link_markup)

    def setup_data_loading_ui(self):
        """Build the upload / UniProt widgets, hook up their observers and display them."""

        def field_options():
            # A fresh Layout and style per widget, so no widget model is shared.
            return dict(
                layout=widgets.Layout(width='300px'),
                style={'description_width': 'initial'},
            )

        def centered():
            return widgets.Layout(align_items='center')

        # --- input widgets ---------------------------------------------------
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            **field_options()
        )
        self.fasta_uploader = widgets.FileUpload(
            accept='.fasta',
            multiple=True,
            description='Upload FASTA Files',
            **field_options()
        )
        self.uniprot_search = widgets.Checkbox(
            value=False,
            description='Use UniProt for missing proteins',
            **field_options()
        )
        # NOTE(review): the reset button is created and wired below but is not
        # part of any displayed container, so it is currently unreachable from
        # the UI - confirm whether it should be added to the layout.
        self.reset_button = widgets.Button(description='Reset', button_style='warning')
        self.output_area = widgets.Output()

        # --- rows: uploader + link to an example file -------------------------
        merged_row = widgets.HBox(
            [self.merged_uploader,
             self.create_download_link("example_merged_dataframe.csv", "Example")],
            layout=centered(),
        )
        fasta_row = widgets.HBox(
            [self.fasta_uploader,
             self.create_download_link("example_fasta.fasta", "Example")],
            layout=centered(),
        )

        # --- "Option 1 OR Option 2" strip -------------------------------------
        uniprot_column = widgets.VBox([
            widgets.HTML("<h3><u>Option 2: Import from UniProt:</u></h3>"),
            self.uniprot_search,
        ])
        fasta_column = widgets.VBox([
            widgets.HTML("<h3><u>Option 1: Upload Protein FASTA Files:</u></h3>"),
            fasta_row,
        ])
        separator = widgets.HTML("<div style='margin: 0 20px; line-height: 100px;'><b>OR</b></div>")
        sequence_sources = widgets.HBox(
            [fasta_column, separator, uniprot_column],
            layout=widgets.Layout(width='800px', margin='0 0 0 0px'),
        )

        # --- top-level container ----------------------------------------------
        panel = widgets.VBox(
            [
                widgets.HTML("<h3><u>Upload Data Files:</u></h3>"),
                merged_row,
                sequence_sources,
                self.output_area,
            ],
            layout=widgets.Layout(width='800px', margin='0 20px 0 0'),
        )

        # --- event wiring -------------------------------------------------------
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')
        self.fasta_uploader.observe(self._on_fasta_upload_change, names='value')
        self.reset_button.on_click(self._reset_ui)
        self.uniprot_search.observe(self._on_uniprot_search_change, names='value')

        display(panel)

    def _on_uniprot_search_change(self, change):
        """Observer for the UniProt checkbox: fetch sequences for proteins that lack one.

        Runs only when the checkbox is switched on. The UniProt client is
        imported lazily on first use; if that import fails the checkbox is
        switched back off. When merged data with a 'Master Protein Accessions'
        column is loaded, every accession without a non-empty sequence in
        ``self.protein_dict`` is requested from UniProt and progress is written
        to ``self.output_area``.

        Parameters:
            change: dict, traitlets change notification; only ``change['new']`` is read
        """
        if not change['new']:
            return

        # Lazy import of the optional client.
        if self.uniprot_client is None:
            try:
                from utils.uniprot_client import UniProtClient
                self.uniprot_client = UniProtClient()
                with self.output_area:
                    clear_output()
                    display(HTML('<b style="color:green;">UniProt client initialized successfully.</b>'))
            except ImportError as e:
                with self.output_area:
                    clear_output()
                    display(HTML(f'<b style="color:red;">Error importing UniProt client: {str(e)}</b>'))
                self.uniprot_search.value = False
                return

        merged = getattr(self, 'merged_df', None)
        if merged is None or merged.empty:
            return
        if 'Master Protein Accessions' not in merged.columns:
            return

        with self.output_area:
            clear_output()
            display(HTML('<b>Checking for missing proteins and fetching from UniProt...</b>'))

            unique_protein_ids = set(merged['Master Protein Accessions'].dropna().unique())

            # A protein counts as missing when it has no entry OR its entry has
            # an empty sequence. process_protein_info() registers name/species
            # with sequence '' - comparing keys alone would report those as
            # "already have sequence information".
            missing_proteins = {
                protein_id for protein_id in unique_protein_ids
                if not (self.protein_dict.get(protein_id) or {}).get("sequence")
            }

            if not missing_proteins:
                display(HTML('<b style="color:green;">All proteins in your dataset already have sequence information.</b>'))
                return

            display(HTML(f'<b>Found {len(missing_proteins)} proteins missing sequence information.</b><br>' +
                        f'<b>Fetching from UniProt...</b>'))

            success_count = 0
            for protein_id in missing_proteins:
                try:
                    name, species, sequence = self.uniprot_client.fetch_protein_info_with_sequence(protein_id)

                    if sequence:
                        # Keep a name/species that is already stored (e.g. taken
                        # from the uploaded table); fall back to UniProt's values.
                        existing = self.protein_dict.get(protein_id) or {}
                        self.protein_dict[protein_id] = {
                            "name": existing.get("name") or (name if name else protein_id),
                            "sequence": sequence,
                            "species": existing.get("species") or (species if species else "Unknown")
                        }
                        success_count += 1
                        display(HTML(f'<span style="color:green;">✓ {protein_id}: {name or "Unknown"} ({species or "Unknown"})</span>'))
                    else:
                        display(HTML(f'<span style="color:orange;">✗ No sequence found for {protein_id}</span>'))
                except Exception as e:
                    # Best effort per protein: report and continue with the next one.
                    display(HTML(f'<span style="color:orange;">✗ Error fetching {protein_id}: {str(e)}</span>'))

            if success_count > 0:
                display(HTML(f'<b style="color:green;">Successfully fetched {success_count} out of {len(missing_proteins)} missing proteins.</b>'))
            else:
                display(HTML(f'<b style="color:red;">Failed to fetch any of the {len(missing_proteins)} missing proteins.</b>'))

    def check_missing_proteins_in_merged_data(self):
        """Report (and optionally fetch) proteins in the merged data that lack a sequence.

        Does nothing when no merged data is loaded or it has no
        'Master Protein Accessions' column. A protein is considered missing
        when ``self.protein_dict`` has no entry for it or the entry's sequence
        is empty. If the UniProt checkbox is ticked the missing proteins are
        fetched; otherwise a warning listing up to ten of them is displayed.
        All output goes to ``self.output_area``.
        """
        merged = getattr(self, 'merged_df', None)
        if merged is None or merged.empty:
            return

        if 'Master Protein Accessions' not in merged.columns:
            return

        unique_protein_ids = set(merged['Master Protein Accessions'].dropna().unique())

        # Entries created from the table's protein_name/protein_species columns
        # carry sequence '' - those are still missing a sequence.
        missing_proteins = {
            protein_id for protein_id in unique_protein_ids
            if not (self.protein_dict.get(protein_id) or {}).get("sequence")
        }

        if not missing_proteins:
            return

        # The checkbox only exists once the UI has been set up.
        checkbox = getattr(self, 'uniprot_search', None)
        use_uniprot = bool(checkbox is not None and checkbox.value)

        # Initialize UniProt client if needed
        if use_uniprot and self.uniprot_client is None:
            try:
                from utils.uniprot_client import UniProtClient
                self.uniprot_client = UniProtClient()
            except ImportError:
                with self.output_area:
                    display(HTML('<b style="color:red;">Error initializing UniProt client.</b>'))
                return

        if use_uniprot and self.uniprot_client:
            with self.output_area:
                display(HTML(f'<b>Found {len(missing_proteins)} proteins missing sequence information.</b><br>' +
                            f'<b>Attempting to fetch from UniProt...</b>'))

                success_count = 0
                for protein_id in missing_proteins:
                    try:
                        name, species, sequence = self.uniprot_client.fetch_protein_info_with_sequence(protein_id)

                        if sequence:
                            # Preserve an already-known name/species; only the
                            # sequence is guaranteed to be new.
                            existing = self.protein_dict.get(protein_id) or {}
                            self.protein_dict[protein_id] = {
                                "name": existing.get("name") or (name if name else protein_id),
                                "sequence": sequence,
                                "species": existing.get("species") or (species if species else "Unknown")
                            }
                            success_count += 1
                        else:
                            display(HTML(f'<span style="color:orange;">No sequence found for {protein_id}</span>'))
                    except Exception as e:
                        # Best effort per protein: report and continue.
                        display(HTML(f'<span style="color:orange;">Error fetching {protein_id}: {str(e)}</span>'))

                # Same outcome reporting as the checkbox handler: do not announce
                # "success" in green when nothing could be fetched.
                if success_count > 0:
                    display(HTML(f'<b style="color:green;">Successfully fetched {success_count} out of {len(missing_proteins)} missing proteins.</b>'))
                else:
                    display(HTML(f'<b style="color:red;">Failed to fetch any of the {len(missing_proteins)} missing proteins.</b>'))
        else:
            with self.output_area:
                # Sort so the preview is stable between runs (set order is arbitrary).
                preview = sorted(str(protein_id) for protein_id in missing_proteins)
                missing_list = ", ".join(preview[:10])
                if len(missing_proteins) > 10:
                    missing_list += f" and {len(missing_proteins) - 10} more"

                display(HTML(
                    f'<b style="color:orange;">Warning: {len(missing_proteins)} proteins in your data are missing sequences.</b><br>' +
                    f'Missing proteins: {missing_list}<br>' +
                    f'<b>To automatically fetch these proteins, check "Use UniProt for missing proteins" above.</b>'
                ))
    
    def _reset_ui(self, b):
        """Reset the UI state.

        Clears the loaded table, the derived groups/column order and the
        protein dictionary, empties both upload widgets and shows a
        confirmation in the output area.

        Parameters:
            b: the clicked button (unused; required by ``Button.on_click``)
        """
        # Drop the data state FIRST. Assigning an empty value to the uploaders
        # below fires their 'value' observers; if the old table were still
        # attached at that point, the upload handler would run its
        # missing-protein check against data that is being discarded.
        self.merged_df = None
        self.group_data = None
        self.protein_dict = {}
        # Derived from the merged table's Avg_ columns - stale once it is gone.
        self.col_order = []

        self.merged_uploader._counter = 0
        self.fasta_uploader._counter = 0
        self.merged_uploader.value = ()
        self.fasta_uploader.value = ()

        with self.output_area:
            self.output_area.clear_output()
            display(HTML('<b style="color:blue;">All uploads cleared.</b>'))
    
    def _on_merged_upload_change(self, change):
        """Observer for the merged-data uploader: load the file and derive the groups.

        On a successful load:
          * protein name/species info is registered via process_protein_info();
          * columns starting with 'Avg_' become one group each, or - when there
            are none - the user picks numeric columns, which are then renamed
            to 'Avg_<name>';
          * the missing-protein check runs once.
        On a failed load nothing beyond the error shown by _load_data() happens.

        Parameters:
            change: dict, traitlets change notification for the uploader's ``value``
        """
        if change['type'] != 'change' or change['name'] != 'value':
            return

        prefix = 'Avg_'
        loaded = False

        with self.output_area:
            self.output_area.clear_output()
            if change['new'] and len(change['new']) > 0:
                file_data = change['new'][0]
                self.merged_df, merged_status = self._load_data(
                    file_data,
                    required_columns=['Master Protein Accessions'],
                    file_type='Merged'
                )
                loaded = merged_status == 'yes' and self.merged_df is not None

            if loaded:
                # Only touch the frame after a successful load: _load_data()
                # returns None on failure and process_protein_info() would
                # raise on it.
                self.process_protein_info(self.merged_df)

                avg_columns = [col for col in self.merged_df.columns if col.startswith(prefix)]
                # Strip the prefix the same way for col_order and for the group
                # names so both always agree.
                self.col_order = [col[len(prefix):] for col in avg_columns]

                if avg_columns:
                    self.group_data = {
                        str(i): {
                            "grouping_variable": col[len(prefix):],
                            "abundance_columns": [col]
                        } for i, col in enumerate(avg_columns, 1)
                    }

                    display(HTML(
                        f'<b style="color:green;">Merged data imported with {self.merged_df.shape[0]} rows and {self.merged_df.shape[1]} columns.</b><br>' +
                        f'<b style="color:green;">Automatically generated {len(self.group_data)} groups from Avg_ columns.</b>'
                    ))

                else:
                    numerical_cols = list(self.merged_df.select_dtypes(include=['number']).columns)
                    if numerical_cols:
                        cols_selector = widgets.SelectMultiple(
                            options=numerical_cols,
                            description='Select data columns:',
                            style={'description_width': 'initial'},
                            layout=widgets.Layout(width='90%', height='200px')
                        )

                        confirm_button = widgets.Button(
                            description='Confirm Selection',
                            button_style='success'
                        )

                        def confirm_selection(b):
                            selected_cols = list(cols_selector.value)
                            if selected_cols:
                                # Rename selected columns to start with 'Avg_'
                                rename_dict = {col: f'Avg_{col}' for col in selected_cols}
                                self.merged_df.rename(columns=rename_dict, inplace=True)
                                self.group_data = {
                                    str(i): {
                                        "grouping_variable": col,
                                        "abundance_columns": [f'Avg_{col}']
                                    } for i, col in enumerate(selected_cols, 1)
                                }
                                # Keep col_order in step with the groups; it was
                                # computed as [] above because no Avg_ columns
                                # existed yet.
                                self.col_order = list(selected_cols)
                                with self.output_area:
                                    self.output_area.clear_output()
                                    display(HTML(
                                        f'<b style="color:green;">Merged data imported with {self.merged_df.shape[0]} rows and {self.merged_df.shape[1]} columns.</b>'
                                    ))
                                    # Check for missing proteins
                                    self.check_missing_proteins_in_merged_data()
                            else:
                                with self.output_area:
                                    display(HTML('<b style="color:red;">No columns selected. Please select at least one data column.</b>'))

                        confirm_button.on_click(confirm_selection)

                        display(widgets.VBox([cols_selector, confirm_button]))

                    else:
                        display(HTML(
                            f'<b style="color:red;">No numerical columns available to select.</b><br>' +
                            f'<b style="color:green;">Merged data imported with {self.merged_df.shape[0]} rows and {self.merged_df.shape[1]} columns.</b>'
                        ))

        # Run the check exactly once per successful load (it used to run twice
        # for files with Avg_ columns, duplicating warnings and fetches).
        if loaded:
            self.check_missing_proteins_in_merged_data()
    
    def _on_fasta_upload_change(self, change):
        """Observer for the FASTA uploader: parse each file into ``self.protein_dict``.

        When merged data with a 'Master Protein Accessions' column is loaded,
        only proteins whose ID occurs in that column are kept; otherwise every
        parsed protein is stored. One status line per file is written to
        ``self.output_area``.

        Parameters:
            change: dict, traitlets change notification for the uploader's ``value``
        """
        if change['type'] != 'change' or change['name'] != 'value':
            return

        with self.output_area:
            self.output_area.clear_output()
            if not change['new']:
                return

            # merged_df starts out as an EMPTY DataFrame (not None), so testing
            # only for None sent a FASTA-first upload into the filtering branch,
            # where the column lookup raised KeyError. Filter only when there
            # really is a dataset to filter against.
            merged = self.merged_df
            dataset_ids = None
            if (merged is not None and not merged.empty
                    and 'Master Protein Accessions' in merged.columns):
                dataset_ids = set(merged['Master Protein Accessions'].dropna().unique())

            for file_data in change['new']:
                try:
                    file_name = getattr(file_data, 'name', None)
                    # Case-insensitive so e.g. "proteins.FASTA" is accepted too.
                    if not (file_name and file_name.lower().endswith('.fasta')):
                        display(HTML(f'<b style="color:red;">Invalid file format. Please upload FASTA files only.</b>'))
                        continue

                    new_proteins = self._parse_uploaded_fasta(file_data)

                    if dataset_ids is not None:
                        matched = {
                            protein_id: protein_data
                            for protein_id, protein_data in new_proteins.items()
                            if protein_id in dataset_ids
                        }
                        self.protein_dict.update(matched)
                        match_count = len(matched)

                        display(HTML(
                            f'<b style="color:green;">Successfully imported FASTA file: {file_name} '
                            f'({len(new_proteins)} proteins, {match_count} matched to dataset)</b>'
                        ))
                    else:
                        self.protein_dict.update(new_proteins)
                        display(HTML(f'<b style="color:green;">Successfully imported Entire FASTA file: {file_name} ({len(new_proteins)} proteins)</b>'))
                except Exception as e:
                    # One bad file must not stop the remaining uploads.
                    display(HTML(f'<b style="color:red;">Error processing FASTA file: {str(e)}</b>'))

    def validate_and_standardize_columns(self, df: pd.DataFrame):
        """Rename software-specific columns to the names used by this notebook.

        Recognised start/end position columns become 'start' / 'stop', the
        first recognised protein-ID column becomes 'Master Protein Accessions',
        and accessions written as a pipe-delimited FASTA header
        (``db|ACCESSION|name``) are reduced to the accession.

        The frame is modified in place and also returned.

        Parameters:
            df: pd.DataFrame, table as read from the uploaded file

        Returns:
            pd.DataFrame: the same frame with standardized columns

        Raises:
            ValueError: if 'start'/'stop' or a protein-ID column cannot be
                found, or if any row lists several protein IDs separated by ';'
        """
        SOFTWARE_START_STOP_COLUMNS = {
            'Start position': 'start',
            'End position': 'stop',
            'Peptide Start': 'start',
            'Peptide End': 'stop',
            'StartPosition': 'start',
            'EndPosition': 'stop',
            'Peptide start': 'start',
            'Peptide end': 'stop',
            'Startposition': 'start',
            'Endposition': 'stop',
            'Peptide Start Position': 'start',
            'Peptide End Position': 'stop',
            'Peptide start position': 'start',
            'Peptide end position': 'stop',
        }

        SOFTWARE_PROTEIN_ID_COLUMNS = [
            'Master Protein Accessions', 'Leading proteins', 'Protein Name',
            'Protein Accession', 'Accession Number', 'ProteinGroupId',
            'Protein ID', 'Accession', 'protein_accession', 'Protein'
        ]

        # Rename start/stop columns. Only the first alias per target is renamed
        # (and none if the target already exists): renaming two aliases to the
        # same name would create duplicate column labels.
        claimed = {target for target in ('start', 'stop') if target in df.columns}
        rename_map = {}
        for col in df.columns:
            target = SOFTWARE_START_STOP_COLUMNS.get(col)
            if target is not None and target not in claimed:
                rename_map[col] = target
                claimed.add(target)
        if rename_map:
            df.rename(columns=rename_map, inplace=True)

        # Validate start/stop columns
        if 'start' not in df.columns or 'stop' not in df.columns:
            raise ValueError("Missing required columns 'start' or 'stop'.")

        # Rename Protein ID column (first candidate in priority order wins)
        for col in SOFTWARE_PROTEIN_ID_COLUMNS:
            if col in df.columns:
                if col != 'Master Protein Accessions':
                    df.rename(columns={col: 'Master Protein Accessions'}, inplace=True)
                break

        if 'Master Protein Accessions' not in df.columns:
            raise ValueError("Missing required column 'Master Protein Accessions'.")

        multi_id_rows = df[df['Master Protein Accessions'].str.contains(';', na=False)]

        if not multi_id_rows.empty:
            display(HTML('<b style="color:red;">Rows with multiple protein IDs detected:</b>'))
            display(multi_id_rows)
            raise ValueError(
                "Multiple protein IDs detected for one or more peptides; either exclude these rows or use another strategy to handle them."
            )

        def _extract_accession(value):
            # Missing values (NaN/None) and plain accessions pass through.
            if not isinstance(value, str) or '|' not in value:
                return value
            # Preferred form first (unchanged behaviour); the looser pattern
            # covers IDs the first one rejects, e.g. isoforms like 'P12345-2'.
            match = (re.search(r'\|([A-Z0-9]+)\|', value)
                     or re.search(r'\|([^|\s]+)\|', value))
            # No 'db|ID|name' shape at all: keep the value instead of crashing.
            return match.group(1) if match else value

        # Extract UniProt ID from FASTA header
        df['Master Protein Accessions'] = df['Master Protein Accessions'].apply(_extract_accession)

        return df

    def _load_data(self, file_obj, required_columns, file_type):
        try:
            content = file_obj.content
            filename = file_obj.name
            extension = filename.split('.')[-1].lower()
            file_stream = io.BytesIO(content)

            if extension == 'csv':
                df = pd.read_csv(file_stream)
            elif extension in ['txt', 'tsv']:
                df = pd.read_csv(file_stream, delimiter='\t')
            elif extension == 'xlsx':
                df = pd.read_excel(file_stream)
            else:
                raise ValueError("Unsupported file format.")

            df.columns = df.columns.str.strip()

            # Standardize columns
            try:
                df = self.validate_and_standardize_columns(df)
            except ValueError as e:
                display(HTML(f'<b style="color:red;">{file_type} File Error: {e}</b>'))
                return None, 'no'

            # Check again for required columns
            if not set(required_columns).issubset(df.columns):
                missing = set(required_columns) - set(df.columns)
                display(HTML(
                    f'<b style="color:red;">{file_type} File Error: Missing required columns after standardization: {", ".join(missing)}</b>'
                ))
                return None, 'no'

            return df, 'yes'
        except Exception as e:
            display(HTML(f'<b style="color:red;">{file_type} File Error: {str(e)}</b>'))
            return None, 'no'
    
    def process_protein_info(self, df):
        """Register protein name/species from the table in ``self.protein_dict``.

        Only acts when ``df`` has both 'protein_name' and 'protein_species'
        columns and every row has a non-null, non-empty value in each. In that
        case the first name/species per 'Master Protein Accessions' value is
        stored. A sequence that is already stored for a protein is kept; new
        entries get an empty sequence.

        Parameters:
            df: pd.DataFrame or None, the merged data

        Returns:
            int | None: size of ``self.protein_dict`` after processing, or
            None when nothing was processed
        """
        # A failed load hands None to this method; there is nothing to process.
        if df is None:
            return None

        info_columns = ['protein_name', 'protein_species']
        if not all(col in df.columns for col in info_columns):
            return None

        # Check if we have valid data for all entries
        all_data_present = (
            df['protein_name'].notna().all() and
            df['protein_species'].notna().all() and
            (df['protein_name'] != '').all() and
            (df['protein_species'] != '').all()
        )
        if not all_data_present:
            return None

        protein_info = df.groupby('Master Protein Accessions').agg({
            'protein_name': 'first',
            'protein_species': 'first'
        }).reset_index()

        for _, row in protein_info.iterrows():
            protein_id = row['Master Protein Accessions']
            # Do not wipe a sequence loaded earlier (e.g. from a FASTA upload):
            # this table only contributes name and species.
            existing = self.protein_dict.get(protein_id) or {}
            self.protein_dict[protein_id] = {
                "name": row['protein_name'],
                "species": row['protein_species'],
                "sequence": existing.get("sequence") or ''
            }
        return len(self.protein_dict)

    def _find_species(self, header):
        """Return the species label whose search terms occur in ``header``.

        Each entry of the module-level ``spec_translate_list`` is read as
        ``[label, term, term, ...]``. Matching is a case-insensitive substring
        test; the first entry with a matching term wins. Returns "unknown"
        when no term matches.
        """
        lowered = header.lower()
        return next(
            (
                entry[0]
                for entry in spec_translate_list
                if any(alias.lower() in lowered for alias in entry[1:])
            ),
            "unknown",
        )

    def _parse_uploaded_fasta(self, file_data):
        """Parse uploaded FASTA file content"""
        fasta_dict = {}
        fasta_text = bytes(file_data.content).decode('utf-8')
        lines = fasta_text.split('\n')
        
        protein_id = ""
        protein_name = ""
        sequence = ""
        species = ""
        
        for line in lines:
            line = line.strip()
            if line.startswith('>'):
                if protein_id:
                    fasta_dict[protein_id] = {
                        "name": protein_name,
                        "sequence": sequence,
                        "species": species
                    }
                sequence = ""
                header_parts = line[1:].split('|')
                if len(header_parts) > 2:
                    protein_id = header_parts[1]
                    protein_name_full = re.split(r' OS=', header_parts[2])[0]
                    protein_name = protein_name_full if ' ' in protein_name_full else protein_name_full
                    species = self._find_species(line)
            else:
                sequence += line
                
        if protein_id:
            fasta_dict[protein_id] = {
                "name": protein_name,
                "sequence": sequence,
                "species": species
            }
        
        return fasta_dict

In [5]:
class HeatmapDataHandler:
    def __init__(self):
        """Initialise selection state, plotting defaults and the selection widgets.

        Reads the module-level ``data_transformer`` for the available grouping
        variables and the column order. ``data_transformer`` may still be
        ``None`` (no data loaded yet); in that case both default to empty
        lists instead of raising. The widgets are then built by
        ``create_widgets``.
        """
        self.filter_type = 'all-peptides'  # Default peptide filter

        # Data containers; they are filled once groups are added
        self.data_variables = {}
        self.filtered_data_variables = {}
        self.available_data_variables = {}
        self.selected_var_keys_list = []

        # Protein lookups; empty because no data has been extracted yet
        self.protein_mapping = {}
        self.available_proteins = set()

        # Grouping variables announced by the data transformer (if any)
        group_data = getattr(data_transformer, 'group_data', None)
        if group_data:
            self.available_grouping_vars = [group['grouping_variable'] for group in group_data.values()]
        else:
            self.available_grouping_vars = []

        # update_var_keys reads col_order, so the attribute must always exist
        self.col_order = getattr(data_transformer, 'col_order', None) or []

        # Label / order editing state
        self.label_widgets = {}
        self.order_widgets = {}
        self.default_label_values = {}
        self.default_order_values = []  # list of labels, as in add_group / reset_selection
        self.label_order_output = widgets.Output()

        # Selection widgets (protein dropdown, group list, buttons, outputs)
        self.create_widgets()

        # Plotting options and their defaults
        self.ms_average_choice = 'yes'
        self.selected_peptides = []
        self.selected_functions = []
        self.legend_title = legend_title
        self.bio_or_pep = 'no'
        self.plot_heatmap = 'yes'
        self.plot_zero = 'no'

        # Set by process_data_variables once groups are selected
        self.user_protein_id = ''
        self.protein_name_short = ''
    
    def create_filtered_data_variables(self):
        return {key: self.data_variables[key] 
                for key in self.selected_var_keys_list
                if key in self.data_variables}
         
    def update_protein_info(self, protein_name):
        self.protein_name_short = protein_name
        # Update plot handler if it exists
        if hasattr(self, 'plot_handler'):
            self.plot_handler.update_protein_name(protein_name)
  
    def process_export(self):
        """Rebuild ``self.complete_data`` for every protein in the current selection.

        Protein IDs are taken from the text before the first '_' of each key
        in ``selected_var_keys_list`` and processed in first-selected order.
        For each protein found in ``data_transformer.protein_dict``:

        * a missing or empty sequence is fetched through the UniProt client
          (created lazily on ``data_transformer``); the outcome is reported
          with an HTML message and a failed fetch does not abort the export;
        * ``export_heatmap_data_to_dict`` is called for every group of
          ``data_transformer.group_data`` whose "<protein>_<grouping variable>"
          key is selected.

        Result layout: ``complete_data[protein_id][grouping_variable]``.
        Proteins missing from the protein dictionary are reported and skipped.
        """
        self.complete_data = {}

        # dict.fromkeys de-duplicates while keeping the selection order, so the
        # export order is reproducible (a set would iterate arbitrarily).
        selected_proteins = list(dict.fromkeys(
            key.split('_')[0] for key in self.selected_var_keys_list
        ))

        merged_df = data_transformer.merged_df

        for protein_id in selected_proteins:
            if protein_id not in data_transformer.protein_dict:
                display(HTML(f"<b style='color:red;'>Data for {protein_id} not found in protein dictionary.</b>"))
                continue

            protein_info = data_transformer.protein_dict[protein_id]
            protein_df = merged_df[merged_df['Master Protein Accessions'] == protein_id]
            is_all_null = 'function' in protein_df.columns and protein_df['function'].isna().all()

            # Fetch the sequence from UniProt when the dictionary has none
            if not protein_info.get('sequence'):
                try:
                    # Create the UniProt client on first use
                    if not hasattr(data_transformer, 'uniprot_client'):
                        from utils.uniprot_client import UniProtClient
                        data_transformer.uniprot_client = UniProtClient()

                    # Only the sequence is needed here; name and species stay as stored
                    _name, _species, sequence = data_transformer.uniprot_client.fetch_protein_info_with_sequence(protein_id)

                    if sequence:
                        protein_info['sequence'] = sequence
                        display(HTML(f"<b style='color:green;'>Sequence for {protein_id} fetched from UniProt.</b>"))
                    else:
                        display(HTML(f"<b style='color:orange;'>Sequence for {protein_id} not found in UniProt.</b>"))
                except Exception as e:
                    # Best effort: the export continues without a sequence
                    display(HTML(f"<b style='color:red;'>Error fetching sequence for {protein_id}: {str(e)}</b>"))

            protein_sequence = protein_info.get('sequence', '')
            protein_species = protein_info.get('species', 'Unknown')
            protein_name = protein_info.get('name', protein_id)

            protein_data = {}
            for group_key, group_info in data_transformer.group_data.items():
                grouping_var_name = group_info['grouping_variable']

                # Only export protein/group combinations the user selected
                if f"{protein_id}_{grouping_var_name}" in self.selected_var_keys_list:
                    protein_data[grouping_var_name] = export_heatmap_data_to_dict(
                        protein_id, grouping_var_name, group_info,
                        protein_sequence, protein_species, protein_name,
                        protein_df, is_all_null
                    )

            # Proteins without any exported group are left out entirely
            if protein_data:
                self.complete_data[protein_id] = protein_data
    
    def update_filter_type(self, new_filter_type):
        """Update the filter type and reprocess data"""
        self.filter_type = new_filter_type
        self.process_export()
        if hasattr(self, '_on_data_changed_callback'):
            self._on_data_changed_callback()
    
    def extract_and_format_data(self):
        # Load the data from the saved directory
        self.process_export()  # Remove the self argument
        self.loaded_data = self.complete_data
        # Initialize the new dictionary
        data_variables = {}

        # Iterate over the loaded data to extract and reorganize it
        for protein_id, protein_data in self.loaded_data.items():
            protein_sequence = protein_data.get('protein_sequence')

            for grouping_var_name, group_info in protein_data.items():
                # Extract the required DataFrames and other information
                func_df = group_info.get('func_heatmap_df')
                abs_df = group_info.get('heatmap_df')
                filtered_abs_df = group_info.get('filtered_heatmap_df')

                label = grouping_var_name
                protein_sequence = group_info.get('protein_sequence')
                protein_name = group_info.get('protein_name')
                protein_species = group_info.get('protein_species')

                # Determine if the func_df is all None
                is_func_df_all_none = func_df.isnull().all().all() if func_df is not None else True

                # Create a unique key combining protein_id and grouping_var_name
                var_key = f"{protein_id}_{grouping_var_name}"

                # Populate the data_variables dictionary using the unique key
                data_variables[var_key] = {
                    'protein_id': protein_id,
                    'protein_sequence': protein_sequence,
                    'protein_name': protein_name,
                    'protein_species': protein_species,
                    'heatmap_df': abs_df,
                    'function_heatmap_df': func_df,
                    'label': label,
                    'is_func_df_all_none': is_func_df_all_none,
                    'filtered_heatmap_df': filtered_abs_df
                }

        return data_variables
    
    def process_data_variables(self):
        """
        Minimal processing to pass essential data from filtered to available data variables
        """
        # Dynamically generate the list of variable names based on loaded data
        variables = list(self.filtered_data_variables.keys())
        protein_id_list = []
        protein_name_list = []
        
        # Pass through only essential data to available_data_variables
        self.available_data_variables = {}
        
        for var in variables:
            if var in self.filtered_data_variables:
                # Create new dict for this variable
                self.available_data_variables[var] = {
                    # Core data frames
                    'heatmap_df': self.filtered_data_variables[var]['heatmap_df'],
                    'function_heatmap_df': self.filtered_data_variables[var]['function_heatmap_df'],
                    'filtered_heatmap_df': self.filtered_data_variables[var]['filtered_heatmap_df'],
    
                    # Essential metadata
                    'protein_id': self.filtered_data_variables[var]['protein_id'],
                    'protein_name': self.filtered_data_variables[var]['protein_name'],
                    'protein_sequence': self.filtered_data_variables[var]['protein_sequence'],
                    'label': self.filtered_data_variables[var]['label'],
                    
                    # Track if function data is all null
                    'is_func_df_all_none': self.filtered_data_variables[var]['function_heatmap_df'].isnull().all().all() 
                        if self.filtered_data_variables[var]['function_heatmap_df'] is not None else True
                }
                
                protein_id_list.append(self.filtered_data_variables[var]['protein_id'])
                protein_name_list.append(self.filtered_data_variables[var]['protein_name'])
    
        # Set protein ID and name
        user_protein_id_set = list(set(protein_id_list))
        user_protein_name_set = list(set(protein_name_list))
    
        if len(user_protein_id_set) > 1 and len(user_protein_name_set) == 1:
            self.user_protein_id = '_'.join(user_protein_id_set)
            self.protein_name_short = user_protein_name_set[0]
        elif len(user_protein_id_set) > 1 and len(user_protein_name_set) > 1:
            self.user_protein_id = '_'.join(user_protein_id_set)
            self.protein_name_short = '_'.join(user_protein_name_set)
        elif len(user_protein_name_set) == 1:
            self.user_protein_id = user_protein_id_set[0]
            self.protein_name_short = user_protein_name_set[0]
            
        self.protein_name_short = str(self.protein_name_short)
                
    # Function to create order input widgets
    def create_order_input_widgets(self):
        """Rebuild the per-variable label and order input widgets.

        For every entry of ``available_data_variables`` (in dictionary order)
        a ``Text`` pre-filled with the entry's label is stored in
        ``label_widgets`` and an ``IntText`` holding the entry's position is
        stored in ``order_widgets``. Nothing is returned.
        """
        # One layout object shared by all order inputs
        order_layout = widgets.Layout(width='90%')

        label_inputs = {}
        order_inputs = {}
        self.label_widgets = label_inputs
        self.order_widgets = order_inputs

        position = 0
        for var_key, details in self.available_data_variables.items():
            label_inputs[var_key] = widgets.Text(
                value=details['label'],
                description='',
                layout=widgets.Layout(width='150px')
            )
            order_inputs[var_key] = widgets.IntText(
                value=position,
                description='',
                layout=order_layout,
            )
            position += 1

    # Function to create widgets
    def create_widgets(self):
        """Create the selection widgets and buttons and connect their handlers.

        When the module-level ``data_transformer`` is truthy, the protein
        dropdown lists the accessions of
        ``merged_df['Master Protein Accessions']`` in ``value_counts()`` order,
        labelled "<accession> - <name>" ('Unknown' when the accession is not
        in ``protein_dict``), with the first accession pre-selected; the group
        list is then filled right away through ``update_var_keys``. Otherwise
        the dropdown starts without options.
        """
        def make_button(text, style, width, **layout_extras):
            # Every button is 30px high; only width and margin vary
            return widgets.Button(
                description=text,
                button_style=style,
                layout=widgets.Layout(width=width, height='30px', **layout_extras)
            )

        # --- Protein dropdown -------------------------------------------
        if data_transformer:
            accession_counts = data_transformer.merged_df['Master Protein Accessions'].value_counts()
            sorted_proteins = accession_counts.index.tolist()

            dropdown_options = []
            for protein in sorted_proteins:
                display_name = data_transformer.protein_dict.get(protein, {'name': 'Unknown'})['name']
                dropdown_options.append((f"{protein} - {display_name}", protein))

            selection_kwargs = {
                'options': dropdown_options,
                # First protein is selected by default
                'value': sorted_proteins[0] if sorted_proteins else None,
            }
        else:
            sorted_proteins = []
            selection_kwargs = {'options': []}

        self.protein_dropdown = widgets.Dropdown(
            description='Protein ID:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='90%'),
            **selection_kwargs
        )

        # --- Search box and group list ----------------------------------
        self.grouping_variable_text = widgets.Text(
            description='Search Term',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='90%')
        )
        self.var_key_dropdown = widgets.SelectMultiple(
            description='Select Groups',
            style={'description_width': 'initial'},
            disabled=False,
            layout=widgets.Layout(width='90%', height='100px')
        )

        # A protein is pre-selected, so fill the group list immediately
        if sorted_proteins:
            self.update_var_keys({'new': sorted_proteins[0], 'type': 'change', 'name': 'value'})

        # --- Selection buttons ------------------------------------------
        self.add_group_button = make_button('Add Group', 'success', '150px')
        self.search_button = make_button('Search', 'info', '150px')
        self.reset_button = make_button('Reset Selection', 'warning', '150px')

        # --- Label / order buttons --------------------------------------
        self.update_label_button = make_button("Update Labels", 'success', '125px')
        self.update_order_button = make_button("Update Order", 'success', '125px')
        self.reset_labelorder_button = make_button(
            "Reset to Default", 'warning', '125px', margin='10px 10px 0 75px'
        )
        self.update_label_button.on_click(self.on_update_label_click)
        self.update_order_button.on_click(self.on_update_order_click)
        self.reset_labelorder_button.on_click(self.on_reset_click)
        self.label_order_button_box = widgets.HBox([self.update_label_button, self.update_order_button])

        self.var_selection_output = widgets.Output()

        # --- Event wiring for the selection widgets ---------------------
        self.protein_dropdown.observe(self.update_var_keys, names='value')
        self.search_button.on_click(self.search_var_keys)
        self.add_group_button.on_click(self.add_group)
        self.reset_button.on_click(self.reset_selection)

        self.button_box = widgets.HBox(
            [self.search_button, self.add_group_button, self.reset_button],
            layout=widgets.Layout(
                height='40px',
                width='90%',
                overflow='hidden',
                justify_content='space-between'
            )
        )
    
    def update_var_keys(self, change):
        self.selected_protein = change['new']
        current_selection = set(self.var_key_dropdown.value)
        
        # Use col_order to maintain the same order, but only include available grouping vars
        new_options = [str(col) for col in self.col_order 
                      if str(col) in map(str, self.available_grouping_vars)]
        
        # Update options while preserving order from col_order
        self.var_key_dropdown.options = new_options
        # Keep selected values that are still valid
        self.var_key_dropdown.value = tuple(val for val in current_selection 
                                          if val in new_options)
        
    # Function to search and filter var_keys based on the grouping variable text input
    def search_var_keys(self, b):
        group_name = self.grouping_variable_text.value
        if group_name:
            matching_keys = [key for key in self.var_key_dropdown.options if group_name in key]
            self.var_key_dropdown.value = matching_keys
        else:
            with self.var_selection_output:
                self.var_selection_output.clear_output()
                display(HTML('<b style="color:red;">Please enter a group name to search.</b>'))

    # Function to add a group of selected var_keys to the list
    def add_group(self, b):
        """Add the highlighted groups of the chosen protein to the selection.

        Each highlighted group becomes the key "<protein>_<group>". New keys
        are appended to ``selected_var_keys_list`` in the order shown, and
        keys that are already selected keep their position, so repeated
        clicks never reshuffle the selection. The selection summary is then
        re-rendered, the data is re-extracted and the label/order panel is
        rebuilt inside ``label_order_output``. ``b`` (the clicked button) is
        unused.
        """
        self.selected_protein = self.protein_dropdown.value
        selected_keys = list(self.var_key_dropdown.value)

        if not (selected_keys and self.selected_protein):
            with self.var_selection_output:
                self.var_selection_output.clear_output()
                display(HTML('<b style="color:red;">Please select a protein and at least one key.</b>'))
            return

        # Merge without duplicates while preserving insertion order
        # (a set union would return the keys in arbitrary order).
        new_keys = [f"{self.selected_protein}_{key}" for key in selected_keys]
        self.selected_var_keys_list = list(dict.fromkeys([*self.selected_var_keys_list, *new_keys]))

        # Show the current selection
        with self.var_selection_output:
            self.var_selection_output.clear_output(wait=True)
            display(HTML(f"""
                <h3> </h3>
                
                <b>{len(self.selected_var_keys_list)} unique variables:</b><br>
                <div style="padding-left: 20px;">
                    {('<br>').join(self.selected_var_keys_list)}
                </div>
                <b>Total unique variables:</b> <br>
            """))

        # Process data for all selected proteins; anything displayed while
        # processing lands in label_order_output
        with self.label_order_output:
            self.label_order_output.clear_output(wait=True)
            self.data_variables = self.extract_and_format_data()
            self.filtered_data_variables = self.create_filtered_data_variables()
            self.process_data_variables()
            # Shallow copy: the per-variable dicts are shared with filtered_data_variables
            self.available_data_variables = self.filtered_data_variables.copy()

            # Rebuild the label/order editors and remember their defaults
            self.create_order_input_widgets()
            self.default_label_values = {key: self.label_widgets[key].value for key in self.label_widgets}
            self.default_order_values = [info['label'] for info in self.available_data_variables.values()]
            display(self.display_label_order_widgets())

        if hasattr(self, '_on_data_changed_callback'):
            self._on_data_changed_callback()
    
    # Function to reset the selection
    def reset_selection(self, b):
        """Clear the current selection while keeping the already loaded data.

        Empties the selection list, the protein dropdown value, the group
        options and the search box; drops the filtered/available variables
        and the label/order editors; clears both output areas and, when a
        truthy ``_plot_handler`` is attached, its plot output and figures.
        ``data_variables``, ``loaded_data`` and ``complete_data`` are put back
        afterwards and the grouping variables are re-read from
        ``data_transformer``. ``b`` (the clicked button) is unused.
        """
        # Snapshot the loaded data before any widget handlers can run
        kept_data_variables = self.data_variables
        kept_loaded_data = getattr(self, 'loaded_data', None)
        kept_complete_data = getattr(self, 'complete_data', None)

        # Selection state (the order of these assignments is kept on purpose:
        # clearing the dropdown value notifies its observers first)
        self.selected_var_keys_list = []
        self.protein_dropdown.value = None
        self.var_key_dropdown.options = []
        self.grouping_variable_text.value = ''

        # Derived data and editor widgets
        self.filtered_data_variables = {}
        self.available_data_variables = {}
        self.label_widgets = {}
        self.order_widgets = {}
        self.default_label_values = {}
        self.default_order_values = []

        # Output areas
        with self.var_selection_output:
            self.var_selection_output.clear_output()
            display(HTML('<b style="color:green;">Selection has been reset. You can make new selections from the available data.</b>'))

        with self.label_order_output:
            self.label_order_output.clear_output()

        # Plots, if a plot handler is attached
        handler = getattr(self, '_plot_handler', None)
        if handler:
            with handler.plot_output:
                handler.plot_output.clear_output()
                plt.close('all')  # Close all matplotlib figures

            handler.fig_port = None
            handler.fig_land = None

        # Put the loaded data back
        self.data_variables = kept_data_variables
        if kept_loaded_data is not None:
            self.loaded_data = kept_loaded_data
        if kept_complete_data is not None:
            self.complete_data = kept_complete_data

        # Refresh the grouping variables from the transformer when present
        group_data = getattr(data_transformer, 'group_data', None)
        if group_data:
            self.available_grouping_vars = [
                group['grouping_variable']
                for group in group_data.values()
            ]

        # Let listeners react to the cleared selection
        if hasattr(self, '_on_data_changed_callback'):
            self._on_data_changed_callback()

    # Function to display messages in the output widget
    def display_message(self, message, is_error=False):
        """Show ``message`` in ``message_output``, replacing what was shown before.

        Parameters:
            message: text to show (inserted into the HTML as is)
            is_error: red bold text when True, green bold text otherwise
        """
        if is_error:
            markup = f"<b style='color:red;'>{message}</b>"
        else:
            markup = f"<b style='color:green;'>{message}</b>"

        with self.message_output:
            self.message_output.clear_output()  # Drop the previous message
            display(HTML(markup))

    # Function to update order based on new order input
    def update_order(self, order_labels):
        """
        Updates the order of available_data_variables based on provided labels.
        The order of items in the dictionary will match the order of labels in order_labels.
        
        Args:
            order_labels (list): List of labels in the desired order
        """
        # Clean the input labels - strip whitespace and empty strings
        order_labels = [label.strip() for label in order_labels if label.strip()]
        
        # Get current label to key mapping
        label_to_key = {info['label']: key for key, info in self.available_data_variables.items()}
        
        # Validate input
        if len(order_labels) != len(self.available_data_variables):
            raise ValueError(f"Number of labels provided ({len(order_labels)}) does not match number of variables ({len(self.available_data_variables)})")
        
        if len(set(order_labels)) != len(order_labels):
            raise ValueError("Duplicate labels found in input")
        
        if not all(label in label_to_key for label in order_labels):
            invalid_labels = [label for label in order_labels if label not in label_to_key]
            raise ValueError(f"Invalid labels found: {invalid_labels}")
        
        # Create a new ordered dictionary with items in the specified order
        ordered_data = {}
        for new_label in order_labels:
            key = label_to_key[new_label]
            ordered_data[key] = self.available_data_variables[key]
            
            # Update the label in the data to match the new order
            ordered_data[key]['label'] = new_label
        
        # Replace the available_data_variables with the new ordered dictionary
        self.available_data_variables = ordered_data
        
        # Update the widgets to reflect new order
        self.create_order_input_widgets()
        
        # Update the text input to show the new order
        current_labels = [info['label'] for info in self.available_data_variables.values()]
        self.new_order_input.value = ', '.join(current_labels)
        
        return self.available_data_variables
    
    # Event handler for updating labels
    def on_update_label_click(self, b):
        try:
            self.update_labels()
            self.display_message("Labels updated successfully.")
        except Exception as e:
            self.display_message(f"Error updating labels: {e}", is_error=True)

    # Event handler for updating order
    def on_update_order_click(self, b):
        """Click handler: apply the order typed into ``new_order_input``.

        The comma-separated labels are passed to ``update_order``, the label
        editors are synchronised, the label/order panel is re-rendered in
        ``label_order_output`` and a status message is written to
        ``message_output``. Errors are shown there in red instead of being
        raised. ``b`` is unused.
        """
        def show_status(markup):
            # message_output is looked up at call time because re-rendering
            # the panel replaces it
            with self.message_output:
                self.message_output.clear_output(wait=True)
                display(HTML(markup))

        try:
            requested = [part.strip() for part in self.new_order_input.value.split(',')]
            self.update_order(requested)

            # Keep the label editors in step with the (re-ordered) data
            for var_key, details in self.available_data_variables.items():
                if var_key in self.label_widgets:
                    self.label_widgets[var_key].value = details['label']

            # Re-render the panel
            with self.label_order_output:
                self.label_order_output.clear_output(wait=True)
                display(self.display_label_order_widgets())

            show_status('<b style="color:green;">Order updated successfully!</b>')

            # Notify listeners, if any
            if hasattr(self, '_on_data_changed_callback'):
                self._on_data_changed_callback()

        except Exception as e:
            show_status(f'<b style="color:red;">Error updating order: {str(e)}</b>')

    # Event handler for resetting labels and order
    def on_reset_click(self, b):
        try:
            # Reset each label widget to its default value
            for key in self.label_widgets:
                self.label_widgets[key].value = self.default_label_values[key]

            # Reset the order widget to its default value
            self.new_order_input.value = ', '.join(self.default_order_values)

            # Apply the default labels and order
            self.on_update_label_click(b)
            self.on_update_order_click(b)

            self.display_message("Labels and order reset to default.")
        except Exception as e:
            self.display_message(f"Error resetting labels and order: {e}", is_error=True)

    # Function to display label and order widgets
    def display_label_order_widgets(self):
        """Build the label-editing / re-ordering panel and return it.

        Side effects: ``self.message_output`` is replaced by a fresh Output
        and ``self.new_order_input`` by a Textarea pre-filled with the current
        labels (comma separated, in ``available_data_variables`` order).

        Returns:
            VBox stacking the label editors, the order text box, the three
            action buttons and the message output.
        """
        def make_button(text, style):
            # Fixed-size button without internal scrollbars
            return widgets.Button(
                description=text,
                button_style=style,
                layout=widgets.Layout(width='150px', height='30px', overflow='hidden')
            )

        # Output widget for status messages
        self.message_output = widgets.Output()

        # --- Label editors: one row per variable --------------------------
        update_label = [widgets.HTML(value="<h3><u>Update Labels:</u></h3>")]
        for i, (var, info) in enumerate(self.available_data_variables.items()):
            # Only create a placeholder editor when none exists for this
            # variable; a dict.get default would build (and discard) a widget
            # for every row.
            if var in self.label_widgets:
                label_editor = self.label_widgets[var]
            else:
                label_editor = widgets.Text()
            update_label.append(HBox([
                widgets.Label(
                    value=f"{i + 1})  {info['label']}  -  {info.get('protein_species', '')}  -  {info.get('protein_name', '')}",
                    layout=widgets.Layout(width='90%', height='30px',overflow='hidden')
                ),
                label_editor
            ]))
        update_label_box = VBox(update_label,  layout=widgets.Layout(margin='0px', height='auto', width='90%', overflow='visible', padding='0px'))

        # --- Order text box -------------------------------------------------
        label_above_input = widgets.HTML(
            value="<h3><u>Re-order Samples:</u></h3>Enter labels in desired order separated by commas (e.g., label_1, label_2, label_3)")

        label_list = [info['label'] for info in self.available_data_variables.values()]
        self.new_order_input = widgets.Textarea(
            value=', '.join(label_list),
            layout=widgets.Layout(
                width='90%',
                height='auto',     # Height follows the content
                overflow='hidden'  # No scrollbars
            )
        )
        update_order_box = VBox([label_above_input, self.new_order_input], layout=widgets.Layout(margin='0px', height='200px', overflow='visible', padding='0px'))

        # --- Action buttons -------------------------------------------------
        update_label_button = make_button("Update Labels", 'success')
        update_order_button = make_button("Update Order", 'success')
        reset_labelorder_button = make_button("Reset to Default", 'warning')

        update_label_button.on_click(self.on_update_label_click)
        update_order_button.on_click(self.on_update_order_click)
        reset_labelorder_button.on_click(self.on_reset_click)

        label_order_button_box = widgets.HBox(
            [update_label_button, update_order_button, reset_labelorder_button],
            layout=widgets.Layout(
                width='90%',
                height='auto',
                overflow='visible',
                justify_content='space-between'  # Spread the buttons horizontally
            )
        )

        # --- Whole panel ----------------------------------------------------
        # NOTE(review): the fixed height of 40px is kept from the original
        # even though the children are taller than that - confirm it is
        # intended by the surrounding page layout.
        vert_button_box = VBox(
            [
                update_label_box,
                update_order_box,
                label_order_button_box,
                self.message_output
            ],
            layout=widgets.Layout(
                margin='0px',
                width='90%',
                height='40px',
                flex_flow='column',
                align_items='stretch'
            )
        )

        return vert_button_box
  
    def update_labels(self):
        # Update labels in available_data_variables based on label_widgets
        for key in self.available_data_variables:
            self.available_data_variables[key]['label'] = self.label_widgets[key].value

    # Function to display the initial selection widgets         
    def display_widgets(self):
        """Assemble the three panels of the selection UI.

        Returns:
            tuple: (input panel with the protein dropdown, search box, group
            list and buttons; output panel wrapping ``var_selection_output``;
            label/order panel built by ``display_label_order_widgets``).
        """
        def compact_layout():
            # Fresh layout per panel: auto height, no margins or padding
            return widgets.Layout(height = 'auto', width = '90%', margin='0px', padding='0px', overflow='hidden',)

        selection_panel = VBox(
            [
                widgets.HTML("<h3><u>Select Protein and Grouping Variables:</u></h3>"),
                self.protein_dropdown,
                self.grouping_variable_text,
                self.var_key_dropdown,
                self.button_box,
            ],
            layout=compact_layout()
        )

        status_panel = VBox([self.var_selection_output], layout=compact_layout())

        label_order_panel = self.display_label_order_widgets()

        return selection_panel, status_panel, label_order_panel

In [6]:
class DynamicVisualizationHandler:
    """Rebuilds the selection and plotting interface whenever the input data changes.

    The handler observes the data transformer's upload/search widgets. Once the
    required data is present it lazily creates a ``HeatmapDataHandler`` (the
    "selector") and a ``HeatmapPlotHandler`` (the "app") and renders their
    widgets into ``dynamic_output``.
    """

    def __init__(self, data_transformer):
        """Store the data transformer and subscribe to its input widgets.

        Parameters:
            data_transformer: object exposing ``merged_uploader``,
                ``fasta_uploader`` and ``uniprot_search`` widgets, plus the
                ``merged_df``, ``group_data`` and ``protein_dict`` attributes
                read by the other methods.
        """
        self.data_transformer = data_transformer
        self.selector = None  # HeatmapDataHandler, created on first successful load
        self.app = None       # HeatmapPlotHandler, created on first grid build
        self.dynamic_output = widgets.Output()
        self.status_output = widgets.Output()

        # Re-render whenever any of the transformer's input widgets changes value
        for source_widget in (
            self.data_transformer.merged_uploader,
            self.data_transformer.fasta_uploader,
            self.data_transformer.uniprot_search,
        ):
            source_widget.observe(self._on_file_change, names='value')

    def _on_file_change(self, change):
        """Handle any file upload changes by rebuilding the interface."""
        self._reinitialize_visualization()

    def _check_required_data(self):
        """Check if we have the minimum required data to proceed.

        Returns:
            bool: True only when ``merged_df`` is a non-empty DataFrame and both
            ``group_data`` and ``protein_dict`` are non-empty dicts.
        """
        transformer = self.data_transformer

        merged_df = getattr(transformer, 'merged_df', None)
        has_merged = isinstance(merged_df, pd.DataFrame) and not merged_df.empty

        group_data = getattr(transformer, 'group_data', None)
        has_groups = isinstance(group_data, dict) and len(group_data) > 0

        protein_dict = getattr(transformer, 'protein_dict', None)
        has_proteins = isinstance(protein_dict, dict) and len(protein_dict) > 0

        return has_merged and has_groups and has_proteins

    def _reinitialize_visualization(self):
        """Reinitialize the visualization with new data."""
        with self.dynamic_output:
            clear_output(wait=True)

            # Always show status
            display(self.status_output)

            # Only show visualization if we have all required data
            if not self._check_required_data():
                return

            if self.selector is None:
                self.selector = HeatmapDataHandler()

            self._update_selector_data()
            display(self._generate_grid())
            display(self.selector.label_order_output)

            if self.app:
                display(self.app.get_layout())
                self.app.show_plots()

    def _update_selector_data(self):
        """Update the selector with new data from data_transformer."""
        if self.selector is None:
            return

        transformer = self.data_transformer

        # Update proteins dictionary
        self.selector.protein_dict = transformer.protein_dict

        # Update available proteins: first accession of each ';'-separated entry.
        # Rows without an accession are dropped so NaN never ends up in the set.
        merged_df = getattr(transformer, 'merged_df', None)
        if isinstance(merged_df, pd.DataFrame):
            first_accessions = (
                merged_df['Master Protein Accessions'].str.split(';').str[0].dropna()
            )
            self.selector.available_proteins = set(first_accessions.unique())

        # Update available grouping variables
        group_data = getattr(transformer, 'group_data', None)
        if isinstance(group_data, dict):
            self.selector.available_grouping_vars = [
                group['grouping_variable'] for group in group_data.values()
            ]

        self.selector._on_data_changed_callback = self._on_selector_data_change

    def _generate_grid(self):
        """Generate the grid layout with visualization widgets.

        Creates the plot handler on first use (otherwise refreshes it from the
        selector) and returns a 2x2 ``GridspecLayout``: selector inputs and
        outputs side by side on the first row, label/order controls spanning
        the second row.
        """
        if self.app is None:
            self.app = HeatmapPlotHandler(self.selector)
            # Set up bidirectional reference
            self.app.set_data_handler(self.selector)
        else:
            # Update the app with new data from selector
            self.app.update_data(self.selector)

        # Get widgets from the selector
        sel_input_widgets, sel_output_widgets, sel_vert_button_box = self.selector.display_widgets()

        # Set fixed sizes for widgets
        sel_input_widgets.layout.height = '300px'
        sel_input_widgets.layout.width = '400px'
        sel_output_widgets.layout.height = '300px'
        sel_output_widgets.layout.width = '90%'
        sel_vert_button_box.layout.height = '400px'
        sel_vert_button_box.layout.width = '600px'

        # Create grid layout
        grid = GridspecLayout(
            2, 2,
            width='800px',
            height='auto',
            grid_gap='5px',
        )

        # Add widgets to the grid
        grid[0, 0] = sel_input_widgets     # Row 0, column 0
        grid[0, 1] = sel_output_widgets    # Row 0, column 1
        grid[1, :] = sel_vert_button_box   # Row 1, spanning both columns

        return grid

    def setup_widget_observers(self):
        """Setup click observers for the selector's buttons (no-op without a selector)."""
        if not self.selector:
            return

        for button in (
            self.selector.add_group_button,
            self.selector.reset_button,
            self.selector.update_label_button,
            self.selector.update_order_button,
            self.selector.reset_labelorder_button,
        ):
            button.on_click(self._on_selector_data_change)

    def _on_selector_data_change(self, change=None):
        """Handle changes in selector data: refresh the app, then re-render."""
        # NOTE(review): _generate_grid() (reached via _on_widget_interaction)
        # calls update_data() again when the app already exists; confirm
        # whether this first call is still needed.
        if self.app:
            self.app.update_data(self.selector)
        self._on_widget_interaction(change)

    def _on_widget_interaction(self, change):
        """Handle widget interactions by re-rendering the whole interface."""
        with self.dynamic_output:
            clear_output(wait=True)

            # Always show status
            display(self.status_output)

            # Nothing to lay out until a selector has been created
            if self.selector is None:
                return

            # Generate and display the grid with updated widgets
            display(self._generate_grid())

            # Display additional widgets from the app layout
            if self.app:
                display(self.app.get_layout())
                self.app.show_plots()

    def display(self):
        """Display the visualization interface."""
        display(self.dynamic_output)
        self._reinitialize_visualization()

In [7]:
class HeatmapPlotHandler:
    def __init__(self, selector):
        # Copy existing instance variables from selector
        instance_variables = {
            attr: getattr(selector, attr)
            for attr in dir(selector)
            if not callable(getattr(selector, attr))  # Exclude methods
            and not attr.startswith("__")            # Exclude magic methods
            and "button" not in attr                 # Exclude attributes containing "button"
        }
        for key, value in instance_variables.items():
            setattr(self, key, value)
      
        # Initialize with data from selector
        self.plot_heatmap, self.plot_zero = 'yes', 'no'
        self.selector = selector
        
        # Explicitly get protein name from selector
        self.protein_name_short = getattr(selector, 'protein_name_short', 'Unknown Protein')

        # List of valid gradient colormaps
        def get_valid_gradient_colormaps():
            """Return ``settings.valid_gradient_cmaps``, looked up at call time."""
            gradient_cmaps = settings.valid_gradient_cmaps
            return gradient_cmaps
        
        # List of valid discrete colormaps
        def get_valid_discretecolormaps():
            """Return ``settings.valid_discrete_cmaps``, looked up at call time."""
            discrete_cmaps = settings.valid_discrete_cmaps
            return discrete_cmaps

        def display_plotting_options(self):
            """Build the plot-option widgets and store them on the handler.

            Despite the name, nothing is displayed here: the function creates
            ``plot_message``, ``peptide_filter_radio``,
            ``ms_average_choice_dropdown``, ``bio_or_pep_dropdown`` and the
            initially hidden ``specific_select_multiple``, then wires the
            bio/pep dropdown to ``on_selection_change``.
            """
            shared_layout = widgets.Layout(width='90%')

            def common_options():
                # Fresh kwargs per widget; the Layout itself is shared, as before.
                return dict(
                    disabled=False,
                    style={'description_width': 'initial'},
                    layout=shared_layout,
                )

            self.plot_message = widgets.HTML("<h3><u>Ploting Options:</u></h3>")

            # Radio buttons choosing which peptides are plotted
            filter_choices = [
                ('All Peptides', 'all-peptides'),
                ('All Functional Peptides', 'bioactive-only'),
                ('Selected Functional Peptides', 'functional-only'),
            ]
            self.peptide_filter_radio = widgets.RadioButtons(
                options=filter_choices,
                value='all-peptides',  # default value
                description='Plot Filter:  \n\n',
                **common_options()
            )

            self.ms_average_choice_dropdown = widgets.Dropdown(
                options=['yes', 'no', 'only'],
                description='Plot Averaged Data:',
                **common_options()
            )

            specific_choices = [
                ('None', 'no'),
                ('Peptide Intervals', '1'),
                ('Bioactive Functions', '2'),
            ]
            self.bio_or_pep_dropdown = widgets.Dropdown(
                options=specific_choices,
                description='Plot Specific Peptides:',
                **common_options()
            )

            # Starts hidden; on_selection_change fills it and makes it visible
            hidden_layout = widgets.Layout(display='none')
            self.specific_select_multiple = widgets.SelectMultiple(
                options=[],
                description='Specific Options:',
                disabled=False,
                layout=hidden_layout
            )

            # Only the bio/pep dropdown drives the contents of the multi-select
            def forward_selection_change(change):
                on_selection_change(self, change)

            self.bio_or_pep_dropdown.observe(forward_selection_change, names='value')
    
        def _on_filter_type_change(self, change):
            if change['type'] == 'change' and change['name'] == 'value':
                if hasattr(self, 'selector'):
                    self.selector.update_filter_type(change['new'])
  

        def create_plotting_widgets(self):
            """(Re)build the colour, figure-label, plot-toggle and action widgets.

            All widgets are stored as attributes on ``self``; nothing is
            displayed here.

            Legend-title inputs 1-5 are first created from the module-level
            ``legend_title`` list. When the current ``self.ms_average_choice`` /
            ``self.bio_or_pep`` combination is one of the known modes, inputs
            1, 3, 4 and 5 are rebuilt from ``self.legend_title`` and the inputs
            that do not apply to that mode are hidden (``display='none'``).
            Input 2 always keeps its module-level default.
            """
            # Layouts. Hidden legend inputs get their own Layout so that setting
            # display='none' on it never affects the visible widgets.
            description_layout_invisible = widgets.Layout(width='90%', overflow='visible')
            description_layout = widgets.Layout(width='90%', overflow='visible')
            dropdown_layout = widgets.Layout(width='50%', overflow='visible')

            def make_text_input(value, description, layout):
                # Text box with an auto-sized description label
                return widgets.Text(
                    value=value,
                    description=description,
                    layout=layout,
                    style={'description_width': 'initial'}
                )

            def make_legend_input(title, layout):
                # Legend-title box pre-filled with, and labelled by, its default title
                return make_text_input(title, f'Legend title ({title}):', layout)

            # Color Widgets
            self.hm_selected_color = widgets.Dropdown(
                options=get_valid_gradient_colormaps(),
                value=default_hm_color,
                description='Heatmap:',
                layout=dropdown_layout,
                style={'description_width': 'initial'}
            )

            self.lp_selected_color = widgets.Dropdown(
                options=get_valid_discretecolormaps(),
                value=default_lp_color,
                description='Line Plot:',
                layout=dropdown_layout,
                style={'description_width': 'initial'}
            )

            self.avglp_selected_color = widgets.Dropdown(
                options=valid_discrete_cmaps,
                value=default_avglp_color,
                description='Avg Line Plot:',
                layout=dropdown_layout,
                style={'description_width': 'initial'}
            )

            self.color_message = widgets.HTML("<h3><u>Color Options:</u></h3>")
            self.color_widget_box = widgets.VBox([
                self.color_message,
                self.hm_selected_color,
                self.lp_selected_color,
                self.avglp_selected_color
            ])

            # Get protein name directly from class instance
            x_label = f"{self.protein_name_short} Sequence" if hasattr(self, 'protein_name_short') else "Protein Sequence"

            # Figure Label Widgets
            self.xaxis_label_input = make_text_input(x_label, 'x-axis label:', description_layout)

            def update_x_label():
                # Looks up self.xaxis_label_input at call time, so it always
                # targets the most recently built widget.
                if hasattr(self.selector, 'protein_name_short'):
                    self.xaxis_label_input.value = f"{self.selector.protein_name_short} Sequence"

            # Observe the selector once per selector object. This function runs
            # again on every dropdown change, and re-registering each time would
            # pile up redundant observers.
            if (hasattr(self.selector, 'observe')
                    and getattr(self, '_x_label_observed_selector', None) is not self.selector):
                self.selector.observe(lambda change: update_x_label(), names=['protein_name_short'])
                self._x_label_observed_selector = self.selector

            self.yaxis_label_input = make_text_input(
                "Averaged Peptide Abundance", 'y-axis label:', description_layout
            )
            self.yaxis_position = widgets.IntSlider(
                value=0,
                min=-10,
                max=10,
                step=1,
                layout=description_layout,
                description='y-axis title position:',
                style={'description_width': 'initial'}
            )

            # Default legend-title inputs, all visible
            self.legend_title_input_1 = make_legend_input(legend_title[0], description_layout)
            self.legend_title_input_2 = make_legend_input(legend_title[1], description_layout)
            self.legend_title_input_3 = make_legend_input(legend_title[2], description_layout)
            self.legend_title_input_4 = make_legend_input(legend_title[3], description_layout)
            self.legend_title_input_5 = make_legend_input(legend_title[4], description_layout)

            # Conditional Widgets: which legend inputs (by 1-based position) are
            # hidden for each (ms_average_choice, bio_or_pep) combination.
            hidden_by_mode = {
                ('yes', '1'): (3,),
                ('yes', '2'): (4,),
                ('yes', 'no'): (3, 4),
                ('no', '1'): (3, 5),
                ('no', '2'): (4, 5),
            }
            if self.ms_average_choice == 'only':
                # 'only' hides the same inputs regardless of bio_or_pep
                hidden_positions = (3, 4)
            elif self.ms_average_choice in ('yes', 'no'):
                hidden_positions = hidden_by_mode.get((self.ms_average_choice, self.bio_or_pep))
            else:
                hidden_positions = None

            # Any other combination (e.g. 'no'/'no') keeps the defaults above
            if hidden_positions is not None:
                for position in (1, 3, 4, 5):
                    is_hidden = position in hidden_positions
                    legend_input = make_legend_input(
                        self.legend_title[position - 1],
                        description_layout_invisible if is_hidden else description_layout
                    )
                    if is_hidden:
                        legend_input.layout.display = 'none'
                    setattr(self, f'legend_title_input_{position}', legend_input)

            # Plot Widgets
            self.plot_port = widgets.ToggleButton(
                value=True,
                description='Portrait Plot',
                disabled=False,
                button_style='',
                tooltip='Show updated plot',
                icon='check'
            )

            self.plot_land = widgets.ToggleButton(
                value=True,
                description='Landscape Plot',
                disabled=False,
                button_style='',
                tooltip='Show updated plot',
                icon='check'
            )

            self.create_plot_message = widgets.HTML("<h3><u>Create Plot Checkboxs:</u></h3>")

            self.plot_toggle_buttons = widgets.HBox([
                self.plot_port,
                self.plot_land
            ])

            self.plot_toggle_widget_box = widgets.VBox([
                self.create_plot_message,
                self.plot_toggle_buttons,
            ])

            self.figure_label_message = widgets.HTML("<h3><u>Figure Label Options:</u></h3>")

            # The plot toggle box is built above but deliberately not part of this box
            self.figure_label_box = widgets.VBox([
                self.figure_label_message,
                self.xaxis_label_input,
                self.yaxis_label_input,
                self.yaxis_position,
                self.legend_title_input_1,
                self.legend_title_input_2,
                self.legend_title_input_3,
                self.legend_title_input_4,
                self.legend_title_input_5,
            ], layout=widgets.Layout(
                width='100%',
                height='300px',
                margin='0px')
            )

            # Add buttons for update and save plot
            self.update_button = widgets.Button(
                description='Generate/Update',
                button_style='success',
                tooltip='Click to update the plot',
                icon='refresh'
            )

            # Save starts disabled
            self.save_button = widgets.Button(
                description='Save Plot',
                button_style='info',
                tooltip='Click to save the plot',
                icon='save',
                disabled=True
            )

            self.update_save_box = widgets.HBox([self.update_button, self.save_button])

        def on_dropdown_change(self, change):
            """Re-read the plot-option dropdowns and rebuild the plotting widgets.

            Copies the dropdown values into ``ms_average_choice`` and
            ``bio_or_pep``, refreshes ``selected_bio_or_pep``, derives
            ``selected_peptides`` / ``selected_functions`` through
            ``proceed_with_label_specific_options`` when specific options are
            chosen (empty lists otherwise), then calls
            ``create_plotting_widgets``. ``change`` itself is not inspected.
            """
            self.ms_average_choice = self.ms_average_choice_dropdown.value
            self.bio_or_pep = self.bio_or_pep_dropdown.value

            wants_specific = self.bio_or_pep != 'no'
            self.selected_bio_or_pep = self.specific_select_multiple.value if wants_specific else []

            if wants_specific and self.selected_bio_or_pep:
                resolved = proceed_with_label_specific_options(self.selected_bio_or_pep, self.bio_or_pep)
            else:
                resolved = ([], [])
            self.selected_peptides, self.selected_functions = resolved

            # Call the method to create plotting widgets
            create_plotting_widgets(self)

        
        # Function to handle updates

        def extract_non_zero_non_nan_values(df):
            """Collect the distinct non-zero, non-NaN cell values of ``df``.

            String cells are split on ``'; '`` and each piece is collected
            separately; every other value is collected as-is.

            Returns:
                set: the collected values.
            """
            collected = set()
            # stack() flattens the frame into a single Series of cells
            for cell in df.stack().values:
                if cell == 0 or pd.isna(cell):
                    continue
                if isinstance(cell, str):
                    # A string cell may hold several '; '-delimited entries
                    collected.update(cell.split('; '))
                else:
                    collected.add(cell)
            return collected
    
        
        def get_interval_start(peptide_interval):
            """Return the integer before the first '-' of an interval label.

            Labels that cannot be parsed yield ``float('inf')`` so they sort
            after every valid interval.
            """
            try:
                leading_token = peptide_interval.split('-')[0]
                start = int(leading_token)
            except (ValueError, AttributeError, IndexError):
                start = float('inf')
            return start
        
        def on_selection_change(self, change):
            """Handle changes in selection with numerically sorted peptide intervals.

            Reacts only to ``'change'`` events on the ``'value'`` trait. Gathers
            the bioactive functions and the peptide columns present across
            every entry of ``self.available_data_variables`` and fills
            ``self.specific_select_multiple`` according to the dropdown value:
            ``'1'`` lists peptide intervals sorted by start position, ``'2'``
            lists functions alphabetically, anything else hides the widget.
            """
            if change['type'] != 'change' or change['name'] != 'value':
                return

            self.bio_or_pep = self.bio_or_pep_dropdown.value

            # Initialize containers for unique values
            unique_functions = set()
            unique_peptides = set()

            # Aggregate unique functions and peptides from ALL available data.
            # The peptide scan used to sit outside this loop, so only the last
            # variable was inspected and an empty collection raised on the
            # unbound loop variable.
            for var in self.available_data_variables:
                var_data = self.available_data_variables[var]

                function_df = var_data['function_heatmap_df']
                # Standardize zero representations. Done in place on the stored
                # frame, as before; NOTE(review): confirm nothing else relies on
                # the original '0' strings.
                function_df.replace('0', 0, inplace=True)
                unique_functions.update(extract_non_zero_non_nan_values(function_df))

                abs_df = var_data['heatmap_df']
                # Every column except the bookkeeping ones is a candidate peptide
                potential_peptide_columns = [col for col in abs_df.columns
                                             if col not in ['AA', 'count', 'average']]

                # Convert all potential columns to numeric at once
                numeric_df = abs_df[potential_peptide_columns].apply(pd.to_numeric, errors='coerce')

                # Keep columns with any value different from zero.
                # NOTE(review): NaN compares as != 0, so a column containing NaN
                # is kept too; confirm this is intended.
                columns_with_data = numeric_df.columns[(numeric_df != 0).any()].tolist()
                unique_peptides.update(columns_with_data)

            # Update widget based on dropdown choice
            if self.bio_or_pep == '1':  # Peptide Intervals
                # Sort by the first number in the interval
                unique_peptides_list = sorted(unique_peptides, key=get_interval_start)
                self.specific_select_multiple.options = [(peptide, peptide) for peptide in unique_peptides_list]
                self.specific_select_multiple.layout.display = 'block'

            elif self.bio_or_pep == '2':  # Bioactive Functions
                unique_functions_list = sorted(unique_functions)
                self.specific_select_multiple.options = [(function, function) for function in unique_functions_list]
                self.specific_select_multiple.layout.display = 'block'

            else:
                self.specific_select_multiple.options = [""]
                self.specific_select_multiple.layout.display = 'none'

     
        # Call the method to create plotting widgets
        create_plotting_widgets(self)

        # Attach observer functions to widgets
        display_plotting_options(self)

        def on_filter_type_change(change):
            """Forward a changed radio value to ``self.selector.update_filter_type``.

            Notifications that are not a 'change' of the 'value' trait are ignored.
            """
            is_value_change = change['type'] == 'change' and change['name'] == 'value'
            if not is_value_change:
                return
            # Update the selector's filter type
            self.selector.update_filter_type(change['new'])
        
        # Now add the observer after the widget exists
        self.peptide_filter_radio.observe(on_filter_type_change, names='value')       
        
        # Attach observer functions to widgets
        self.ms_average_choice_dropdown.observe(
            lambda change: on_dropdown_change(self, change), names='value'
        )
        self.bio_or_pep_dropdown.observe(
            lambda change: on_dropdown_change(self, change), names='value'
        )
        self.specific_select_multiple.observe(
            lambda change: on_dropdown_change(self, change), names='value'
        )

        # Manually trigger the function once to use default values at the start
           #on_dropdown_change(self, None)
 

    
        # Create plot output widget
        self.plot_output = widgets.Output(layout=widgets.Layout(
        width='100%',
        height='100%',  # Automatically adjust the height to fit the content
        #overflow='hidden'  # Eliminate scrollbars)
        ))
                # Attach button click events
        self.update_button.on_click(self.on_update_plot_clicked)
        self.save_button.on_click(self.on_save_plot_clicked)
     
    def update_data(self, selector):
        """Copy the current protein/data state from ``selector`` onto this handler.

        Only the attributes listed in ``essential_attrs`` are copied, and only
        those that ``selector`` actually has. When ``protein_name_short`` is
        copied and the x-axis label widget already exists, the label is
        refreshed too, using the same "<name> Sequence" / "Protein Sequence"
        fallback as ``update_protein_name``.

        Parameters:
            selector: object exposing some or all of the essential attributes;
                a falsy value makes the call a no-op.
        """
        if not selector:
            return

        # Update only the essential data structures
        essential_attrs = (
            'available_data_variables',
            'user_protein_id',
            'protein_name_short',
            'selected_protein',
        )

        for attr in essential_attrs:
            if not hasattr(selector, attr):
                continue

            value = getattr(selector, attr)
            setattr(self, attr, value)

            # Keep the x-axis label in step with the protein name. An empty or
            # missing name falls back to a generic label instead of being
            # rendered as "None Sequence".
            if attr == 'protein_name_short' and hasattr(self, 'xaxis_label_input'):
                self.xaxis_label_input.value = f"{value} Sequence" if value else "Protein Sequence"

    def update_protein_name(self, new_name):
        """Store ``new_name`` as the short protein name and refresh the x-axis label.

        The label widget is only touched when ``self.xaxis_label_input``
        exists. A falsy name produces the generic label "Protein Sequence".
        """
        self.protein_name_short = new_name

        if not hasattr(self, 'xaxis_label_input'):
            return

        if new_name:
            label_text = f"{new_name} Sequence"
        else:
            label_text = "Protein Sequence"
        self.xaxis_label_input.value = label_text
                             
    def on_update_plot_clicked(self, b):
        """Regenerate the plots and show them in ``self.plot_output``.

        Click handler for the update button. Closes all open matplotlib
        figures, calls ``update_plot`` with the current widget values, stores
        the two returned figures on ``self.fig_port`` / ``self.fig_land`` and
        displays each one that has at least one axes. Any exception raised
        along the way is shown as a red message inside the output area
        instead of propagating.

        Parameters:
            b: the clicked button (not used).
        """
        import gc
        from html import escape as html_escape

        with self.plot_output:
            try:
                # Clear any existing plots and free memory
                self.plot_output.clear_output(wait=True)
                plt.close('all')

                # Reset figure references
                self.fig_port = None
                self.fig_land = None

                self.save_button.disabled = False

                # Force garbage collection
                gc.collect()

                # Call the update_plot function
                self.fig_port, self.fig_land = update_plot(
                    self.available_data_variables,
                    self.ms_average_choice,
                    self.bio_or_pep,
                    self.selected_peptides,
                    self.selected_functions,
                    self.hm_selected_color.value,
                    self.lp_selected_color.value,
                    self.avglp_selected_color.value,
                    self.xaxis_label_input.value,
                    self.yaxis_label_input.value,
                    self.yaxis_position.value,
                    self.legend_title_input_1.value,
                    self.legend_title_input_2.value,
                    self.legend_title_input_3.value,
                    self.legend_title_input_4.value,
                    self.legend_title_input_5.value,
                    self.plot_land.value,
                    self.plot_port.value,
                    self.peptide_filter_radio.value
                )

                # Display only valid figures
                if self.fig_port is not None and len(self.fig_port.axes) > 0:
                    display(HTML('<br><h2>Portrait Averaged Plot:</h2>'))
                    display(self.fig_port)

                if self.fig_land is not None and len(self.fig_land.axes) > 0:
                    display(HTML('<br><h2>Landscape Averaged Plot:</h2>'))
                    display(self.fig_land)

            except Exception as e:
                # Escape the message before embedding it in HTML: exception
                # text frequently contains '<' or '>' (for example
                # "'<' not supported between instances of ..."), which the
                # browser would otherwise treat as markup and drop.
                display(HTML(f'<br><b style="color:red;">Error generating plots: {html_escape(str(e))}</b>'))

            finally:
                # Always cleanup; the figure objects themselves stay
                # referenced on self.fig_port / self.fig_land.
                plt.close('all')

                # Re-attach the click handler. The previous guard,
                # has_trait('on_click'), was always False because on_click is
                # a method rather than a trait, so the handler was re-attached
                # on every call; this keeps that behaviour without the
                # misleading check.
                # NOTE(review): Button.on_click is expected to ignore a
                # callback that is already registered - confirm against the
                # installed ipywidgets version.
                self.update_button.on_click(self.on_update_plot_clicked)
                    
    def create_download_link(self, fig, filename):
        """Render ``fig`` to PNG, trigger a browser download and return a button.

        The figure is saved at 300 dpi into an in-memory buffer and embedded
        as a base64 data URI. The returned button, when clicked, displays a
        hidden link plus a script that clicks it; the same handler is also
        invoked once before returning so the download starts immediately.

        Parameters:
            fig: figure object providing a matplotlib-style ``savefig``.
            filename: name offered for the downloaded file; also used in the
                link's element id.

        Returns:
            The "Download <filename>" button with its click handler attached.
        """
        from io import BytesIO
        import base64

        # Render the figure into memory and encode it for the data URI.
        png_buffer = BytesIO()
        fig.savefig(png_buffer, format='png', dpi=300, bbox_inches='tight')
        b64 = base64.b64encode(png_buffer.getvalue()).decode()

        def on_button_clicked(b):
            from IPython.display import HTML
            # Hidden anchor carrying the image, followed by a script that clicks it
            anchor_html = f'<a href="data:image/png;base64,{b64}" download="{filename}" id="download_link_{filename}"></a>'
            click_script = f'''<script>
                document.getElementById("download_link_{filename}").click();
            </script>'''
            display(HTML(anchor_html))
            display(HTML(click_script))

        download_button = widgets.Button(
            description=f'Download {filename}',
            button_style='info',
            icon='download'
        )
        download_button.on_click(on_button_clicked)

        # Automatically trigger the click
        on_button_clicked(download_button)

        return download_button
    
    def on_save_plot_clicked(self, b):
        """Download the stored figures as PNG files and redisplay them.

        Click handler for the save button. If no figure with axes has been
        stored yet, a hint is shown instead. Otherwise the file names are
        built from ``self.user_protein_id``, ``self.protein_name_short`` and
        the current plot options, ``create_download_link`` is called for each
        figure that has axes, and the figures are displayed again below a
        success message. All output, including error messages, is written to
        ``self.plot_output``.

        Parameters:
            b: the clicked button; disabled for the duration of the call and
                always re-enabled afterwards.
        """
        from html import escape as html_escape

        def has_axes(fig):
            # True for a figure object that contains at least one axes.
            return fig is not None and len(fig.axes) > 0

        # Disable the save button temporarily to prevent multiple clicks
        b.disabled = True

        try:
            with self.plot_output:
                # The try/except lives inside the output context on purpose:
                # anything displayed after the `with` exits is not captured by
                # the widget, and under IPython the Output context manager
                # renders and suppresses exceptions itself, so a handler
                # placed outside it would not get to show its message.
                try:
                    port_fig = self.fig_port
                    land_fig = self.fig_land

                    # Check if plots exist before proceeding
                    if (port_fig is None and land_fig is None) or \
                       (port_fig is not None and len(port_fig.axes) == 0 and
                        land_fig is not None and len(land_fig.axes) == 0):
                        self.plot_output.clear_output(wait=True)
                        display(HTML("<div style='display: inline-block; margin: 10px 0;'><b style='color: red'>Please generate plots using the Update/Display button before saving.</b></div>"))
                        return

                    # Clear previous output and show loading message
                    self.plot_output.clear_output(wait=True)
                    display(HTML("<div style='display: inline-block; margin: 10px 0;'><b style='color:blue'>Generating high resolution image for download. Please wait...</b></div>"))

                    # Create filenames
                    additional_vars = []
                    if self.plot_heatmap == 'yes':
                        additional_vars.append('heatmap')
                    elif self.plot_heatmap == 'no':
                        additional_vars.append('no-heatmap')

                    if self.bio_or_pep == '1':
                        additional_vars.append('intervals')
                    elif self.bio_or_pep == '2':
                        additional_vars.append('bioactive-functions')
                    elif self.bio_or_pep == 'no':
                        additional_vars.append('averages-only')

                    additional_vars_str = '_'.join(additional_vars)
                    # Replace anything that is not a word character or '-' so
                    # the protein name is safe to use in a file name.
                    self.protein_filename_short = re.sub(r'[^\w-]', '-', self.protein_name_short)
                    self.filename_port = f'portrait_{self.user_protein_id}_{self.protein_filename_short}_average-only'
                    self.filename_land = f'landscape_{self.user_protein_id}_{self.protein_filename_short}_{additional_vars_str}'

                    port_ready = has_axes(port_fig)
                    land_ready = has_axes(land_fig)

                    # Create and trigger downloads
                    download_buttons = []
                    if port_ready:
                        download_buttons.append(
                            self.create_download_link(port_fig, f"{self.filename_port}.png")
                        )
                    if land_ready:
                        download_buttons.append(
                            self.create_download_link(land_fig, f"{self.filename_land}.png")
                        )

                    # Display final content
                    self.plot_output.clear_output(wait=True)
                    if download_buttons:
                        display(HTML("<div style='display: inline-block; margin: 10px 0;'><b style='color:green'>Success! Your plots have been downloaded automatically.</b></div>"))

                        if port_ready:
                            display(HTML('<br><h2>Portrait Averaged Plot:</h2>'))
                            display(port_fig)

                        if land_ready:
                            display(HTML('<br><h2>Landscape Averaged Plot:</h2>'))
                            display(land_fig)
                    else:
                        display(HTML("<div style='display: inline-block; margin: 10px 0;'><p style='color: red;'>No plots were generated to download.</p></div>"))

                except Exception as e:
                    self.plot_output.clear_output(wait=True)
                    # Escape the exception text so '<' / '>' in it are shown
                    # literally instead of being parsed as markup.
                    error_message = f"An error occurred while saving plots: {html_escape(str(e))}"
                    display(HTML(f"<div style='display: inline-block; margin: 10px 0;'><b style='color: red'>{error_message}</b></div>"))

        finally:
            # Always re-enable the button, including on the early return above
            b.disabled = False
            
    def get_layout(self):
        """Assemble the control panel and return it inside a one-cell grid.

        Returns:
            A 1x1 ``GridspecLayout`` (800px wide) whose only cell holds a
            500px-wide ``VBox`` with the input widgets in display order.
        """
        # Single-cell grid that hosts the whole control panel
        layout_grid = GridspecLayout(
            1, 1,
            width='800px',
            height='auto',
            grid_gap='5px',
        )

        # Controls are stacked top to bottom in this order
        panel_children = [
            self.plot_message,
            self.peptide_filter_radio,
            self.ms_average_choice_dropdown,
            self.bio_or_pep_dropdown,
            self.specific_select_multiple,
            self.color_widget_box,
            self.figure_label_box,
            widgets.HTML("<h3><u>Display and Save Plot</u></h3>"),
            self.update_save_box,
        ]
        control_panel = VBox(panel_children)
        control_panel.layout.width = '500px'

        layout_grid[0, 0] = control_panel
        return layout_grid
    
        #return input_widgets,  self.figure_label_box, self.update_save_box

    def show_plots(self):
        """Display the plot output area (``self.plot_output``)."""
        display(self.plot_output)

    def set_data_handler(self, data_handler):
        """Link this handler and ``data_handler`` to each other.

        Stores ``data_handler`` on ``self.data_handler`` and stores this
        object on ``data_handler._plot_handler``.
        """
        # Forward reference first, then the back-reference on the data handler
        self.data_handler = data_handler
        data_handler._plot_handler = self


In [None]:

# Entry-point cell: build the application objects and render the UI.
# DataTransformation and DynamicVisualizationHandler are not imported in the
# first cell, so they are expected to be defined by earlier cells of this
# notebook; run those cells first.

# Create the data transformer. This rebinds the module-level
# `data_transformer` placeholder that the first cell initialises to None.
data_transformer = DataTransformation()

# Set up the data loading UI
data_transformer.setup_data_loading_ui()

# Create the visualization handler around the transformer and display it
viz_handler = DynamicVisualizationHandler(data_transformer)
viz_handler.display()

Output(layout=Layout(height='100%', width='100%'))