In [80]:
# Core data processing libraries
import pandas as pd
import numpy as np

from pathlib import Path
# Import required libraries
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Visualization libraries
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import matplotlib.lines as mlines
import matplotlib.patches as patches
import seaborn as sns
from itertools import cycle

# Utility imports
import csv, json, traitlets, re, io, warnings, os, copy, base64, gc, traceback
from io import BytesIO
from functools import partial
from itertools import combinations

# Jupyter-specific imports

from traitlets import HasTraits, Instance, observe
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import (
    interact, interactive, fixed, interact_manual,
    GridspecLayout, VBox, HBox, Layout, Output
)
from ipydatagrid import DataGrid

# Initialize settings
import _settings as settings
from utils.uniprot_client import UniProtClient
warnings.simplefilter(action='ignore', category=FutureWarning)

# Global variables from settings
spec_translate_list = settings.SPEC_TRANSLATE_LIST
valid_discrete_cmaps = settings.valid_discrete_cmaps
valid_gradient_cmaps = settings.all_gradient_cmaps
# Define default values for the color maps
default_hm_color = 'RdYlGn_r'
#default_hm_color = 'Purples'

default_lp_color = 'Set3'
default_avglp_color = 'Dark2'

# Define the color map for the heatmap
hm_selected_color = default_hm_color
cmap = plt.get_cmap(hm_selected_color)

# Define the color map for the individual line plots
lp_selected_color = default_lp_color

# Define the color for the averaged line plots
avglp_selected_color = default_avglp_color
avg_cmap = plt.get_cmap(avglp_selected_color)

# Define settings for different numbers of variables


chuck_size = 78
plot_heatmap, plot_zero = 'yes', 'no'
legend_title = ['Sample Type:','Peptide Counts:','Bioactivity Function:','Peptide Interval:', 'Average Absorbance:']


In [81]:
"""-----------------Export_Function---------------------------------""" 
# Function to calculate y-ticks based on the min and max values of datasets
def calculate_y_ticks(min_values, max_values):
    """
    Calculate y-ticks with proper handling of zero/negative values
    """
    # Check for empty lists or all zeros
    if not min_values or not max_values:
        return [0, 1, 10]  # Default scale if no data
        
    # Filter out zeros and negative values, keep valid positives
    valid_mins = [x for x in min_values if x > 0]
    valid_maxs = [x for x in max_values if x > 0]
    
    # If no valid values after filtering
    if not valid_mins or not valid_maxs:
        return [0, 1, 10]
    
    try:
        overall_min = np.nanmin(valid_mins)
        overall_max = np.nanmax(valid_maxs)
        
        # If min or max are zero or negative after nanmin/nanmax
        if overall_min <= 0 or overall_max <= 0:
            return [0, 1, 10]
            
        min = np.floor(overall_min)
        #min_power = 10 ** np.floor(np.log10(overall_min))
        max = np.ceil(overall_max)
        #max_power = 10 ** np.ceil(np.log10(overall_max))
        
        # Calculate midpoint in log space
        #mid_point = np.sqrt(min_power * max_power)
        #mid_point = (min_power + max_power)/2
        mid_point = np.average([min, max])
        mid_point_rounded = np.round(mid_point)
        #mid_point_rounded = 10 ** np.round(np.log10(mid_point))
        
        #return [min_power, mid_point_rounded, max_power]
        return [min, mid_point, max]
        
    except (ValueError, RuntimeWarning):
        return [0, 1, 10]  # Fallback for any calculation errors

def calculate_abundance(protein_sequence, peptide_dataframe, grouping_variable):
    protein_sequence_length = len(protein_sequence)
    data = []
    Avg_column = f'Avg_{grouping_variable}'

    for idx, row in peptide_dataframe.iterrows():
        try:
            start_idx = int(row['start']) - 1
            stop_idx = int(row['stop'])
            abundance_value = row[Avg_column]

            values = [0] * protein_sequence_length
            if abundance_value > 0:
                values[start_idx:stop_idx] = [abundance_value] * (stop_idx - start_idx)
            data.append(values)

        except (KeyError, ValueError) as e:
            print(f'{type(e).__name__}: {e} - Skipping this row')
            continue

    if not data:
        print("No valid data to process. Returning empty DataFrame.")
        return pd.DataFrame()

    # Create the initial DataFrame
    abundance_df = pd.DataFrame(data).T
    abundance_df.columns = [f'{int(row["start"])}-{int(row["stop"])}' for _, row in peptide_dataframe.iterrows()]

    
    # Initialize count list
    count_list = []
    
    # Check if abundance_df is empty or has no non-zero values
    if abundance_df.empty or abundance_df.eq(0).all().all():
        # Create all zeros for counts and averages
        count_list = [0] * len(protein_sequence)
        abundance_df['average'] = 0
    else:
        # Calculate counts and averages normally
        count_list = abundance_df.gt(0).sum(axis=1)
        abundance_df['average'] = abundance_df.replace(0, np.nan).mean(axis=1)
    
    # Assign the counts and amino acids
    abundance_df['count'] = count_list
    abundance_df['AA'] = list(protein_sequence)
    return abundance_df

def calculate_function(protein_sequence, peptide_dataframe, grouping_variable):
    protein_sequence_length = len(protein_sequence)
    data = []

    for _, row in peptide_dataframe.iterrows():
        start_idx = int(row['start'] - 1)
        stop_idx = int(row['stop'])
        if stop_idx > protein_sequence_length:
            stop_idx -= 1
            
        function_value = row['function'] if 'function' in peptide_dataframe.columns else np.nan
        
        values = [None] * protein_sequence_length
        for i in range(start_idx, stop_idx):
            values[i] = function_value

        data.append(values)

    function_df = pd.DataFrame(data).T
    function_df.columns = [f'{int(row["start"])}-{int(row["stop"])}' for _, row in peptide_dataframe.iterrows()]

    return function_df

def export_heatmap_data_to_dict(protein_id, group_key, group_info, protein_sequence, 
                               protein_species, protein_name, protein_df, is_all_null):
    """
    Exports the heatmap data to a dictionary based on filter type.
    """
    grouping_var = group_info['grouping_variable']
    relevant_columns = group_info['abundance_columns']
    
    # Check if required columns exist, if not create default DataFrame
    required_columns = ['function', 'start', 'stop']
    if not all(col in protein_df.columns for col in required_columns):
        filtered_df = pd.DataFrame({
            'start': [0],
            'stop': [1],
            'function': ['0']
        })
        filtered_protein_df = filtered_df.copy()
    else:
        filtered_protein_df = protein_df.dropna(subset=['function', 'start', 'stop'])
        filtered_df = filtered_protein_df[['start', 'stop', 'function']]

    # Calculate the heatmap data
    func_heatmap_df = calculate_function(protein_sequence, filtered_df, grouping_var)
    heatmap_df = calculate_abundance(protein_sequence, protein_df, grouping_var)
    filtered_heatmap_df = calculate_abundance(protein_sequence, filtered_protein_df, grouping_var)

    # Create the dictionary to store the results
    heatmap_data = {
        'protein_id': protein_id,
        'protein_sequence': protein_sequence,
        'protein_name': protein_name,
        'protein_species': protein_species,
        'func_heatmap_df': func_heatmap_df,
        'heatmap_df': heatmap_df,
        'filtered_heatmap_df': filtered_heatmap_df
    }
    
    return heatmap_data

def chunk_dataframe(df, chunk_size, exclude_columns=3):
    # Select all rows and all but the last 'exclude_columns' columns
    df_subset = df.iloc[:, :-exclude_columns] if exclude_columns else df
    
    # Check if df_subset is empty or all zeros
    if df_subset.empty or (df_subset == 0).all().all():
        # Return a single chunk with the default structure
        default_chunk = pd.DataFrame(
            np.zeros((chunk_size, df_subset.shape[1])), 
            columns=df_subset.columns
        )
        return [default_chunk]
    
    # Calculate the number of rows needed to make the last chunk exactly 'chunk_size'
    total_rows = df_subset.shape[0]
    remainder = total_rows % chunk_size
    
    if remainder != 0:
        # Rows needed to complete the last chunk
        rows_to_add = chunk_size - remainder
        # Create a DataFrame with zero values for the missing rows
        additional_rows = pd.DataFrame(
            np.zeros((rows_to_add, df_subset.shape[1])), 
            columns=df_subset.columns
        )
        # Append these rows to df_subset
        df_subset = pd.concat(
            [df_subset, additional_rows], 
            ignore_index=True, 
            copy=False,
            verify_integrity=True
        )
    
    # Create chunks of the DataFrame
    max_index = df_subset.index.max() + 1
    return [df_subset.iloc[i:i + chunk_size] for i in range(0, max_index, chunk_size)]
    
def process_available_data(available_data_variables, filter_type, selected_functions):
    """
    Process data for visualization when update_plot is called
    """
    if not available_data_variables:
        return 

    count_list = []
    min_values = []
    max_values = []
    seq_len_list = []
    chunk_size = 78
    
    # Process each variable's data for visualization
    for var in available_data_variables:
        if filter_type == 'all-peptides':
            df = available_data_variables[var]['heatmap_df']
        elif filter_type == 'bioactive-only':
            df = available_data_variables[var]['filtered_heatmap_df']
        elif filter_type == 'functional-only':
            if not selected_functions:
                error_html = '<span style="color: red; font-weight: bold;">Invalid Selection. Please select "Bioactive Functions" for "Plot Specific Peptides", with at least one function from the "Specific Options" dropdown menu selected.</span>'
                display(HTML(error_html))
                #return {'error': True,}    
            
            # Filter both functional and absorbance data for specific functions
            func_df = available_data_variables[var]['function_heatmap_df']
            if func_df is not None and not func_df.empty:
                # Get list of columns that contain any of the selected functions
                functional_positions = []
                for col in func_df.columns:
                    # Check if any selected function appears in any cell of this column
                    if any(any(func in str(cell) for func in selected_functions) for cell in func_df[col]):
                        functional_positions.append(col)
                              
                # Get absorbance data for these positions
                abs_df = available_data_variables[var]['filtered_heatmap_df']
                if abs_df is not None and not abs_df.empty and functional_positions:
                    # Only keep columns that exist in abs_df
                    valid_columns = [col for col in functional_positions if col in abs_df.columns]
                    if valid_columns:
                        # Get the filtered absorbance data for selected columns
                        selected_abs_df = abs_df[valid_columns]
                        
                        # Create DataFrame with sequence and selected columns
                        df = pd.DataFrame({
                            'AA': abs_df['AA'],
                        })
                        
                        # Add the selected columns
                        df = pd.concat([df, selected_abs_df], axis=1)
                        
                        # Now calculate count and average from all data
                        non_zero_mask = df.drop('AA', axis=1) > 0
                        df['count'] = non_zero_mask.sum(axis=1)
                        df['average'] = df.drop('AA', axis=1).where(non_zero_mask).mean(axis=1)
                    else:
                        df = None
                else:
                    df = None
            else:
                df = None

        # Get counts and MS data
        if df is None or df.empty or 'count' not in df.columns or 'average' not in df.columns:
            # Create default DataFrame with zero values
            protein_length = len(available_data_variables[var]['protein_sequence'])
            df = pd.DataFrame({
                'count': [0] * protein_length,
                'average': [0] * protein_length,
                'AA': list(available_data_variables[var]['protein_sequence'])
            })

        # Get counts and MS data
        peptide_counts = df['count']
        ms_data = df['average']
        count_list.append(peptide_counts)
  
        # Calculate min/max MS values
        min_ms = ms_data[ms_data > 0].min()
        max_ms = ms_data.max()
        min_values.append(min_ms)
        max_values.append(max_ms)
        
        # Add computed properties to the variable data
        available_data_variables[var].update({
            'peptide_counts': peptide_counts,
            'ms_data': ms_data,
            'ms_data_list': list(ms_data),
            'AA_list': df['AA'].tolist(),
            'max_peptide_counts': peptide_counts.max(),
            'min_peptide_counts': peptide_counts.min(),
            'max_ms_data': max_ms,
            'min_ms_data': min_ms,
            
            # Generate chunks
            'amino_acids_chunks': [
                available_data_variables[var]['protein_sequence'][i:i + chunk_size]
                for i in range(0, len(available_data_variables[var]['protein_sequence']), chunk_size)
            ],
            'peptide_counts_chunks': [
                peptide_counts[i:i + chunk_size]
                for i in range(0, len(peptide_counts), chunk_size)
            ],
            'ms_data_chunks': [
                ms_data[i:i + chunk_size]
                for i in range(0, len(ms_data), chunk_size)
            ]
        })
        
        # Process bioactive peptide data
        columns_to_include = df.columns.difference(['AA', 'count'])
        df_filtered = df[columns_to_include]
        
        available_data_variables[var].update({
            'bioactive_peptide_abs_df': df_filtered,
            #'bioactive_peptide_chunks': chunk_dataframe(df_filtered, chunk_size),
            #'bioactive_function_chunks': chunk_dataframe(available_data_variables[var]['function_heatmap_df'], chunk_size),
            'bioactive_peptide_func_df': available_data_variables[var]['function_heatmap_df']
        })
        
        seq_len_list.append(len(available_data_variables[var]['amino_acids_chunks'][0]))

    # Calculate global values
    global axis_number, total_plots, y_ticks
    max_sequence_length = max(seq_len_list)
    axis_number = len(available_data_variables) + 2
    num_sets = len(next(iter(available_data_variables.values()))['amino_acids_chunks'])
    total_plots = num_sets * axis_number
    style_map = assign_line_styles(available_data_variables)
    
    # Process counts
    if len(set([item for sublist in count_list for item in sublist])) >= 1:
        flat_list = [item for sublist in count_list for item in sublist]
        # Check if there are any non-zero values before filtering
        non_zero_values = [item for item in flat_list if item != 0]
        if non_zero_values:  # If there are non-zero values
            list_of_counts = set(non_zero_values)
            max_count = max(non_zero_values)
            num_unique_count = len(set(non_zero_values))
        else:  # If all values are zero
            list_of_counts = {0}
            max_count = 0
            num_unique_count = 1
    else:
        max_count = 0
        num_unique_count = 0
        list_of_counts = set()
    # Calculate number of colors
    num_colors = 6 if num_unique_count >= 6 else num_unique_count

    # Calculate global values
    if min_values and max_values:
        y_ticks = calculate_y_ticks(min_values, max_values)
        y_ticks_str = ', '.join(f'{tick:.2e}' for tick in y_ticks)
        y_ticks_html = f'<b>Max/Min of MS data (y-ticks):</b> {y_ticks_str}'
    else:
        y_ticks = [0.1, 1, 10]  # Default log scale values
        y_ticks_html = '<span style="color:red;">Insufficient data to calculate MS data y-ticks.</span>'
        y_ticks_html += f'<b>Protein Sequence Length:</b> {max_sequence_length}'

    return {
        'list_of_counts': list_of_counts,
        'min_values': min_values,
        'max_values': max_values,
        'seq_len_list': seq_len_list,
        'max_sequence_length': max_sequence_length,
        'y_ticks': y_ticks,
        'y_ticks_html': y_ticks_html,
        'max_count': max_count,
        'num_unique_count': num_unique_count,
        'num_colors': num_colors,
        'total_plots': total_plots,
        'style_map': style_map,
        'error': False

    }

# Update plot function
def update_plot(available_data_variables, ms_average_choice, bio_or_pep, selected_peptides, selected_functions,
                hm_selected_color, lp_selected_color, avglp_selected_color,
                xaxis_label, yaxis_label, yaxis_position, legend_title_input_1, legend_title_input_2,
                legend_title_input_3,  plot_land, plot_port, filter_type):
    if not available_data_variables:  # Check if data is available
        display(HTML(f'<br><span style="color:red;">No data is available for plotting.</span>'))
        return (None, None)  # Return tuple of None values
    result = process_available_data(available_data_variables, filter_type, selected_functions)

    # Unpack the dictionary into individual global variables
    
    import _settings as settings
    if result:
        # Unpack the dictionary into individual global variables
        global list_of_counts, min_values, max_values, seq_len_list, max_sequence_length
        global y_ticks_html, max_count, num_unique_count, num_colors, total_plots, style_map, y_ticks
        
        lineplot_height, scale_factor = port_hm_settings.get(len(available_data_variables), (20, 0.1))
        list_of_counts = result['list_of_counts']
        min_values = result['min_values']
        max_values = result['max_values']
        seq_len_list = result['seq_len_list']
        max_sequence_length = result['max_sequence_length']
        y_ticks_html = result['y_ticks_html']
        max_count = result['max_count']
        num_unique_count = result['num_unique_count']
        num_colors = result['num_colors']
        total_plots = result['total_plots']
        style_map = result['style_map']

        yaxis_position_land = yaxis_position_port = 0.0 - 0.01 * yaxis_position
        selected_legend_title = [legend_title_input_1, legend_title_input_2, legend_title_input_3, legend_title[4]]

        # Your plotting code here, using the widget values as inputs

        cmap = plt.get_cmap(hm_selected_color)
        avg_cmap = plt.get_cmap(avglp_selected_color)

        if ms_average_choice == 'no' and bio_or_pep == 'no':
            display(HTML(f'<br><b>Reslect "Ploting Options", no data is available for plotting.</b><br><br>'))
            
            return None, None

        else:
            """ #                                       Function Call to Generate Plot
            #---------------------------------------------------------------------------------------------------------------------------------------------
            """
            if plot_port and (bio_or_pep == 'no'):
                # Temporarily suppress specific warning
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', UserWarning)
                    if ms_average_choice in ['yes', 'only']:
                        fig_port = visualize_sequence_heatmap_portrait(
                            available_data_variables,
                            0.001,  
                            lineplot_height,  
                            1,  
                            xaxis_label,  
                            yaxis_label,  
                            selected_legend_title,  
                            yaxis_position_port,  
                            cmap,
                            avg_cmap,
                            lp_selected_color,
                            avglp_selected_color,
                            selected_functions,
                            ms_average_choice,
                            bio_or_pep,
                            78)  # removed chunk_size= to make it a positional argument
                    else:
                        display(HTML(f'<br><b>No was selected earlier regarding the plotting of averaged absorbances, preventing the plotting of the averaged plot.</b><br><br>'))
                    if max_count <= 1:
                        display(HTML(f'<br><b style="color:red;">You have to few peptides for proper heatmapping.</b>'))

            """_____________________________________________________EXECUTING CODE TO PLOT W/ UPDATES_________________________________________________________________"""

            # Plotting code
            if plot_land:
                amino_acid_height = 0.25 + 0.1 * (
                    len(available_data_variables) // 4 if len(available_data_variables) >= 8 else 0)
                indices_height = amino_acid_height + 0.08
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', UserWarning)
                    fig_land = visualize_sequence_heatmap_lanscape(
                        available_data_variables,
                        amino_acid_height,
                        7.5,
                        indices_height,
                        xaxis_label,
                        yaxis_label,
                        selected_legend_title,
                        yaxis_position_land,
                        cmap,
                        avg_cmap,
                        lp_selected_color,
                        avglp_selected_color,
                        selected_peptides,
                        selected_functions,
                        ms_average_choice,
                        bio_or_pep,
                        filter_type)

                if max_count <= 1:
                    display(HTML(f'<br><b style="color:red;">You have to few peptides for proper heatmapping.</b>'))
        if plot_port:
            if bio_or_pep != 'no':
                    display(HTML(f'<br><b>The portrait plot only supports ploting averaged abundance of peptides which is not formated to be combined with specified peptide intervals or functions.</b><br><br>'))
                    fig_port_return = None 
            else:
                if len(fig_port.axes) > 0:
                    fig_port_return = fig_port
        else:
            fig_port_return = None 
            
        if plot_land:
            if len(fig_land.axes) > 0:
                fig_land_return = fig_land
        else:
            fig_land_return = None         

        if not plot_port:
            if not plot_land:
                display(HTML(f"<br><br><p><strong>One or both of the create plot checkbox must be selected to generate a plot.</strong></p>"))

        return fig_port_return, fig_land_return # Return the list of figures
    else:
        display(HTML(f'<br><b>Reslect "Variable Options", no data is available for plotting.</b><br><br>'))

"""_________________________________________Data Visualization Functions_________________________________"""
# Function to plot rows of amino acids with backgrounds colored
def plot_row_color(ax, amino_acids, colors):
    ax.axis('off')
    ax.set_xlim(0, max_sequence_length)
    ax.set_xlabel('')
    for j, (aa, color) in enumerate(zip(amino_acids, colors)):
        ax.text(j + 0.5, 0.5, aa, color='black', ha='center', va='center', fontsize=14,
                backgroundcolor=mcolors.rgb2hex(color))

# Assigns line type for landscape plot if plotting individual peptides
def assign_line_styles(data_variables):
    # Define a set of line styles you find visually distinct
    line_styles = cycle(['-', '--', ':', '-.'])
    #line_styles = cycle(['-'])

    # Extract unique labels from your data variables
    unique_labels = set(data['label'] for data in data_variables.values())

    # Map each unique label to a line style
    style_map = {label: next(line_styles) for label in unique_labels}

    # Map each unique label to a line style
    #style_map = {'Gastric IVT': '--',
    #             'Intestinal IVT': ':',
    #             'J1H': '-', 'J2H': '-', 'J3H': '-', 'J4H': '-',}
    return style_map

# Function to plot rows of amino acids with NO backgrounds colored
def plot_row(ax, amino_acids):
    ax.axis('off')
    ax.set_xlim(0, max_sequence_length)
    ax.set_xlabel('')
    for j, (aa) in enumerate(amino_acids):
        ax.text(j + 0.5, 0.5, aa, color='black', ha='center', va='center', fontsize=8)  # backgroundcolor='white')

# Function to plot continuse averaged lines
def plot_average_ms_data(ax, ms_data, label, var_index, y_ticks, i, chunk_size, avg_cmap, line_style):
    """
    Plot average MS data as a line plot on a twin axis, incorporating a line style, and return the line object and its properties.
    """
    start_limit = i * chunk_size
    end_limit = (i + 1) * chunk_size - 1
    ax.set_xlim(start_limit, end_limit)

    if isinstance(ms_data, (pd.DataFrame, pd.Series)):
        x_values = ms_data.index.tolist()
        y_values = ms_data.values
    else:
        x_values = list(range(len(ms_data)))  # Indices from 0 to len(ms_data)-1
        y_values = ms_data

    num_colors = avg_cmap.N

    # Get the color from the colormap and save it into avglp_selected_color
    color = avg_cmap(var_index % num_colors)
    line, = ax.plot(x_values, y_values, label=label, color=color, linestyle=line_style)
    ax.set_yscale('log')
    ax.set_yticks(y_ticks)
    ax.set_ylim(min(y_ticks), max(y_ticks))
    ax.tick_params(axis='y', labelsize=16)
    ax.yaxis.tick_left()
    ax.set_xticks([])
    ax.set_xticklabels([])

    return line, label, var_index  # Return the line and its label for further processing if needed

# Function to extract non-zero, non-NaN values
def extract_non_zero_non_nan_values(df):
    unique_functions = set()
    # Iterate over each value in the DataFrame
    for value in df.stack().values:  # df.stack() stacks the DataFrame into a Series
        if value != 0 and not pd.isna(value):  # Check if value is non-zero and not NaN
            if isinstance(value, str):
                # If the value is a string, it could contain multiple delimited entries
                entries = value.split('; ')
                unique_functions.update(entries)
            else:
                unique_functions.add(value)
    return unique_functions

    # This function is used in plotting to filter the data to only plot selected peptides or bioactive peptides, independent of averaged data


def filter_data_by_selection(bp_abs, bp_func, selected_peptides, selected_functions, bio_or_pep, filter_type):
    """
    Filters the data based on user selection with improved column handling
    """
    # Early return if any DataFrame is None or empty
    if bp_abs is None or bp_abs.empty:
        print("DataFrame is None or empty")
        return pd.DataFrame(), pd.DataFrame()
        
    # Special handling for peptide selection (bio_or_pep = '1')
    if bio_or_pep == '1' and selected_peptides:
        try:
            # First identify which metadata columns we want to keep
            meta_columns = ['AA', 'count']
            keep_columns = [col for col in meta_columns if col in bp_abs.columns]
            
            # Then find our selected peptide columns that exist in the DataFrame
            peptide_columns = []
            for pep in selected_peptides:
                if pep in bp_abs.columns:
                    peptide_columns.append(pep)
            
           
            # Combine metadata and peptide columns
            all_columns = keep_columns + peptide_columns
            
            if peptide_columns:  # Only proceed if we found matching peptide columns
                # Create new DataFrames with only the columns we want
                filtered_bp_abs = bp_abs[all_columns].copy()
                filtered_bp_func = pd.DataFrame()

                return filtered_bp_abs, filtered_bp_func
            else:
                print("No matching peptide columns found")
                return pd.DataFrame(), pd.DataFrame()
                
        except Exception as e:
            print(f"Error in peptide filtering: {str(e)}")
            traceback.print_exc()
            return pd.DataFrame(), pd.DataFrame()
            
    # Handle other cases (bio_or_pep = '2' or 'no')
    elif bio_or_pep == '2' and selected_functions:
        # Filter by selected functions
        if bp_func is None or bp_func.empty:
            return pd.DataFrame(), pd.DataFrame()
            
        try:
            def has_selected_function(value):
                if pd.isna(value):
                    return False
                value_str = str(value)
                return any(func in value_str for func in selected_functions)
            
            # Create mask based on function presence
            mask = bp_func.apply(lambda col: col.apply(has_selected_function))
            cols_with_functions = mask.any()
            relevant_columns = cols_with_functions[cols_with_functions].index
            
            # Filter both DataFrames using the same columns
            if len(relevant_columns) > 0:
                filtered_bp_abs = bp_abs[relevant_columns]
                filtered_bp_func = bp_func[relevant_columns]
                return filtered_bp_abs, filtered_bp_func
            else:
                return pd.DataFrame(), pd.DataFrame()
                
        except Exception as e:
            print(f"Error in function filtering: {str(e)}")
            return pd.DataFrame(), pd.DataFrame()
            
    # No filtering case
    else:
        return bp_abs, bp_func

def process_chunk_data(ax2, chunk_abs, chunk_func, chunk_size, i, y_ticks, handles, labels, 
                      sample_list, var_name_list, line_style, var_name, var_ms_data, 
                      selected_peptides, selected_functions, lp_selected_color, ms_average_choice, bio_or_pep):
    

    """
    Process and plot chunk data with proper handling of peptides and functions.
    """
    print_list = []
    
    start_limit = i * chunk_size
    end_limit = (i + 1) * chunk_size - 1
    ax2.set_xlim(start_limit, end_limit)
    
    # Set up colormap
    common_columns = []
    colormap = plt.get_cmap(lp_selected_color)
        
    # Create a dictionary to map functions to colors
    if bio_or_pep == '1':
        num_colors = max(len(selected_peptides), 1)
        items_to_color = selected_peptides
    elif bio_or_pep == '2':
        num_colors = max(len(selected_functions), 1)
        items_to_color = selected_functions
    else:
        num_colors = 1
        items_to_color = []
    
    # Generate evenly spaced colors
    colors = colormap(np.linspace(0, 1, num_colors))
    
    # Create a dictionary mapping each function to a specific color
    function_colors = {item: colors[i % len(colors)] for i, item in enumerate(items_to_color)}
    function_colors['Multiple'] = 'black'  # Keep 'Multiple' as black
    
    # Track which functions we've already plotted
    plotted_functions = set()
    
    # Process each column in the abundance data
    for col in chunk_abs.columns:
        y_values = chunk_abs[col].dropna()
        y_values = y_values[y_values > 0]
        
        if not y_values.empty:
            x_values = y_values.index
            label_value = 'No Label'
            
            if bio_or_pep == '1':
                # For peptide intervals, just use the column name as label
                label_value = col
            elif bio_or_pep == '2' and col in chunk_func.columns:
                # For bioactive functions, determine the appropriate function label
                func_values = chunk_func[col].dropna()
                
                if not func_values.empty:
                    # Get unique functions for this peptide
                    peptide_functions = set()
                    for val in func_values:
                        if pd.notna(val):
                            if isinstance(val, str) and ';' in val:
                                peptide_functions.update(val.split(';'))
                            else:
                                peptide_functions.add(str(val))
                    
                    # Intersect with selected functions to get only those we care about
                    relevant_functions = peptide_functions.intersection(selected_functions)
                    
                    if len(relevant_functions) > 1:
                        # If multiple relevant functions, add to print list and use 'Multiple'
                        print_list.append(sorted(relevant_functions))
                        label_value = 'Multiple'
                    elif len(relevant_functions) == 1:
                        # If one relevant function, use it as the label
                        label_value = next(iter(relevant_functions))
            
            # Only plot if we have a valid label and it's in our selected items (or 'Multiple')
            if (label_value in items_to_color or label_value == 'Multiple') and label_value != 'No Label':
                color = function_colors.get(label_value, 'grey')
                
                # If we've already plotted this function and it's not 'Multiple', use a dotted line
                # to distinguish it from the same function in different peptides
                current_line_style = line_style
                if label_value in plotted_functions and label_value != 'Multiple':
                    current_line_style = 'dotted'
                
                try:
                    # Plot the line
                    lines = ax2.plot(x_values, y_values, 
                                   label=label_value,
                                   linestyle=current_line_style, 
                                   color=color)
                    line = lines[0]
                    
                    # Only add to legend if we haven't seen this function before
                    if label_value not in plotted_functions:
                        handles.append(line)
                        labels.append(label_value)
                        sample_list.append(current_line_style)
                        var_name_list.append(var_name)
                        plotted_functions.add(label_value)
                    
                except Exception as e:
                    print(f"Error plotting line for {label_value}: {str(e)}")
                    continue
    
    # Set up axis properties
    ax2.set_yscale('log')
    ax2.set_yticks(y_ticks)
    ax2.set_ylim(min(y_ticks), max(y_ticks))
    ax2.tick_params(axis='y', labelsize=16)
    ax2.yaxis.tick_left()
    
    return print_list

def get_grouped_colors(counts, max_count, num_groups, plot_zero, cmap):
    # Initialize colors list with None to maintain length
    colors = [None] * len(counts)

    # Set the start point based on the user's input
    start_point = 0 if plot_zero == 'yes' else 1

    # Group counts into fewer categories if necessary, excluding zeros
    group_bounds = np.linspace(start_point, max_count, num_groups + 1)
    group_labels = np.digitize(counts, bins=group_bounds, right=True)  # Find group labels for counts
    ##print("group_labels", group_labels)
    norm = Normalize(vmin=start_point, vmax=num_groups)  # Normalize based on the number of groups

    # Map each count to a color
    #for i, count in enumerate(counts):
    for i, (count, label) in enumerate(zip(counts, group_labels)):
        if plot_zero == 'no' and count == 0:
            colors[i] = 'white'
        else:
            # Find which group the count belongs to
            group_idx = np.digitize(count, group_bounds) - 1
            
            if max_count > 20:
                # For large ranges, use continuous color mapping
                colors[i] = cmap(count / max_count)
            else:
                # For smaller ranges, use discrete color groups
                #colors[i] = cmap(norm(group_bounds[min(group_idx, len(group_bounds)-1)]))
                colors[i] = cmap(norm(label))

    #print("label",label)
    return colors

# Function to create legend for heatmap
def create_heatmap_legend_handles(cmap, num_colors, max_count, plot_zero):
    """
    Create color-coded legend handles for the heatmap with proper handling of zero values
    and error cases
    """
    try:
        # Handle case where all values are zero
        if max_count == 0:
           norm = Normalize(vmin=0, vmax=1)
           color = cmap(norm(0))
           return [patches.Patch(color=color, label='0')], ['0']
                 
        # Set the start point based on plot_zero
        start_point = 0 if plot_zero == 'yes' else 1
        
        # Create group boundaries to match get_grouped_colors
        count_ranges = np.linspace(start_point, max_count, num_colors + 1)
        
        # Determine interval type based on count vs intervals
        if max_count > num_colors:
           plt_interval = max_count
        else:
           plt_interval = num_colors
           
        legend_handles = []
        heatmap_labels = []
        norm = Normalize(vmin=start_point, vmax=max_count)

        # For small ranges, start from index 1 to skip the duplicate 1
        start_idx = 1 if plt_interval <= 6 else 0
        for i in range(start_idx, len(count_ranges)):
            color = cmap(norm(count_ranges[i]))
            if plt_interval <= 6 and plot_zero == 'no':
                label = f'{int(count_ranges[i])}'
            elif plt_interval <= 6 and plot_zero == 'yes':
                label = f'{int(count_ranges[i])}'
            elif i + 1 >= len(count_ranges):
                label = f'{int(count_ranges[i])} - {max(count_ranges)}'
                break
            else:
                # Create label showing range
                start_val = int(count_ranges[i])
                end_val = int(count_ranges[i + 1])
                label = f'{start_val} - {end_val}'
                
            #if i = len(count_ranges):
            #    color = cmap(norm(count_ranges[i]))
 
            #print("i:",i,"label:",label)
            #legend_handles.append(patches.Patch(color=color, label=label))
            legend_handles.append(patches.Patch(color=color, label=label))
            heatmap_labels.append(label)
        return legend_handles, heatmap_labels
       
    except Exception as e:
       # Fallback for error cases
       color = cmap(0)
       handle = patches.Patch(color=color, label='0')
       return [handle], ['0']

def create_custom_legends(fig, labels, handles, var_name_list, legend_titles, heatmap_legend_handles,
                          heatmap_legend_labels, ms_average_choice, bio_or_pep, plot_type):
    
    handles_dict = {}
    sample_handles_dict = {}

    # Modify label if needed and populate dictionaries
    for label, handle, sample_name in zip(labels, handles, var_name_list):
        handles_dict[label] = handle  # Store or update the handle with modified label

        # Store handles for sample types (assuming sample_name correctly aligns with the handles)
        if sample_name not in sample_handles_dict:
            sample_handles_dict[sample_name] = handle

    # Create new handles for the legend with modified properties
    new_handles_func = [copy.copy(handle) for handle in handles_dict.values()]
    new_labels_func = [label for label in handles_dict.keys()]

    # Initial empty lists for combined handles and labels
    combined_handles = []
    combined_labels = []
    # Filter for "Averaged" labels
    averaged_handles = []
    averaged_labels = []
    other_handles = []
    other_labels = []

    for handle, label in zip(new_handles_func, new_labels_func):
        if "Averaged" in label:
            clean_label = label.replace("Averaged ", "")  # Remove 'Averaged ' from the label
            averaged_handles.append(handle)
            averaged_labels.append(clean_label)  # Append the cleaned label

        else:
            handle.set_linestyle('-')  # Set line style to solid
            other_handles.append(handle)
            other_labels.append(label)

    new_handles_samples = []
    if not averaged_handles:
        for handle in sample_handles_dict.values():
            new_handle = copy.copy(handle)
            new_handle.set_color('black')  # Set color to black for sample type handles
            new_handles_samples.append(new_handle)

    # Dummy handles for subtitles
    line_type = mlines.Line2D([], [], color='none', label='Line Type')
    average_color = mlines.Line2D([], [], color='none', label='Average Absorbance')
    line_color = mlines.Line2D([], [], color='none', label='Line Color')
    pep_count_placeholder = mlines.Line2D([], [], color='none', label='Line Type')

    line_type_title = legend_titles[0]
    avgline_color_title = legend_titles[2] if bio_or_pep == 'no' else legend_titles[3]
    color_title = legend_titles[2]

    if ms_average_choice == 'yes':
        if bio_or_pep != 'no':
            combined_handles = [line_color] + other_handles + [average_color] + averaged_handles
            combined_labels = [color_title] + other_labels + [avgline_color_title] + averaged_labels

        else: #if bio_or_pep == 'no':
            combined_handles = [average_color] + averaged_handles
            combined_labels = [avgline_color_title] + averaged_labels

    if ms_average_choice == 'only':
        combined_handles = [average_color] + averaged_handles
        combined_labels = [avgline_color_title] + averaged_labels
        
    if ms_average_choice == 'no' and bio_or_pep != 'no':
        combined_handles = [line_color] + other_handles + [line_type] + new_handles_samples
        combined_labels = [color_title] + other_labels + [line_type_title] + [key for key in sample_handles_dict.keys()]

    legend_peptide_count = None
    if plot_type == "land":
        # Create the peptide count legend separately
        if plot_heatmap == 'yes':
            legend_peptide_count = fig.legend(handles=heatmap_legend_handles, loc='center',
                                              fontsize=14,
                                              title=legend_titles[1],
                                              title_fontsize=14,
                                              bbox_to_anchor=(0.5, -0.1),
                                              ncol=len(heatmap_legend_handles))

        # Create the combined legend (for other handles/labels)
        combined_legend = fig.legend(handles=combined_handles, labels=combined_labels,
                                     loc='upper left', bbox_to_anchor=(0.99, 0.975),
                                     fontsize=14, handlelength=2)
        plt.tight_layout()
        fig.subplots_adjust(right=0.9)  # Make room for side legend

    # Create a single combined legend
    elif plot_type == "port":

        # Combine heatmap handles with other handles for a single legend
        combined_handles = [line_color] + other_handles + [pep_count_placeholder] + heatmap_legend_handles
        combined_labels = [avgline_color_title] + other_labels + [legend_titles[1]] + heatmap_legend_labels

        # Create a single combined legend with just the combined handles (no need for additional labels)
        combined_legend = fig.legend(handles=combined_handles,
                                     labels=combined_labels,
                                     loc='upper left',
                                     bbox_to_anchor=(0.9025, 0.875),
                                     fontsize=14)

    return combined_legend, legend_peptide_count

"""_________________________________________Landscape Plot________________________________________"""
### def visualize_sequence_heatmap_individual_lines(available_data_variables, amino_acid_height, lineplot_height, indices_height, filename, xaxis_label, yaxis_label, legend_title, yaxis_position):
def visualize_sequence_heatmap_lanscape(available_data_variables,
                                                         amino_acid_height,
                                                         lineplot_height,
                                                         indices_height,
                                                         xaxis_label,
                                                         yaxis_label,
                                                         legend_title,
                                                         yaxis_position,
                                                         cmap,
                                                         avg_cmap,
                                                         lp_selected_color,
                                                         avglp_selected_color,
                                                         selected_peptides,
                                                         selected_functions,
                                                         ms_average_choice,
                                                         bio_or_pep, filter_type):
    # Use a list comprehension to find the maximum length of 'AA_list' across all variables
    max_sequence_length = max([len(available_data_variables[var]['AA_list']) for var in available_data_variables])

    # Check if there are multiple distinct protein IDs in available_data_variables
    multiple_proteins = len(set([available_data_variables[var]['protein_id'] for var in available_data_variables])) > 1

    chunk_size = max_sequence_length
    # Create legend handles for the heatmap
    heatmap_legend_handles, heatmap_legend_labels = create_heatmap_legend_handles(cmap, num_colors, max_count, plot_zero)  # You can change the number 5 to have more or fewer color intervals

    # Function to plot rows of amino acids with backgrounds colored
    # Function to plot rows of amino acids with backgrounds colored
    def plot_row_color_landscape(ax, amino_acids, colors):
        ax.axis('off')
        ax.set_xlim(0, max_sequence_length)
        ax.set_xlabel('')
        # Height and width for each cell
        cell_width = 1  # Each amino acid is spaced evenly by 1 unit on the x-axis
        cell_height = 1  # Set height of the row

        for j, (aa, color) in enumerate(zip(amino_acids, colors)):
            # Create a rectangle (cell) with the background color
            rect = patches.Rectangle((j, 0), cell_width, cell_height, color=mcolors.rgb2hex(color))
            ax.add_patch(rect)  # Add the colored rectangle to the plot

    # Function to plot rows of amino acids with backgrounds colored
    def plot_row_landscape(ax, amino_acids):
        ax3.axis('off')
        ax3.set_xlim(0, max_sequence_length)
        ax3.set_xlabel('')
        aa_font_size = 10
        if max_sequence_length > 200:
            aa_font_size -= 0.5
        if max_sequence_length > 250:
            aa_font_size -= 1
        if max_sequence_length > 300:
            return
        for j, (aa) in enumerate(amino_acids):
                ax3.text(j + 0.5, 0.5, aa, color='black', ha='center', va='center',
                         fontsize=aa_font_size)

    axis_number = 3  # Total number of plots per set of data
    if plot_heatmap == 'yes':
        # Define height ratios for each subplot in a set
        height_ratios = (
                    [lineplot_height] + [indices_height] + [amino_acid_height] * len(available_data_variables) + [amino_acid_height])
        axis_number = len(available_data_variables) + 3

    elif plot_heatmap == 'no':
        height_ratios = ([lineplot_height] +[indices_height] + [amino_acid_height] )
        axis_number = 3

    fig, axes = plt.subplots(axis_number, 1, figsize=(25, (lineplot_height + indices_height + amino_acid_height)),
                             gridspec_kw={'height_ratios': height_ratios, 'hspace': 0})


    # Initialize for legend handling
    handles, labels, sample_list, var_name_list = [], [], [], []
    total_count = 0

    # Loop through each set of data and create plots
    for var_index, var in enumerate(available_data_variables):
        ax1 = axes[0]
        ax1.axis('off')
        ax1 = ax1.twinx()  # Create a twin y-axis
        ax1.yaxis.set_minor_locator(plt.NullLocator())

        var_amino_acids = available_data_variables[var]['AA_list']
        var_counts = available_data_variables[var]['peptide_counts']
        var_ms_data = available_data_variables[var]['ms_data_list']
        var_name = available_data_variables[var]['label']
        var_colors = get_grouped_colors(var_counts, max_count, num_colors, plot_zero, cmap)
        bp_abs = available_data_variables[var]['bioactive_peptide_abs_df']
        bp_func = available_data_variables[var]['bioactive_peptide_func_df']
        # Default line style
        line_style = '-'  # Default style if other conditions don't apply

        #if ms_average_choice == 'yes' and bio_or_pep == 'no':
        #   line_style = '-'  # This can stay '-' or change as per your requirement
        #else:
        line_style = style_map[var_name]  # Get assigned line style from the style map

        if bio_or_pep != 'no' and ms_average_choice != 'only':
            filtered_bp_abs, filtered_bp_fun = filter_data_by_selection(bp_abs, bp_func, selected_peptides,
                                                                        selected_functions, bio_or_pep, filter_type)
            print_list = process_chunk_data(ax1, filtered_bp_abs, filtered_bp_fun, chunk_size, 0, y_ticks, handles,
                                            labels, sample_list, var_name_list, line_style, var_name, var_ms_data,
                                            selected_peptides, selected_functions, lp_selected_color, ms_average_choice,
                                            bio_or_pep)

        if ms_average_choice in ['yes', 'only']:
            # Ensure that line_style is defined before this point or provide a default value
            line, label, _ = plot_average_ms_data(ax1, var_ms_data, f'Averaged {var_name}', var_index, y_ticks, 0,
                                                  chunk_size, avg_cmap, line_style)
            handles.append(line)
            labels.append(f'{label}')
            sample_list.append(line_style)
            var_name_list.append(var_name)

        # Plot indices below the MS line plot
        ax2 = axes[1]
        ax2.axis('off')
        ax2.set_xlim(0, max_sequence_length)
        indices = [0]
        if var_index == 0:
            # Add indices at increments of 20, starting from 20 up to the length of the array, but not including the last index if it's less than 20 away
            indices.extend(range(20, max_sequence_length - 5, 20))
            # Always add the last index of the array
            indices.append(max_sequence_length - 1)
            for idx in indices:
                ax2.text(idx + 0.5, 0.5, str(total_count + idx + 1), ha='center', va='center', fontsize=16)
            total_count += max_sequence_length

            # Amino acid plots
            ax3 = axes[2]

            # Check if there's only one distinct protein
            if not multiple_proteins:
                plot_row_landscape(ax3, var_amino_acids)  # Call plot_row_landscape if only 1 protein
            else:
                ax3.axis('off')  # Plot a blank line by turning off the axis

        if plot_heatmap == 'yes':
            # Amino acid plots
            ax = axes[var_index + 3]
            plot_row_color_landscape(ax, var_amino_acids, var_colors)
            ax.text(0, 0.5, var_name, ha='right', va='center', fontsize=14)
    # Create the legend after the plotting loop, using the handles and labels without duplicates
    # This will create a dictionary only with entries where the label is not '0'abs
    create_custom_legends(fig, labels, handles, var_name_list, legend_title, heatmap_legend_handles,
                          heatmap_legend_labels, ms_average_choice, bio_or_pep, plot_type="land")

    if bio_or_pep == '2' and ms_average_choice != 'only':
        # Flatten the list of lists into a single list of strings
        print_list = [item for sublist in print_list for item in sublist]

        # Remove duplicates from the list by converting it to a set, then convert it back to a list
        print_list = list(set(print_list))

        # Remove the empty string if it exists
        if '' in print_list:
            print_list.remove('')

        # Enumerate the list and format the output
        print_list = [f"     {i + 1}. {label}" for i, label in enumerate(print_list)]
        print_list.insert(0, "The following labels have been relabeled as 'Multiple':")

        # Join all the elements of the list into a single string with newlines
        footnote = '\n'.join(print_list)
        #fig.text(0.15, -0.2, footnote, ha='left', va='center', fontsize=12)



    fig.text(yaxis_position, 0.90, yaxis_label, va='top', rotation='vertical', fontsize=16)
    fig.text(0.5, -0.025, xaxis_label, ha='center', va='center', fontsize=16)
    plt.tight_layout()
    #plt.subplots_adjust(left=0.05)  # Create space on the left for the y-label

    if ms_average_choice == 'only' and bio_or_pep == '1':
        display(HTML(f'<b style="color:red;">The selection of Plot Averaged Data option "only" and Plot Specific Peptides option "Peptide Intervals" is invalid. Only the average absrobance will be plotted.</b>'))

    if bio_or_pep == '2' and ms_average_choice != 'only':
        if len(print_list) > 1:
            print(footnote)
    return fig 

In [82]:
def import_data_variables(export_dir='protein_data_export'):
    """
    Import previously exported data variables including DataFrames from CSV files.
    
    Args:
        export_dir: Directory where the data was exported
        
    Returns:
        Dictionary with the reconstructed data structure
    """
    # Check if export directory exists
    if not os.path.exists(export_dir):
        raise FileNotFoundError(f"Export directory '{export_dir}' not found")
    
    # Load metadata file
    metadata_path = Path(export_dir) / "metadata.json"
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
    
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
    # Function to recursively reconstruct the data structure
    def reconstruct_dict(meta_dict):
        result = {}
        
        for key, info in meta_dict.items():
            # Handle different types based on the metadata
            if isinstance(info, dict):
                if info.get('type') == 'DataFrame':
                    # Load DataFrame from CSV
                    csv_path = info.get('csv_path')
                    if csv_path and os.path.exists(csv_path):
                        result[key] = pd.read_csv(csv_path, index_col=0)
                    else:
                        print(f"Warning: CSV file for {key} not found at {csv_path}")
                        result[key] = pd.DataFrame()  # Empty DataFrame as fallback
                
                elif info.get('type') == 'dict' and 'content' in info:
                    # Recursively process nested dictionaries
                    result[key] = reconstruct_dict(info['content'])
                
                elif info.get('type') in ['str', 'int', 'float', 'bool', 'list'] and 'value' in info:
                    # Simple types with values
                    result[key] = info['value']
                
                elif info.get('type') == 'numpy_scalar' and 'value' in info:
                    # Numpy scalar types
                    result[key] = info['value']
                
                elif info.get('type') == 'ndarray':
                    # For arrays, we'd need to have saved them separately
                    # Here we just note that it was an array
                    print(f"Warning: Numpy array {key} cannot be fully reconstructed from metadata alone")
                    result[key] = None
                
                elif info.get('type') == 'bool_':
                    # Handle numpy bool_ type
                    if 'value' in info:
                        result[key] = bool(info['value'])
                    else:
                        # Default to False if no value was provided
                        result[key] = False
                    
                else:
                    # Other types we can't automatically reconstruct
                    print(f"Note: Attempting to reconstruct {key} of type {info.get('type')}")
                    # Try to get the value if it exists
                    if 'value' in info:
                        result[key] = info['value']
                    else:
                        result[key] = None
            else:
                # Handle unexpected metadata structure
                print(f"Warning: Unexpected metadata format for {key}")
                result[key] = None
        
        return result
    
    # Reconstruct the top-level dictionary
    available_data_variables = {}
    
    for top_key, top_meta in metadata.items():
        if isinstance(top_meta, dict):
            available_data_variables[top_key] = reconstruct_dict(top_meta)
        else:
            print(f"Warning: Unexpected top-level metadata for {top_key}")
            available_data_variables[top_key] = None
    
    print(f"Successfully imported data from {export_dir}")
    print(f"Reconstructed {len(available_data_variables)} top-level variables")
    
    return available_data_variables

def print_imports():
    # Import the data
    global available_data_variables, filter_type, ms_average_choice, bio_or_pep, selected_peptides, selected_functions, hm_selected_color, lp_selected_color, avglp_selected_color, xaxis_label_input, yaxis_label_input, yaxis_position, legend_title_input_1, legend_title_input_2, legend_title_input_3, plot_land, plot_port, selector_filter_type
    available_data_variables = import_data_variables('protein_data_export')

    # Now you can use available_data_variables just like before
    # For example, to check the keys:
    print("Top-level keys:", list(available_data_variables.keys()))
    for item in (available_data_variables[list(available_data_variables.keys())[0]].items()):
        if "heatmap_df" not in item[0]:
            print(item)
        else:
            print(item[0], item[1].shape)
    # To check if DataFrames and bool values were loaded correctly:
    #for key in available_data_variables:
    #    if 'heatmap_df' in available_data_variables[key]:
    #        df = available_data_variables[key]['heatmap_df']
    #        print(f"{key} heatmap_df shape: {df.shape}")
        


    # Assign the variables from the debug output
    ms_average_choice = True  # 'yes' converts to boolean True
    bio_or_pep = False  # 'no' converts to boolean False
    selected_peptides = []  # Empty list
    selected_functions = []  # Empty list
    hm_selected_color = "RdYlGn_r"
    lp_selected_color = "Set3"
    avglp_selected_color = "Dark2"
    xaxis_label_input = " Beta-casein Sequence"
    yaxis_label_input = "Averaged Peptide Abundance"
    yaxis_position = 0
    legend_title_input_1 = "Sample Type:"
    legend_title_input_2 = "Peptide Counts:"
    legend_title_input_3 = "Average Absorbance:"
    plot_land = True
    plot_port = True
    filter_type = "all-peptides"
    # Print the variables to confirm
    print(f"\n----------User Input Variables ------------")
    print(f"ms_average_choice: {ms_average_choice}")
    print(f"bio_or_pep: {bio_or_pep}")
    print(f"selected_peptides: {selected_peptides}")
    print(f"selected_functions: {selected_functions}")
    print(f"hm_selected_color: {hm_selected_color}")
    print(f"lp_selected_color: {lp_selected_color}")
    print(f"avglp_selected_color: {avglp_selected_color}")
    print(f"xaxis_label_input: {xaxis_label_input}")
    print(f"yaxis_label_input: {yaxis_label_input}")
    print(f"yaxis_position: {yaxis_position}")
    print(f"legend_title_input_1: {legend_title_input_1}")
    print(f"legend_title_input_2: {legend_title_input_2}")
    print(f"legend_title_input_3: {legend_title_input_3}")
    print(f"plot_land: {plot_land}")
    print(f"plot_port: {plot_port}")
    print(f"filter_type: {filter_type}")

print_imports()

Successfully imported data from protein_data_export
Reconstructed 4 top-level variables
Top-level keys: ['P02666_Low', 'P02666_Extreme', 'P02666_Threshold', 'P02666_Moderate']
('protein_id', 'P02666')
('protein_sequence', 'MKVLILACLVALALARELEELNVPGEIVESLSSSEESITRINKKIEKFQSEEQQQTEDELQDKIHPFAQTQSLVYPFPGPIPNSLPQNIPPLTQTPVVVPPFLQPEVMGVSKVKEAMAPKHKEMPFPKYPVEPFTESQSLTLTDVENLHLPLPLLQSWMHQPHQPLPPTVMFPPQSVLSLSQSKVLPVPQKAVPYPQRDMPIQAFLLYQEPVLGPVRGPFPIIV')
('protein_name', ' Beta-casein')
('protein_species', 'Bovine')
heatmap_df (224, 1139)
function_heatmap_df (224, 451)
('label', 'Low')
('is_func_df_all_none', False)
filtered_heatmap_df (224, 454)

----------User Input Variables ------------
ms_average_choice: True
bio_or_pep: False
selected_peptides: []
selected_functions: []
hm_selected_color: RdYlGn_r
lp_selected_color: Set3
avglp_selected_color: Dark2
xaxis_label_input:  Beta-casein Sequence
yaxis_label_input: Averaged Peptide Abundance
yaxis_position: 0
legend_title_input_1: Sample Type:
le

In [83]:
result = process_available_data(available_data_variables, filter_type, selected_functions)
if result:
    # Unpack the dictionary into individual global variables
    global list_of_counts, min_values, max_values, seq_len_list, max_sequence_length
    global y_ticks_html, max_count, num_unique_count, num_colors, total_plots, style_map, y_ticks
    list_of_counts = result['list_of_counts']
    min_values = result['min_values']
    max_values = result['max_values']
    seq_len_list = result['seq_len_list']
    max_sequence_length = result['max_sequence_length']
    y_ticks_html = result['y_ticks_html']
    max_count = result['max_count']
    num_unique_count = result['num_unique_count']
    num_colors = result['num_colors']
    total_plots = result['total_plots']
    style_map = result['style_map']    
    print("items in results")
    for item in result:
        print(item,result[item])
    

items in results
list_of_counts {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 95, 97, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 139, 140, 141, 144, 146, 148, 150, 151, 152}
min_values [86664.66729099999, 25476.19531, 568514.53125, 68407.0125]
max_values [181268251.3099857, 93337996.59268063, 148101035.71410823, 139981520.34881058]
seq_len_list [78, 78, 78, 78]
max_sequence_length 78
y_ticks [25476.0, 90646864.0, 181268252.0]
y_ticks_html <b>Max/Min of MS data (y-ticks):</b> 2.55e+04, 9.06e+07, 1.81e+08
max_count 152
num_unique_count 136
num_colors 6
total_plo

In [84]:

from matplotlib.colors import to_rgba
import matplotlib.pyplot as plt
# Create a matplotlib colormap object from the Plotly colorscale name
cmap = plt.get_cmap(hm_selected_color)  

# Sort variables by key for consistent order
var_keys = sorted(list(available_data_variables.keys()))

# Get max sequence length for x-axis consistency
max_seq_len = max([len(available_data_variables[var]['protein_sequence']) for var in available_data_variables])

# find number of variables and make subplots
num_vars = len(available_data_variables)
fig = make_subplots(
    rows=num_vars, 
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.03,  # Slightly increase spacing for labels
    subplot_titles=[""] * num_vars  # Empty titles, we'll add them later
)

# Process each variable and add it to its own subplot
for i, var_key in enumerate(var_keys):
    var_data = available_data_variables[var_key]
    row = i + 1  # 1-indexed for subplots
    
    # Get the label for this variable
    if '_' in var_key:
        _, group_key = var_key.split('_', 1)
        label = group_key
    else:
        label = var_key
        
    # Get heatmap data
    if 'heatmap_df' in var_data and not var_data['heatmap_df'].empty:
        heatmap_df = var_data['heatmap_df']
        
        # Get counts and data
        counts = heatmap_df['count'].tolist()
        ms_data = heatmap_df['average'].fillna(0).tolist()
        aa_list = heatmap_df['AA'].tolist()
        
        # Get positions for all amino acids (continuous x-axis)
        positions = list(range(1, len(counts) + 1))
        
        # Create colors based on counts
        # Use the original get_grouped_colors function
        def plotly_colors():
            bar_colors_rgba = get_grouped_colors(counts, max_count, 
                                            num_groups=6,  # You can adjust this number
                                            plot_zero='no',  # or 'yes' if you want to include zeros
                                            cmap=cmap)

            # Convert matplotlib RGBA colors to Plotly format
            bar_colors = []
            for color in bar_colors_rgba:
                if color == 'white':
                    bar_colors.append('rgba(255,255,255,1.0)')
                else:
                    r, g, b, a = to_rgba(color)
                    bar_colors.append(f'rgba({int(r*255)},{int(g*255)},{int(b*255)},{a})')   
            return bar_colors
        bar_colors = plotly_colors()
        
  
        # For zero values, use very small heights
        min_height = 1e-1
        heights = []
        for ms in ms_data:
            if ms > 0:
                heights.append(max(ms, min_height))
            else:
                heights.append(min_height)  # Small non-zero value for log scale

        fig.add_trace(
            go.Bar(
                x=positions,
                y=heights,
                marker_color=bar_colors,
                marker_line_width=0,
                width=1.0,  # Full width to eliminate gaps
                hoverinfo='text',
                hovertext=[
                    f"Position: {pos}<br>Amino Acid: {aa}<br>Count: {count}<br>Abundance: {height:.2e}"
                    for pos, aa, count, height in zip(positions, aa_list, counts, heights)
                ],
                showlegend=False,
            ),
            row=row, col=1
        )
        def plotly_yaxis():    
            # Get min, max, median values for annotations
            y_min = [min(min_values)]
            y_max = [max(max_values)]

            # Get the three tick values from the calculate_y_ticks function
            tick1, tick2, tick3 = calculate_y_ticks(y_min, y_max)
            
            # Calculate log range for the y-axis display
            #log_min = np.log10(tick1)
            #log_max = np.log10(tick3)
            
            # Update y-axis with log scaling
            fig.update_yaxes(
                title=dict(
                    text=label,
                    font=dict(size=12),
                    standoff=5  # Closer to the axis
                ),
                #type='log',
                showgrid=True,
                gridwidth=1,
                gridcolor='rgba(0,0,0,0.3)',
                range= [tick1, tick3],#[log_min, log_max],  # Log-scaled range
                tickvals=[tick1, tick2, tick3],  # Show exactly 3 ticks
                ticktext=["", "", ""],  # Hide tick text but keep the ticks
                showticklabels=False,  # Don't show tick labels
                ticks="outside",  # Show tick marks outside
                ticklen=6,  # Make ticks longer
                tickwidth=2,  # Make ticks thicker
                row=row, col=1
            )
            return tick1, tick2, tick3
        tick1, tick2, tick3 = plotly_yaxis()

# Create a discrete legend at the right side of the plot
def plotly_legend(tick1, tick2, tick3):
    # PART 1: Create legend for peptide counts
    heatmap_legend_handles, heatmap_legend_labels = create_heatmap_legend_handles(cmap, num_colors, max_count, plot_zero)
    #print(f"Legend handles: {len(heatmap_legend_handles)}")
    #print(f"Legend labels: {heatmap_legend_labels}")
    # Add a separator with the first legend title
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                size=0,
                opacity=0,
            ),
            name=f"<b>{legend_title_input_2}</b>", # Bold formatting for title
            showlegend=True,
            legendgroup='peptide_count_title'
        )
    )

    # Add legend items for peptide counts
    for i, (label, handle) in enumerate(zip(heatmap_legend_labels, heatmap_legend_handles)):
        # Get color from matplotlib handle and convert to Plotly format
        color = handle.get_facecolor()
        r, g, b, a = color
        plotly_color = f'rgba({int(r*255)},{int(g*255)},{int(b*255)},{a})'
        
        # Add an invisible trace with the right legend properties
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(
                    size=10,
                    color=plotly_color,
                    line=dict(width=1, color='black')
                ),
                name=label,
                legendgroup=legend_title_input_2,
                showlegend=True
            )
        )
    
    # PART 2: Create legend for absorbance values
    # Add a separator with the second legend title
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                size=0,
                opacity=0,
            ),
            name=f"<b>{legend_title_input_3}</b>", # Bold formatting for title
            showlegend=True,
            legendgroup='absorbance_title'
        )
    )

    # Add an invisible trace to create a title-like effect for the second legend
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                size=0,
                opacity=0,
            ),
            name='', # Empty name to prevent duplicate text
            showlegend=False,
            legendgroup='absorbance_title'
        )
    )
    
    # Set icon for absorbance legend items
    absorbance_symbol = 'line-ew'  # horizontal line symbol
    
    # Add the absorbance values to the legend
    for i, (tick, label) in enumerate(zip([tick1, tick2, tick3], ['Y-Min', 'Y-Mid', 'Y-Max'])):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(
                    size=10,
                    symbol=absorbance_symbol,
                    color='black',
                    line=dict(width=2)
                ),
                name=f"{label}: {tick:.1g}",
                #name=f"{label}: {tick:.1e}",
                legendgroup=legend_title_input_3,
                showlegend=True
            )
        )
    
    # Update the legend layout
    fig.update_layout(
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.02,  # Position further to the right
            bordercolor="black",
            orientation="v",
            itemsizing='constant',  # Make all legend items the same size
            groupclick="toggleitem",
            bgcolor="rgba(255,255,255,0.8)",  # Semi-transparent background
            font=dict(size=11)
        )
    )
plotly_legend(tick1, tick2, tick3)

# Add position markers every 20 amino acids
# Update x-axes to be continuous with no gaps
def plotly_xaxes():
    for i in range(1, num_vars + 1):
        # Add position markers every 20 amino acids    
        marker_positions = list(range(20, max_seq_len + 1, 20))
        marker_texts = [str(pos) for pos in marker_positions]
        

        
        # Only show x-axis title on bottom subplot
        if i == num_vars:
            fig.update_xaxes(
                title=xaxis_label_input,
                showgrid=True,
                gridwidth=1,
                gridcolor='rgba(200,200,200,0.3)',
                range=[0, max_seq_len + 1],  # Set consistent range for all plots
                zeroline=False,
                dtick=20,  # Set position marks every 20 units
                row=i, col=1
            )
            # Add text labels for every 20th position
            for pos in marker_positions:
                fig.add_annotation(
                    x=pos,
                    y=-0.15,  # Below the x-axis
                    xref="x",
                    yref="paper",
                    text=str(pos),
                    showarrow=False,
                    font=dict(size=10),
                    row=i, col=1
                )

        else:
            for pos in marker_positions:
                # Add a vertical line at each marker position
                fig.add_shape(
                    type="line",
                    x0=pos, y0=0, 
                    x1=pos, y1=1,
                    xref="x", yref="paper",
                    line=dict(
                        color="rgba(100,100,100,0.3)",
                        width=1,
                        dash="dot"
                    ),
                    row=i, col=1
                )
            fig.update_xaxes(
                showticklabels=False,  # Hide tick labels except on bottom plot
                showgrid=True,
                gridwidth=1,
                gridcolor='rgba(200,200,200,0.3)',
                range=[0, max_seq_len + 1],  # Set consistent range for all plots
                zeroline=False,
                row=i, col=1
            )
plotly_xaxes()

# Update overall layout
fig.update_layout(
    height=130 * num_vars,  # Adjust height based on number of variables
    width=1000,
    bargap=0,      # No gap between bars in the same trace
    bargroupgap=0, # No gap between bar groups
    plot_bgcolor='lightgray',
    showlegend=True,
    margin=dict(l=50, r=50, t=50, b=50)  # Increased bottom margin for position markers
)

# Display the figure
fig.show()